diff --git a/aggregator/aggregator.go b/aggregator/aggregator.go index 6ca7ddd7..b8e75291 100644 --- a/aggregator/aggregator.go +++ b/aggregator/aggregator.go @@ -8,17 +8,15 @@ import ( "sync" "time" - chainioavsregistry "github.com/Layr-Labs/eigensdk-go/chainio/clients/avsregistry" "github.com/Layr-Labs/eigensdk-go/chainio/clients/wallet" "github.com/Layr-Labs/eigensdk-go/chainio/txmgr" + "github.com/Layr-Labs/eigensdk-go/crypto/bls" "github.com/Layr-Labs/eigensdk-go/logging" "github.com/Layr-Labs/eigensdk-go/metrics" "github.com/Layr-Labs/eigensdk-go/services/avsregistry" blsagg "github.com/Layr-Labs/eigensdk-go/services/bls_aggregation" - opinfoserv "github.com/Layr-Labs/eigensdk-go/services/operatorsinfo" "github.com/Layr-Labs/eigensdk-go/signerv2" eigentypes "github.com/Layr-Labs/eigensdk-go/types" - "github.com/ethereum/go-ethereum/common" "github.com/prometheus/client_golang/prometheus" "github.com/NethermindEth/near-sffl/aggregator/database" @@ -46,6 +44,9 @@ var ( DigestError = errors.New("Failed to get message digest") TaskResponseDigestError = errors.New("Failed to get task response digest") GetOperatorSetUpdateBlockError = errors.New("Failed to get operator set update block") + OperatorNotFoundError = errors.New("Operator not found") + InvalidSignatureError = errors.New("Invalid signature") + UnsupportedMessageTypeError = errors.New("Unsupported message type") // REST errors StateRootUpdateNotFoundError = errors.New("StateRootUpdate not found") @@ -61,6 +62,7 @@ type RpcAggregatorer interface { ProcessSignedOperatorSetUpdateMessage(signedOperatorSetUpdateMessage *messages.SignedOperatorSetUpdateMessage) error GetAggregatedCheckpointMessages(fromTimestamp, toTimestamp uint64) (*messages.CheckpointMessages, error) GetRegistryCoordinatorAddress(reply *string) error + GetOperatorInfoById(ctx context.Context, operatorId eigentypes.OperatorId) (eigentypes.OperatorInfo, bool) } type RestAggregatorer interface { @@ -120,6 +122,7 @@ type Aggregator struct { metrics metrics.Metrics aggregatorListener AggregatorEventListener + operatorRegistrationsService OperatorRegistrationsService taskBlsAggregationService blsagg.BlsAggregationService stateRootUpdateBlsAggregationService MessageBlsAggregationService operatorSetUpdateBlsAggregationService MessageBlsAggregationService @@ -184,13 +187,6 @@ func NewAggregator( txMgr := txmgr.NewSimpleTxManager(txSender, ethHttpClient, logger, config.AggregatorAddress).WithGasLimitMultiplier(1.5) - // note that the subscriber needs a ws connection instead of http - avsRegistryChainSubscriber, err := chainioavsregistry.BuildAvsRegistryChainSubscriber(common.HexToAddress(config.SFFLRegistryCoordinatorAddr.String()), ethWsClient, logger) - if err != nil { - logger.Error("Cannot create AvsRegistryChainSubscriber", "err", err) - return nil, err - } - avsWriter, err := chainio.BuildAvsWriterFromConfig(txMgr, config, ethHttpClient, logger) if err != nil { logger.Error("Cannot create avsWriter", "err", err) @@ -215,8 +211,12 @@ func NewAggregator( return nil, err } - operatorPubkeysService := opinfoserv.NewOperatorsInfoServiceInMemory(ctx, avsRegistryChainSubscriber, avsReader, logger) - avsRegistryService := avsregistry.NewAvsRegistryServiceChainCaller(avsReader, operatorPubkeysService, logger) + operatorRegistrationsService, err := NewOperatorRegistrationsServiceInMemory(ctx, avsSubscriber, avsReader, logger) + if err != nil { + return nil, err + } + + avsRegistryService := avsregistry.NewAvsRegistryServiceChainCaller(avsReader, operatorRegistrationsService, logger) taskBlsAggregationService := blsagg.NewBlsAggregatorService(avsRegistryService, logger) stateRootUpdateBlsAggregationService := NewMessageBlsAggregatorService(avsRegistryService, ethHttpClient, logger) operatorSetUpdateBlsAggregationService := NewMessageBlsAggregatorService(avsRegistryService, ethHttpClient, logger) @@ -232,6 +232,7 @@ func NewAggregator( rollupBroadcaster: rollupBroadcaster, httpClient: ethHttpClient, wsClient: ethWsClient, + operatorRegistrationsService: operatorRegistrationsService, clock: core.SystemClock, taskBlsAggregationService: taskBlsAggregationService, stateRootUpdateBlsAggregationService: stateRootUpdateBlsAggregationService, @@ -522,6 +523,11 @@ func (agg *Aggregator) handleOperatorSetUpdateReachedQuorum(ctx context.Context, } func (agg *Aggregator) ProcessSignedCheckpointTaskResponse(signedCheckpointTaskResponse *messages.SignedCheckpointTaskResponse) error { + err := agg.verifySignature(signedCheckpointTaskResponse) + if err != nil { + return err + } + taskIndex := signedCheckpointTaskResponse.TaskResponse.ReferenceTaskIndex taskResponseDigest, err := signedCheckpointTaskResponse.TaskResponse.Digest() if err != nil { @@ -551,19 +557,24 @@ func (agg *Aggregator) ProcessSignedCheckpointTaskResponse(signedCheckpointTaskR // Rpc request handlers func (agg *Aggregator) ProcessSignedStateRootUpdateMessage(signedStateRootUpdateMessage *messages.SignedStateRootUpdateMessage) error { - messageDigest, err := signedStateRootUpdateMessage.Message.Digest() + timestamp := signedStateRootUpdateMessage.Message.Timestamp + err := agg.validateMessageTimestamp(timestamp) if err != nil { - agg.logger.Error("Failed to get message digest", "err", err) - return DigestError + agg.logger.Error("Failed to validate message timestamp", "err", err, "timestamp", timestamp) + return err } - timestamp := signedStateRootUpdateMessage.Message.Timestamp - err = agg.validateMessageTimestamp(timestamp) + err = agg.verifySignature(signedStateRootUpdateMessage) if err != nil { - agg.logger.Error("Failed to validate message timestamp", "err", err, "timestamp", timestamp) return err } + messageDigest, err := signedStateRootUpdateMessage.Message.Digest() + if err != nil { + agg.logger.Error("Failed to get message digest", "err", err) + return DigestError + } + err = agg.stateRootUpdateBlsAggregationService.InitializeMessageIfNotExists( messageDigest, coretypes.QUORUM_NUMBERS, @@ -588,19 +599,24 @@ func (agg *Aggregator) ProcessSignedStateRootUpdateMessage(signedStateRootUpdate } func (agg *Aggregator) ProcessSignedOperatorSetUpdateMessage(signedOperatorSetUpdateMessage *messages.SignedOperatorSetUpdateMessage) error { - messageDigest, err := signedOperatorSetUpdateMessage.Message.Digest() + timestamp := signedOperatorSetUpdateMessage.Message.Timestamp + err := agg.validateMessageTimestamp(timestamp) if err != nil { - agg.logger.Error("Failed to get message digest", "err", err) - return DigestError + agg.logger.Error("Failed to validate message timestamp", "err", err, "timestamp", timestamp) + return err } - timestamp := signedOperatorSetUpdateMessage.Message.Timestamp - err = agg.validateMessageTimestamp(timestamp) + err = agg.verifySignature(signedOperatorSetUpdateMessage) if err != nil { - agg.logger.Error("Failed to validate message timestamp", "err", err, "timestamp", timestamp) return err } + messageDigest, err := signedOperatorSetUpdateMessage.Message.Digest() + if err != nil { + agg.logger.Error("Failed to get message digest", "err", err) + return DigestError + } + blockNumber, err := agg.avsReader.GetOperatorSetUpdateBlock(context.Background(), signedOperatorSetUpdateMessage.Message.Id) if err != nil { agg.logger.Error("Failed to get operator set update block", "err", err) @@ -691,6 +707,59 @@ func (agg *Aggregator) GetCheckpointMessages(fromTimestamp, toTimestamp uint64) }, nil } +func (agg *Aggregator) GetOperatorInfoById(ctx context.Context, operatorId eigentypes.OperatorId) (eigentypes.OperatorInfo, bool) { + operatorInfo, ok := agg.operatorRegistrationsService.GetOperatorInfoById(ctx, operatorId) + return operatorInfo, ok +} + +func (agg *Aggregator) verifySignature(signedMessage interface{}) error { + var operatorId eigentypes.OperatorId + var signature bls.Signature + var digest [32]byte + var err error + + switch signedMessage := signedMessage.(type) { + case *messages.SignedCheckpointTaskResponse: + operatorId = signedMessage.OperatorId + signature = signedMessage.BlsSignature + digest, err = signedMessage.TaskResponse.Digest() + if err != nil { + return TaskResponseDigestError + } + case *messages.SignedStateRootUpdateMessage: + operatorId = signedMessage.OperatorId + signature = signedMessage.BlsSignature + digest, err = signedMessage.Message.Digest() + if err != nil { + return DigestError + } + case *messages.SignedOperatorSetUpdateMessage: + operatorId = signedMessage.OperatorId + signature = signedMessage.BlsSignature + digest, err = signedMessage.Message.Digest() + if err != nil { + return DigestError + } + default: + return UnsupportedMessageTypeError + } + + operatorInfo, ok := agg.GetOperatorInfoById(context.Background(), operatorId) + if !ok { + return OperatorNotFoundError + } + + ok, err = signature.Verify(operatorInfo.Pubkeys.G2Pubkey, digest) + if err != nil { + return InvalidSignatureError + } + if !ok { + return InvalidSignatureError + } + + return nil +} + func (agg *Aggregator) validateMessageTimestamp(messageTimestamp uint64) error { now := agg.clock.Now().Unix() timeoutInSeconds := types.MESSAGE_SUBMISSION_TIMEOUT.Seconds() diff --git a/aggregator/aggregator_test.go b/aggregator/aggregator_test.go index 78a53814..0ffe4914 100644 --- a/aggregator/aggregator_test.go +++ b/aggregator/aggregator_test.go @@ -34,13 +34,14 @@ var MOCK_OPERATOR_BLS_PRIVATE_KEY, _ = bls.NewPrivateKey(MOCK_OPERATOR_BLS_PRIVA var MOCK_OPERATOR_KEYPAIR = bls.NewKeyPair(MOCK_OPERATOR_BLS_PRIVATE_KEY) var MOCK_OPERATOR_G1PUBKEY = MOCK_OPERATOR_KEYPAIR.GetPubKeyG1() var MOCK_OPERATOR_G2PUBKEY = MOCK_OPERATOR_KEYPAIR.GetPubKeyG2() +var MOCK_OPERATOR_PUBKEYS = eigentypes.OperatorPubkeys{ + G1Pubkey: MOCK_OPERATOR_G1PUBKEY, + G2Pubkey: MOCK_OPERATOR_G2PUBKEY, +} var MOCK_OPERATOR_PUBKEY_DICT = map[eigentypes.OperatorId]types.OperatorInfo{ MOCK_OPERATOR_ID: { - OperatorPubkeys: eigentypes.OperatorPubkeys{ - G1Pubkey: MOCK_OPERATOR_G1PUBKEY, - G2Pubkey: MOCK_OPERATOR_G2PUBKEY, - }, - OperatorAddr: common.Address{}, + OperatorPubkeys: MOCK_OPERATOR_PUBKEYS, + OperatorAddr: common.Address{}, }, } @@ -55,7 +56,7 @@ func TestSendNewTask(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - aggregator, mockAvsReaderer, mockAvsWriterer, mockTaskBlsAggService, _, _, _, _, mockClient, err := createMockAggregator(mockCtrl, MOCK_OPERATOR_PUBKEY_DICT) + aggregator, mockAvsReaderer, mockAvsWriterer, mockTaskBlsAggService, _, _, _, _, _, mockClient, err := createMockAggregator(mockCtrl, MOCK_OPERATOR_PUBKEY_DICT) assert.Nil(t, err) var TASK_INDEX = uint32(0) @@ -92,7 +93,7 @@ func TestHandleStateRootUpdateAggregationReachedQuorum(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - aggregator, _, _, _, _, _, mockMsgDb, _, _, err := createMockAggregator(mockCtrl, MOCK_OPERATOR_PUBKEY_DICT) + aggregator, _, _, _, _, _, _, mockMsgDb, _, _, err := createMockAggregator(mockCtrl, MOCK_OPERATOR_PUBKEY_DICT) assert.Nil(t, err) msg := messages.StateRootUpdateMessage{} @@ -122,7 +123,7 @@ func TestHandleOperatorSetUpdateAggregationReachedQuorum(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - aggregator, _, _, _, _, _, mockMsgDb, mockRollupBroadcaster, _, err := createMockAggregator(mockCtrl, MOCK_OPERATOR_PUBKEY_DICT) + aggregator, _, _, _, _, _, _, mockMsgDb, mockRollupBroadcaster, _, err := createMockAggregator(mockCtrl, MOCK_OPERATOR_PUBKEY_DICT) assert.Nil(t, err) msg := messages.OperatorSetUpdateMessage{} @@ -158,7 +159,7 @@ func TestExpiredStateRootUpdateMessage(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - aggregator, _, _, _, _, _, _, _, _, err := createMockAggregator(mockCtrl, MOCK_OPERATOR_PUBKEY_DICT) + aggregator, _, _, _, _, _, _, _, _, _, err := createMockAggregator(mockCtrl, MOCK_OPERATOR_PUBKEY_DICT) assert.NoError(t, err) nowTimestamp := uint64(6000) @@ -178,7 +179,7 @@ func TestExpiredOperatorSetUpdate(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - aggregator, _, _, _, _, _, _, _, _, err := createMockAggregator(mockCtrl, MOCK_OPERATOR_PUBKEY_DICT) + aggregator, _, _, _, _, _, _, _, _, _, err := createMockAggregator(mockCtrl, MOCK_OPERATOR_PUBKEY_DICT) assert.NoError(t, err) nowTimestamp := uint64(8000) @@ -196,7 +197,7 @@ func TestExpiredOperatorSetUpdate(t *testing.T) { func createMockAggregator( mockCtrl *gomock.Controller, operatorPubkeyDict map[eigentypes.OperatorId]types.OperatorInfo, -) (*Aggregator, *chainiomocks.MockAvsReaderer, *chainiomocks.MockAvsWriterer, *blsaggservmock.MockBlsAggregationService, *aggmocks.MockMessageBlsAggregationService, *aggmocks.MockMessageBlsAggregationService, *dbmocks.MockDatabaser, *aggmocks.MockRollupBroadcasterer, *safeclientmocks.MockSafeClient, error) { +) (*Aggregator, *chainiomocks.MockAvsReaderer, *chainiomocks.MockAvsWriterer, *blsaggservmock.MockBlsAggregationService, *aggmocks.MockMessageBlsAggregationService, *aggmocks.MockMessageBlsAggregationService, *aggmocks.MockOperatorRegistrationsService, *dbmocks.MockDatabaser, *aggmocks.MockRollupBroadcasterer, *safeclientmocks.MockSafeClient, error) { logger := sdklogging.NewNoopLogger() mockAvsWriter := chainiomocks.NewMockAvsWriterer(mockCtrl) mockAvsReader := chainiomocks.NewMockAvsReaderer(mockCtrl) @@ -206,6 +207,7 @@ func createMockAggregator( mockMsgDb := dbmocks.NewMockDatabaser(mockCtrl) mockRollupBroadcaster := aggmocks.NewMockRollupBroadcasterer(mockCtrl) mockClient := safeclientmocks.NewMockSafeClient(mockCtrl) + mockOperatorRegistrationsService := aggmocks.NewMockOperatorRegistrationsService(mockCtrl) aggregator := &Aggregator{ logger: logger, @@ -214,6 +216,7 @@ func createMockAggregator( taskBlsAggregationService: mockTaskBlsAggregationService, stateRootUpdateBlsAggregationService: mockStateRootUpdateBlsAggregationService, operatorSetUpdateBlsAggregationService: mockOperatorSetUpdateBlsAggregationService, + operatorRegistrationsService: mockOperatorRegistrationsService, msgDb: mockMsgDb, tasks: make(map[coretypes.TaskIndex]taskmanager.CheckpointTask), taskResponses: make(map[coretypes.TaskIndex]map[eigentypes.TaskResponseDigest]messages.CheckpointTaskResponse), @@ -225,5 +228,5 @@ func createMockAggregator( aggregatorListener: &SelectiveAggregatorListener{}, clock: core.SystemClock, } - return aggregator, mockAvsReader, mockAvsWriter, mockTaskBlsAggregationService, mockStateRootUpdateBlsAggregationService, mockOperatorSetUpdateBlsAggregationService, mockMsgDb, mockRollupBroadcaster, mockClient, nil + return aggregator, mockAvsReader, mockAvsWriter, mockTaskBlsAggregationService, mockStateRootUpdateBlsAggregationService, mockOperatorSetUpdateBlsAggregationService, mockOperatorRegistrationsService, mockMsgDb, mockRollupBroadcaster, mockClient, nil } diff --git a/aggregator/gen.go b/aggregator/gen.go index 31996765..8e5a6dc9 100644 --- a/aggregator/gen.go +++ b/aggregator/gen.go @@ -4,4 +4,5 @@ package aggregator //go:generate mockgen -destination=./mocks/rpc_aggregator.go -package=mocks github.com/NethermindEth/near-sffl/aggregator RpcAggregatorer //go:generate mockgen -destination=./mocks/message_blsagg.go -package=mocks github.com/NethermindEth/near-sffl/aggregator MessageBlsAggregationService //go:generate mockgen -destination=./mocks/rollup_broadcaster.go -package=mocks github.com/NethermindEth/near-sffl/aggregator RollupBroadcasterer +//go:generate mockgen -destination=./mocks/operator_registrations_inmemory.go -package=mocks github.com/NethermindEth/near-sffl/aggregator OperatorRegistrationsService //go:generate mockgen -destination=./mocks/eth_client.go -package=mocks github.com/Layr-Labs/eigensdk-go/chainio/clients/eth Client diff --git a/aggregator/mocks/eth_client.go b/aggregator/mocks/eth_client.go index 081e963c..e7a5f5d7 100644 --- a/aggregator/mocks/eth_client.go +++ b/aggregator/mocks/eth_client.go @@ -5,7 +5,6 @@ // // mockgen -destination=./mocks/eth_client.go -package=mocks github.com/Layr-Labs/eigensdk-go/chainio/clients/eth Client // - // Package mocks is a generated GoMock package. package mocks diff --git a/aggregator/mocks/message_blsagg.go b/aggregator/mocks/message_blsagg.go index 14223ecf..8abfd2fb 100644 --- a/aggregator/mocks/message_blsagg.go +++ b/aggregator/mocks/message_blsagg.go @@ -5,7 +5,6 @@ // // mockgen -destination=./mocks/message_blsagg.go -package=mocks github.com/NethermindEth/near-sffl/aggregator MessageBlsAggregationService // - // Package mocks is a generated GoMock package. package mocks diff --git a/aggregator/mocks/operator_registrations_inmemory.go b/aggregator/mocks/operator_registrations_inmemory.go new file mode 100644 index 00000000..16fddc21 --- /dev/null +++ b/aggregator/mocks/operator_registrations_inmemory.go @@ -0,0 +1,71 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/NethermindEth/near-sffl/aggregator (interfaces: OperatorRegistrationsService) +// +// Generated by this command: +// +// mockgen -destination=./mocks/operator_registrations_inmemory.go -package=mocks github.com/NethermindEth/near-sffl/aggregator OperatorRegistrationsService +// +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + types "github.com/Layr-Labs/eigensdk-go/types" + common "github.com/ethereum/go-ethereum/common" + gomock "go.uber.org/mock/gomock" +) + +// MockOperatorRegistrationsService is a mock of OperatorRegistrationsService interface. +type MockOperatorRegistrationsService struct { + ctrl *gomock.Controller + recorder *MockOperatorRegistrationsServiceMockRecorder +} + +// MockOperatorRegistrationsServiceMockRecorder is the mock recorder for MockOperatorRegistrationsService. +type MockOperatorRegistrationsServiceMockRecorder struct { + mock *MockOperatorRegistrationsService +} + +// NewMockOperatorRegistrationsService creates a new mock instance. +func NewMockOperatorRegistrationsService(ctrl *gomock.Controller) *MockOperatorRegistrationsService { + mock := &MockOperatorRegistrationsService{ctrl: ctrl} + mock.recorder = &MockOperatorRegistrationsServiceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockOperatorRegistrationsService) EXPECT() *MockOperatorRegistrationsServiceMockRecorder { + return m.recorder +} + +// GetOperatorInfo mocks base method. +func (m *MockOperatorRegistrationsService) GetOperatorInfo(arg0 context.Context, arg1 common.Address) (types.OperatorInfo, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetOperatorInfo", arg0, arg1) + ret0, _ := ret[0].(types.OperatorInfo) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// GetOperatorInfo indicates an expected call of GetOperatorInfo. +func (mr *MockOperatorRegistrationsServiceMockRecorder) GetOperatorInfo(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOperatorInfo", reflect.TypeOf((*MockOperatorRegistrationsService)(nil).GetOperatorInfo), arg0, arg1) +} + +// GetOperatorInfoById mocks base method. +func (m *MockOperatorRegistrationsService) GetOperatorInfoById(arg0 context.Context, arg1 types.Bytes32) (types.OperatorInfo, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetOperatorInfoById", arg0, arg1) + ret0, _ := ret[0].(types.OperatorInfo) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// GetOperatorInfoById indicates an expected call of GetOperatorInfoById. +func (mr *MockOperatorRegistrationsServiceMockRecorder) GetOperatorInfoById(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOperatorInfoById", reflect.TypeOf((*MockOperatorRegistrationsService)(nil).GetOperatorInfoById), arg0, arg1) +} diff --git a/aggregator/mocks/rest_aggregator.go b/aggregator/mocks/rest_aggregator.go index 524ec199..95c9b132 100644 --- a/aggregator/mocks/rest_aggregator.go +++ b/aggregator/mocks/rest_aggregator.go @@ -5,7 +5,6 @@ // // mockgen -destination=./mocks/rest_aggregator.go -package=mocks github.com/NethermindEth/near-sffl/aggregator RestAggregatorer // - // Package mocks is a generated GoMock package. package mocks diff --git a/aggregator/mocks/rollup_broadcaster.go b/aggregator/mocks/rollup_broadcaster.go index 639c6cb0..b048ec4b 100644 --- a/aggregator/mocks/rollup_broadcaster.go +++ b/aggregator/mocks/rollup_broadcaster.go @@ -5,7 +5,6 @@ // // mockgen -destination=./mocks/rollup_broadcaster.go -package=mocks github.com/NethermindEth/near-sffl/aggregator RollupBroadcasterer // - // Package mocks is a generated GoMock package. package mocks diff --git a/aggregator/mocks/rpc_aggregator.go b/aggregator/mocks/rpc_aggregator.go index b0016a50..2768a7ed 100644 --- a/aggregator/mocks/rpc_aggregator.go +++ b/aggregator/mocks/rpc_aggregator.go @@ -5,13 +5,14 @@ // // mockgen -destination=./mocks/rpc_aggregator.go -package=mocks github.com/NethermindEth/near-sffl/aggregator RpcAggregatorer // - // Package mocks is a generated GoMock package. package mocks import ( + context "context" reflect "reflect" + types "github.com/Layr-Labs/eigensdk-go/types" messages "github.com/NethermindEth/near-sffl/core/types/messages" gomock "go.uber.org/mock/gomock" ) @@ -54,6 +55,21 @@ func (mr *MockRpcAggregatorerMockRecorder) GetAggregatedCheckpointMessages(arg0, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAggregatedCheckpointMessages", reflect.TypeOf((*MockRpcAggregatorer)(nil).GetAggregatedCheckpointMessages), arg0, arg1) } +// GetOperatorInfoById mocks base method. +func (m *MockRpcAggregatorer) GetOperatorInfoById(arg0 context.Context, arg1 types.Bytes32) (types.OperatorInfo, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetOperatorInfoById", arg0, arg1) + ret0, _ := ret[0].(types.OperatorInfo) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// GetOperatorInfoById indicates an expected call of GetOperatorInfoById. +func (mr *MockRpcAggregatorerMockRecorder) GetOperatorInfoById(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOperatorInfoById", reflect.TypeOf((*MockRpcAggregatorer)(nil).GetOperatorInfoById), arg0, arg1) +} + // GetRegistryCoordinatorAddress mocks base method. func (m *MockRpcAggregatorer) GetRegistryCoordinatorAddress(arg0 *string) error { m.ctrl.T.Helper() diff --git a/aggregator/operator_registrations_inmemory.go b/aggregator/operator_registrations_inmemory.go new file mode 100644 index 00000000..58c220be --- /dev/null +++ b/aggregator/operator_registrations_inmemory.go @@ -0,0 +1,234 @@ +package aggregator + +import ( + "context" + "sync" + + "github.com/Layr-Labs/eigensdk-go/chainio/clients/avsregistry" + "github.com/Layr-Labs/eigensdk-go/crypto/bls" + "github.com/Layr-Labs/eigensdk-go/logging" + "github.com/Layr-Labs/eigensdk-go/services/operatorsinfo" + "github.com/Layr-Labs/eigensdk-go/types" + "github.com/ethereum/go-ethereum/common" +) + +type OperatorRegistrationsService interface { + operatorsinfo.OperatorsInfoService + + GetOperatorInfoById(ctx context.Context, operatorId types.OperatorId) (operatorInfo types.OperatorInfo, operatorFound bool) +} + +type OperatorRegistrationsServiceInMemory struct { + avsRegistrySubscriber avsregistry.AvsRegistrySubscriber + avsRegistryReader avsregistry.AvsRegistryReader + logger logging.Logger + queryByAddrC chan<- queryByAddr + queryByIdC chan<- queryById + + idToAddr map[types.OperatorId]common.Address + addrToId map[common.Address]types.OperatorId + pubkeysById map[types.OperatorId]types.OperatorPubkeys + socketById map[types.OperatorId]types.Socket +} + +type queryByAddr struct { + operatorAddr common.Address + respC chan<- resp +} + +type queryById struct { + operatorId types.OperatorId + respC chan<- resp +} + +type resp struct { + operatorInfo types.OperatorInfo + operatorExists bool +} + +var _ OperatorRegistrationsService = (*OperatorRegistrationsServiceInMemory)(nil) + +func NewOperatorRegistrationsServiceInMemory( + ctx context.Context, + avsRegistrySubscriber avsregistry.AvsRegistrySubscriber, + avsRegistryReader avsregistry.AvsRegistryReader, + logger logging.Logger, +) (*OperatorRegistrationsServiceInMemory, error) { + queryByAddrC := make(chan queryByAddr) + queryByIdC := make(chan queryById) + + ors := &OperatorRegistrationsServiceInMemory{ + avsRegistrySubscriber: avsRegistrySubscriber, + avsRegistryReader: avsRegistryReader, + logger: logger, + queryByAddrC: queryByAddrC, + queryByIdC: queryByIdC, + idToAddr: make(map[types.OperatorId]common.Address), + addrToId: make(map[common.Address]types.OperatorId), + pubkeysById: make(map[types.OperatorId]types.OperatorPubkeys), + socketById: make(map[types.OperatorId]types.Socket), + } + err := ors.asyncInit(ctx, queryByAddrC, queryByIdC) + if err != nil { + return nil, err + } + + return ors, nil +} + +func (ors *OperatorRegistrationsServiceInMemory) asyncInit(ctx context.Context, queryByAddrC chan queryByAddr, queryByIdC chan queryById) error { + wg := sync.WaitGroup{} + defer wg.Wait() + + initErrC := make(chan error) + ors.startServiceInGoroutine(ctx, queryByAddrC, queryByIdC, &wg, initErrC) + + return <-initErrC +} + +func (ors *OperatorRegistrationsServiceInMemory) startServiceInGoroutine(ctx context.Context, queryByAddrC <-chan queryByAddr, queryByIdC <-chan queryById, wg *sync.WaitGroup, initErrC chan<- error) { + wg.Add(1) + + go func() { + ors.logger.Debug("Subscribing to new pubkey and socket registration events", "service", "OperatorRegistrationsServiceInMemory") + newPubkeyRegistrationC, newPubkeyRegistrationSub, err := ors.avsRegistrySubscriber.SubscribeToNewPubkeyRegistrations() + if err != nil { + ors.logger.Error("Error opening websocket subscription for new pubkey registrations", "err", err, "service", "OperatorRegistrationsServiceInMemory") + wg.Done() + initErrC <- err + return + } + + newSocketRegistrationC, newSocketRegistrationSub, err := ors.avsRegistrySubscriber.SubscribeToOperatorSocketUpdates() + if err != nil { + ors.logger.Error("Error opening websocket subscription for new socket registrations", "err", err, "service", "OperatorRegistrationsServiceInMemory") + wg.Done() + initErrC <- err + return + } + + err = ors.queryPastRegisteredOperators(ctx) + if err != nil { + wg.Done() + initErrC <- err + return + } + wg.Done() + close(initErrC) + + for { + select { + case <-ctx.Done(): + ors.logger.Info("OperatorRegistrationsServiceInMemory: Context cancelled, exiting") + return + + case err := <-newPubkeyRegistrationSub.Err(): + newPubkeyRegistrationSub.Unsubscribe() + ors.logger.Error("Error in websocket subscription for new pubkey registration events", "err", err, "service", "OperatorRegistrationsServiceInMemory") + + case err := <-newSocketRegistrationSub.Err(): + newSocketRegistrationSub.Unsubscribe() + ors.logger.Error("Error in websocket subscription for new socket registration events", "err", err, "service", "OperatorRegistrationsServiceInMemory") + + case newPubkeyRegistrationEvent := <-newPubkeyRegistrationC: + pubkeys := types.OperatorPubkeys{ + G1Pubkey: bls.NewG1Point(newPubkeyRegistrationEvent.PubkeyG1.X, newPubkeyRegistrationEvent.PubkeyG1.Y), + G2Pubkey: bls.NewG2Point(newPubkeyRegistrationEvent.PubkeyG2.X, newPubkeyRegistrationEvent.PubkeyG2.Y), + } + operatorId := types.OperatorIdFromG1Pubkey(pubkeys.G1Pubkey) + operatorAddr := newPubkeyRegistrationEvent.Operator + + ors.idToAddr[operatorId] = operatorAddr + ors.addrToId[operatorAddr] = operatorId + ors.pubkeysById[operatorId] = pubkeys + + ors.logger.Debug("Added operator info to dict", + "service", "OperatorRegistrationsServiceInMemory", + "block", newPubkeyRegistrationEvent.Raw.BlockNumber, + "operatorAddr", operatorAddr, + "operatorId", operatorId, + "G1pubkey", pubkeys.G1Pubkey, + "G2pubkey", pubkeys.G2Pubkey, + ) + + case newSocketRegistrationEvent := <-newSocketRegistrationC: + operatorId := types.OperatorId(newSocketRegistrationEvent.OperatorId) + socket := types.Socket(newSocketRegistrationEvent.Socket) + ors.logger.Debug("Received new socket registration event", "service", "OperatorRegistrationsServiceInMemory", "operatorId", operatorId, "socket", socket) + + ors.socketById[operatorId] = socket + + case q := <-queryByAddrC: + operatorId, idExists := ors.addrToId[q.operatorAddr] + pubkeys, pubkeysExist := ors.pubkeysById[operatorId] + socket, socketExists := ors.socketById[operatorId] + + operatorInfo := types.OperatorInfo{ + Pubkeys: pubkeys, + Socket: socket, + } + q.respC <- resp{operatorInfo, idExists && pubkeysExist && socketExists} + + case q := <-queryByIdC: + pubkeys, pubkeysExist := ors.pubkeysById[q.operatorId] + socket, socketExists := ors.socketById[q.operatorId] + + operatorInfo := types.OperatorInfo{ + Pubkeys: pubkeys, + Socket: socket, + } + q.respC <- resp{operatorInfo, pubkeysExist && socketExists} + } + } + }() +} + +func (ors *OperatorRegistrationsServiceInMemory) queryPastRegisteredOperators(ctx context.Context) error { + alreadyRegisteredOperatorAddrs, alreadyRegisteredOperatorPubkeys, err := ors.avsRegistryReader.QueryExistingRegisteredOperatorPubKeys(ctx, nil, nil) + if err != nil { + ors.logger.Error("Error querying existing registered operators", "err", err, "service", "OperatorRegistrationsServiceInMemory") + return err + } + + socketById, err := ors.avsRegistryReader.QueryExistingRegisteredOperatorSockets(ctx, nil, nil) + if err != nil { + ors.logger.Error("Error querying existing registered operator sockets", "err", err, "service", "OperatorRegistrationsServiceInMemory") + return err + } + + for i, operatorAddr := range alreadyRegisteredOperatorAddrs { + operatorPubkeys := alreadyRegisteredOperatorPubkeys[i] + operatorId := types.OperatorIdFromG1Pubkey(operatorPubkeys.G1Pubkey) + + ors.idToAddr[operatorId] = operatorAddr + ors.addrToId[operatorAddr] = operatorId + ors.pubkeysById[operatorId] = operatorPubkeys + ors.socketById[operatorId] = socketById[operatorId] + } + + return nil +} + +func (ors *OperatorRegistrationsServiceInMemory) GetOperatorInfo(ctx context.Context, operator common.Address) (types.OperatorInfo, bool) { + respC := make(chan resp) + ors.queryByAddrC <- queryByAddr{operator, respC} + + select { + case <-ctx.Done(): + return types.OperatorInfo{}, false + case resp := <-respC: + return resp.operatorInfo, resp.operatorExists + } +} + +func (ors *OperatorRegistrationsServiceInMemory) GetOperatorInfoById(ctx context.Context, operatorId types.OperatorId) (types.OperatorInfo, bool) { + respC := make(chan resp) + ors.queryByIdC <- queryById{operatorId, respC} + + select { + case <-ctx.Done(): + return types.OperatorInfo{}, false + case resp := <-respC: + return resp.operatorInfo, resp.operatorExists + } +} diff --git a/aggregator/rpc_server/server.go b/aggregator/rpc_server/server.go index 389f9cc9..e2e5a0cf 100644 --- a/aggregator/rpc_server/server.go +++ b/aggregator/rpc_server/server.go @@ -19,11 +19,12 @@ import ( var ( TaskNotFoundError400 = errors.New("400. Task not found") OperatorNotPartOfTaskQuorum400 = errors.New("400. Operator not part of quorum") + OperatorNotFoundError400 = errors.New("400. Operator not found") TaskResponseDigestNotFoundError500 = errors.New("500. Failed to get task response digest") MessageDigestNotFoundError500 = errors.New("500. Failed to get message digest") OperatorSetUpdateBlockNotFoundError500 = errors.New("500. Failed to get operator set update block") UnknownErrorWhileVerifyingSignature400 = errors.New("400. Failed to verify signature") - SignatureVerificationFailed400 = errors.New("400. Signature verification failed") + InvalidSignatureError400 = errors.New("400. Invalid signature") CallToGetCheckSignaturesIndicesFailed500 = errors.New("500. Failed to get check signatures indices") MessageExpiredError500 = errors.New("500. Message expired") UnknownError400 = errors.New("400. Unknown error") @@ -32,6 +33,8 @@ var ( aggregator.DigestError: MessageDigestNotFoundError500, aggregator.TaskResponseDigestError: TaskResponseDigestNotFoundError500, aggregator.GetOperatorSetUpdateBlockError: OperatorSetUpdateBlockNotFoundError500, + aggregator.InvalidSignatureError: InvalidSignatureError400, + aggregator.OperatorNotFoundError: OperatorNotFoundError400, aggregator.MessageExpiredError: MessageExpiredError500, } ) diff --git a/aggregator/rpc_server_test.go b/aggregator/rpc_server_test.go index a5e2e4b4..2297ceff 100644 --- a/aggregator/rpc_server_test.go +++ b/aggregator/rpc_server_test.go @@ -29,7 +29,7 @@ func TestProcessSignedCheckpointTaskResponse(t *testing.T) { var FROM_NEAR_BLOCK = uint64(3) var TO_NEAR_BLOCK = uint64(4) - aggregator, _, _, mockBlsAggServ, _, _, _, _, _, err := createMockAggregator(mockCtrl, MOCK_OPERATOR_PUBKEY_DICT) + aggregator, _, _, mockBlsAggServ, _, _, mockOperatorRegistrationsServ, _, _, _, err := createMockAggregator(mockCtrl, MOCK_OPERATOR_PUBKEY_DICT) assert.Nil(t, err) signedCheckpointTaskResponse, err := createMockSignedCheckpointTaskResponse(MockTask{ @@ -45,8 +45,11 @@ func TestProcessSignedCheckpointTaskResponse(t *testing.T) { // TODO(samlaf): is this the right way to test writing to external service? // or is there some wisdom to "don't mock 3rd party code"? // see https://hynek.me/articles/what-to-mock-in-5-mins/ - mockBlsAggServ.EXPECT().ProcessNewSignature(context.Background(), TASK_INDEX, signedCheckpointTaskResponseDigest, + ctx := context.Background() + mockBlsAggServ.EXPECT().ProcessNewSignature(ctx, TASK_INDEX, signedCheckpointTaskResponseDigest, &signedCheckpointTaskResponse.BlsSignature, signedCheckpointTaskResponse.OperatorId) + mockOperatorRegistrationsServ.EXPECT().GetOperatorInfoById(ctx, signedCheckpointTaskResponse.OperatorId).Return(eigentypes.OperatorInfo{Pubkeys: MOCK_OPERATOR_PUBKEYS}, true) + err = aggregator.ProcessSignedCheckpointTaskResponse(signedCheckpointTaskResponse) assert.Nil(t, err) } @@ -55,7 +58,7 @@ func TestProcessSignedStateRootUpdateMessage(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - aggregator, _, _, _, mockMessageBlsAggServ, _, _, _, _, err := createMockAggregator(mockCtrl, MOCK_OPERATOR_PUBKEY_DICT) + aggregator, _, _, _, mockMessageBlsAggServ, _, mockOperatorRegistrationsServ, _, _, _, err := createMockAggregator(mockCtrl, MOCK_OPERATOR_PUBKEY_DICT) assert.Nil(t, err) aggregator.clock = core.Clock{Now: func() time.Time { return time.Unix(10_000, 0) }} @@ -75,15 +78,43 @@ func TestProcessSignedStateRootUpdateMessage(t *testing.T) { mockMessageBlsAggServ.EXPECT().ProcessNewSignature(context.Background(), messageDigest, &signedMessage.BlsSignature, signedMessage.OperatorId) mockMessageBlsAggServ.EXPECT().InitializeMessageIfNotExists(messageDigest, coretypes.QUORUM_NUMBERS, []eigentypes.QuorumThresholdPercentage{types.MESSAGE_AGGREGATION_QUORUM_THRESHOLD}, types.MESSAGE_TTL, types.MESSAGE_BLS_AGGREGATION_TIMEOUT, uint64(0)) + mockOperatorRegistrationsServ.EXPECT().GetOperatorInfoById(context.Background(), signedMessage.OperatorId).Return(eigentypes.OperatorInfo{Pubkeys: MOCK_OPERATOR_PUBKEYS}, true) + err = aggregator.ProcessSignedStateRootUpdateMessage(signedMessage) assert.Nil(t, err) } +func TestProcessInvalidSignedStateRootUpdateMessage(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + aggregator, _, _, _, _, _, mockOperatorRegistrationsServ, _, _, _, err := createMockAggregator(mockCtrl, MOCK_OPERATOR_PUBKEY_DICT) + assert.Nil(t, err) + + aggregator.clock = core.Clock{Now: func() time.Time { return time.Unix(10_000, 0) }} + message := messages.StateRootUpdateMessage{ + RollupId: 1, + BlockHeight: 2, + Timestamp: 9_995, + NearDaCommitment: keccak256(4), + NearDaTransactionId: keccak256(5), + StateRoot: keccak256(6), + } + + signedMessage, err := createMockSignedStateRootUpdateMessage(message, *MOCK_OPERATOR_KEYPAIR) + assert.Nil(t, err) + signedMessage.BlsSignature = *newInvalidSignature() + + mockOperatorRegistrationsServ.EXPECT().GetOperatorInfoById(context.Background(), signedMessage.OperatorId).Return(eigentypes.OperatorInfo{Pubkeys: MOCK_OPERATOR_PUBKEYS}, true) + err = aggregator.ProcessSignedStateRootUpdateMessage(signedMessage) + assert.Equal(t, err.Error(), "Invalid signature") +} + func TestProcessOperatorSetUpdateMessage(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - aggregator, mockAvsReader, _, _, _, mockMessageBlsAggServ, _, _, _, err := createMockAggregator(mockCtrl, MOCK_OPERATOR_PUBKEY_DICT) + aggregator, mockAvsReader, _, _, _, mockMessageBlsAggServ, mockOperatorRegistrationsServ, _, _, _, err := createMockAggregator(mockCtrl, MOCK_OPERATOR_PUBKEY_DICT) assert.Nil(t, err) aggregator.clock = core.Clock{Now: func() time.Time { return time.Unix(10_000, 0) }} @@ -100,11 +131,14 @@ func TestProcessOperatorSetUpdateMessage(t *testing.T) { messageDigest, err := signedMessage.Message.Digest() assert.Nil(t, err) - mockAvsReader.EXPECT().GetOperatorSetUpdateBlock(context.Background(), uint64(1)).Return(uint32(10), nil) + ctx := context.Background() + mockAvsReader.EXPECT().GetOperatorSetUpdateBlock(ctx, uint64(1)).Return(uint32(10), nil) - mockMessageBlsAggServ.EXPECT().ProcessNewSignature(context.Background(), messageDigest, + mockMessageBlsAggServ.EXPECT().ProcessNewSignature(ctx, messageDigest, &signedMessage.BlsSignature, signedMessage.OperatorId) mockMessageBlsAggServ.EXPECT().InitializeMessageIfNotExists(messageDigest, coretypes.QUORUM_NUMBERS, []eigentypes.QuorumThresholdPercentage{types.MESSAGE_AGGREGATION_QUORUM_THRESHOLD}, types.MESSAGE_TTL, types.MESSAGE_BLS_AGGREGATION_TIMEOUT, uint64(9)) + mockOperatorRegistrationsServ.EXPECT().GetOperatorInfoById(context.Background(), signedMessage.OperatorId).Return(eigentypes.OperatorInfo{Pubkeys: MOCK_OPERATOR_PUBKEYS}, true) + err = aggregator.ProcessSignedOperatorSetUpdateMessage(signedMessage) assert.Nil(t, err) } @@ -113,7 +147,7 @@ func TestGetAggregatedCheckpointMessages(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - aggregator, _, _, _, _, _, mockDb, _, _, err := createMockAggregator(mockCtrl, MOCK_OPERATOR_PUBKEY_DICT) + aggregator, _, _, _, _, _, _, mockDb, _, _, err := createMockAggregator(mockCtrl, MOCK_OPERATOR_PUBKEY_DICT) assert.Nil(t, err) var checkpointMessages messages.CheckpointMessages @@ -177,3 +211,7 @@ func createMockSignedOperatorSetUpdateMessage(mockMessage messages.OperatorSetUp } return signedOperatorSetUpdateMessage, nil } + +func newInvalidSignature() *bls.Signature { + return bls.NewZeroSignature() +} diff --git a/config-files/aggregator.yaml b/config-files/aggregator.yaml index 5d68b015..f4535fb0 100644 --- a/config-files/aggregator.yaml +++ b/config-files/aggregator.yaml @@ -1,5 +1,5 @@ # 'production' only prints info and above. 'development' also prints debug -environment: production +environment: development eth_rpc_url: http://localhost:8545 eth_ws_url: ws://localhost:8545 # address which the aggregator listens on for operator signed messages diff --git a/config-files/operator.anvil.yaml b/config-files/operator.anvil.yaml index d45fb609..85a1645b 100644 --- a/config-files/operator.anvil.yaml +++ b/config-files/operator.anvil.yaml @@ -1,5 +1,5 @@ # this sets the logger level (true = info, false = debug) -production: true +production: false operator_address: 0xD5A0359da7B310917d7760385516B2426E86ab7f diff --git a/core/chainio/avs_subscriber.go b/core/chainio/avs_subscriber.go index 7b3237e1..b11f9557 100644 --- a/core/chainio/avs_subscriber.go +++ b/core/chainio/avs_subscriber.go @@ -1,19 +1,20 @@ package chainio import ( + "github.com/Layr-Labs/eigensdk-go/chainio/clients/avsregistry" + "github.com/Layr-Labs/eigensdk-go/chainio/clients/eth" + sdklogging "github.com/Layr-Labs/eigensdk-go/logging" "github.com/ethereum/go-ethereum/accounts/abi/bind" gethcommon "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/event" - "github.com/Layr-Labs/eigensdk-go/chainio/clients/eth" - sdklogging "github.com/Layr-Labs/eigensdk-go/logging" - opsetupdatereg "github.com/NethermindEth/near-sffl/contracts/bindings/SFFLOperatorSetUpdateRegistry" taskmanager "github.com/NethermindEth/near-sffl/contracts/bindings/SFFLTaskManager" ) type AvsSubscriberer interface { + avsregistry.AvsRegistrySubscriber SubscribeToNewTasks(checkpointTaskCreatedChan chan *taskmanager.ContractSFFLTaskManagerCheckpointTaskCreated) (event.Subscription, error) SubscribeToTaskResponses(taskResponseLogs chan *taskmanager.ContractSFFLTaskManagerCheckpointTaskResponded) (event.Subscription, error) SubscribeToOperatorSetUpdates(operatorSetUpdateChan chan *opsetupdatereg.ContractSFFLOperatorSetUpdateRegistryOperatorSetUpdatedAtBlock) (event.Subscription, error) @@ -25,23 +26,34 @@ type AvsSubscriberer interface { // it takes a single url, so the bindings, even though they have watcher functions, those can't be used // with the http connection... seems very very stupid. Am I missing something? type AvsSubscriber struct { + avsregistry.AvsRegistrySubscriber AvsContractBindings *AvsManagersBindings logger sdklogging.Logger } +var _ (AvsSubscriberer) = (*AvsSubscriber)(nil) + func BuildAvsSubscriber(registryCoordinatorAddr, blsOperatorStateRetrieverAddr gethcommon.Address, ethclient eth.Client, logger sdklogging.Logger) (*AvsSubscriber, error) { avsContractBindings, err := NewAvsManagersBindings(registryCoordinatorAddr, blsOperatorStateRetrieverAddr, ethclient, logger) if err != nil { logger.Error("Failed to create contract bindings", "err", err) return nil, err } - return NewAvsSubscriber(avsContractBindings, logger), nil + + avsRegistrySubscriber, err := avsregistry.NewAvsRegistryChainSubscriber(logger, avsContractBindings.RegistryCoordinator, avsContractBindings.BlsApkRegistry) + if err != nil { + logger.Error("Failed to create chain registry subscriber", "err", err) + return nil, err + } + + return NewAvsSubscriber(avsContractBindings, avsRegistrySubscriber, logger), nil } -func NewAvsSubscriber(avsContractBindings *AvsManagersBindings, logger sdklogging.Logger) *AvsSubscriber { +func NewAvsSubscriber(avsContractBindings *AvsManagersBindings, avsRegistrySubscriber avsregistry.AvsRegistrySubscriber, logger sdklogging.Logger) *AvsSubscriber { return &AvsSubscriber{ - AvsContractBindings: avsContractBindings, - logger: logger, + AvsRegistrySubscriber: avsRegistrySubscriber, + AvsContractBindings: avsContractBindings, + logger: logger, } } diff --git a/core/chainio/bindings.go b/core/chainio/bindings.go index d82b3975..98305f73 100644 --- a/core/chainio/bindings.go +++ b/core/chainio/bindings.go @@ -2,29 +2,37 @@ package chainio import ( "github.com/Layr-Labs/eigensdk-go/chainio/clients/eth" + blsapkreg "github.com/Layr-Labs/eigensdk-go/contracts/bindings/BLSApkRegistry" + regcoord "github.com/Layr-Labs/eigensdk-go/contracts/bindings/RegistryCoordinator" "github.com/Layr-Labs/eigensdk-go/logging" - "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" erc20mock "github.com/NethermindEth/near-sffl/contracts/bindings/ERC20Mock" opsetupdatereg "github.com/NethermindEth/near-sffl/contracts/bindings/SFFLOperatorSetUpdateRegistry" - regcoord "github.com/NethermindEth/near-sffl/contracts/bindings/SFFLRegistryCoordinator" + sfflregcoord "github.com/NethermindEth/near-sffl/contracts/bindings/SFFLRegistryCoordinator" csservicemanager "github.com/NethermindEth/near-sffl/contracts/bindings/SFFLServiceManager" taskmanager "github.com/NethermindEth/near-sffl/contracts/bindings/SFFLTaskManager" ) type AvsManagersBindings struct { - RegistryCoordinator *regcoord.ContractSFFLRegistryCoordinator + RegistryCoordinator *regcoord.ContractRegistryCoordinator + SFFLRegistryCoordinator *sfflregcoord.ContractSFFLRegistryCoordinator OperatorSetUpdateRegistry *opsetupdatereg.ContractSFFLOperatorSetUpdateRegistry TaskManager *taskmanager.ContractSFFLTaskManager ServiceManager *csservicemanager.ContractSFFLServiceManager + BlsApkRegistry blsapkreg.ContractBLSApkRegistryFilters ethClient eth.Client logger logging.Logger } func NewAvsManagersBindings(registryCoordinatorAddr, operatorStateRetrieverAddr common.Address, ethclient eth.Client, logger logging.Logger) (*AvsManagersBindings, error) { - contractRegistryCoordinator, err := regcoord.NewContractSFFLRegistryCoordinator(registryCoordinatorAddr, ethclient) + contractSfflRegistryCoordinator, err := sfflregcoord.NewContractSFFLRegistryCoordinator(registryCoordinatorAddr, ethclient) + if err != nil { + return nil, err + } + + contractRegistryCoordinator, err := regcoord.NewContractRegistryCoordinator(registryCoordinatorAddr, ethclient) if err != nil { return nil, err } @@ -51,7 +59,7 @@ func NewAvsManagersBindings(registryCoordinatorAddr, operatorStateRetrieverAddr return nil, err } - operatorSetUpdateRegistryAddr, err := contractRegistryCoordinator.OperatorSetUpdateRegistry(&bind.CallOpts{}) + operatorSetUpdateRegistryAddr, err := contractSfflRegistryCoordinator.OperatorSetUpdateRegistry(&bind.CallOpts{}) if err != nil { return nil, err } @@ -62,11 +70,23 @@ func NewAvsManagersBindings(registryCoordinatorAddr, operatorStateRetrieverAddr return nil, err } + blsApkRegistryAddr, err := contractRegistryCoordinator.BlsApkRegistry(&bind.CallOpts{}) + if err != nil { + return nil, err + } + + blsApkRegistry, err := blsapkreg.NewContractBLSApkRegistry(blsApkRegistryAddr, ethclient) + if err != nil { + return nil, err + } + return &AvsManagersBindings{ RegistryCoordinator: contractRegistryCoordinator, + SFFLRegistryCoordinator: contractSfflRegistryCoordinator, OperatorSetUpdateRegistry: contractOperatorSetUpdateRegistry, ServiceManager: contractServiceManager, TaskManager: contractTaskManager, + BlsApkRegistry: blsApkRegistry, ethClient: ethclient, logger: logger, }, nil diff --git a/core/chainio/mocks/avs_subscriber.go b/core/chainio/mocks/avs_subscriber.go index 53966183..e858c7a5 100644 --- a/core/chainio/mocks/avs_subscriber.go +++ b/core/chainio/mocks/avs_subscriber.go @@ -11,6 +11,8 @@ package mocks import ( reflect "reflect" + contractBLSApkRegistry "github.com/Layr-Labs/eigensdk-go/contracts/bindings/BLSApkRegistry" + contractRegistryCoordinator "github.com/Layr-Labs/eigensdk-go/contracts/bindings/RegistryCoordinator" contractSFFLOperatorSetUpdateRegistry "github.com/NethermindEth/near-sffl/contracts/bindings/SFFLOperatorSetUpdateRegistry" contractSFFLTaskManager "github.com/NethermindEth/near-sffl/contracts/bindings/SFFLTaskManager" types "github.com/ethereum/go-ethereum/core/types" @@ -56,6 +58,22 @@ func (mr *MockAvsSubscribererMockRecorder) ParseCheckpointTaskResponded(arg0 any return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ParseCheckpointTaskResponded", reflect.TypeOf((*MockAvsSubscriberer)(nil).ParseCheckpointTaskResponded), arg0) } +// SubscribeToNewPubkeyRegistrations mocks base method. +func (m *MockAvsSubscriberer) SubscribeToNewPubkeyRegistrations() (chan *contractBLSApkRegistry.ContractBLSApkRegistryNewPubkeyRegistration, event.Subscription, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SubscribeToNewPubkeyRegistrations") + ret0, _ := ret[0].(chan *contractBLSApkRegistry.ContractBLSApkRegistryNewPubkeyRegistration) + ret1, _ := ret[1].(event.Subscription) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// SubscribeToNewPubkeyRegistrations indicates an expected call of SubscribeToNewPubkeyRegistrations. +func (mr *MockAvsSubscribererMockRecorder) SubscribeToNewPubkeyRegistrations() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubscribeToNewPubkeyRegistrations", reflect.TypeOf((*MockAvsSubscriberer)(nil).SubscribeToNewPubkeyRegistrations)) +} + // SubscribeToNewTasks mocks base method. func (m *MockAvsSubscriberer) SubscribeToNewTasks(arg0 chan *contractSFFLTaskManager.ContractSFFLTaskManagerCheckpointTaskCreated) (event.Subscription, error) { m.ctrl.T.Helper() @@ -86,6 +104,22 @@ func (mr *MockAvsSubscribererMockRecorder) SubscribeToOperatorSetUpdates(arg0 an return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubscribeToOperatorSetUpdates", reflect.TypeOf((*MockAvsSubscriberer)(nil).SubscribeToOperatorSetUpdates), arg0) } +// SubscribeToOperatorSocketUpdates mocks base method. +func (m *MockAvsSubscriberer) SubscribeToOperatorSocketUpdates() (chan *contractRegistryCoordinator.ContractRegistryCoordinatorOperatorSocketUpdate, event.Subscription, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SubscribeToOperatorSocketUpdates") + ret0, _ := ret[0].(chan *contractRegistryCoordinator.ContractRegistryCoordinatorOperatorSocketUpdate) + ret1, _ := ret[1].(event.Subscription) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// SubscribeToOperatorSocketUpdates indicates an expected call of SubscribeToOperatorSocketUpdates. +func (mr *MockAvsSubscribererMockRecorder) SubscribeToOperatorSocketUpdates() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubscribeToOperatorSocketUpdates", reflect.TypeOf((*MockAvsSubscriberer)(nil).SubscribeToOperatorSocketUpdates)) +} + // SubscribeToTaskResponses mocks base method. func (m *MockAvsSubscriberer) SubscribeToTaskResponses(arg0 chan *contractSFFLTaskManager.ContractSFFLTaskManagerCheckpointTaskResponded) (event.Subscription, error) { m.ctrl.T.Helper()