diff --git a/aggregator/aggregator.go b/aggregator/aggregator.go index 4dadeb8c6..76caea899 100644 --- a/aggregator/aggregator.go +++ b/aggregator/aggregator.go @@ -53,6 +53,20 @@ var ( CheckpointNotFoundError = errors.New("CheckpointMessages not found") ) +type RpcAggregatorer interface { + ProcessSignedCheckpointTaskResponse(signedCheckpointTaskResponse *messages.SignedCheckpointTaskResponse) error + ProcessSignedStateRootUpdateMessage(signedStateRootUpdateMessage *messages.SignedStateRootUpdateMessage) error + ProcessSignedOperatorSetUpdateMessage(signedOperatorSetUpdateMessage *messages.SignedOperatorSetUpdateMessage) error + GetAggregatedCheckpointMessages(fromTimestamp, toTimestamp uint64) (*messages.CheckpointMessages, error) + GetRegistryCoordinatorAddress(reply *string) error +} + +type RestAggregatorer interface { + GetStateRootUpdateAggregation(rollupId uint32, blockHeight uint64) (*types.GetStateRootUpdateAggregationResponse, error) + GetOperatorSetUpdateAggregation(id uint64) (*types.GetOperatorSetUpdateAggregationResponse, error) + GetCheckpointMessages(fromTimestamp, toTimestamp uint64) (*types.GetCheckpointMessagesResponse, error) +} + // Aggregator sends checkpoint tasks onchain, then listens for operator signed TaskResponses. // It aggregates responses signatures, and if any of the TaskResponses reaches the QuorumThreshold for each quorum // (currently we only use a single quorum of the ERC20Mock token), it sends the aggregated TaskResponse and signature onchain. @@ -118,6 +132,8 @@ type Aggregator struct { } var _ core.Metricable = (*Aggregator)(nil) +var _ RpcAggregatorer = (*Aggregator)(nil) +var _ RestAggregatorer = (*Aggregator)(nil) // NewAggregator creates a new Aggregator with the provided config. // TODO: Remove this context once OperatorPubkeysServiceInMemory's API is @@ -516,7 +532,7 @@ func (agg *Aggregator) handleOperatorSetUpdateReachedQuorum(ctx context.Context, } } -func (agg *Aggregator) ProcessSignedCheckpointTaskResponse(signedCheckpointTaskResponse *messages.SignedCheckpointTaskResponse, reply *bool) error { +func (agg *Aggregator) ProcessSignedCheckpointTaskResponse(signedCheckpointTaskResponse *messages.SignedCheckpointTaskResponse) error { taskIndex := signedCheckpointTaskResponse.TaskResponse.ReferenceTaskIndex taskResponseDigest, err := signedCheckpointTaskResponse.TaskResponse.Digest() if err != nil { @@ -545,7 +561,7 @@ func (agg *Aggregator) ProcessSignedCheckpointTaskResponse(signedCheckpointTaskR } // Rpc request handlers -func (agg *Aggregator) ProcessSignedStateRootUpdateMessage(signedStateRootUpdateMessage *messages.SignedStateRootUpdateMessage, reply *bool) error { +func (agg *Aggregator) ProcessSignedStateRootUpdateMessage(signedStateRootUpdateMessage *messages.SignedStateRootUpdateMessage) error { messageDigest, err := signedStateRootUpdateMessage.Message.Digest() if err != nil { agg.logger.Error("Failed to get message digest", "err", err) @@ -575,7 +591,7 @@ func (agg *Aggregator) ProcessSignedStateRootUpdateMessage(signedStateRootUpdate return err } -func (agg *Aggregator) ProcessSignedOperatorSetUpdateMessage(signedOperatorSetUpdateMessage *messages.SignedOperatorSetUpdateMessage, reply *bool) error { +func (agg *Aggregator) ProcessSignedOperatorSetUpdateMessage(signedOperatorSetUpdateMessage *messages.SignedOperatorSetUpdateMessage) error { messageDigest, err := signedOperatorSetUpdateMessage.Message.Digest() if err != nil { agg.logger.Error("Failed to get message digest", "err", err) @@ -617,22 +633,16 @@ func (agg *Aggregator) GetRegistry() *prometheus.Registry { return agg.registry } -type GetAggregatedCheckpointMessagesArgs struct { - FromTimestamp, ToTimestamp uint64 -} - -func (agg *Aggregator) GetAggregatedCheckpointMessages(args *GetAggregatedCheckpointMessagesArgs, reply *messages.CheckpointMessages) error { - checkpointMessages, err := agg.msgDb.FetchCheckpointMessages(args.FromTimestamp, args.ToTimestamp) +func (agg *Aggregator) GetAggregatedCheckpointMessages(fromTimestamp, toTimestamp uint64) (*messages.CheckpointMessages, error) { + checkpointMessages, err := agg.msgDb.FetchCheckpointMessages(fromTimestamp, toTimestamp) if err != nil { - return err + return nil, err } - *reply = *checkpointMessages - - return nil + return checkpointMessages, nil } -func (agg *Aggregator) GetRegistryCoordinatorAddress(_ *struct{}, reply *string) error { +func (agg *Aggregator) GetRegistryCoordinatorAddress(reply *string) error { *reply = agg.config.SFFLRegistryCoordinatorAddr.String() return nil } diff --git a/aggregator/aggregator_test.go b/aggregator/aggregator_test.go index 0d9d9cf2b..739ed4c59 100644 --- a/aggregator/aggregator_test.go +++ b/aggregator/aggregator_test.go @@ -177,8 +177,6 @@ func createMockAggregator( rollupBroadcaster: mockRollupBroadcaster, httpClient: mockClient, wsClient: mockClient, - rpcListener: &SelectiveRpcListener{}, - restListener: &SelectiveRestListener{}, aggregatorListener: &SelectiveAggregatorListener{}, } return aggregator, mockAvsReader, mockAvsWriter, mockTaskBlsAggregationService, mockStateRootUpdateBlsAggregationService, mockOperatorSetUpdateBlsAggregationService, mockMsgDb, mockRollupBroadcaster, mockClient, nil diff --git a/aggregator/cmd/main.go b/aggregator/cmd/main.go index 60ff802f4..1e3cb3c41 100644 --- a/aggregator/cmd/main.go +++ b/aggregator/cmd/main.go @@ -74,14 +74,18 @@ func aggregatorMain(ctx *cli.Context) error { registry := agg.GetRegistry() rpcServer := rpcserver.NewRpcServer(config.AggregatorServerIpPortAddr, agg, logger) if registry != nil { - rpcServer.EnableMetrics(registry) + if err = rpcServer.EnableMetrics(registry); err != nil { + return err + } } go rpcServer.Start() restServer := restserver.NewRestServer(config.AggregatorRestServerIpPortAddr, agg, logger) if registry != nil { - restServer.EnableMetrics(registry) + if err = restServer.EnableMetrics(registry); err != nil { + return err + } } go restServer.Start() diff --git a/aggregator/gen.go b/aggregator/gen.go index c6efd5542..b6a8cd633 100644 --- a/aggregator/gen.go +++ b/aggregator/gen.go @@ -1,5 +1,6 @@ package aggregator +//go:generate mockgen -destination=./mocks/rest_aggregator.go -package=mocks github.com/NethermindEth/near-sffl/aggregator RestAggregatorer //go:generate mockgen -destination=./mocks/message_blsagg.go -package=mocks github.com/NethermindEth/near-sffl/aggregator MessageBlsAggregationService //go:generate mockgen -destination=./mocks/rollup_broadcaster.go -package=mocks github.com/NethermindEth/near-sffl/aggregator RollupBroadcasterer //go:generate mockgen -destination=./mocks/eth_client.go -package=mocks github.com/Layr-Labs/eigensdk-go/chainio/clients/eth Client diff --git a/aggregator/mocks/rest_aggregator.go b/aggregator/mocks/rest_aggregator.go new file mode 100644 index 000000000..95c9b1327 --- /dev/null +++ b/aggregator/mocks/rest_aggregator.go @@ -0,0 +1,84 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/NethermindEth/near-sffl/aggregator (interfaces: RestAggregatorer) +// +// Generated by this command: +// +// mockgen -destination=./mocks/rest_aggregator.go -package=mocks github.com/NethermindEth/near-sffl/aggregator RestAggregatorer +// +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + types "github.com/NethermindEth/near-sffl/aggregator/types" + gomock "go.uber.org/mock/gomock" +) + +// MockRestAggregatorer is a mock of RestAggregatorer interface. +type MockRestAggregatorer struct { + ctrl *gomock.Controller + recorder *MockRestAggregatorerMockRecorder +} + +// MockRestAggregatorerMockRecorder is the mock recorder for MockRestAggregatorer. +type MockRestAggregatorerMockRecorder struct { + mock *MockRestAggregatorer +} + +// NewMockRestAggregatorer creates a new mock instance. +func NewMockRestAggregatorer(ctrl *gomock.Controller) *MockRestAggregatorer { + mock := &MockRestAggregatorer{ctrl: ctrl} + mock.recorder = &MockRestAggregatorerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRestAggregatorer) EXPECT() *MockRestAggregatorerMockRecorder { + return m.recorder +} + +// GetCheckpointMessages mocks base method. +func (m *MockRestAggregatorer) GetCheckpointMessages(arg0, arg1 uint64) (*types.GetCheckpointMessagesResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetCheckpointMessages", arg0, arg1) + ret0, _ := ret[0].(*types.GetCheckpointMessagesResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetCheckpointMessages indicates an expected call of GetCheckpointMessages. +func (mr *MockRestAggregatorerMockRecorder) GetCheckpointMessages(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCheckpointMessages", reflect.TypeOf((*MockRestAggregatorer)(nil).GetCheckpointMessages), arg0, arg1) +} + +// GetOperatorSetUpdateAggregation mocks base method. +func (m *MockRestAggregatorer) GetOperatorSetUpdateAggregation(arg0 uint64) (*types.GetOperatorSetUpdateAggregationResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetOperatorSetUpdateAggregation", arg0) + ret0, _ := ret[0].(*types.GetOperatorSetUpdateAggregationResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetOperatorSetUpdateAggregation indicates an expected call of GetOperatorSetUpdateAggregation. +func (mr *MockRestAggregatorerMockRecorder) GetOperatorSetUpdateAggregation(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOperatorSetUpdateAggregation", reflect.TypeOf((*MockRestAggregatorer)(nil).GetOperatorSetUpdateAggregation), arg0) +} + +// GetStateRootUpdateAggregation mocks base method. +func (m *MockRestAggregatorer) GetStateRootUpdateAggregation(arg0 uint32, arg1 uint64) (*types.GetStateRootUpdateAggregationResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetStateRootUpdateAggregation", arg0, arg1) + ret0, _ := ret[0].(*types.GetStateRootUpdateAggregationResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetStateRootUpdateAggregation indicates an expected call of GetStateRootUpdateAggregation. +func (mr *MockRestAggregatorerMockRecorder) GetStateRootUpdateAggregation(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStateRootUpdateAggregation", reflect.TypeOf((*MockRestAggregatorer)(nil).GetStateRootUpdateAggregation), arg0, arg1) +} diff --git a/aggregator/rest_server/server.go b/aggregator/rest_server/server.go index 7a10a5aa9..164cb7169 100644 --- a/aggregator/rest_server/server.go +++ b/aggregator/rest_server/server.go @@ -24,7 +24,7 @@ var ( type RestServer struct { serverIpPortAddr string - app *aggregator.Aggregator + app aggregator.RestAggregatorer logger logging.Logger listener EventListener @@ -32,7 +32,7 @@ type RestServer struct { var _ core.Metricable = (*RestServer)(nil) -func NewRestServer(serverIpPortAddr string, app *aggregator.Aggregator, logger logging.Logger) *RestServer { +func NewRestServer(serverIpPortAddr string, app aggregator.RestAggregatorer, logger logging.Logger) *RestServer { return &RestServer{ serverIpPortAddr: serverIpPortAddr, app: app, diff --git a/aggregator/rest_server/server_test.go b/aggregator/rest_server/server_test.go new file mode 100644 index 000000000..a512b84a9 --- /dev/null +++ b/aggregator/rest_server/server_test.go @@ -0,0 +1,161 @@ +package rest_server + +import ( + "encoding/json" + "fmt" + sdklogging "github.com/Layr-Labs/eigensdk-go/logging" + "github.com/NethermindEth/near-sffl/aggregator/mocks" + "github.com/NethermindEth/near-sffl/core/types/messages" + "github.com/NethermindEth/near-sffl/tests" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + + aggtypes "github.com/NethermindEth/near-sffl/aggregator/types" +) + +func TestGetStateRootUpdateAggregation(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + logger := sdklogging.NewNoopLogger() + aggregator := mocks.NewMockRestAggregatorer(mockCtrl) + restServer := NewRestServer("", aggregator, logger) + + msg := messages.StateRootUpdateMessage{ + RollupId: 1, + BlockHeight: 2, + Timestamp: 3, + NearDaCommitment: tests.Keccak256(4), + NearDaTransactionId: tests.Keccak256(5), + StateRoot: tests.Keccak256(6), + } + msgDigest, err := msg.Digest() + assert.Nil(t, err) + + response := aggtypes.GetStateRootUpdateAggregationResponse{ + Message: msg, + Aggregation: messages.MessageBlsAggregation{ + MessageDigest: msgDigest, + }, + } + aggregator.EXPECT().GetStateRootUpdateAggregation(msg.RollupId, msg.BlockHeight).Return(&response, nil) + + req, err := http.NewRequest( + "GET", + fmt.Sprintf("/aggregation/state-root-update?rollupId=%d&blockHeight=%d", msg.RollupId, msg.BlockHeight), + nil, + ) + assert.Nil(t, err) + + recorder := httptest.NewRecorder() + err = restServer.handleGetStateRootUpdateAggregation(recorder, req) + assert.Nil(t, err) + assert.Equal(t, recorder.Code, http.StatusOK) + + var body aggtypes.GetStateRootUpdateAggregationResponse + err = json.Unmarshal(recorder.Body.Bytes(), &body) + assert.Nil(t, err) + assert.Equal(t, body, response) +} + +func TestGetOperatorSetUpdateAggregation(t *testing.T) { + mockCtrl := gomock.NewController(t) + + logger := sdklogging.NewNoopLogger() + aggregator := mocks.NewMockRestAggregatorer(mockCtrl) + restServer := NewRestServer("", aggregator, logger) + + msg := messages.OperatorSetUpdateMessage{ + Id: 1, + Timestamp: 2, + } + digest, err := msg.Digest() + assert.Nil(t, err) + + response := aggtypes.GetOperatorSetUpdateAggregationResponse{ + Message: msg, + Aggregation: messages.MessageBlsAggregation{ + MessageDigest: digest, + }, + } + + aggregator.EXPECT().GetOperatorSetUpdateAggregation(msg.Id).Return(&response, nil) + + req, err := http.NewRequest( + "GET", + fmt.Sprintf("/aggregation/operator-set-update?id=%d", msg.Id), + nil, + ) + assert.Nil(t, err) + + recorder := httptest.NewRecorder() + err = restServer.handleGetOperatorSetUpdateAggregation(recorder, req) + assert.Nil(t, err) + assert.Equal(t, recorder.Code, http.StatusOK) + + var actual aggtypes.GetOperatorSetUpdateAggregationResponse + err = json.Unmarshal(recorder.Body.Bytes(), &actual) + assert.Nil(t, err) + assert.Equal(t, response, actual) +} + +func TestGetCheckpointMessages(t *testing.T) { + mockCtrl := gomock.NewController(t) + + logger := sdklogging.NewNoopLogger() + aggregator := mocks.NewMockRestAggregatorer(mockCtrl) + restServer := NewRestServer("", aggregator, logger) + + stateRootMessage := messages.StateRootUpdateMessage{ + RollupId: 1, + BlockHeight: 2, + Timestamp: 3, + } + stateRootDigest, err := stateRootMessage.Digest() + assert.Nil(t, err) + stateRootAggregation := messages.MessageBlsAggregation{ + MessageDigest: stateRootDigest, + } + + operatorSetMesssage := messages.OperatorSetUpdateMessage{ + Id: 1, + Timestamp: 2, + } + operatorSetDigest, err := operatorSetMesssage.Digest() + assert.Nil(t, err) + operatorSetAggregation := messages.MessageBlsAggregation{ + MessageDigest: operatorSetDigest, + } + + var fromTimestamp, toTimestamp uint64 = 0, 3 + response := aggtypes.GetCheckpointMessagesResponse{ + CheckpointMessages: messages.CheckpointMessages{ + StateRootUpdateMessages: []messages.StateRootUpdateMessage{stateRootMessage}, + StateRootUpdateMessageAggregations: []messages.MessageBlsAggregation{stateRootAggregation}, + OperatorSetUpdateMessages: []messages.OperatorSetUpdateMessage{operatorSetMesssage}, + OperatorSetUpdateMessageAggregations: []messages.MessageBlsAggregation{operatorSetAggregation}, + }, + } + aggregator.EXPECT().GetCheckpointMessages(fromTimestamp, toTimestamp).Return(&response, nil) + + req, err := http.NewRequest( + "GET", + fmt.Sprintf("/checkpoint/messages?fromTimestamp=%d&toTimestamp=%d", fromTimestamp, toTimestamp), + nil, + ) + assert.Nil(t, err) + + recorder := httptest.NewRecorder() + err = restServer.handleGetCheckpointMessages(recorder, req) + assert.Nil(t, err) + assert.Equal(t, recorder.Code, http.StatusOK) + + var actual aggtypes.GetCheckpointMessagesResponse + err = json.Unmarshal(recorder.Body.Bytes(), &actual) + assert.Nil(t, err) + assert.Equal(t, response, actual) +} diff --git a/aggregator/rest_server_test.go b/aggregator/rest_server_test.go deleted file mode 100644 index 606ce7ccd..000000000 --- a/aggregator/rest_server_test.go +++ /dev/null @@ -1,199 +0,0 @@ -package aggregator - -import ( - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" - "go.uber.org/mock/gomock" - - aggtypes "github.com/NethermindEth/near-sffl/aggregator/types" - "github.com/NethermindEth/near-sffl/core/types/messages" -) - -func TestGetStateRootUpdateAggregation(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - aggregator, _, _, _, _, _, mockDb, _, _, err := createMockAggregator(mockCtrl, MOCK_OPERATOR_PUBKEY_DICT) - assert.Nil(t, err) - - go aggregator.startRestServer() - - msg := messages.StateRootUpdateMessage{ - RollupId: 1, - BlockHeight: 2, - Timestamp: 3, - NearDaCommitment: keccak256(4), - NearDaTransactionId: keccak256(5), - StateRoot: keccak256(6), - } - msgDigest, err := msg.Digest() - assert.Nil(t, err) - - aggregation := aggtypes.MessageBlsAggregationServiceResponse{ - MessageBlsAggregation: messages.MessageBlsAggregation{ - MessageDigest: msgDigest, - }, - } - - mockDb.EXPECT().FetchStateRootUpdate(msg.RollupId, msg.BlockHeight).Return(&msg, nil) - - mockDb.EXPECT().FetchStateRootUpdateAggregation(msg.RollupId, msg.BlockHeight).Return(&aggregation.MessageBlsAggregation, nil) - - req, err := http.NewRequest( - "GET", - fmt.Sprintf("/aggregation/state-root-update?rollupId=%d&blockHeight=%d", msg.RollupId, msg.BlockHeight), - nil, - ) - assert.Nil(t, err) - - recorder := httptest.NewRecorder() - - aggregator.handleGetStateRootUpdateAggregation(recorder, req) - - expectedBody := aggtypes.GetStateRootUpdateAggregationResponse{ - Message: msg, - Aggregation: aggregation.MessageBlsAggregation, - } - var body aggtypes.GetStateRootUpdateAggregationResponse - - assert.Equal(t, recorder.Code, http.StatusOK) - - if recorder.Code != http.StatusOK { - fmt.Printf("HTTP Error: %s", recorder.Body.Bytes()) - } - - err = json.Unmarshal(recorder.Body.Bytes(), &body) - assert.Nil(t, err) - - assert.Equal(t, body, expectedBody) -} - -func TestGetOperatorSetUpdateAggregation(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - aggregator, _, _, _, _, _, mockDb, _, _, err := createMockAggregator(mockCtrl, MOCK_OPERATOR_PUBKEY_DICT) - assert.Nil(t, err) - - go aggregator.startRestServer() - - msg := messages.OperatorSetUpdateMessage{ - Id: 1, - Timestamp: 2, - } - msgDigest, err := msg.Digest() - assert.Nil(t, err) - - aggregation := messages.MessageBlsAggregation{ - MessageDigest: msgDigest, - } - - mockDb.EXPECT().FetchOperatorSetUpdate(msg.Id).Return(&msg, nil) - - mockDb.EXPECT().FetchOperatorSetUpdateAggregation(msg.Id).Return(&aggregation, nil) - - req, err := http.NewRequest( - "GET", - fmt.Sprintf("/aggregation/operator-set-update?id=%d", msg.Id), - nil, - ) - assert.Nil(t, err) - - recorder := httptest.NewRecorder() - - aggregator.handleGetOperatorSetUpdateAggregation(recorder, req) - - expectedBody := aggtypes.GetOperatorSetUpdateAggregationResponse{ - Message: msg, - Aggregation: aggregation, - } - var body aggtypes.GetOperatorSetUpdateAggregationResponse - - assert.Equal(t, recorder.Code, http.StatusOK) - - if recorder.Code != http.StatusOK { - fmt.Printf("HTTP Error: %s", recorder.Body.Bytes()) - } - - err = json.Unmarshal(recorder.Body.Bytes(), &body) - assert.Nil(t, err) - - assert.Equal(t, body, expectedBody) -} - -func TestGetCheckpointMessages(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - aggregator, _, _, _, _, _, mockDb, _, _, err := createMockAggregator(mockCtrl, MOCK_OPERATOR_PUBKEY_DICT) - assert.Nil(t, err) - - go aggregator.startRestServer() - - msg := messages.StateRootUpdateMessage{ - RollupId: 1, - BlockHeight: 2, - Timestamp: 3, - } - msgDigest, err := msg.Digest() - assert.Nil(t, err) - - aggregation := messages.MessageBlsAggregation{ - MessageDigest: msgDigest, - } - - msg2 := messages.OperatorSetUpdateMessage{ - Id: 1, - Timestamp: 2, - } - msgDigest2, err := msg2.Digest() - assert.Nil(t, err) - - aggregation2 := messages.MessageBlsAggregation{ - MessageDigest: msgDigest2, - } - - mockDb.EXPECT().FetchCheckpointMessages(uint64(0), uint64(3)).Return(&messages.CheckpointMessages{ - StateRootUpdateMessages: []messages.StateRootUpdateMessage{msg}, - StateRootUpdateMessageAggregations: []messages.MessageBlsAggregation{aggregation}, - OperatorSetUpdateMessages: []messages.OperatorSetUpdateMessage{msg2}, - OperatorSetUpdateMessageAggregations: []messages.MessageBlsAggregation{aggregation2}, - }, nil) - - req, err := http.NewRequest( - "GET", - fmt.Sprintf("/checkpoint/messages?fromTimestamp=%d&toTimestamp=%d", 0, 3), - nil, - ) - assert.Nil(t, err) - - recorder := httptest.NewRecorder() - - aggregator.handleGetCheckpointMessages(recorder, req) - - expectedBody := aggtypes.GetCheckpointMessagesResponse{ - CheckpointMessages: messages.CheckpointMessages{ - StateRootUpdateMessages: []messages.StateRootUpdateMessage{msg}, - StateRootUpdateMessageAggregations: []messages.MessageBlsAggregation{aggregation}, - OperatorSetUpdateMessages: []messages.OperatorSetUpdateMessage{msg2}, - OperatorSetUpdateMessageAggregations: []messages.MessageBlsAggregation{aggregation2}, - }, - } - var body aggtypes.GetCheckpointMessagesResponse - - assert.Equal(t, recorder.Code, http.StatusOK) - - if recorder.Code != http.StatusOK { - fmt.Printf("HTTP Error: %s", recorder.Body.Bytes()) - } - - err = json.Unmarshal(recorder.Body.Bytes(), &body) - assert.Nil(t, err) - - assert.Equal(t, body, expectedBody) -} diff --git a/aggregator/rpc_server/server.go b/aggregator/rpc_server/server.go index 69cd4ae1f..137004403 100644 --- a/aggregator/rpc_server/server.go +++ b/aggregator/rpc_server/server.go @@ -36,7 +36,7 @@ var ( type RpcServer struct { serverIpPortAddr string - app *aggregator.Aggregator + app aggregator.RpcAggregatorer logger logging.Logger listener EventListener @@ -44,7 +44,7 @@ type RpcServer struct { var _ core.Metricable = (*RpcServer)(nil) -func NewRpcServer(serverIpPortAddr string, app *aggregator.Aggregator, logger logging.Logger) *RpcServer { +func NewRpcServer(serverIpPortAddr string, app aggregator.RpcAggregatorer, logger logging.Logger) *RpcServer { return &RpcServer{ serverIpPortAddr: serverIpPortAddr, app: app, @@ -96,7 +96,7 @@ func (s *RpcServer) ProcessSignedCheckpointTaskResponse(signedCheckpointTaskResp s.listener.IncTotalSignedCheckpointTaskResponse() s.listener.ObserveLastMessageReceivedTime(signedCheckpointTaskResponse.OperatorId, CheckpointTaskResponseLabel) - err := s.app.ProcessSignedCheckpointTaskResponse(signedCheckpointTaskResponse, reply) + err := s.app.ProcessSignedCheckpointTaskResponse(signedCheckpointTaskResponse) if err != nil { s.listener.IncSignedCheckpointTaskResponse( signedCheckpointTaskResponse.OperatorId, @@ -120,7 +120,7 @@ func (s *RpcServer) ProcessSignedStateRootUpdateMessage(signedStateRootUpdateMes operatorId := signedStateRootUpdateMessage.OperatorId rollupId := signedStateRootUpdateMessage.Message.RollupId - err := s.app.ProcessSignedStateRootUpdateMessage(signedStateRootUpdateMessage, reply) + err := s.app.ProcessSignedStateRootUpdateMessage(signedStateRootUpdateMessage) s.listener.IncSignedStateRootUpdateMessage(operatorId, rollupId, err != nil, hasNearDaCommitment) if err != nil { return mapErrors(err) @@ -136,7 +136,7 @@ func (s *RpcServer) ProcessSignedOperatorSetUpdateMessage(signedOperatorSetUpdat s.listener.ObserveLastMessageReceivedTime(operatorId, OperatorSetUpdateMessageLabel) s.listener.IncTotalSignedOperatorSetUpdateMessage() - err := s.app.ProcessSignedOperatorSetUpdateMessage(signedOperatorSetUpdateMessage, reply) + err := s.app.ProcessSignedOperatorSetUpdateMessage(signedOperatorSetUpdateMessage) s.listener.IncSignedOperatorSetUpdateMessage(operatorId, err != nil) if err != nil { return mapErrors(err) @@ -145,12 +145,23 @@ func (s *RpcServer) ProcessSignedOperatorSetUpdateMessage(signedOperatorSetUpdat return nil } -func (s *RpcServer) GetAggregatedCheckpointMessages(args *aggregator.GetAggregatedCheckpointMessagesArgs, reply *messages.CheckpointMessages) error { - return s.app.GetAggregatedCheckpointMessages(args, reply) +type GetAggregatedCheckpointMessagesArgs struct { + FromTimestamp, ToTimestamp uint64 } -func (s *RpcServer) GetRegistryCoordinatorAddress(data *struct{}, reply *string) error { - return s.app.GetRegistryCoordinatorAddress(data, reply) +func (s *RpcServer) GetAggregatedCheckpointMessages(args *GetAggregatedCheckpointMessagesArgs, reply *messages.CheckpointMessages) error { + result, err := s.app.GetAggregatedCheckpointMessages(args.FromTimestamp, args.ToTimestamp) + if err != nil { + return mapErrors(err) + } + + *reply = *result + + return nil +} + +func (s *RpcServer) GetRegistryCoordinatorAddress(_ *struct{}, reply *string) error { + return s.app.GetRegistryCoordinatorAddress(reply) } func (s *RpcServer) NotifyOperatorInitialization(operatorId eigentypes.OperatorId, reply *bool) error { diff --git a/aggregator/rpc_server_test.go b/aggregator/rpc_server_test.go index 914abf8c9..17d9a0bc7 100644 --- a/aggregator/rpc_server_test.go +++ b/aggregator/rpc_server_test.go @@ -45,7 +45,7 @@ func TestProcessSignedCheckpointTaskResponse(t *testing.T) { // see https://hynek.me/articles/what-to-mock-in-5-mins/ mockBlsAggServ.EXPECT().ProcessNewSignature(context.Background(), TASK_INDEX, signedCheckpointTaskResponseDigest, &signedCheckpointTaskResponse.BlsSignature, signedCheckpointTaskResponse.OperatorId) - err = aggregator.ProcessSignedCheckpointTaskResponse(signedCheckpointTaskResponse, nil) + err = aggregator.ProcessSignedCheckpointTaskResponse(signedCheckpointTaskResponse) assert.Nil(t, err) } @@ -73,7 +73,7 @@ func TestProcessSignedStateRootUpdateMessage(t *testing.T) { mockMessageBlsAggServ.EXPECT().ProcessNewSignature(context.Background(), messageDigest, &signedMessage.BlsSignature, signedMessage.OperatorId) mockMessageBlsAggServ.EXPECT().InitializeMessageIfNotExists(messageDigest, coretypes.QUORUM_NUMBERS, []eigentypes.QuorumThresholdPercentage{types.MESSAGE_AGGREGATION_QUORUM_THRESHOLD}, types.MESSAGE_TTL, types.MESSAGE_BLS_AGGREGATION_TIMEOUT, uint64(0)) - err = aggregator.ProcessSignedStateRootUpdateMessage(signedMessage, nil) + err = aggregator.ProcessSignedStateRootUpdateMessage(signedMessage) assert.Nil(t, err) } @@ -102,7 +102,7 @@ func TestProcessOperatorSetUpdateMessage(t *testing.T) { mockMessageBlsAggServ.EXPECT().ProcessNewSignature(context.Background(), messageDigest, &signedMessage.BlsSignature, signedMessage.OperatorId) mockMessageBlsAggServ.EXPECT().InitializeMessageIfNotExists(messageDigest, coretypes.QUORUM_NUMBERS, []eigentypes.QuorumThresholdPercentage{types.MESSAGE_AGGREGATION_QUORUM_THRESHOLD}, types.MESSAGE_TTL, types.MESSAGE_BLS_AGGREGATION_TIMEOUT, uint64(9)) - err = aggregator.ProcessSignedOperatorSetUpdateMessage(signedMessage, nil) + err = aggregator.ProcessSignedOperatorSetUpdateMessage(signedMessage) assert.Nil(t, err) } @@ -114,9 +114,8 @@ func TestGetAggregatedCheckpointMessages(t *testing.T) { assert.Nil(t, err) var checkpointMessages messages.CheckpointMessages - mockDb.EXPECT().FetchCheckpointMessages(uint64(1), uint64(2)).Return(&checkpointMessages, nil) - err = aggregator.GetAggregatedCheckpointMessages(&GetAggregatedCheckpointMessagesArgs{uint64(1), uint64(2)}, &checkpointMessages) + _, err = aggregator.GetAggregatedCheckpointMessages(uint64(1), uint64(2)) assert.Nil(t, err) }