diff --git a/core/safeclient/client_test.go b/core/safeclient/client_test.go index a4f97d4d..711f7a03 100644 --- a/core/safeclient/client_test.go +++ b/core/safeclient/client_test.go @@ -617,3 +617,101 @@ func TestLogCache(t *testing.T) { log = <-logCh assert.Equal(t, uint64(2), log.BlockNumber) } + +func TestSubscribeFilterLogs_Unsubscribe(t *testing.T) { + logger, err := logging.NewZapLogger("development") + assert.NoError(t, err) + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + mockClient := mocks.NewMockEthClient(mockCtrl) + mockClient.EXPECT().BlockNumber(gomock.Any()).Return(uint64(1_000), nil) + mockClient.EXPECT().SubscribeFilterLogs(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, query ethereum.FilterQuery, ch chan<- types.Log) (ethereum.Subscription, error) { + errChann := make(chan error) + + sub := mocks.NewMockSubscription(mockCtrl) + sub.EXPECT().Unsubscribe().Do(func() { close(errChann) }) + sub.EXPECT().Err().Return(errChann) + + return sub, nil + }, + ) + + client, err := safeclient.NewSafeEthClient("", logger, safeclient.WithCustomCreateClient(func(string, logging.Logger) (eth.Client, error) { return mockClient, nil })) + assert.NoError(t, err) + assert.NotNil(t, client) + defer client.Close() + + filterQuery := ethereum.FilterQuery{ + FromBlock: big.NewInt(900), + ToBlock: big.NewInt(1_100), + } + logCh := make(chan types.Log) + + sub, err := client.SubscribeFilterLogs(context.Background(), filterQuery, logCh) + assert.NoError(t, err) + assert.NotNil(t, sub) +} + +func TestSubscribeFilterLogs_ErrorInSubscription_Resubscribe(t *testing.T) { + logger, err := logging.NewZapLogger("development") + assert.NoError(t, err) + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + mockClient := mocks.NewMockEthClient(mockCtrl) + mockClient.EXPECT().BlockNumber(gomock.Any()).Return(uint64(1_000), nil).Times(2) + + var triggerError func() + mockClient.EXPECT().SubscribeFilterLogs(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, query ethereum.FilterQuery, ch chan<- types.Log) (ethereum.Subscription, error) { + sub := mocks.NewMockSubscription(mockCtrl) + errChan := make(chan error) + triggerError = func() { + errChan <- errors.New("error") + } + sub.EXPECT().Unsubscribe().Do(func() { close(errChan) }) + sub.EXPECT().Err().Return(errChan) + + return sub, nil + }, + ).Times(2) // First subscription + one resubscription + + client, err := safeclient.NewSafeEthClient("", logger, safeclient.WithCustomCreateClient(func(string, logging.Logger) (eth.Client, error) { return mockClient, nil })) + assert.NoError(t, err) + assert.NotNil(t, client) + defer client.Close() + + filterQuery := ethereum.FilterQuery{ + FromBlock: big.NewInt(900), + ToBlock: big.NewInt(1_100), + } + logCh := make(chan types.Log) + + sub, err := client.SubscribeFilterLogs(context.Background(), filterQuery, logCh) + assert.NoError(t, err) + assert.NotNil(t, sub) + + triggerError() +} + +func TestSafeSubscription_ConcurrentUnsubscribe(t *testing.T) { + mockCtrl := gomock.NewController(t) + sub := mocks.NewMockSubscription(mockCtrl) + sub.EXPECT().Unsubscribe().Times(1) + + safeSub := safeclient.NewSafeSubscription(sub) + + var wg sync.WaitGroup + for i := 1; i <= 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + safeSub.Unsubscribe() + }() + } + wg.Wait() +}