Skip to content

Commit

Permalink
fix: unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
taco-paco committed Jun 3, 2024
1 parent 39d162e commit bde63d6
Show file tree
Hide file tree
Showing 10 changed files with 302 additions and 233 deletions.
38 changes: 24 additions & 14 deletions aggregator/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,20 @@ var (
CheckpointNotFoundError = errors.New("CheckpointMessages not found")
)

type RpcAggregatorer interface {
ProcessSignedCheckpointTaskResponse(signedCheckpointTaskResponse *messages.SignedCheckpointTaskResponse) error
ProcessSignedStateRootUpdateMessage(signedStateRootUpdateMessage *messages.SignedStateRootUpdateMessage) error
ProcessSignedOperatorSetUpdateMessage(signedOperatorSetUpdateMessage *messages.SignedOperatorSetUpdateMessage) error
GetAggregatedCheckpointMessages(fromTimestamp, toTimestamp uint64) (*messages.CheckpointMessages, error)
GetRegistryCoordinatorAddress(reply *string) error
}

type RestAggregatorer interface {
GetStateRootUpdateAggregation(rollupId uint32, blockHeight uint64) (*types.GetStateRootUpdateAggregationResponse, error)
GetOperatorSetUpdateAggregation(id uint64) (*types.GetOperatorSetUpdateAggregationResponse, error)
GetCheckpointMessages(fromTimestamp, toTimestamp uint64) (*types.GetCheckpointMessagesResponse, error)
}

// Aggregator sends checkpoint tasks onchain, then listens for operator signed TaskResponses.
// It aggregates responses signatures, and if any of the TaskResponses reaches the QuorumThreshold for each quorum
// (currently we only use a single quorum of the ERC20Mock token), it sends the aggregated TaskResponse and signature onchain.
Expand Down Expand Up @@ -118,6 +132,8 @@ type Aggregator struct {
}

var _ core.Metricable = (*Aggregator)(nil)
var _ RpcAggregatorer = (*Aggregator)(nil)
var _ RestAggregatorer = (*Aggregator)(nil)

// NewAggregator creates a new Aggregator with the provided config.
// TODO: Remove this context once OperatorPubkeysServiceInMemory's API is
Expand Down Expand Up @@ -516,7 +532,7 @@ func (agg *Aggregator) handleOperatorSetUpdateReachedQuorum(ctx context.Context,
}
}

