diff --git a/config/config-docker.yml b/config/config-docker.yml index 948c7ac..ed01692 100644 --- a/config/config-docker.yml +++ b/config/config-docker.yml @@ -33,4 +33,6 @@ assets: ordinals: host: "http://ord-poc.devnet.babylonchain.io" port: 8888 - timeout: 1000 \ No newline at end of file + timeout: 1000 +terms_acceptance_logging: + enabled: true diff --git a/config/config-local.yml b/config/config-local.yml index a9d03c9..97f8f53 100644 --- a/config/config-local.yml +++ b/config/config-local.yml @@ -33,4 +33,6 @@ assets: ordinals: host: "http://ord-poc.devnet.babylonchain.io" port: 8888 - timeout: 5000 \ No newline at end of file + timeout: 5000 +terms_acceptance_logging: + enabled: true diff --git a/internal/api/handlers/handler.go b/internal/api/handlers/handler.go index b0eb5f9..9e1941c 100644 --- a/internal/api/handlers/handler.go +++ b/internal/api/handlers/handler.go @@ -2,6 +2,7 @@ package handlers import ( "context" + "encoding/json" "fmt" "net/http" @@ -171,3 +172,27 @@ func parseStateFilterQuery( } return stateEnum, nil } + +// parseTermsAcceptanceLoggingRequest parses the terms acceptance request bdoy and returns the address and public key +func parseTermsAcceptanceLoggingRequest(request *http.Request, btcNetParam *chaincfg.Params) (string, string, *types.Error) { + var req TermsAcceptanceLoggingRequest + if err := json.NewDecoder(request.Body).Decode(&req); err != nil { + return "", "", types.NewErrorWithMsg(http.StatusBadRequest, types.BadRequest, "Invalid request payload") + } + + // Validate the Bitcoin address + if _, err := utils.CheckBtcAddressType(req.Address, btcNetParam); err != nil { + return "", "", types.NewErrorWithMsg(http.StatusBadRequest, types.BadRequest, "Invalid Bitcoin address") + } + + // Validate the public key + if _, err := utils.GetSchnorrPkFromHex(req.PublicKey); err != nil { + return "", "", types.NewErrorWithMsg(http.StatusBadRequest, types.BadRequest, "Invalid public key") + } + + if req.Address == "" || req.PublicKey == "" { + return "", "", types.NewErrorWithMsg(http.StatusBadRequest, types.BadRequest, "Address and public key are required") + } + + return req.Address, req.PublicKey, nil +} diff --git a/internal/api/handlers/terms.go b/internal/api/handlers/terms.go new file mode 100644 index 0000000..676c686 --- /dev/null +++ b/internal/api/handlers/terms.go @@ -0,0 +1,29 @@ +package handlers + +import ( + "net/http" + + "github.com/babylonlabs-io/staking-api-service/internal/types" +) + +type TermsAcceptanceLoggingRequest struct { + Address string `json:"address"` + PublicKey string `json:"public_key"` +} + +type TermsAcceptancePublic struct { + Status bool `json:"status"` +} + +func (h *Handler) LogTermsAcceptance(request *http.Request) (*Result, *types.Error) { + address, publicKey, err := parseTermsAcceptanceLoggingRequest(request, h.config.Server.BTCNetParam) + if err != nil { + return nil, err + } + + if err := h.services.AcceptTerms(request.Context(), address, publicKey); err != nil { + return nil, err + } + + return NewResult(TermsAcceptancePublic{Status: true}), nil +} diff --git a/internal/api/routes.go b/internal/api/routes.go index 2380deb..5d3b74a 100644 --- a/internal/api/routes.go +++ b/internal/api/routes.go @@ -10,6 +10,10 @@ func (a *Server) SetupRoutes(r *chi.Mux) { handlers := a.handlers r.Get("/healthcheck", registerHandler(handlers.HealthCheck)) + if a.cfg.TermsAcceptanceLogging.Enabled { + r.Post("/log_terms_acceptance", registerHandler(handlers.LogTermsAcceptance)) + } + r.Get("/v1/staker/delegations", registerHandler(handlers.GetStakerDelegations)) r.Post("/v1/unbonding", registerHandler(handlers.UnbondDelegation)) r.Get("/v1/unbonding/eligibility", registerHandler(handlers.GetUnbondingEligibility)) diff --git a/internal/config/config.go b/internal/config/config.go index b2e0f81..3ebad61 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -10,11 +10,12 @@ import ( ) type Config struct { - Server *ServerConfig `mapstructure:"server"` - Db *DbConfig `mapstructure:"db"` - Queue *queue.QueueConfig `mapstructure:"queue"` - Metrics *MetricsConfig `mapstructure:"metrics"` - Assets *AssetsConfig `mapstructure:"assets"` + Server *ServerConfig `mapstructure:"server"` + Db *DbConfig `mapstructure:"db"` + Queue *queue.QueueConfig `mapstructure:"queue"` + Metrics *MetricsConfig `mapstructure:"metrics"` + Assets *AssetsConfig `mapstructure:"assets"` + TermsAcceptanceLogging *TermsAcceptanceConfig `mapstructure:"terms_acceptance_logging"` } func (cfg *Config) Validate() error { diff --git a/internal/config/terms.go b/internal/config/terms.go new file mode 100644 index 0000000..04d94c2 --- /dev/null +++ b/internal/config/terms.go @@ -0,0 +1,5 @@ +package config + +type TermsAcceptanceConfig struct { + Enabled bool `mapstructure:"enabled"` +} diff --git a/internal/db/interface.go b/internal/db/interface.go index 258654b..d64280b 100644 --- a/internal/db/interface.go +++ b/internal/db/interface.go @@ -104,6 +104,8 @@ type DBClient interface { ctx context.Context, paginationToken string, ) (*DbResultMap[model.DelegationDocument], error) + // SaveTermsAcceptance saves the acceptance of the terms of service of the public key + SaveTermsAcceptance(ctx context.Context, termsAcceptance *model.TermsAcceptance) error } type DelegationFilter struct { diff --git a/internal/db/model/setup.go b/internal/db/model/setup.go index f5ea00d..d56c347 100644 --- a/internal/db/model/setup.go +++ b/internal/db/model/setup.go @@ -24,6 +24,7 @@ const ( BtcInfoCollection = "btc_info" UnprocessableMsgCollection = "unprocessable_messages" PkAddressMappingsCollection = "pk_address_mappings" + TermsAcceptanceCollection = "terms_acceptance" ) type index struct { diff --git a/internal/db/model/terms.go b/internal/db/model/terms.go new file mode 100644 index 0000000..3b05563 --- /dev/null +++ b/internal/db/model/terms.go @@ -0,0 +1,11 @@ +package model + +import ( + "go.mongodb.org/mongo-driver/bson/primitive" +) + +type TermsAcceptance struct { + Id primitive.ObjectID `bson:"_id,omitempty"` + Address string `bson:"address"` + PublicKey string `bson:"public_key"` +} diff --git a/internal/db/terms.go b/internal/db/terms.go new file mode 100644 index 0000000..da55e95 --- /dev/null +++ b/internal/db/terms.go @@ -0,0 +1,14 @@ +package db + +import ( + "context" + + "github.com/babylonlabs-io/staking-api-service/internal/db/model" +) + +func (db *Database) SaveTermsAcceptance(ctx context.Context, termsAcceptance *model.TermsAcceptance) error { + collection := db.Client.Database(db.DbName).Collection(model.TermsAcceptanceCollection) + + _, err := collection.InsertOne(ctx, termsAcceptance) + return err +} diff --git a/internal/services/terms.go b/internal/services/terms.go new file mode 100644 index 0000000..0ed1bf4 --- /dev/null +++ b/internal/services/terms.go @@ -0,0 +1,21 @@ +package services + +import ( + "context" + + "github.com/babylonlabs-io/staking-api-service/internal/db/model" + "github.com/babylonlabs-io/staking-api-service/internal/types" +) + +func (s *Services) AcceptTerms(ctx context.Context, address, publicKey string) *types.Error { + termsAcceptance := &model.TermsAcceptance{ + Address: address, + PublicKey: publicKey, + } + + if err := s.DbClient.SaveTermsAcceptance(ctx, termsAcceptance); err != nil { + return types.NewInternalServiceError(err) + } + + return nil +} diff --git a/tests/config/config-test.yml b/tests/config/config-test.yml index 1aa5c41..db0e31c 100644 --- a/tests/config/config-test.yml +++ b/tests/config/config-test.yml @@ -33,4 +33,6 @@ assets: ordinals: host: "http://ord-poc.devnet.babylonchain.io" port: 8888 - timeout: 100 \ No newline at end of file + timeout: 100 +terms_acceptance_logging: + enabled: true diff --git a/tests/integration_test/terms_acceptance_test.go b/tests/integration_test/terms_acceptance_test.go new file mode 100644 index 0000000..31910d3 --- /dev/null +++ b/tests/integration_test/terms_acceptance_test.go @@ -0,0 +1,66 @@ +package tests + +import ( + "bytes" + "encoding/json" + "math/rand" + "net/http" + "testing" + "time" + + "github.com/babylonlabs-io/staking-api-service/internal/api/handlers" + "github.com/babylonlabs-io/staking-api-service/tests/testutils" + "github.com/stretchr/testify/assert" +) + +const ( + termsAcceptancePath = "/log_terms_acceptance" +) + +func TestTermsAcceptance(t *testing.T) { + testServer := setupTestServer(t, nil) + defer testServer.Close() + + r := rand.New(rand.NewSource(time.Now().UnixNano())) + address, _ := testutils.RandomBtcAddress(r, testServer.Config.Server.BTCNetParam) + publicKey, _ := testutils.RandomPk() + + // Prepare request body + requestBody := handlers.TermsAcceptanceLoggingRequest{ + Address: address, + PublicKey: publicKey, + } + bodyBytes, _ := json.Marshal(requestBody) + + url := testServer.Server.URL + termsAcceptancePath + resp, err := http.Post(url, "application/json", bytes.NewReader(bodyBytes)) + assert.NoError(t, err, "making POST request to terms acceptance endpoint should not fail") + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode, "expected HTTP 200 OK status") + + var response handlers.PublicResponse[handlers.TermsAcceptancePublic] + err = json.NewDecoder(resp.Body).Decode(&response) + assert.NoError(t, err, "decoding response body should not fail") + assert.Equal(t, true, response.Data.Status) +} + +func TestTermsAcceptanceInvalidAddress(t *testing.T) { + testServer := setupTestServer(t, nil) + defer testServer.Close() + + // Use invalid address + invalidAddress := "invalidaddress" + publicKey, _ := testutils.RandomPk() + + requestBody := handlers.TermsAcceptanceLoggingRequest{} + bodyBytes, _ := json.Marshal(requestBody) + + url := testServer.Server.URL + termsAcceptancePath + "?address=" + invalidAddress + "&public_key=" + publicKey + resp, err := http.Post(url, "application/json", bytes.NewReader(bodyBytes)) + assert.NoError(t, err, "making POST request to terms acceptance endpoint should not fail") + defer resp.Body.Close() + + // Check response + assert.Equal(t, http.StatusBadRequest, resp.StatusCode, "expected HTTP 400 Bad Request status") +} diff --git a/tests/mocks/mock_db_client.go b/tests/mocks/mock_db_client.go index 1b07152..57fcbd6 100644 --- a/tests/mocks/mock_db_client.go +++ b/tests/mocks/mock_db_client.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.41.0. DO NOT EDIT. +// Code generated by mockery v2.44.1. DO NOT EDIT. package mocks @@ -532,6 +532,24 @@ func (_m *DBClient) SaveActiveStakingDelegation(ctx context.Context, stakingTxHa return r0 } +// SaveTermsAcceptance provides a mock function with given fields: ctx, termsAcceptance +func (_m *DBClient) SaveTermsAcceptance(ctx context.Context, termsAcceptance *model.TermsAcceptance) error { + ret := _m.Called(ctx, termsAcceptance) + + if len(ret) == 0 { + panic("no return value specified for SaveTermsAcceptance") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *model.TermsAcceptance) error); ok { + r0 = rf(ctx, termsAcceptance) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // SaveTimeLockExpireCheck provides a mock function with given fields: ctx, stakingTxHashHex, expireHeight, txType func (_m *DBClient) SaveTimeLockExpireCheck(ctx context.Context, stakingTxHashHex string, expireHeight uint64, txType string) error { ret := _m.Called(ctx, stakingTxHashHex, expireHeight, txType) diff --git a/tests/mocks/mock_ordinal_client.go b/tests/mocks/mock_ordinal_client.go index 79346c3..9b30238 100644 --- a/tests/mocks/mock_ordinal_client.go +++ b/tests/mocks/mock_ordinal_client.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.41.0. DO NOT EDIT. +// Code generated by mockery v2.44.1. DO NOT EDIT. package mocks diff --git a/tests/testutils/datagen.go b/tests/testutils/datagen.go index 0290085..90c2ea9 100644 --- a/tests/testutils/datagen.go +++ b/tests/testutils/datagen.go @@ -14,6 +14,8 @@ import ( "github.com/babylonlabs-io/staking-queue-client/client" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/schnorr" + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" ) @@ -140,6 +142,19 @@ func RandomBytes(r *rand.Rand, n uint64) ([]byte, string) { return randomBytes, hex.EncodeToString(randomBytes) } +func RandomBtcAddress(r *rand.Rand, params *chaincfg.Params) (string, error) { + privKey, err := btcec.NewPrivateKey() + if err != nil { + return "", err + } + pubKey := privKey.PubKey() + addr, err := btcutil.NewAddressTaproot(schnorr.SerializePubKey(pubKey), params) + if err != nil { + return "", err + } + return addr.EncodeAddress(), nil +} + // GenerateRandomTimestamp generates a random timestamp before the specified timestamp. // If beforeTimestamp is 0, then the current time is used. func GenerateRandomTimestamp(afterTimestamp, beforeTimestamp int64) int64 {