diff --git a/docs/docs.go b/docs/docs.go index f2868e5a..31d25277 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -267,12 +267,18 @@ const docTemplate = `{ }, "/v1/stats/staker": { "get": { - "description": "Fetches details of top stakers by their active total value locked (ActiveTvl) in descending order.", + "description": "Fetches staker stats for babylon staking including tvl, total delegations, active tvl and active delegations.\nIf staker_btc_pk query parameter is provided, it will return stats for the specific staker.\nOtherwise, it will return the top stakers ranked by active tvl.", "produces": [ "application/json" ], - "summary": "Get Top Staker Stats by Active TVL", + "summary": "Get Staker Stats", "parameters": [ + { + "type": "string", + "description": "Public key of the staker to fetch", + "name": "staker_btc_pk", + "in": "query" + }, { "type": "string", "description": "Pagination key to fetch the next page of top stakers", diff --git a/docs/swagger.json b/docs/swagger.json index d30841e8..2bb84e1a 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -259,12 +259,18 @@ }, "/v1/stats/staker": { "get": { - "description": "Fetches details of top stakers by their active total value locked (ActiveTvl) in descending order.", + "description": "Fetches staker stats for babylon staking including tvl, total delegations, active tvl and active delegations.\nIf staker_btc_pk query parameter is provided, it will return stats for the specific staker.\nOtherwise, it will return the top stakers ranked by active tvl.", "produces": [ "application/json" ], - "summary": "Get Top Staker Stats by Active TVL", + "summary": "Get Staker Stats", "parameters": [ + { + "type": "string", + "description": "Public key of the staker to fetch", + "name": "staker_btc_pk", + "in": "query" + }, { "type": "string", "description": "Pagination key to fetch the next page of top stakers", diff --git a/docs/swagger.yaml b/docs/swagger.yaml index edd27346..2428a47d 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -410,9 +410,15 @@ paths: summary: Get Overall Stats /v1/stats/staker: get: - description: Fetches details of top stakers by their active total value locked - (ActiveTvl) in descending order. + description: |- + Fetches staker stats for babylon staking including tvl, total delegations, active tvl and active delegations. + If staker_btc_pk query parameter is provided, it will return stats for the specific staker. + Otherwise, it will return the top stakers ranked by active tvl. parameters: + - description: Public key of the staker to fetch + in: query + name: staker_btc_pk + type: string - description: Pagination key to fetch the next page of top stakers in: query name: pagination_key @@ -428,7 +434,7 @@ paths: description: 'Error: Bad Request' schema: $ref: '#/definitions/github_com_babylonlabs-io_staking-api-service_internal_types.Error' - summary: Get Top Staker Stats by Active TVL + summary: Get Staker Stats /v1/unbonding: post: consumes: diff --git a/internal/api/handlers/stats.go b/internal/api/handlers/stats.go index fa931160..d346f466 100644 --- a/internal/api/handlers/stats.go +++ b/internal/api/handlers/stats.go @@ -3,6 +3,7 @@ package handlers import ( "net/http" + "github.com/babylonlabs-io/staking-api-service/internal/services" "github.com/babylonlabs-io/staking-api-service/internal/types" ) @@ -21,15 +22,37 @@ func (h *Handler) GetOverallStats(request *http.Request) (*Result, *types.Error) return NewResult(stats), nil } -// GetTopStakerStats gets top stakers by active tvl -// @Summary Get Top Staker Stats by Active TVL -// @Description Fetches details of top stakers by their active total value locked (ActiveTvl) in descending order. +// GetStakersStats gets staker stats for babylon staking +// @Summary Get Staker Stats +// @Description Fetches staker stats for babylon staking including tvl, total delegations, active tvl and active delegations. +// @Description If staker_btc_pk query parameter is provided, it will return stats for the specific staker. +// @Description Otherwise, it will return the top stakers ranked by active tvl. // @Produce json +// @Param staker_btc_pk query string false "Public key of the staker to fetch" // @Param pagination_key query string false "Pagination key to fetch the next page of top stakers" // @Success 200 {object} PublicResponse[[]services.StakerStatsPublic]{array} "List of top stakers by active tvl" // @Failure 400 {object} types.Error "Error: Bad Request" // @Router /v1/stats/staker [get] -func (h *Handler) GetTopStakerStats(request *http.Request) (*Result, *types.Error) { +func (h *Handler) GetStakersStats(request *http.Request) (*Result, *types.Error) { + // Check if the request is for a specific staker + stakerPk, err := parsePublicKeyQuery(request, "staker_btc_pk", true) + if err != nil { + return nil, err + } + if stakerPk != "" { + var result []services.StakerStatsPublic + stakerStats, err := h.services.GetStakerStats(request.Context(), stakerPk) + if err != nil { + return nil, err + } + if stakerStats != nil { + result = append(result, *stakerStats) + } + + return NewResult(result), nil + } + + // Otherwise, fetch the top stakers ranked by active tvl paginationKey, err := parsePaginationQuery(request) if err != nil { return nil, err diff --git a/internal/api/routes.go b/internal/api/routes.go index 3a9aa834..2380debf 100644 --- a/internal/api/routes.go +++ b/internal/api/routes.go @@ -16,7 +16,7 @@ func (a *Server) SetupRoutes(r *chi.Mux) { r.Get("/v1/global-params", registerHandler(handlers.GetBabylonGlobalParams)) r.Get("/v1/finality-providers", registerHandler(handlers.GetFinalityProviders)) r.Get("/v1/stats", registerHandler(handlers.GetOverallStats)) - r.Get("/v1/stats/staker", registerHandler(handlers.GetTopStakerStats)) + r.Get("/v1/stats/staker", registerHandler(handlers.GetStakersStats)) r.Get("/v1/staker/delegation/check", registerHandler(handlers.CheckStakerDelegationExist)) r.Get("/v1/delegation", registerHandler(handlers.GetDelegationByTxHash)) diff --git a/internal/db/interface.go b/internal/db/interface.go index 290be918..bf176d7e 100644 --- a/internal/db/interface.go +++ b/internal/db/interface.go @@ -66,6 +66,10 @@ type DBClient interface { ctx context.Context, stakingTxHashHex, stakerPkHex string, amount uint64, ) error FindTopStakersByTvl(ctx context.Context, paginationToken string) (*DbResultMap[*model.StakerStatsDocument], error) + // GetStakerStats fetches the staker stats by the staker's public key. + GetStakerStats( + ctx context.Context, stakerPkHex string, + ) (*model.StakerStatsDocument, error) UpsertLatestBtcInfo( ctx context.Context, height uint64, confirmedTvl uint64, unconfirmedTvl uint64, ) error diff --git a/internal/db/stats.go b/internal/db/stats.go index 629bed4d..da6187ff 100644 --- a/internal/db/stats.go +++ b/internal/db/stats.go @@ -2,6 +2,7 @@ package db import ( "context" + "errors" "fmt" "math/rand" @@ -412,3 +413,23 @@ func (db *Database) FindTopStakersByTvl(ctx context.Context, paginationToken str model.BuildStakerStatsByStakerPaginationToken, ) } + +func (db *Database) GetStakerStats( + ctx context.Context, stakerPkHex string, +) (*model.StakerStatsDocument, error) { + client := db.Client.Database(db.DbName).Collection(model.StakerStatsCollection) + filter := bson.M{"_id": stakerPkHex} + var result model.StakerStatsDocument + err := client.FindOne(ctx, filter).Decode(&result) + if err != nil { + // If the document is not found, return nil + if errors.Is(err, mongo.ErrNoDocuments) { + return nil, &NotFoundError{ + Key: stakerPkHex, + Message: "Staker stats not found", + } + } + return nil, err + } + return &result, nil +} diff --git a/internal/services/stats.go b/internal/services/stats.go index e69608ae..7d4a9a34 100644 --- a/internal/services/stats.go +++ b/internal/services/stats.go @@ -195,6 +195,28 @@ func (s *Services) GetOverallStats( }, nil } +func (s *Services) GetStakerStats( + ctx context.Context, stakerPkHex string, +) (*StakerStatsPublic, *types.Error) { + stats, err := s.DbClient.GetStakerStats(ctx, stakerPkHex) + if err != nil { + // Not found error is not an error, return nil + if db.IsNotFoundError(err) { + return nil, nil + } + log.Ctx(ctx).Error().Err(err).Msg("error while fetching staker stats") + return nil, types.NewInternalServiceError(err) + } + + return &StakerStatsPublic{ + StakerPkHex: stakerPkHex, + ActiveTvl: stats.ActiveTvl, + TotalTvl: stats.TotalTvl, + ActiveDelegations: stats.ActiveDelegations, + TotalDelegations: stats.TotalDelegations, + }, nil +} + func (s *Services) GetTopStakersByActiveTvl( ctx context.Context, pageToken string, ) ([]StakerStatsPublic, string, *types.Error) { diff --git a/tests/integration_test/stats_test.go b/tests/integration_test/stats_test.go index 3de5c614..1bc312a6 100644 --- a/tests/integration_test/stats_test.go +++ b/tests/integration_test/stats_test.go @@ -222,7 +222,7 @@ func TestStatsEndpoints(t *testing.T) { assert.Equal(t, uint64(0), overallStats.PendingTvl) // Test the top staker stats endpoint - stakerStats, _ := fetchStakerStatsEndpoint(t, testServer) + stakerStats := fetchStakerStatsEndpoint(t, testServer, "") assert.Equal(t, 1, len(stakerStats)) assert.Equal(t, activeStakingEvent.StakerPkHex, stakerStats[0].StakerPkHex) assert.Equal(t, int64(activeStakingEvent.StakingValue), stakerStats[0].ActiveTvl) @@ -267,7 +267,7 @@ func TestStatsEndpoints(t *testing.T) { assert.Equal(t, int64(1), overallStats.TotalDelegations) assert.Equal(t, uint64(1), overallStats.TotalStakers) - stakerStats, _ = fetchStakerStatsEndpoint(t, testServer) + stakerStats = fetchStakerStatsEndpoint(t, testServer, "") assert.Equal(t, 1, len(stakerStats)) assert.Equal(t, activeStakingEvent.StakerPkHex, stakerStats[0].StakerPkHex) assert.Equal(t, int64(0), stakerStats[0].ActiveTvl) @@ -297,7 +297,7 @@ func TestStatsEndpoints(t *testing.T) { assert.Equal(t, int64(3), overallStats.TotalDelegations) assert.Equal(t, uint64(2), overallStats.TotalStakers, "expected 2 stakers as the last 2 belong to same staker") - stakerStats, _ = fetchStakerStatsEndpoint(t, testServer) + stakerStats = fetchStakerStatsEndpoint(t, testServer, "") assert.Equal(t, 2, len(stakerStats)) // Also make sure the returned data is sorted by active TVL @@ -321,6 +321,46 @@ func TestStatsEndpoints(t *testing.T) { assert.Equal(t, int64(90), overallStats.ActiveTvl) } +func TestReturnEmptyArrayWhenNoStakerStatsFound(t *testing.T) { + testServer := setupTestServer(t, nil) + defer testServer.Close() + stakerPk, err := testutils.RandomPk() + require.NoError(t, err) + stakerStats := fetchStakerStatsEndpoint(t, testServer, stakerPk) + assert.Equal(t, 0, len(stakerStats)) +} + +func FuzzReturnStakerStatsByStakerPk(f *testing.F) { + attachRandomSeedsToFuzzer(f, 3) + f.Fuzz(func(t *testing.T, seed int64) { + testServer := setupTestServer(t, nil) + defer testServer.Close() + r := rand.New(rand.NewSource(seed)) + events := testutils.GenerateRandomActiveStakingEvents(r, &testutils.TestActiveEventGeneratorOpts{ + NumOfEvents: testutils.RandomPositiveInt(r, 10), + Stakers: testutils.GeneratePks(10), + EnforceNotOverflow: true, + }) + sendTestMessage(testServer.Queues.ActiveStakingQueueClient, events) + time.Sleep(10 * time.Second) + + // Find the unique staker pks + var stakerPks []string + for _, e := range events { + // append into stakerPks if it's not already there + if !testutils.Contains(stakerPks, e.StakerPkHex) { + stakerPks = append(stakerPks, e.StakerPkHex) + } + } + + // Fetch the staker stats for each staker + for _, stakerPk := range stakerPks { + stakerStats := fetchStakerStatsEndpoint(t, testServer, stakerPk) + assert.Equal(t, 1, len(stakerStats)) + } + }) +} + func FuzzStatsEndpointReturnHighestUnconfirmedTvlFromEvents(f *testing.F) { attachRandomSeedsToFuzzer(f, 5) f.Fuzz(func(t *testing.T, seed int64) { @@ -468,8 +508,11 @@ func fetchOverallStatsEndpoint(t *testing.T, testServer *TestServer) services.Ov return responseBody.Data } -func fetchStakerStatsEndpoint(t *testing.T, testServer *TestServer) ([]services.StakerStatsPublic, string) { +func fetchStakerStatsEndpoint(t *testing.T, testServer *TestServer, stakerPk string) []services.StakerStatsPublic { url := testServer.Server.URL + topStakerStatsPath + if stakerPk != "" { + url += "?staker_btc_pk=" + stakerPk + } resp, err := http.Get(url) assert.NoError(t, err, "making GET request to staker stats endpoint should not fail") defer resp.Body.Close() @@ -484,5 +527,5 @@ func fetchStakerStatsEndpoint(t *testing.T, testServer *TestServer) ([]services. err = json.Unmarshal(bodyBytes, &responseBody) assert.NoError(t, err, "unmarshalling response body should not fail") - return responseBody.Data, responseBody.Pagination.NextKey + return responseBody.Data } diff --git a/tests/mocks/mock_db_client.go b/tests/mocks/mock_db_client.go index f326a609..c3c05c5f 100644 --- a/tests/mocks/mock_db_client.go +++ b/tests/mocks/mock_db_client.go @@ -394,6 +394,36 @@ func (_m *DBClient) GetOverallStats(ctx context.Context) (*model.OverallStatsDoc return r0, r1 } +// GetStakerStats provides a mock function with given fields: ctx, stakerPkHex +func (_m *DBClient) GetStakerStats(ctx context.Context, stakerPkHex string) (*model.StakerStatsDocument, error) { + ret := _m.Called(ctx, stakerPkHex) + + if len(ret) == 0 { + panic("no return value specified for GetStakerStats") + } + + var r0 *model.StakerStatsDocument + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*model.StakerStatsDocument, error)); ok { + return rf(ctx, stakerPkHex) + } + if rf, ok := ret.Get(0).(func(context.Context, string) *model.StakerStatsDocument); ok { + r0 = rf(ctx, stakerPkHex) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.StakerStatsDocument) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, stakerPkHex) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // IncrementFinalityProviderStats provides a mock function with given fields: ctx, stakingTxHashHex, fpPkHex, amount func (_m *DBClient) IncrementFinalityProviderStats(ctx context.Context, stakingTxHashHex string, fpPkHex string, amount uint64) error { ret := _m.Called(ctx, stakingTxHashHex, fpPkHex, amount) diff --git a/tests/testutils/utils.go b/tests/testutils/utils.go new file mode 100644 index 00000000..a7661eff --- /dev/null +++ b/tests/testutils/utils.go @@ -0,0 +1,11 @@ +package testutils + +// Contains checks if a slice of strings contains a specific string. +func Contains(slice []string, item string) bool { + for _, s := range slice { + if s == item { + return true + } + } + return false +}