func (agg *Aggregator) ProcessSignedCheckpointTaskResponse(signedCheckpointTaskResponse *messages.SignedCheckpointTaskResponse, reply *bool) error {
func (agg *Aggregator) ProcessSignedCheckpointTaskResponse(signedCheckpointTaskResponse *messages.SignedCheckpointTaskResponse) error {
taskIndex := signedCheckpointTaskResponse.TaskResponse.ReferenceTaskIndex
taskResponseDigest, err := signedCheckpointTaskResponse.TaskResponse.Digest()
if err != nil {
Expand Down Expand Up @@ -545,7 +561,7 @@ func (agg *Aggregator) ProcessSignedCheckpointTaskResponse(signedCheckpointTaskR
}

// Rpc request handlers
func (agg *Aggregator) ProcessSignedStateRootUpdateMessage(signedStateRootUpdateMessage *messages.SignedStateRootUpdateMessage, reply *bool) error {
func (agg *Aggregator) ProcessSignedStateRootUpdateMessage(signedStateRootUpdateMessage *messages.SignedStateRootUpdateMessage) error {
messageDigest, err := signedStateRootUpdateMessage.Message.Digest()
if err != nil {
agg.logger.Error("Failed to get message digest", "err", err)
Expand Down Expand Up @@ -575,7 +591,7 @@ func (agg *Aggregator) ProcessSignedStateRootUpdateMessage(signedStateRootUpdate
return err
}

func (agg *Aggregator) ProcessSignedOperatorSetUpdateMessage(signedOperatorSetUpdateMessage *messages.SignedOperatorSetUpdateMessage, reply *bool) error {
func (agg *Aggregator) ProcessSignedOperatorSetUpdateMessage(signedOperatorSetUpdateMessage *messages.SignedOperatorSetUpdateMessage) error {
messageDigest, err := signedOperatorSetUpdateMessage.Message.Digest()
if err != nil {
agg.logger.Error("Failed to get message digest", "err", err)
Expand Down Expand Up @@ -617,22 +633,16 @@ func (agg *Aggregator) GetRegistry() *prometheus.Registry {
return agg.registry
}

type GetAggregatedCheckpointMessagesArgs struct {
FromTimestamp, ToTimestamp uint64
}

func (agg *Aggregator) GetAggregatedCheckpointMessages(args *GetAggregatedCheckpointMessagesArgs, reply *messages.CheckpointMessages) error {
checkpointMessages, err := agg.msgDb.FetchCheckpointMessages(args.FromTimestamp, args.ToTimestamp)
func (agg *Aggregator) GetAggregatedCheckpointMessages(fromTimestamp, toTimestamp uint64) (*messages.CheckpointMessages, error) {
checkpointMessages, err := agg.msgDb.FetchCheckpointMessages(fromTimestamp, toTimestamp)
if err != nil {
return err
return nil, err
}

*reply = *checkpointMessages

return nil
return checkpointMessages, nil
}

func (agg *Aggregator) GetRegistryCoordinatorAddress(_ *struct{}, reply *string) error {
func (agg *Aggregator) GetRegistryCoordinatorAddress(reply *string) error {
*reply = agg.config.SFFLRegistryCoordinatorAddr.String()
return nil
}
Expand Down
2 changes: 0 additions & 2 deletions aggregator/aggregator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,6 @@ func createMockAggregator(
rollupBroadcaster: mockRollupBroadcaster,
httpClient: mockClient,
wsClient: mockClient,
rpcListener: &SelectiveRpcListener{},
restListener: &SelectiveRestListener{},
aggregatorListener: &SelectiveAggregatorListener{},
}
return aggregator, mockAvsReader, mockAvsWriter, mockTaskBlsAggregationService, mockStateRootUpdateBlsAggregationService, mockOperatorSetUpdateBlsAggregationService, mockMsgDb, mockRollupBroadcaster, mockClient, nil
Expand Down
8 changes: 6 additions & 2 deletions aggregator/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,18 @@ func aggregatorMain(ctx *cli.Context) error {
registry := agg.GetRegistry()
rpcServer := rpcserver.NewRpcServer(config.AggregatorServerIpPortAddr, agg, logger)
if registry != nil {
rpcServer.EnableMetrics(registry)
if err = rpcServer.EnableMetrics(registry); err != nil {
return err
}
}

go rpcServer.Start()

restServer := restserver.NewRestServer(config.AggregatorRestServerIpPortAddr, agg, logger)
if registry != nil {
restServer.EnableMetrics(registry)
if err = restServer.EnableMetrics(registry); err != nil {
return err
}
}
go restServer.Start()

Expand Down
1 change: 1 addition & 0 deletions aggregator/gen.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package aggregator

//go:generate mockgen -destination=./mocks/rest_aggregator.go -package=mocks github.com/NethermindEth/near-sffl/aggregator RestAggregatorer
//go:generate mockgen -destination=./mocks/message_blsagg.go -package=mocks github.com/NethermindEth/near-sffl/aggregator MessageBlsAggregationService
//go:generate mockgen -destination=./mocks/rollup_broadcaster.go -package=mocks github.com/NethermindEth/near-sffl/aggregator RollupBroadcasterer
//go:generate mockgen -destination=./mocks/eth_client.go -package=mocks github.com/Layr-Labs/eigensdk-go/chainio/clients/eth Client
84 changes: 84 additions & 0 deletions aggregator/mocks/rest_aggregator.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions aggregator/rest_server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ var (

type RestServer struct {
serverIpPortAddr string
app *aggregator.Aggregator
app aggregator.RestAggregatorer

logger logging.Logger
listener EventListener
}

var _ core.Metricable = (*RestServer)(nil)

func NewRestServer(serverIpPortAddr string, app *aggregator.Aggregator, logger logging.Logger) *RestServer {
func NewRestServer(serverIpPortAddr string, app aggregator.RestAggregatorer, logger logging.Logger) *RestServer {
return &RestServer{
serverIpPortAddr: serverIpPortAddr,
app: app,
Expand Down
161 changes: 161 additions & 0 deletions aggregator/rest_server/server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package rest_server

import (
"encoding/json"
"fmt"
sdklogging "github.com/Layr-Labs/eigensdk-go/logging"
"github.com/NethermindEth/near-sffl/aggregator/mocks"
"github.com/NethermindEth/near-sffl/core/types/messages"
"github.com/NethermindEth/near-sffl/tests"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"

aggtypes "github.com/NethermindEth/near-sffl/aggregator/types"
)

func TestGetStateRootUpdateAggregation(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

logger := sdklogging.NewNoopLogger()
aggregator := mocks.NewMockRestAggregatorer(mockCtrl)
restServer := NewRestServer("", aggregator, logger)

msg := messages.StateRootUpdateMessage{
RollupId: 1,
BlockHeight: 2,
Timestamp: 3,
NearDaCommitment: tests.Keccak256(4),
NearDaTransactionId: tests.Keccak256(5),
StateRoot: tests.Keccak256(6),
}
msgDigest, err := msg.Digest()
assert.Nil(t, err)

response := aggtypes.GetStateRootUpdateAggregationResponse{
Message: msg,
Aggregation: messages.MessageBlsAggregation{
MessageDigest: msgDigest,
},
}
aggregator.EXPECT().GetStateRootUpdateAggregation(msg.RollupId, msg.BlockHeight).Return(&response, nil)

req, err := http.NewRequest(
"GET",
fmt.Sprintf("/aggregation/state-root-update?rollupId=%d&blockHeight=%d", msg.RollupId, msg.BlockHeight),
nil,
)
assert.Nil(t, err)

recorder := httptest.NewRecorder()
err = restServer.handleGetStateRootUpdateAggregation(recorder, req)
assert.Nil(t, err)
assert.Equal(t, recorder.Code, http.StatusOK)

var body aggtypes.GetStateRootUpdateAggregationResponse
err = json.Unmarshal(recorder.Body.Bytes(), &body)
assert.Nil(t, err)
assert.Equal(t, body, response)
}

func TestGetOperatorSetUpdateAggregation(t *testing.T) {
mockCtrl := gomock.NewController(t)

logger := sdklogging.NewNoopLogger()
aggregator := mocks.NewMockRestAggregatorer(mockCtrl)
restServer := NewRestServer("", aggregator, logger)

msg := messages.OperatorSetUpdateMessage{
Id: 1,
Timestamp: 2,
}
digest, err := msg.Digest()
assert.Nil(t, err)

response := aggtypes.GetOperatorSetUpdateAggregationResponse{
Message: msg,
Aggregation: messages.MessageBlsAggregation{
MessageDigest: digest,
},
}

aggregator.EXPECT().GetOperatorSetUpdateAggregation(msg.Id).Return(&response, nil)

req, err := http.NewRequest(
"GET",
fmt.Sprintf("/aggregation/operator-set-update?id=%d", msg.Id),
nil,
)
assert.Nil(t, err)

recorder := httptest.NewRecorder()
err = restServer.handleGetOperatorSetUpdateAggregation(recorder, req)
assert.Nil(t, err)
assert.Equal(t, recorder.Code, http.StatusOK)

var actual aggtypes.GetOperatorSetUpdateAggregationResponse
err = json.Unmarshal(recorder.Body.Bytes(), &actual)
assert.Nil(t, err)
assert.Equal(t, response, actual)
}

func TestGetCheckpointMessages(t *testing.T) {
mockCtrl := gomock.NewController(t)

logger := sdklogging.NewNoopLogger()
aggregator := mocks.NewMockRestAggregatorer(mockCtrl)
restServer := NewRestServer("", aggregator, logger)

stateRootMessage := messages.StateRootUpdateMessage{
RollupId: 1,
BlockHeight: 2,
Timestamp: 3,
}
stateRootDigest, err := stateRootMessage.Digest()
assert.Nil(t, err)
stateRootAggregation := messages.MessageBlsAggregation{
MessageDigest: stateRootDigest,
}

operatorSetMesssage := messages.OperatorSetUpdateMessage{
Id: 1,
Timestamp: 2,
}
operatorSetDigest, err := operatorSetMesssage.Digest()
assert.Nil(t, err)
operatorSetAggregation := messages.MessageBlsAggregation{
MessageDigest: operatorSetDigest,
}

var fromTimestamp, toTimestamp uint64 = 0, 3
response := aggtypes.GetCheckpointMessagesResponse{
CheckpointMessages: messages.CheckpointMessages{
StateRootUpdateMessages: []messages.StateRootUpdateMessage{stateRootMessage},
StateRootUpdateMessageAggregations: []messages.MessageBlsAggregation{stateRootAggregation},
OperatorSetUpdateMessages: []messages.OperatorSetUpdateMessage{operatorSetMesssage},
OperatorSetUpdateMessageAggregations: []messages.MessageBlsAggregation{operatorSetAggregation},
},
}
aggregator.EXPECT().GetCheckpointMessages(fromTimestamp, toTimestamp).Return(&response, nil)

req, err := http.NewRequest(
"GET",
fmt.Sprintf("/checkpoint/messages?fromTimestamp=%d&toTimestamp=%d", fromTimestamp, toTimestamp),
nil,
)
assert.Nil(t, err)

recorder := httptest.NewRecorder()
err = restServer.handleGetCheckpointMessages(recorder, req)
assert.Nil(t, err)
assert.Equal(t, recorder.Code, http.StatusOK)

var actual aggtypes.GetCheckpointMessagesResponse
err = json.Unmarshal(recorder.Body.Bytes(), &actual)
assert.Nil(t, err)
assert.Equal(t, response, actual)
}
Loading

0 comments on commit bde63d6

Please sign in to comment.