From db19286769725b9ef8b5f52f8007dd3952113430 Mon Sep 17 00:00:00 2001 From: dimkouv Date: Wed, 13 Sep 2023 13:29:55 +0300 Subject: [PATCH 1/2] move libs under internal pkg --- .../services/ocr2/plugins/ccip/cache/cache.go | 138 --------- .../ocr2/plugins/ccip/cache/cache_mock.go | 53 ---- .../ocr2/plugins/ccip/cache/cache_test.go | 178 ------------ .../ocr2/plugins/ccip/cache/snoozed_roots.go | 65 ----- .../plugins/ccip/cache/snoozed_roots_test.go | 40 --- .../ocr2/plugins/ccip/cache/tokens.go | 266 ------------------ .../ocr2/plugins/ccip/cache/tokens_test.go | 227 --------------- .../ocr2/plugins/ccip/commit_plugin.go | 13 +- .../plugins/ccip/commit_reporting_plugin.go | 13 +- .../ccip/commit_reporting_plugin_test.go | 15 +- .../plugins/ccip/execution_batch_building.go | 10 +- .../ocr2/plugins/ccip/execution_plugin.go | 9 +- .../ccip/execution_reporting_plugin.go | 8 +- .../ccip/execution_reporting_plugin_test.go | 8 +- .../ccip/{ => internal}/ccipevents/client.go | 0 .../{ => internal}/ccipevents/logpoller.go | 0 .../ccipevents/logpoller_test.go | 0 .../{hasher => internal/hashlib}/hasher.go | 2 +- .../hashlib}/hasher_test.go | 2 +- .../hashlib}/leaf_hasher.go | 2 +- .../hashlib}/leaf_hasher_test.go | 22 +- .../merkle_multi_proof_test_vector.go | 0 .../merklemulti/merkle_multi.go | 16 +- .../merklemulti/merkle_multi_test.go | 8 +- .../oraclelib}/backfilled_oracle.go | 2 +- .../oraclelib}/backfilled_oracle_test.go | 11 +- .../ocr2/plugins/ccip/plugins_common.go | 6 +- .../ccip/testhelpers/ccip_contracts.go | 8 +- .../testhelpers/plugins/plugin_harness.go | 8 +- 29 files changed, 85 insertions(+), 1045 deletions(-) delete mode 100644 core/services/ocr2/plugins/ccip/cache/cache.go delete mode 100644 core/services/ocr2/plugins/ccip/cache/cache_mock.go delete mode 100644 core/services/ocr2/plugins/ccip/cache/cache_test.go delete mode 100644 core/services/ocr2/plugins/ccip/cache/snoozed_roots.go delete mode 100644 core/services/ocr2/plugins/ccip/cache/snoozed_roots_test.go delete mode 100644 core/services/ocr2/plugins/ccip/cache/tokens.go delete mode 100644 core/services/ocr2/plugins/ccip/cache/tokens_test.go rename core/services/ocr2/plugins/ccip/{ => internal}/ccipevents/client.go (100%) rename core/services/ocr2/plugins/ccip/{ => internal}/ccipevents/logpoller.go (100%) rename core/services/ocr2/plugins/ccip/{ => internal}/ccipevents/logpoller_test.go (100%) rename core/services/ocr2/plugins/ccip/{hasher => internal/hashlib}/hasher.go (98%) rename core/services/ocr2/plugins/ccip/{hasher => internal/hashlib}/hasher_test.go (95%) rename core/services/ocr2/plugins/ccip/{hasher => internal/hashlib}/leaf_hasher.go (99%) rename core/services/ocr2/plugins/ccip/{hasher => internal/hashlib}/leaf_hasher_test.go (78%) rename core/services/ocr2/plugins/ccip/{ => internal}/merklemulti/fixtures/merkle_multi_proof_test_vector.go (100%) rename core/services/ocr2/plugins/ccip/{ => internal}/merklemulti/merkle_multi.go (90%) rename core/services/ocr2/plugins/ccip/{ => internal}/merklemulti/merkle_multi_test.go (97%) rename core/services/ocr2/plugins/ccip/{ => internal/oraclelib}/backfilled_oracle.go (99%) rename core/services/ocr2/plugins/ccip/{ => internal/oraclelib}/backfilled_oracle_test.go (80%) diff --git a/core/services/ocr2/plugins/ccip/cache/cache.go b/core/services/ocr2/plugins/ccip/cache/cache.go deleted file mode 100644 index 1668d6cf33..0000000000 --- a/core/services/ocr2/plugins/ccip/cache/cache.go +++ /dev/null @@ -1,138 +0,0 @@ -package cache - -import ( - "context" - "sync" - - "github.com/ethereum/go-ethereum/common" - - "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" -) - -// AutoSync cache provides only a Get method, the expiration and syncing is a black-box for the caller. -// -//go:generate mockery --quiet --name AutoSync --output . --filename cache_mock.go --inpackage --case=underscore -type AutoSync[T any] interface { - Get(ctx context.Context) (T, error) -} - -// CachedChain represents caching on-chain calls based on the events read from logpoller.LogPoller. -// Instead of directly going to on-chain to fetch data, we start with checking logpoller.LogPoller events (database request). -// If we discover that change occurred since last update, we perform RPC to the chain using ContractOrigin.CallOrigin function. -// Purpose of this struct is handle common logic in a single place, you only need to override methods from ContractOrigin -// and Get function (behaving as orchestrator) will take care of the rest. -// -// That being said, adding caching layer to the new contract is as simple as: -// * implementing ContractOrigin interface -// * registering proper events in log poller -type CachedChain[T any] struct { - // Static configuration - observedEvents []common.Hash - logPoller logpoller.LogPoller - address []common.Address - optimisticConfirmations int64 - - // Cache - lock *sync.RWMutex - value T - lastChangeBlock int64 - origin ContractOrigin[T] -} - -type ContractOrigin[T any] interface { - // Copy must return copy of the cached data to limit locking section to the minimum - Copy(T) T - // CallOrigin fetches data that is next stored within cache. Usually, should perform RPC to the source (e.g. chain) - CallOrigin(ctx context.Context) (T, error) -} - -// Get is an entry point to the caching. Main function that decides whether cache content is fresh and should be returned -// to the caller, or whether we need to update it's content from on-chain data. -// This decision is made based on the events emitted by Smart Contracts -func (c *CachedChain[T]) Get(ctx context.Context) (T, error) { - var empty T - - lastChangeBlock := c.readLastChangeBlock() - - // Handles first call, because cache is not eagerly populated - if lastChangeBlock == 0 { - return c.initializeCache(ctx) - } - - currentBlockNumber, err := c.logPoller.LatestBlockByEventSigsAddrsWithConfs(lastChangeBlock, c.observedEvents, c.address, int(c.optimisticConfirmations), pg.WithParentCtx(ctx)) - - if err != nil { - return empty, err - } - - // In case of new updates, fetch fresh data from the origin - if currentBlockNumber > lastChangeBlock { - return c.fetchFromOrigin(ctx, currentBlockNumber) - } - return c.copyCachedValue(), nil -} - -// initializeCache performs first call to origin when is not populated yet. -// It's done eagerly, so cache it's populated for the first time when data is needed, not at struct initialization -func (c *CachedChain[T]) initializeCache(ctx context.Context) (T, error) { - var empty T - - // To prevent missing data when blocks are produced after calling the origin, - // we first get the latest block and then call the origin. - latestBlock, err := c.logPoller.LatestBlock(pg.WithParentCtx(ctx)) - if err != nil { - return empty, err - } - - // Init - value, err := c.origin.CallOrigin(ctx) - if err != nil { - return empty, err - } - - c.updateCache(value, latestBlock-c.optimisticConfirmations) - return c.copyCachedValue(), nil - -} - -// fetchFromOrigin fetches data from origin. This action is performed when logpoller.LogPoller says there were events -// emitted since last update. -func (c *CachedChain[T]) fetchFromOrigin(ctx context.Context, currentBlockNumber int64) (T, error) { - var empty T - value, err := c.origin.CallOrigin(ctx) - if err != nil { - return empty, err - } - c.updateCache(value, currentBlockNumber) - - return c.copyCachedValue(), nil -} - -// updateCache performs updating two critical variables for cache to work properly: -// * value that is stored within cache -// * lastChangeBlock representing last seen event from logpoller.LogPoller -func (c *CachedChain[T]) updateCache(newValue T, currentBlockNumber int64) { - c.lock.Lock() - defer c.lock.Unlock() - - // Double-lock checking. No need to update if other goroutine was faster - if currentBlockNumber <= c.lastChangeBlock { - return - } - - c.value = newValue - c.lastChangeBlock = currentBlockNumber -} - -func (c *CachedChain[T]) readLastChangeBlock() int64 { - c.lock.RLock() - defer c.lock.RUnlock() - return c.lastChangeBlock -} - -func (c *CachedChain[T]) copyCachedValue() T { - c.lock.RLock() - defer c.lock.RUnlock() - return c.origin.Copy(c.value) -} diff --git a/core/services/ocr2/plugins/ccip/cache/cache_mock.go b/core/services/ocr2/plugins/ccip/cache/cache_mock.go deleted file mode 100644 index a5cd3d901d..0000000000 --- a/core/services/ocr2/plugins/ccip/cache/cache_mock.go +++ /dev/null @@ -1,53 +0,0 @@ -// Code generated by mockery v2.28.1. DO NOT EDIT. - -package cache - -import ( - context "context" - - mock "github.com/stretchr/testify/mock" -) - -// MockAutoSync is an autogenerated mock type for the AutoSync type -type MockAutoSync[T interface{}] struct { - mock.Mock -} - -// Get provides a mock function with given fields: ctx -func (_m *MockAutoSync[T]) Get(ctx context.Context) (T, error) { - ret := _m.Called(ctx) - - var r0 T - var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (T, error)); ok { - return rf(ctx) - } - if rf, ok := ret.Get(0).(func(context.Context) T); ok { - r0 = rf(ctx) - } else { - r0 = ret.Get(0).(T) - } - - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -type mockConstructorTestingTNewMockAutoSync interface { - mock.TestingT - Cleanup(func()) -} - -// NewMockAutoSync creates a new instance of MockAutoSync. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewMockAutoSync[T interface{}](t mockConstructorTestingTNewMockAutoSync) *MockAutoSync[T] { - mock := &MockAutoSync[T]{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/core/services/ocr2/plugins/ccip/cache/cache_test.go b/core/services/ocr2/plugins/ccip/cache/cache_test.go deleted file mode 100644 index e37b09b95e..0000000000 --- a/core/services/ocr2/plugins/ccip/cache/cache_test.go +++ /dev/null @@ -1,178 +0,0 @@ -package cache - -import ( - "context" - "strconv" - "sync" - "testing" - - "github.com/ethereum/go-ethereum/common" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" - - "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller" - lpMocks "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller/mocks" - "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" -) - -const ( - cachedValue = "cached_value" -) - -func TestGet_InitDataForTheFirstTime(t *testing.T) { - lp := lpMocks.NewLogPoller(t) - lp.On("LatestBlock", mock.Anything).Maybe().Return(int64(100), nil) - - contract := newCachedContract(lp, "", []string{"value1"}, 0) - - value, err := contract.Get(testutils.Context(t)) - require.NoError(t, err) - require.Equal(t, "value1", value) -} - -func TestGet_ReturnDataFromCacheIfNoNewEvents(t *testing.T) { - latestBlock := int64(100) - lp := lpMocks.NewLogPoller(t) - mockLogPollerQuery(lp, latestBlock) - - contract := newCachedContract(lp, cachedValue, []string{"value1"}, latestBlock) - - value, err := contract.Get(testutils.Context(t)) - require.NoError(t, err) - require.Equal(t, cachedValue, value) -} - -func TestGet_CallOriginForNewEvents(t *testing.T) { - latestBlock := int64(100) - lp := lpMocks.NewLogPoller(t) - m := mockLogPollerQuery(lp, latestBlock+1) - - contract := newCachedContract(lp, cachedValue, []string{"value1", "value2", "value3"}, latestBlock) - - // First call - value, err := contract.Get(testutils.Context(t)) - require.NoError(t, err) - require.Equal(t, "value1", value) - - currentBlock := contract.lastChangeBlock - require.Equal(t, latestBlock+1, currentBlock) - - m.Unset() - mockLogPollerQuery(lp, latestBlock+1) - - // Second call doesn't change anything - value, err = contract.Get(testutils.Context(t)) - require.NoError(t, err) - require.Equal(t, "value1", value) - require.Equal(t, int64(101), contract.lastChangeBlock) -} - -func TestGet_CacheProgressing(t *testing.T) { - firstBlock := int64(100) - secondBlock := int64(105) - thirdBlock := int64(110) - - lp := lpMocks.NewLogPoller(t) - m := mockLogPollerQuery(lp, secondBlock) - - contract := newCachedContract(lp, cachedValue, []string{"value1", "value2", "value3"}, firstBlock) - - // First call - value, err := contract.Get(testutils.Context(t)) - require.NoError(t, err) - require.Equal(t, "value1", value) - require.Equal(t, secondBlock, contract.lastChangeBlock) - - m.Unset() - mockLogPollerQuery(lp, thirdBlock) - - // Second call - value, err = contract.Get(testutils.Context(t)) - require.NoError(t, err) - require.Equal(t, "value2", value) - require.Equal(t, thirdBlock, contract.lastChangeBlock) -} - -func TestGet_ConcurrentAccess(t *testing.T) { - mockedPoller := lpMocks.NewLogPoller(t) - progressingPoller := ProgressingLogPoller{ - LogPoller: mockedPoller, - latestBlock: 1, - } - - iterations := 100 - originValues := make([]string, iterations) - for i := 0; i < iterations; i++ { - originValues[i] = "value_" + strconv.Itoa(i) - } - contract := newCachedContract(&progressingPoller, "empty", originValues, 1) - - var wg sync.WaitGroup - wg.Add(iterations) - for i := 0; i < iterations; i++ { - go func() { - defer wg.Done() - _, _ = contract.Get(testutils.Context(t)) - }() - } - wg.Wait() - - // 1 init block + 100 iterations - require.Equal(t, int64(101), contract.lastChangeBlock) - - // Make sure that recent value is stored in cache - val := contract.copyCachedValue() - require.Equal(t, "value_99", val) -} - -func newCachedContract(lp logpoller.LogPoller, cacheValue string, originValue []string, lastChangeBlock int64) *CachedChain[string] { - return &CachedChain[string]{ - observedEvents: []common.Hash{{}}, - logPoller: lp, - address: []common.Address{{}}, - optimisticConfirmations: 0, - - lock: &sync.RWMutex{}, - value: cacheValue, - lastChangeBlock: lastChangeBlock, - origin: &FakeContractOrigin{values: originValue}, - } -} - -func mockLogPollerQuery(lp *lpMocks.LogPoller, latestBlock int64) *mock.Call { - return lp.On("LatestBlockByEventSigsAddrsWithConfs", mock.Anything, []common.Hash{{}}, []common.Address{{}}, 0, mock.Anything). - Maybe().Return(latestBlock, nil) -} - -type ProgressingLogPoller struct { - *lpMocks.LogPoller - latestBlock int64 - lock sync.Mutex -} - -func (lp *ProgressingLogPoller) LatestBlockByEventSigsAddrsWithConfs(int64, []common.Hash, []common.Address, int, ...pg.QOpt) (int64, error) { - lp.lock.Lock() - defer lp.lock.Unlock() - lp.latestBlock++ - return lp.latestBlock, nil -} - -type FakeContractOrigin struct { - values []string - counter int - lock sync.Mutex -} - -func (f *FakeContractOrigin) CallOrigin(context.Context) (string, error) { - f.lock.Lock() - defer func() { - f.counter++ - f.lock.Unlock() - }() - return f.values[f.counter], nil -} - -func (f *FakeContractOrigin) Copy(value string) string { - return value -} diff --git a/core/services/ocr2/plugins/ccip/cache/snoozed_roots.go b/core/services/ocr2/plugins/ccip/cache/snoozed_roots.go deleted file mode 100644 index 916716942a..0000000000 --- a/core/services/ocr2/plugins/ccip/cache/snoozed_roots.go +++ /dev/null @@ -1,65 +0,0 @@ -package cache - -import ( - "time" - - "github.com/patrickmn/go-cache" -) - -const ( - // EvictionGracePeriod defines how long after the permissionless execution threshold a root is still kept in the cache - EvictionGracePeriod = 1 * time.Hour - // CleanupInterval defines how often roots have to be evicted - CleanupInterval = 30 * time.Minute -) - -type SnoozedRoots interface { - IsSnoozed(merkleRoot [32]byte) bool - MarkAsExecuted(merkleRoot [32]byte) - Snooze(merkleRoot [32]byte) -} - -type snoozedRoots struct { - cache *cache.Cache - // Both rootSnoozedTime and permissionLessExecutionThresholdDuration can be kept in the snoozedRoots without need to be updated. - // Those config properties are populates via onchain/offchain config. When changed, OCR plugin will be restarted and cache initialized with new config. - rootSnoozedTime time.Duration - permissionLessExecutionThresholdDuration time.Duration -} - -func newSnoozedRoots( - permissionLessExecutionThresholdDuration time.Duration, - rootSnoozeTime time.Duration, - evictionGracePeriod time.Duration, - cleanupInterval time.Duration, -) *snoozedRoots { - evictionTime := permissionLessExecutionThresholdDuration + evictionGracePeriod - internalCache := cache.New(evictionTime, cleanupInterval) - - return &snoozedRoots{ - cache: internalCache, - rootSnoozedTime: rootSnoozeTime, - permissionLessExecutionThresholdDuration: permissionLessExecutionThresholdDuration, - } -} - -func NewSnoozedRoots(permissionLessExecutionThresholdDuration time.Duration, rootSnoozeTime time.Duration) *snoozedRoots { - return newSnoozedRoots(permissionLessExecutionThresholdDuration, rootSnoozeTime, EvictionGracePeriod, CleanupInterval) -} - -func (s *snoozedRoots) IsSnoozed(merkleRoot [32]byte) bool { - rawValue, found := s.cache.Get(merkleRootToString(merkleRoot)) - return found && time.Now().Before(rawValue.(time.Time)) -} - -func (s *snoozedRoots) MarkAsExecuted(merkleRoot [32]byte) { - s.cache.SetDefault(merkleRootToString(merkleRoot), time.Now().Add(s.permissionLessExecutionThresholdDuration)) -} - -func (s *snoozedRoots) Snooze(merkleRoot [32]byte) { - s.cache.SetDefault(merkleRootToString(merkleRoot), time.Now().Add(s.rootSnoozedTime)) -} - -func merkleRootToString(merkleRoot [32]byte) string { - return string(merkleRoot[:]) -} diff --git a/core/services/ocr2/plugins/ccip/cache/snoozed_roots_test.go b/core/services/ocr2/plugins/ccip/cache/snoozed_roots_test.go deleted file mode 100644 index f3813df575..0000000000 --- a/core/services/ocr2/plugins/ccip/cache/snoozed_roots_test.go +++ /dev/null @@ -1,40 +0,0 @@ -package cache - -import ( - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestSnoozedRoots(t *testing.T) { - c := NewSnoozedRoots(1*time.Minute, 1*time.Minute) - - k1 := [32]byte{1} - k2 := [32]byte{2} - - // return false for non existing element - snoozed := c.IsSnoozed(k1) - assert.False(t, snoozed) - - // after an element is marked as executed it should be snoozed - c.MarkAsExecuted(k1) - snoozed = c.IsSnoozed(k1) - assert.True(t, snoozed) - - // after snoozing an element it should be snoozed - c.Snooze(k2) - snoozed = c.IsSnoozed(k2) - assert.True(t, snoozed) -} - -func TestEvictingElements(t *testing.T) { - c := newSnoozedRoots(1*time.Millisecond, 1*time.Hour, 1*time.Millisecond, 1*time.Millisecond) - - k1 := [32]byte{1} - c.Snooze(k1) - - time.Sleep(1 * time.Second) - - assert.False(t, c.IsSnoozed(k1)) -} diff --git a/core/services/ocr2/plugins/ccip/cache/tokens.go b/core/services/ocr2/plugins/ccip/cache/tokens.go deleted file mode 100644 index 5a53cf964b..0000000000 --- a/core/services/ocr2/plugins/ccip/cache/tokens.go +++ /dev/null @@ -1,266 +0,0 @@ -package cache - -import ( - "context" - "fmt" - "sync" - - "github.com/ethereum/go-ethereum/accounts/abi/bind" - "github.com/ethereum/go-ethereum/common" - "golang.org/x/exp/slices" - - evmclient "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" - "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller" - "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/evm_2_evm_offramp" - "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/price_registry" - "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/generated/link_token_interface" - "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/abihelpers" -) - -// NewCachedFeeTokens cache fee tokens returned from PriceRegistry -func NewCachedFeeTokens( - lp logpoller.LogPoller, - priceRegistry price_registry.PriceRegistryInterface, - optimisticConfirmations int64, -) *CachedChain[[]common.Address] { - return &CachedChain[[]common.Address]{ - observedEvents: []common.Hash{ - abihelpers.EventSignatures.FeeTokenAdded, - abihelpers.EventSignatures.FeeTokenRemoved, - }, - logPoller: lp, - address: []common.Address{priceRegistry.Address()}, - optimisticConfirmations: optimisticConfirmations, - lock: &sync.RWMutex{}, - value: []common.Address{}, - lastChangeBlock: 0, - origin: &feeTokensOrigin{priceRegistry: priceRegistry}, - } -} - -type CachedTokens struct { - SupportedTokens map[common.Address]common.Address - FeeTokens []common.Address -} - -// NewCachedSupportedTokens cache both fee tokens and supported tokens. Therefore, it uses 4 different events -// when checking for changes in logpoller.LogPoller -func NewCachedSupportedTokens( - lp logpoller.LogPoller, - offRamp evm_2_evm_offramp.EVM2EVMOffRampInterface, - priceRegistry price_registry.PriceRegistryInterface, - optimisticConfirmations int64, -) *CachedChain[CachedTokens] { - return &CachedChain[CachedTokens]{ - observedEvents: []common.Hash{ - abihelpers.EventSignatures.FeeTokenAdded, - abihelpers.EventSignatures.FeeTokenRemoved, - abihelpers.EventSignatures.PoolAdded, - abihelpers.EventSignatures.PoolRemoved, - }, - logPoller: lp, - address: []common.Address{priceRegistry.Address(), offRamp.Address()}, - optimisticConfirmations: optimisticConfirmations, - lock: &sync.RWMutex{}, - value: CachedTokens{}, - lastChangeBlock: 0, - origin: &feeAndSupportedTokensOrigin{ - feeTokensOrigin: feeTokensOrigin{priceRegistry: priceRegistry}, - supportedTokensOrigin: supportedTokensOrigin{offRamp: offRamp}}, - } -} - -func NewTokenToDecimals( - lggr logger.Logger, - lp logpoller.LogPoller, - offRamp evm_2_evm_offramp.EVM2EVMOffRampInterface, - priceRegistry price_registry.PriceRegistryInterface, - client evmclient.Client, - optimisticConfirmations int64, -) *CachedChain[map[common.Address]uint8] { - return &CachedChain[map[common.Address]uint8]{ - observedEvents: []common.Hash{ - abihelpers.EventSignatures.FeeTokenAdded, - abihelpers.EventSignatures.FeeTokenRemoved, - abihelpers.EventSignatures.PoolAdded, - abihelpers.EventSignatures.PoolRemoved, - }, - logPoller: lp, - address: []common.Address{priceRegistry.Address(), offRamp.Address()}, - optimisticConfirmations: optimisticConfirmations, - lock: &sync.RWMutex{}, - value: make(map[common.Address]uint8), - lastChangeBlock: 0, - origin: &tokenToDecimals{ - lggr: lggr, - priceRegistry: priceRegistry, - offRamp: offRamp, - tokenFactory: func(token common.Address) (link_token_interface.LinkTokenInterface, error) { - return link_token_interface.NewLinkToken(token, client) - }, - }, - } -} - -type supportedTokensOrigin struct { - offRamp evm_2_evm_offramp.EVM2EVMOffRampInterface -} - -func (t *supportedTokensOrigin) Copy(value map[common.Address]common.Address) map[common.Address]common.Address { - return copyMap(value) -} - -// CallOrigin Generates the source to dest token mapping based on the offRamp. -// NOTE: this queries the offRamp n+1 times, where n is the number of enabled tokens. -func (t *supportedTokensOrigin) CallOrigin(ctx context.Context) (map[common.Address]common.Address, error) { - srcToDstTokenMapping := make(map[common.Address]common.Address) - sourceTokens, err := t.offRamp.GetSupportedTokens(&bind.CallOpts{Context: ctx}) - if err != nil { - return nil, err - } - - seenDestinationTokens := make(map[common.Address]struct{}) - - for _, sourceToken := range sourceTokens { - dst, err1 := t.offRamp.GetDestinationToken(&bind.CallOpts{Context: ctx}, sourceToken) - if err1 != nil { - return nil, err1 - } - - if _, exists := seenDestinationTokens[dst]; exists { - return nil, fmt.Errorf("offRamp misconfig, destination token %s already exists", dst) - } - - seenDestinationTokens[dst] = struct{}{} - srcToDstTokenMapping[sourceToken] = dst - } - return srcToDstTokenMapping, nil -} - -type feeTokensOrigin struct { - priceRegistry price_registry.PriceRegistryInterface -} - -func (t *feeTokensOrigin) Copy(value []common.Address) []common.Address { - return copyArray(value) -} - -func (t *feeTokensOrigin) CallOrigin(ctx context.Context) ([]common.Address, error) { - return t.priceRegistry.GetFeeTokens(&bind.CallOpts{Context: ctx}) -} - -func copyArray(source []common.Address) []common.Address { - dst := make([]common.Address, len(source)) - copy(dst, source) - return dst -} - -type feeAndSupportedTokensOrigin struct { - feeTokensOrigin feeTokensOrigin - supportedTokensOrigin supportedTokensOrigin -} - -func (t *feeAndSupportedTokensOrigin) Copy(value CachedTokens) CachedTokens { - return CachedTokens{ - SupportedTokens: t.supportedTokensOrigin.Copy(value.SupportedTokens), - FeeTokens: t.feeTokensOrigin.Copy(value.FeeTokens), - } -} - -func (t *feeAndSupportedTokensOrigin) CallOrigin(ctx context.Context) (CachedTokens, error) { - supportedTokens, err := t.supportedTokensOrigin.CallOrigin(ctx) - if err != nil { - return CachedTokens{}, err - } - feeToken, err := t.feeTokensOrigin.CallOrigin(ctx) - if err != nil { - return CachedTokens{}, err - } - return CachedTokens{ - SupportedTokens: supportedTokens, - FeeTokens: feeToken, - }, nil -} - -func copyMap[M ~map[K]V, K comparable, V any](m M) M { - cpy := make(M) - for k, v := range m { - cpy[k] = v - } - return cpy -} - -type tokenToDecimals struct { - lggr logger.Logger - offRamp evm_2_evm_offramp.EVM2EVMOffRampInterface - priceRegistry price_registry.PriceRegistryInterface - tokenFactory func(address common.Address) (link_token_interface.LinkTokenInterface, error) - tokenDecimals sync.Map -} - -func (t *tokenToDecimals) Copy(value map[common.Address]uint8) map[common.Address]uint8 { - return copyMap(value) -} - -// CallOrigin Generates the token to decimal mapping for dest tokens and fee tokens. -// NOTE: this queries token decimals n times, where n is the number of tokens whose decimals are not already cached. -func (t *tokenToDecimals) CallOrigin(ctx context.Context) (map[common.Address]uint8, error) { - mapping := make(map[common.Address]uint8) - - destTokens, err := t.offRamp.GetDestinationTokens(&bind.CallOpts{Context: ctx}) - if err != nil { - return nil, err - } - - feeTokens, err := t.priceRegistry.GetFeeTokens(&bind.CallOpts{Context: ctx}) - if err != nil { - return nil, err - } - - // In case if a fee token is not an offramp dest token, we still want to update its decimals and price - for _, feeToken := range feeTokens { - if !slices.Contains(destTokens, feeToken) { - destTokens = append(destTokens, feeToken) - } - } - - for _, token := range destTokens { - if decimals, exists := t.getCachedDecimals(token); exists { - mapping[token] = decimals - continue - } - - tokenContract, err := t.tokenFactory(token) - if err != nil { - return nil, err - } - - decimals, err := tokenContract.Decimals(&bind.CallOpts{Context: ctx}) - if err != nil { - return nil, fmt.Errorf("get token %s decimals: %w", token, err) - } - - t.setCachedDecimals(token, decimals) - mapping[token] = decimals - } - return mapping, nil -} - -func (t *tokenToDecimals) getCachedDecimals(token common.Address) (uint8, bool) { - rawVal, exists := t.tokenDecimals.Load(token.String()) - if !exists { - return 0, false - } - - decimals, isUint8 := rawVal.(uint8) - if !isUint8 { - return 0, false - } - - return decimals, true -} - -func (t *tokenToDecimals) setCachedDecimals(token common.Address, decimals uint8) { - t.tokenDecimals.Store(token.String(), decimals) -} diff --git a/core/services/ocr2/plugins/ccip/cache/tokens_test.go b/core/services/ocr2/plugins/ccip/cache/tokens_test.go deleted file mode 100644 index b562c44563..0000000000 --- a/core/services/ocr2/plugins/ccip/cache/tokens_test.go +++ /dev/null @@ -1,227 +0,0 @@ -package cache - -import ( - "context" - "fmt" - "math/big" - "testing" - - "github.com/ethereum/go-ethereum/common" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" - - mock_contracts "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/mocks" - "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/generated/link_token_interface" - "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" - "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/utils" -) - -func Test_tokenToDecimals(t *testing.T) { - tokenPriceMappings := map[common.Address]uint8{ - common.HexToAddress("0xA"): 10, - common.HexToAddress("0xB"): 5, - common.HexToAddress("0xC"): 2, - } - - tests := []struct { - name string - destTokens []common.Address - feeTokens []common.Address - want map[common.Address]uint8 - wantErr bool - }{ - { - name: "empty map for empty tokens from origin", - destTokens: []common.Address{}, - feeTokens: []common.Address{}, - want: map[common.Address]uint8{}, - }, - { - name: "separate destination and fee tokens", - destTokens: []common.Address{common.HexToAddress("0xC")}, - feeTokens: []common.Address{common.HexToAddress("0xB")}, - want: map[common.Address]uint8{ - common.HexToAddress("0xC"): 2, - common.HexToAddress("0xB"): 5, - }, - }, - { - name: "fee tokens and dest tokens are overlapping", - destTokens: []common.Address{common.HexToAddress("0xA")}, - feeTokens: []common.Address{common.HexToAddress("0xA")}, - want: map[common.Address]uint8{ - common.HexToAddress("0xA"): 10, - }, - }, - { - name: "only fee tokens are returned", - destTokens: []common.Address{}, - feeTokens: []common.Address{common.HexToAddress("0xA"), common.HexToAddress("0xC")}, - want: map[common.Address]uint8{ - common.HexToAddress("0xA"): 10, - common.HexToAddress("0xC"): 2, - }, - }, - { - name: "missing tokens are skipped", - destTokens: []common.Address{}, - feeTokens: []common.Address{common.HexToAddress("0xD")}, - want: map[common.Address]uint8{}, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - offRamp := &mock_contracts.EVM2EVMOffRampInterface{} - offRamp.On("GetDestinationTokens", mock.Anything).Return(tt.destTokens, nil) - - priceRegistry := &mock_contracts.PriceRegistryInterface{} - priceRegistry.On("GetFeeTokens", mock.Anything).Return(tt.feeTokens, nil) - - tokenToDecimal := &tokenToDecimals{ - lggr: logger.TestLogger(t), - offRamp: offRamp, - priceRegistry: priceRegistry, - tokenFactory: createTokenFactory(tokenPriceMappings), - } - - got, err := tokenToDecimal.CallOrigin(testutils.Context(t)) - if tt.wantErr { - require.Error(t, err) - return - } - - require.NoError(t, err) - assert.Equal(t, tt.want, got) - - // we set token factory to always return an error - // we don't expect it to be used again, decimals should be in cache. - tokenToDecimal.tokenFactory = func(address common.Address) (link_token_interface.LinkTokenInterface, error) { - return nil, fmt.Errorf("some error") - } - got, err = tokenToDecimal.CallOrigin(testutils.Context(t)) - require.NoError(t, err) - assert.Equal(t, tt.want, got) - }) - } -} - -func TestCallOrigin(t *testing.T) { - src1 := common.HexToAddress("10") - dst1 := common.HexToAddress("11") - src2 := common.HexToAddress("20") - dst2 := common.HexToAddress("21") - - testCases := []struct { - name string - srcTokens []common.Address - srcToDst map[common.Address]common.Address - expErr bool - }{ - { - name: "base", - srcTokens: []common.Address{src1, src2}, - srcToDst: map[common.Address]common.Address{ - src1: dst1, - src2: dst2, - }, - expErr: false, - }, - { - name: "dup dst token", - srcTokens: []common.Address{src1, src2}, - srcToDst: map[common.Address]common.Address{ - src1: dst1, - src2: dst1, - }, - expErr: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - offRamp := mock_contracts.NewEVM2EVMOffRampInterface(t) - offRamp.On("GetSupportedTokens", mock.Anything).Return(tc.srcTokens, nil) - for src, dst := range tc.srcToDst { - offRamp.On("GetDestinationToken", mock.Anything, src).Return(dst, nil) - } - o := supportedTokensOrigin{offRamp: offRamp} - srcToDst, err := o.CallOrigin(context.Background()) - - if tc.expErr { - assert.Error(t, err) - return - } - - assert.NoError(t, err) - for src, dst := range tc.srcToDst { - assert.Equal(t, dst, srcToDst[src]) - } - }) - } -} - -func Test_copyArray(t *testing.T) { - t.Run("base", func(t *testing.T) { - a := []common.Address{common.HexToAddress("1"), common.HexToAddress("2")} - b := copyArray(a) - assert.Equal(t, a, b) - b[0] = common.HexToAddress("3") - assert.NotEqual(t, a, b) - }) - - t.Run("empty", func(t *testing.T) { - b := copyArray([]common.Address{}) - assert.Empty(t, b) - }) -} - -func Test_copyMap(t *testing.T) { - t.Run("base", func(t *testing.T) { - val := map[string]int{"a": 100, "b": 50} - cp := copyMap(val) - assert.Len(t, val, 2) - assert.Equal(t, 100, cp["a"]) - assert.Equal(t, 50, cp["b"]) - val["b"] = 10 - assert.Equal(t, 50, cp["b"]) - }) - - t.Run("pointer val", func(t *testing.T) { - val := map[string]*big.Int{"a": big.NewInt(100), "b": big.NewInt(50)} - cp := copyMap(val) - val["a"] = big.NewInt(20) - assert.Equal(t, int64(100), cp["a"].Int64()) - }) -} - -func Test_cachedDecimals(t *testing.T) { - tokenDecimalsCache := &tokenToDecimals{} - addr := utils.RandomAddress() - - decimals, exists := tokenDecimalsCache.getCachedDecimals(addr) - assert.Zero(t, decimals) - assert.False(t, exists) - - tokenDecimalsCache.setCachedDecimals(addr, 123) - decimals, exists = tokenDecimalsCache.getCachedDecimals(addr) - assert.Equal(t, uint8(123), decimals) - assert.True(t, exists) -} - -func createTokenFactory(decimalMapping map[common.Address]uint8) func(address common.Address) (link_token_interface.LinkTokenInterface, error) { - return func(address common.Address) (link_token_interface.LinkTokenInterface, error) { - linkToken := &mock_contracts.LinkTokenInterface{} - if decimals, found := decimalMapping[address]; found { - // Make sure each token is fetched only once - linkToken.On("Decimals", mock.Anything).Return(decimals, nil) - } else { - linkToken.On("Decimals", mock.Anything).Return(uint8(0), errors.New("Error")) - } - return linkToken, nil - } -} diff --git a/core/services/ocr2/plugins/ccip/commit_plugin.go b/core/services/ocr2/plugins/ccip/commit_plugin.go index 170df6713b..f86b06b15d 100644 --- a/core/services/ocr2/plugins/ccip/commit_plugin.go +++ b/core/services/ocr2/plugins/ccip/commit_plugin.go @@ -10,9 +10,14 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/pkg/errors" chainselectors "github.com/smartcontractkit/chain-selectors" - relaylogger "github.com/smartcontractkit/chainlink-relay/pkg/logger" + libocr2 "github.com/smartcontractkit/libocr/offchainreporting2plus" + relaylogger "github.com/smartcontractkit/chainlink-relay/pkg/logger" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/ccipevents" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/hashlib" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/oraclelib" + "github.com/smartcontractkit/chainlink/v2/core/chains/evm" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr" @@ -21,9 +26,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/abihelpers" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/ccipevents" ccipconfig "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/config" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/hasher" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/promwrapper" "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" @@ -92,7 +95,7 @@ func NewCommitServices(lggr logger.Logger, jb job.Job, chainSet evm.LegacyChainC return nil, err } - leafHasher := hasher.NewLeafHasher(staticConfig.SourceChainSelector, staticConfig.ChainSelector, onRamp.Address(), hasher.NewKeccakCtx()) + leafHasher := hashlib.NewLeafHasher(staticConfig.SourceChainSelector, staticConfig.ChainSelector, onRamp.Address(), hashlib.NewKeccakCtx()) // Note that lggr already has the jobName and contractID (commit store) commitLggr := lggr.Named("CCIPCommit").With( "sourceChain", ChainName(int64(chainId)), @@ -136,7 +139,7 @@ func NewCommitServices(lggr logger.Logger, jb job.Job, chainSet evm.LegacyChainC "sourceRouter", sourceRouter.Address()) // If this is a brand-new job, then we make use of the start blocks. If not then we're rebooting and log poller will pick up where we left off. if new { - return []job.ServiceCtx{NewBackfilledOracle( + return []job.ServiceCtx{oraclelib.NewBackfilledOracle( commitLggr, sourceChain.LogPoller(), destChain.LogPoller(), diff --git a/core/services/ocr2/plugins/ccip/commit_reporting_plugin.go b/core/services/ocr2/plugins/ccip/commit_reporting_plugin.go index e6207fae90..8d22fb8461 100644 --- a/core/services/ocr2/plugins/ccip/commit_reporting_plugin.go +++ b/core/services/ocr2/plugins/ccip/commit_reporting_plugin.go @@ -24,11 +24,12 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/price_registry" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/abihelpers" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/cache" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/ccipevents" ccipconfig "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/config" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/hasher" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/merklemulti" + + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/cache" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/ccipevents" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/hashlib" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/merklemulti" "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) @@ -64,7 +65,7 @@ type CommitPluginConfig struct { sourceNative common.Address sourceFeeEstimator gas.EvmFeeEstimator sourceClient, destClient evmclient.Client - leafHasher hasher.LeafHasherInterface[[32]byte] + leafHasher hashlib.LeafHasherInterface[[32]byte] checkFinalityTags bool } @@ -678,7 +679,7 @@ func (r *CommitReportingPlugin) buildReport(ctx context.Context, lggr logger.Log return commit_store.CommitStoreCommitReport{}, fmt.Errorf("tried building a tree without leaves") } - tree, err := merklemulti.NewTree(hasher.NewKeccakCtx(), leaves) + tree, err := merklemulti.NewTree(hashlib.NewKeccakCtx(), leaves) if err != nil { return commit_store.CommitStoreCommitReport{}, err } diff --git a/core/services/ocr2/plugins/ccip/commit_reporting_plugin_test.go b/core/services/ocr2/plugins/ccip/commit_reporting_plugin_test.go index be1037d0b2..416cad1258 100644 --- a/core/services/ocr2/plugins/ccip/commit_reporting_plugin_test.go +++ b/core/services/ocr2/plugins/ccip/commit_reporting_plugin_test.go @@ -33,12 +33,11 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/abihelpers" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/cache" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/ccipevents" ccipconfig "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/config" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/hasher" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/merklemulti" - + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/cache" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/ccipevents" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/hashlib" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/merklemulti" plugintesthelpers "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/testhelpers/plugins" "github.com/smartcontractkit/chainlink/v2/core/store/models" "github.com/smartcontractkit/chainlink/v2/core/utils" @@ -85,7 +84,7 @@ func setupCommitTestHarness(t *testing.T) commitTestHarness { sourceChainSelector: th.Source.ChainSelector, destClient: backendClient, sourceClient: backendClient, - leafHasher: hasher.NewLeafHasher(th.Source.ChainSelector, th.Dest.ChainSelector, th.Source.OnRamp.Address(), hasher.NewKeccakCtx()), + leafHasher: hashlib.NewLeafHasher(th.Source.ChainSelector, th.Dest.ChainSelector, th.Source.OnRamp.Address(), hashlib.NewKeccakCtx()), }, inflightReports: newInflightCommitReportsContainer(time.Hour), onchainConfig: th.CommitOnchainConfig, @@ -142,7 +141,7 @@ func TestCommitReportEncoding(t *testing.T) { newGasPrice := big.NewInt(2000e9) // $2000 per eth * 1gwei // Send a report. - mctx := hasher.NewKeccakCtx() + mctx := hashlib.NewKeccakCtx() tree, err := merklemulti.NewTree(mctx, [][32]byte{mctx.Hash([]byte{0xaa})}) require.NoError(t, err) report := commit_store.CommitStoreCommitReport{ @@ -1096,7 +1095,7 @@ func TestShouldAcceptFinalizedReport(t *testing.T) { } func TestCommitReportToEthTxMeta(t *testing.T) { - mctx := hasher.NewKeccakCtx() + mctx := hashlib.NewKeccakCtx() tree, err := merklemulti.NewTree(mctx, [][32]byte{mctx.Hash([]byte{0xaa})}) require.NoError(t, err) diff --git a/core/services/ocr2/plugins/ccip/execution_batch_building.go b/core/services/ocr2/plugins/ccip/execution_batch_building.go index ede8b919a1..5f2019c24f 100644 --- a/core/services/ocr2/plugins/ccip/execution_batch_building.go +++ b/core/services/ocr2/plugins/ccip/execution_batch_building.go @@ -13,16 +13,16 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/evm_2_evm_onramp" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/abihelpers" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/ccipevents" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/hasher" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/merklemulti" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/ccipevents" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/hashlib" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/merklemulti" "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) func getProofData( ctx context.Context, lggr logger.Logger, - hashLeaf hasher.LeafHasherInterface[[32]byte], + hashLeaf hashlib.LeafHasherInterface[[32]byte], onRampAddress common.Address, sourceEventsClient ccipevents.Client, interval commit_store.CommitStoreInterval, @@ -41,7 +41,7 @@ func getProofData( if err != nil { return nil, nil, nil, err } - tree, err = merklemulti.NewTree(hasher.NewKeccakCtx(), leaves) + tree, err = merklemulti.NewTree(hashlib.NewKeccakCtx(), leaves) if err != nil { return nil, nil, nil, err } diff --git a/core/services/ocr2/plugins/ccip/execution_plugin.go b/core/services/ocr2/plugins/ccip/execution_plugin.go index 1a103cea7e..2e9f45e709 100644 --- a/core/services/ocr2/plugins/ccip/execution_plugin.go +++ b/core/services/ocr2/plugins/ccip/execution_plugin.go @@ -15,6 +15,9 @@ import ( libocr2 "github.com/smartcontractkit/libocr/offchainreporting2plus" relaylogger "github.com/smartcontractkit/chainlink-relay/pkg/logger" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/ccipevents" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/hashlib" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/oraclelib" "github.com/smartcontractkit/chainlink/v2/core/chains/evm" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" @@ -26,9 +29,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/abihelpers" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/ccipevents" ccipconfig "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/config" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/hasher" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/observability" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/promwrapper" "github.com/smartcontractkit/chainlink/v2/core/services/pg" @@ -121,7 +122,7 @@ func NewExecutionServices(lggr logger.Logger, jb job.Job, chainSet evm.LegacyCha destClient: destChain.Client(), sourceClient: sourceChain.Client(), destGasEstimator: destChain.GasEstimator(), - leafHasher: hasher.NewLeafHasher(offRampConfig.SourceChainSelector, offRampConfig.ChainSelector, onRamp.Address(), hasher.NewKeccakCtx()), + leafHasher: hashlib.NewLeafHasher(offRampConfig.SourceChainSelector, offRampConfig.ChainSelector, onRamp.Address(), hashlib.NewKeccakCtx()), }) err = wrappedPluginFactory.UpdateLogPollerFilters(zeroAddress, qopts...) @@ -145,7 +146,7 @@ func NewExecutionServices(lggr logger.Logger, jb job.Job, chainSet evm.LegacyCha // If this is a brand-new job, then we make use of the start blocks. If not then we're rebooting and log poller will pick up where we left off. if new { return []job.ServiceCtx{ - NewBackfilledOracle( + oraclelib.NewBackfilledOracle( execLggr, sourceChain.LogPoller(), destChain.LogPoller(), diff --git a/core/services/ocr2/plugins/ccip/execution_reporting_plugin.go b/core/services/ocr2/plugins/ccip/execution_reporting_plugin.go index f7f0c6e219..b35ae4a600 100644 --- a/core/services/ocr2/plugins/ccip/execution_reporting_plugin.go +++ b/core/services/ocr2/plugins/ccip/execution_reporting_plugin.go @@ -31,10 +31,10 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/router" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/abihelpers" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/cache" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/ccipevents" ccipconfig "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/config" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/hasher" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/cache" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/ccipevents" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/hashlib" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/observability" "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) @@ -65,7 +65,7 @@ type ExecutionPluginConfig struct { destClient evmclient.Client sourceClient evmclient.Client destGasEstimator gas.EvmFeeEstimator - leafHasher hasher.LeafHasherInterface[[32]byte] + leafHasher hashlib.LeafHasherInterface[[32]byte] } type ExecutionReportingPlugin struct { diff --git a/core/services/ocr2/plugins/ccip/execution_reporting_plugin_test.go b/core/services/ocr2/plugins/ccip/execution_reporting_plugin_test.go index a89636ea18..cc1c543566 100644 --- a/core/services/ocr2/plugins/ccip/execution_reporting_plugin_test.go +++ b/core/services/ocr2/plugins/ccip/execution_reporting_plugin_test.go @@ -27,9 +27,10 @@ import ( mock_contracts "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/mocks" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/abihelpers" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/cache" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/ccipevents" ccipconfig "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/config" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/cache" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/ccipevents" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/hashlib" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/testhelpers" plugintesthelpers "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/testhelpers/plugins" "github.com/smartcontractkit/chainlink/v2/core/utils" @@ -39,7 +40,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/commit_store" "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/evm_2_evm_offramp" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/hasher" "github.com/smartcontractkit/chainlink/v2/core/store/models" ) @@ -90,7 +90,7 @@ func setupExecTestHarness(t *testing.T) execTestHarness { destClient: th.DestClient, sourceClient: th.SourceClient, sourceWrappedNativeToken: th.Source.WrappedNative.Address(), - leafHasher: hasher.NewLeafHasher(th.Source.ChainSelector, th.Dest.ChainSelector, th.Source.OnRamp.Address(), hasher.NewKeccakCtx()), + leafHasher: hashlib.NewLeafHasher(th.Source.ChainSelector, th.Dest.ChainSelector, th.Source.OnRamp.Address(), hashlib.NewKeccakCtx()), destGasEstimator: destFeeEstimator, }, onchainConfig: th.ExecOnchainConfig, diff --git a/core/services/ocr2/plugins/ccip/ccipevents/client.go b/core/services/ocr2/plugins/ccip/internal/ccipevents/client.go similarity index 100% rename from core/services/ocr2/plugins/ccip/ccipevents/client.go rename to core/services/ocr2/plugins/ccip/internal/ccipevents/client.go diff --git a/core/services/ocr2/plugins/ccip/ccipevents/logpoller.go b/core/services/ocr2/plugins/ccip/internal/ccipevents/logpoller.go similarity index 100% rename from core/services/ocr2/plugins/ccip/ccipevents/logpoller.go rename to core/services/ocr2/plugins/ccip/internal/ccipevents/logpoller.go diff --git a/core/services/ocr2/plugins/ccip/ccipevents/logpoller_test.go b/core/services/ocr2/plugins/ccip/internal/ccipevents/logpoller_test.go similarity index 100% rename from core/services/ocr2/plugins/ccip/ccipevents/logpoller_test.go rename to core/services/ocr2/plugins/ccip/internal/ccipevents/logpoller_test.go diff --git a/core/services/ocr2/plugins/ccip/hasher/hasher.go b/core/services/ocr2/plugins/ccip/internal/hashlib/hasher.go similarity index 98% rename from core/services/ocr2/plugins/ccip/hasher/hasher.go rename to core/services/ocr2/plugins/ccip/internal/hashlib/hasher.go index 2f2a4f555a..6eeb4e2f1d 100644 --- a/core/services/ocr2/plugins/ccip/hasher/hasher.go +++ b/core/services/ocr2/plugins/ccip/internal/hashlib/hasher.go @@ -1,4 +1,4 @@ -package hasher +package hashlib import ( "bytes" diff --git a/core/services/ocr2/plugins/ccip/hasher/hasher_test.go b/core/services/ocr2/plugins/ccip/internal/hashlib/hasher_test.go similarity index 95% rename from core/services/ocr2/plugins/ccip/hasher/hasher_test.go rename to core/services/ocr2/plugins/ccip/internal/hashlib/hasher_test.go index 9fa558efba..856be2358b 100644 --- a/core/services/ocr2/plugins/ccip/hasher/hasher_test.go +++ b/core/services/ocr2/plugins/ccip/internal/hashlib/hasher_test.go @@ -1,4 +1,4 @@ -package hasher +package hashlib import ( "testing" diff --git a/core/services/ocr2/plugins/ccip/hasher/leaf_hasher.go b/core/services/ocr2/plugins/ccip/internal/hashlib/leaf_hasher.go similarity index 99% rename from core/services/ocr2/plugins/ccip/hasher/leaf_hasher.go rename to core/services/ocr2/plugins/ccip/internal/hashlib/leaf_hasher.go index 8a93e8a520..3bcd9e9c85 100644 --- a/core/services/ocr2/plugins/ccip/hasher/leaf_hasher.go +++ b/core/services/ocr2/plugins/ccip/internal/hashlib/leaf_hasher.go @@ -1,4 +1,4 @@ -package hasher +package hashlib import ( "math/big" diff --git a/core/services/ocr2/plugins/ccip/hasher/leaf_hasher_test.go b/core/services/ocr2/plugins/ccip/internal/hashlib/leaf_hasher_test.go similarity index 78% rename from core/services/ocr2/plugins/ccip/hasher/leaf_hasher_test.go rename to core/services/ocr2/plugins/ccip/internal/hashlib/leaf_hasher_test.go index 4acf16dbd4..fd35c14bfc 100644 --- a/core/services/ocr2/plugins/ccip/hasher/leaf_hasher_test.go +++ b/core/services/ocr2/plugins/ccip/internal/hashlib/leaf_hasher_test.go @@ -1,4 +1,4 @@ -package hasher_test +package hashlib import ( "encoding/hex" @@ -6,20 +6,20 @@ import ( "testing" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" "github.com/stretchr/testify/require" "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/evm_2_evm_onramp" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/hasher" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/testhelpers" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/abihelpers" ) func TestHasher(t *testing.T) { sourceChainSelector, destChainSelector := uint64(1), uint64(4) onRampAddress := common.HexToAddress("0x5550000000000000000000000000000000000001") - hashingCtx := hasher.NewKeccakCtx() + hashingCtx := NewKeccakCtx() - hasher := hasher.NewLeafHasher(sourceChainSelector, destChainSelector, onRampAddress, hashingCtx) + hasher := NewLeafHasher(sourceChainSelector, destChainSelector, onRampAddress, hashingCtx) message := evm_2_evm_onramp.InternalEVM2EVMMessage{ SourceChainSelector: sourceChainSelector, @@ -37,7 +37,9 @@ func TestHasher(t *testing.T) { MessageId: [32]byte{}, } - hash, err := hasher.HashLeaf(testhelpers.GenerateCCIPSendLog(t, message)) + pack, err := abihelpers.MessageArgs.Pack(message) + require.NoError(t, err) + hash, err := hasher.HashLeaf(types.Log{Topics: []common.Hash{abihelpers.EventSignatures.SendRequested}, Data: pack}) require.NoError(t, err) // NOTE: Must match spec @@ -62,7 +64,9 @@ func TestHasher(t *testing.T) { MessageId: [32]byte{}, } - hash, err = hasher.HashLeaf(testhelpers.GenerateCCIPSendLog(t, message)) + pack, err = abihelpers.MessageArgs.Pack(message) + require.NoError(t, err) + hash, err = hasher.HashLeaf(types.Log{Topics: []common.Hash{abihelpers.EventSignatures.SendRequested}, Data: pack}) require.NoError(t, err) // NOTE: Must match spec @@ -72,7 +76,7 @@ func TestHasher(t *testing.T) { func TestMetaDataHash(t *testing.T) { sourceChainSelector, destChainSelector := uint64(1), uint64(4) onRampAddress := common.HexToAddress("0x5550000000000000000000000000000000000001") - ctx := hasher.NewKeccakCtx() - hash := hasher.GetMetaDataHash(ctx, ctx.Hash([]byte("EVM2EVMSubscriptionMessagePlus")), sourceChainSelector, onRampAddress, destChainSelector) + ctx := NewKeccakCtx() + hash := GetMetaDataHash(ctx, ctx.Hash([]byte("EVM2EVMSubscriptionMessagePlus")), sourceChainSelector, onRampAddress, destChainSelector) require.Equal(t, "e8b93c9d01a7a72ec6c7235e238701cf1511b267a31fdb78dd342649ee58c08d", hex.EncodeToString(hash[:])) } diff --git a/core/services/ocr2/plugins/ccip/merklemulti/fixtures/merkle_multi_proof_test_vector.go b/core/services/ocr2/plugins/ccip/internal/merklemulti/fixtures/merkle_multi_proof_test_vector.go similarity index 100% rename from core/services/ocr2/plugins/ccip/merklemulti/fixtures/merkle_multi_proof_test_vector.go rename to core/services/ocr2/plugins/ccip/internal/merklemulti/fixtures/merkle_multi_proof_test_vector.go diff --git a/core/services/ocr2/plugins/ccip/merklemulti/merkle_multi.go b/core/services/ocr2/plugins/ccip/internal/merklemulti/merkle_multi.go similarity index 90% rename from core/services/ocr2/plugins/ccip/merklemulti/merkle_multi.go rename to core/services/ocr2/plugins/ccip/internal/merklemulti/merkle_multi.go index 55096b1485..c9031b470e 100644 --- a/core/services/ocr2/plugins/ccip/merklemulti/merkle_multi.go +++ b/core/services/ocr2/plugins/ccip/internal/merklemulti/merkle_multi.go @@ -6,16 +6,16 @@ import ( "github.com/pkg/errors" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/hasher" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/hashlib" ) -type singleLayerProof[H hasher.Hash] struct { +type singleLayerProof[H hashlib.Hash] struct { nextIndices []int subProof []H sourceFlags []bool } -type Proof[H hasher.Hash] struct { +type Proof[H hashlib.Hash] struct { Hashes []H `json:"hashes"` SourceFlags []bool `json:"source_flags"` } @@ -44,7 +44,7 @@ func siblingIndex(idx int) int { return idx ^ 1 } -func proveSingleLayer[H hasher.Hash](layer []H, indices []int) (singleLayerProof[H], error) { +func proveSingleLayer[H hashlib.Hash](layer []H, indices []int) (singleLayerProof[H], error) { var ( authIndices []int nextIndices []int @@ -77,11 +77,11 @@ func proveSingleLayer[H hasher.Hash](layer []H, indices []int) (singleLayerProof }, nil } -type Tree[H hasher.Hash] struct { +type Tree[H hashlib.Hash] struct { layers [][]H } -func NewTree[H hasher.Hash](ctx hasher.Ctx[H], leafHashes []H) (*Tree[H], error) { +func NewTree[H hashlib.Hash](ctx hashlib.Ctx[H], leafHashes []H) (*Tree[H], error) { if len(leafHashes) == 0 { return nil, errors.New("Cannot construct a tree without leaves") } @@ -131,7 +131,7 @@ func (t *Tree[H]) Prove(indices []int) (Proof[H], error) { return proof, nil } -func computeNextLayer[H hasher.Hash](ctx hasher.Ctx[H], layer []H) ([]H, []H) { +func computeNextLayer[H hashlib.Hash](ctx hashlib.Ctx[H], layer []H) ([]H, []H) { if len(layer) == 1 { return layer, layer } @@ -145,7 +145,7 @@ func computeNextLayer[H hasher.Hash](ctx hasher.Ctx[H], layer []H) ([]H, []H) { return layer, nextLayer } -func VerifyComputeRoot[H hasher.Hash](ctx hasher.Ctx[H], leafHashes []H, proof Proof[H]) (H, error) { +func VerifyComputeRoot[H hashlib.Hash](ctx hashlib.Ctx[H], leafHashes []H, proof Proof[H]) (H, error) { leavesLength := len(leafHashes) proofsLength := len(proof.Hashes) if leavesLength == 0 && proofsLength == 0 { diff --git a/core/services/ocr2/plugins/ccip/merklemulti/merkle_multi_test.go b/core/services/ocr2/plugins/ccip/internal/merklemulti/merkle_multi_test.go similarity index 97% rename from core/services/ocr2/plugins/ccip/merklemulti/merkle_multi_test.go rename to core/services/ocr2/plugins/ccip/internal/merklemulti/merkle_multi_test.go index 0d7f81600a..fc85172158 100644 --- a/core/services/ocr2/plugins/ccip/merklemulti/merkle_multi_test.go +++ b/core/services/ocr2/plugins/ccip/internal/merklemulti/merkle_multi_test.go @@ -10,12 +10,12 @@ import ( "github.com/stretchr/testify/require" "gonum.org/v1/gonum/stat/combin" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/hasher" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/merklemulti/fixtures" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/hashlib" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/merklemulti/fixtures" ) var ( - ctx = hasher.NewKeccakCtx() + ctx = hashlib.NewKeccakCtx() a, b, c, d, e, f = ctx.Hash([]byte{0xa}), ctx.Hash([]byte{0xb}), ctx.Hash([]byte{0xc}), ctx.Hash([]byte{0xd}), ctx.Hash([]byte{0xe}), ctx.Hash([]byte{0xf}) ) @@ -87,7 +87,7 @@ func TestSpecFixtureVerifyProof(t *testing.T) { func TestSpecFixtureNewTree(t *testing.T) { for _, testVector := range fixtures.TestVectors { var leafHashes = hashesFromHexStrings(testVector.AllLeafs) - mctx := hasher.NewKeccakCtx() + mctx := hashlib.NewKeccakCtx() tree, err := NewTree(mctx, leafHashes) assert.NoError(t, err) actualRoot := tree.Root() diff --git a/core/services/ocr2/plugins/ccip/backfilled_oracle.go b/core/services/ocr2/plugins/ccip/internal/oraclelib/backfilled_oracle.go similarity index 99% rename from core/services/ocr2/plugins/ccip/backfilled_oracle.go rename to core/services/ocr2/plugins/ccip/internal/oraclelib/backfilled_oracle.go index 63cd4ad9c8..b4d9da24b3 100644 --- a/core/services/ocr2/plugins/ccip/backfilled_oracle.go +++ b/core/services/ocr2/plugins/ccip/internal/oraclelib/backfilled_oracle.go @@ -1,4 +1,4 @@ -package ccip +package oraclelib import ( "context" diff --git a/core/services/ocr2/plugins/ccip/backfilled_oracle_test.go b/core/services/ocr2/plugins/ccip/internal/oraclelib/backfilled_oracle_test.go similarity index 80% rename from core/services/ocr2/plugins/ccip/backfilled_oracle_test.go rename to core/services/ocr2/plugins/ccip/internal/oraclelib/backfilled_oracle_test.go index 8b40c2e84e..f2fe03e7b6 100644 --- a/core/services/ocr2/plugins/ccip/backfilled_oracle_test.go +++ b/core/services/ocr2/plugins/ccip/internal/oraclelib/backfilled_oracle_test.go @@ -1,4 +1,4 @@ -package ccip_test +package oraclelib import ( "testing" @@ -10,10 +10,9 @@ import ( lpmocks "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller/mocks" "github.com/smartcontractkit/chainlink/v2/core/logger" jobmocks "github.com/smartcontractkit/chainlink/v2/core/services/job/mocks" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip" ) -func TestOracleBackfill(t *testing.T) { +func TestBackfilledOracle(t *testing.T) { // First scenario: Start() fails, check that all Replay are being called. lp1 := lpmocks.NewLogPoller(t) lp2 := lpmocks.NewLogPoller(t) @@ -21,7 +20,7 @@ func TestOracleBackfill(t *testing.T) { lp2.On("Replay", mock.Anything, int64(2)).Return(nil) oracle1 := jobmocks.NewServiceCtx(t) oracle1.On("Start", mock.Anything).Return(errors.New("Failed to start")).Twice() - job := ccip.NewBackfilledOracle(logger.TestLogger(t), lp1, lp2, 1, 2, oracle1) + job := NewBackfilledOracle(logger.TestLogger(t), lp1, lp2, 1, 2, oracle1) job.Run() assert.False(t, job.IsRunning()) @@ -33,7 +32,7 @@ func TestOracleBackfill(t *testing.T) { oracle2.On("Start", mock.Anything).Return(nil).Twice() oracle2.On("Close").Return(nil).Once() - job2 := ccip.NewBackfilledOracle(logger.TestLogger(t), lp1, lp2, 1, 2, oracle2) + job2 := NewBackfilledOracle(logger.TestLogger(t), lp1, lp2, 1, 2, oracle2) job2.Run() assert.True(t, job2.IsRunning()) assert.Nil(t, job2.Close()) @@ -50,7 +49,7 @@ func TestOracleBackfill(t *testing.T) { lp12.On("Replay", mock.Anything, int64(2)).Return(errors.New("Replay failed")).Once() oracle := jobmocks.NewServiceCtx(t) - job3 := ccip.NewBackfilledOracle(logger.NullLogger, lp11, lp12, 1, 2, oracle) + job3 := NewBackfilledOracle(logger.NullLogger, lp11, lp12, 1, 2, oracle) job3.Run() assert.False(t, job3.IsRunning()) } diff --git a/core/services/ocr2/plugins/ccip/plugins_common.go b/core/services/ocr2/plugins/ccip/plugins_common.go index f3df762e55..9fed840bd9 100644 --- a/core/services/ocr2/plugins/ccip/plugins_common.go +++ b/core/services/ocr2/plugins/ccip/plugins_common.go @@ -19,9 +19,9 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/evm_2_evm_onramp_1_0_0" "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/evm_2_evm_onramp_1_1_0" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/ccipevents" ccipconfig "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/config" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/hasher" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/ccipevents" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/hashlib" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/observability" "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/utils" @@ -140,7 +140,7 @@ func calculateUsdPerUnitGas(sourceGasPrice *big.Int, usdPerFeeCoin *big.Int) *bi func leavesFromIntervals( lggr logger.Logger, interval commit_store.CommitStoreInterval, - hasher hasher.LeafHasherInterface[[32]byte], + hasher hashlib.LeafHasherInterface[[32]byte], sendReqs []ccipevents.Event[evm_2_evm_onramp.EVM2EVMOnRampCCIPSendRequested], ) ([][32]byte, error) { var seqNrs []uint64 diff --git a/core/services/ocr2/plugins/ccip/testhelpers/ccip_contracts.go b/core/services/ocr2/plugins/ccip/testhelpers/ccip_contracts.go index 96f015c95e..feeb7528c0 100644 --- a/core/services/ocr2/plugins/ccip/testhelpers/ccip_contracts.go +++ b/core/services/ocr2/plugins/ccip/testhelpers/ccip_contracts.go @@ -38,8 +38,8 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/shared/generated/burn_mint_erc677" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/abihelpers" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/hasher" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/merklemulti" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/hashlib" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/merklemulti" "github.com/smartcontractkit/chainlink/v2/core/utils" ) @@ -1345,8 +1345,8 @@ func (args *ManualExecArgs) execute(report *commit_store.CommitStoreCommitReport log.Info().Msg("Executing request manually") seqNr := args.seqNr // Build a merkle tree for the report - mctx := hasher.NewKeccakCtx() - leafHasher := hasher.NewLeafHasher(args.SourceChainID, args.DestChainID, common.HexToAddress(args.OnRamp), mctx) + mctx := hashlib.NewKeccakCtx() + leafHasher := hashlib.NewLeafHasher(args.SourceChainID, args.DestChainID, common.HexToAddress(args.OnRamp), mctx) onRampContract, err := evm_2_evm_onramp.NewEVM2EVMOnRamp(common.HexToAddress(args.OnRamp), args.SourceChain) if err != nil { return nil, err diff --git a/core/services/ocr2/plugins/ccip/testhelpers/plugins/plugin_harness.go b/core/services/ocr2/plugins/ccip/testhelpers/plugins/plugin_harness.go index 3ac293d6ae..61de46d80f 100644 --- a/core/services/ocr2/plugins/ccip/testhelpers/plugins/plugin_harness.go +++ b/core/services/ocr2/plugins/ccip/testhelpers/plugins/plugin_harness.go @@ -21,8 +21,8 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/abihelpers" ccipconfig "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/config" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/hasher" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/merklemulti" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/hashlib" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/merklemulti" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/testhelpers" "github.com/smartcontractkit/chainlink/v2/core/utils" ) @@ -195,8 +195,8 @@ func (mb MessageBatch) ToExecutionReport() evm_2_evm_offramp.InternalExecutionRe } func (th *CCIPPluginTestHarness) GenerateAndSendMessageBatch(t *testing.T, nMessages int, payloadSize int, nTokensPerMessage int) MessageBatch { - mctx := hasher.NewKeccakCtx() - leafHasher := hasher.NewLeafHasher(th.Source.ChainSelector, th.Dest.ChainSelector, th.Source.OnRamp.Address(), mctx) + mctx := hashlib.NewKeccakCtx() + leafHasher := hashlib.NewLeafHasher(th.Source.ChainSelector, th.Dest.ChainSelector, th.Source.OnRamp.Address(), mctx) maxPayload := make([]byte, payloadSize) for i := 0; i < payloadSize; i++ { From 7f8d4c9477d86db2d294a63402a2a6612c1ca7cd Mon Sep 17 00:00:00 2001 From: dimkouv Date: Wed, 13 Sep 2023 13:31:36 +0300 Subject: [PATCH 2/2] add cache --- .gitignore | 2 +- .../ocr2/plugins/ccip/internal/cache/cache.go | 138 +++++++++ .../plugins/ccip/internal/cache/cache_mock.go | 53 ++++ .../plugins/ccip/internal/cache/cache_test.go | 178 ++++++++++++ .../ccip/internal/cache/snoozed_roots.go | 65 +++++ .../ccip/internal/cache/snoozed_roots_test.go | 40 +++ .../plugins/ccip/internal/cache/tokens.go | 266 ++++++++++++++++++ .../ccip/internal/cache/tokens_test.go | 227 +++++++++++++++ 8 files changed, 968 insertions(+), 1 deletion(-) create mode 100644 core/services/ocr2/plugins/ccip/internal/cache/cache.go create mode 100644 core/services/ocr2/plugins/ccip/internal/cache/cache_mock.go create mode 100644 core/services/ocr2/plugins/ccip/internal/cache/cache_test.go create mode 100644 core/services/ocr2/plugins/ccip/internal/cache/snoozed_roots.go create mode 100644 core/services/ocr2/plugins/ccip/internal/cache/snoozed_roots_test.go create mode 100644 core/services/ocr2/plugins/ccip/internal/cache/tokens.go create mode 100644 core/services/ocr2/plugins/ccip/internal/cache/tokens_test.go diff --git a/.gitignore b/.gitignore index de57cfb672..cd6a6ea09c 100644 --- a/.gitignore +++ b/.gitignore @@ -75,7 +75,7 @@ MacOSX* cache core/services/ocr2/plugins/ccip/transactions.rlp lcov.info -!core/services/ocr2/plugins/ccip/cache/ +!core/services/ocr2/plugins/ccip/internal/cache/ core/scripts/ccip/json/credentials diff --git a/core/services/ocr2/plugins/ccip/internal/cache/cache.go b/core/services/ocr2/plugins/ccip/internal/cache/cache.go new file mode 100644 index 0000000000..1668d6cf33 --- /dev/null +++ b/core/services/ocr2/plugins/ccip/internal/cache/cache.go @@ -0,0 +1,138 @@ +package cache + +import ( + "context" + "sync" + + "github.com/ethereum/go-ethereum/common" + + "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller" + "github.com/smartcontractkit/chainlink/v2/core/services/pg" +) + +// AutoSync cache provides only a Get method, the expiration and syncing is a black-box for the caller. +// +//go:generate mockery --quiet --name AutoSync --output . --filename cache_mock.go --inpackage --case=underscore +type AutoSync[T any] interface { + Get(ctx context.Context) (T, error) +} + +// CachedChain represents caching on-chain calls based on the events read from logpoller.LogPoller. +// Instead of directly going to on-chain to fetch data, we start with checking logpoller.LogPoller events (database request). +// If we discover that change occurred since last update, we perform RPC to the chain using ContractOrigin.CallOrigin function. +// Purpose of this struct is handle common logic in a single place, you only need to override methods from ContractOrigin +// and Get function (behaving as orchestrator) will take care of the rest. +// +// That being said, adding caching layer to the new contract is as simple as: +// * implementing ContractOrigin interface +// * registering proper events in log poller +type CachedChain[T any] struct { + // Static configuration + observedEvents []common.Hash + logPoller logpoller.LogPoller + address []common.Address + optimisticConfirmations int64 + + // Cache + lock *sync.RWMutex + value T + lastChangeBlock int64 + origin ContractOrigin[T] +} + +type ContractOrigin[T any] interface { + // Copy must return copy of the cached data to limit locking section to the minimum + Copy(T) T + // CallOrigin fetches data that is next stored within cache. Usually, should perform RPC to the source (e.g. chain) + CallOrigin(ctx context.Context) (T, error) +} + +// Get is an entry point to the caching. Main function that decides whether cache content is fresh and should be returned +// to the caller, or whether we need to update it's content from on-chain data. +// This decision is made based on the events emitted by Smart Contracts +func (c *CachedChain[T]) Get(ctx context.Context) (T, error) { + var empty T + + lastChangeBlock := c.readLastChangeBlock() + + // Handles first call, because cache is not eagerly populated + if lastChangeBlock == 0 { + return c.initializeCache(ctx) + } + + currentBlockNumber, err := c.logPoller.LatestBlockByEventSigsAddrsWithConfs(lastChangeBlock, c.observedEvents, c.address, int(c.optimisticConfirmations), pg.WithParentCtx(ctx)) + + if err != nil { + return empty, err + } + + // In case of new updates, fetch fresh data from the origin + if currentBlockNumber > lastChangeBlock { + return c.fetchFromOrigin(ctx, currentBlockNumber) + } + return c.copyCachedValue(), nil +} + +// initializeCache performs first call to origin when is not populated yet. +// It's done eagerly, so cache it's populated for the first time when data is needed, not at struct initialization +func (c *CachedChain[T]) initializeCache(ctx context.Context) (T, error) { + var empty T + + // To prevent missing data when blocks are produced after calling the origin, + // we first get the latest block and then call the origin. + latestBlock, err := c.logPoller.LatestBlock(pg.WithParentCtx(ctx)) + if err != nil { + return empty, err + } + + // Init + value, err := c.origin.CallOrigin(ctx) + if err != nil { + return empty, err + } + + c.updateCache(value, latestBlock-c.optimisticConfirmations) + return c.copyCachedValue(), nil + +} + +// fetchFromOrigin fetches data from origin. This action is performed when logpoller.LogPoller says there were events +// emitted since last update. +func (c *CachedChain[T]) fetchFromOrigin(ctx context.Context, currentBlockNumber int64) (T, error) { + var empty T + value, err := c.origin.CallOrigin(ctx) + if err != nil { + return empty, err + } + c.updateCache(value, currentBlockNumber) + + return c.copyCachedValue(), nil +} + +// updateCache performs updating two critical variables for cache to work properly: +// * value that is stored within cache +// * lastChangeBlock representing last seen event from logpoller.LogPoller +func (c *CachedChain[T]) updateCache(newValue T, currentBlockNumber int64) { + c.lock.Lock() + defer c.lock.Unlock() + + // Double-lock checking. No need to update if other goroutine was faster + if currentBlockNumber <= c.lastChangeBlock { + return + } + + c.value = newValue + c.lastChangeBlock = currentBlockNumber +} + +func (c *CachedChain[T]) readLastChangeBlock() int64 { + c.lock.RLock() + defer c.lock.RUnlock() + return c.lastChangeBlock +} + +func (c *CachedChain[T]) copyCachedValue() T { + c.lock.RLock() + defer c.lock.RUnlock() + return c.origin.Copy(c.value) +} diff --git a/core/services/ocr2/plugins/ccip/internal/cache/cache_mock.go b/core/services/ocr2/plugins/ccip/internal/cache/cache_mock.go new file mode 100644 index 0000000000..a5cd3d901d --- /dev/null +++ b/core/services/ocr2/plugins/ccip/internal/cache/cache_mock.go @@ -0,0 +1,53 @@ +// Code generated by mockery v2.28.1. DO NOT EDIT. + +package cache + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" +) + +// MockAutoSync is an autogenerated mock type for the AutoSync type +type MockAutoSync[T interface{}] struct { + mock.Mock +} + +// Get provides a mock function with given fields: ctx +func (_m *MockAutoSync[T]) Get(ctx context.Context) (T, error) { + ret := _m.Called(ctx) + + var r0 T + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (T, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) T); ok { + r0 = rf(ctx) + } else { + r0 = ret.Get(0).(T) + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type mockConstructorTestingTNewMockAutoSync interface { + mock.TestingT + Cleanup(func()) +} + +// NewMockAutoSync creates a new instance of MockAutoSync. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewMockAutoSync[T interface{}](t mockConstructorTestingTNewMockAutoSync) *MockAutoSync[T] { + mock := &MockAutoSync[T]{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/services/ocr2/plugins/ccip/internal/cache/cache_test.go b/core/services/ocr2/plugins/ccip/internal/cache/cache_test.go new file mode 100644 index 0000000000..e37b09b95e --- /dev/null +++ b/core/services/ocr2/plugins/ccip/internal/cache/cache_test.go @@ -0,0 +1,178 @@ +package cache + +import ( + "context" + "strconv" + "sync" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller" + lpMocks "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller/mocks" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" + "github.com/smartcontractkit/chainlink/v2/core/services/pg" +) + +const ( + cachedValue = "cached_value" +) + +func TestGet_InitDataForTheFirstTime(t *testing.T) { + lp := lpMocks.NewLogPoller(t) + lp.On("LatestBlock", mock.Anything).Maybe().Return(int64(100), nil) + + contract := newCachedContract(lp, "", []string{"value1"}, 0) + + value, err := contract.Get(testutils.Context(t)) + require.NoError(t, err) + require.Equal(t, "value1", value) +} + +func TestGet_ReturnDataFromCacheIfNoNewEvents(t *testing.T) { + latestBlock := int64(100) + lp := lpMocks.NewLogPoller(t) + mockLogPollerQuery(lp, latestBlock) + + contract := newCachedContract(lp, cachedValue, []string{"value1"}, latestBlock) + + value, err := contract.Get(testutils.Context(t)) + require.NoError(t, err) + require.Equal(t, cachedValue, value) +} + +func TestGet_CallOriginForNewEvents(t *testing.T) { + latestBlock := int64(100) + lp := lpMocks.NewLogPoller(t) + m := mockLogPollerQuery(lp, latestBlock+1) + + contract := newCachedContract(lp, cachedValue, []string{"value1", "value2", "value3"}, latestBlock) + + // First call + value, err := contract.Get(testutils.Context(t)) + require.NoError(t, err) + require.Equal(t, "value1", value) + + currentBlock := contract.lastChangeBlock + require.Equal(t, latestBlock+1, currentBlock) + + m.Unset() + mockLogPollerQuery(lp, latestBlock+1) + + // Second call doesn't change anything + value, err = contract.Get(testutils.Context(t)) + require.NoError(t, err) + require.Equal(t, "value1", value) + require.Equal(t, int64(101), contract.lastChangeBlock) +} + +func TestGet_CacheProgressing(t *testing.T) { + firstBlock := int64(100) + secondBlock := int64(105) + thirdBlock := int64(110) + + lp := lpMocks.NewLogPoller(t) + m := mockLogPollerQuery(lp, secondBlock) + + contract := newCachedContract(lp, cachedValue, []string{"value1", "value2", "value3"}, firstBlock) + + // First call + value, err := contract.Get(testutils.Context(t)) + require.NoError(t, err) + require.Equal(t, "value1", value) + require.Equal(t, secondBlock, contract.lastChangeBlock) + + m.Unset() + mockLogPollerQuery(lp, thirdBlock) + + // Second call + value, err = contract.Get(testutils.Context(t)) + require.NoError(t, err) + require.Equal(t, "value2", value) + require.Equal(t, thirdBlock, contract.lastChangeBlock) +} + +func TestGet_ConcurrentAccess(t *testing.T) { + mockedPoller := lpMocks.NewLogPoller(t) + progressingPoller := ProgressingLogPoller{ + LogPoller: mockedPoller, + latestBlock: 1, + } + + iterations := 100 + originValues := make([]string, iterations) + for i := 0; i < iterations; i++ { + originValues[i] = "value_" + strconv.Itoa(i) + } + contract := newCachedContract(&progressingPoller, "empty", originValues, 1) + + var wg sync.WaitGroup + wg.Add(iterations) + for i := 0; i < iterations; i++ { + go func() { + defer wg.Done() + _, _ = contract.Get(testutils.Context(t)) + }() + } + wg.Wait() + + // 1 init block + 100 iterations + require.Equal(t, int64(101), contract.lastChangeBlock) + + // Make sure that recent value is stored in cache + val := contract.copyCachedValue() + require.Equal(t, "value_99", val) +} + +func newCachedContract(lp logpoller.LogPoller, cacheValue string, originValue []string, lastChangeBlock int64) *CachedChain[string] { + return &CachedChain[string]{ + observedEvents: []common.Hash{{}}, + logPoller: lp, + address: []common.Address{{}}, + optimisticConfirmations: 0, + + lock: &sync.RWMutex{}, + value: cacheValue, + lastChangeBlock: lastChangeBlock, + origin: &FakeContractOrigin{values: originValue}, + } +} + +func mockLogPollerQuery(lp *lpMocks.LogPoller, latestBlock int64) *mock.Call { + return lp.On("LatestBlockByEventSigsAddrsWithConfs", mock.Anything, []common.Hash{{}}, []common.Address{{}}, 0, mock.Anything). + Maybe().Return(latestBlock, nil) +} + +type ProgressingLogPoller struct { + *lpMocks.LogPoller + latestBlock int64 + lock sync.Mutex +} + +func (lp *ProgressingLogPoller) LatestBlockByEventSigsAddrsWithConfs(int64, []common.Hash, []common.Address, int, ...pg.QOpt) (int64, error) { + lp.lock.Lock() + defer lp.lock.Unlock() + lp.latestBlock++ + return lp.latestBlock, nil +} + +type FakeContractOrigin struct { + values []string + counter int + lock sync.Mutex +} + +func (f *FakeContractOrigin) CallOrigin(context.Context) (string, error) { + f.lock.Lock() + defer func() { + f.counter++ + f.lock.Unlock() + }() + return f.values[f.counter], nil +} + +func (f *FakeContractOrigin) Copy(value string) string { + return value +} diff --git a/core/services/ocr2/plugins/ccip/internal/cache/snoozed_roots.go b/core/services/ocr2/plugins/ccip/internal/cache/snoozed_roots.go new file mode 100644 index 0000000000..916716942a --- /dev/null +++ b/core/services/ocr2/plugins/ccip/internal/cache/snoozed_roots.go @@ -0,0 +1,65 @@ +package cache + +import ( + "time" + + "github.com/patrickmn/go-cache" +) + +const ( + // EvictionGracePeriod defines how long after the permissionless execution threshold a root is still kept in the cache + EvictionGracePeriod = 1 * time.Hour + // CleanupInterval defines how often roots have to be evicted + CleanupInterval = 30 * time.Minute +) + +type SnoozedRoots interface { + IsSnoozed(merkleRoot [32]byte) bool + MarkAsExecuted(merkleRoot [32]byte) + Snooze(merkleRoot [32]byte) +} + +type snoozedRoots struct { + cache *cache.Cache + // Both rootSnoozedTime and permissionLessExecutionThresholdDuration can be kept in the snoozedRoots without need to be updated. + // Those config properties are populates via onchain/offchain config. When changed, OCR plugin will be restarted and cache initialized with new config. + rootSnoozedTime time.Duration + permissionLessExecutionThresholdDuration time.Duration +} + +func newSnoozedRoots( + permissionLessExecutionThresholdDuration time.Duration, + rootSnoozeTime time.Duration, + evictionGracePeriod time.Duration, + cleanupInterval time.Duration, +) *snoozedRoots { + evictionTime := permissionLessExecutionThresholdDuration + evictionGracePeriod + internalCache := cache.New(evictionTime, cleanupInterval) + + return &snoozedRoots{ + cache: internalCache, + rootSnoozedTime: rootSnoozeTime, + permissionLessExecutionThresholdDuration: permissionLessExecutionThresholdDuration, + } +} + +func NewSnoozedRoots(permissionLessExecutionThresholdDuration time.Duration, rootSnoozeTime time.Duration) *snoozedRoots { + return newSnoozedRoots(permissionLessExecutionThresholdDuration, rootSnoozeTime, EvictionGracePeriod, CleanupInterval) +} + +func (s *snoozedRoots) IsSnoozed(merkleRoot [32]byte) bool { + rawValue, found := s.cache.Get(merkleRootToString(merkleRoot)) + return found && time.Now().Before(rawValue.(time.Time)) +} + +func (s *snoozedRoots) MarkAsExecuted(merkleRoot [32]byte) { + s.cache.SetDefault(merkleRootToString(merkleRoot), time.Now().Add(s.permissionLessExecutionThresholdDuration)) +} + +func (s *snoozedRoots) Snooze(merkleRoot [32]byte) { + s.cache.SetDefault(merkleRootToString(merkleRoot), time.Now().Add(s.rootSnoozedTime)) +} + +func merkleRootToString(merkleRoot [32]byte) string { + return string(merkleRoot[:]) +} diff --git a/core/services/ocr2/plugins/ccip/internal/cache/snoozed_roots_test.go b/core/services/ocr2/plugins/ccip/internal/cache/snoozed_roots_test.go new file mode 100644 index 0000000000..f3813df575 --- /dev/null +++ b/core/services/ocr2/plugins/ccip/internal/cache/snoozed_roots_test.go @@ -0,0 +1,40 @@ +package cache + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestSnoozedRoots(t *testing.T) { + c := NewSnoozedRoots(1*time.Minute, 1*time.Minute) + + k1 := [32]byte{1} + k2 := [32]byte{2} + + // return false for non existing element + snoozed := c.IsSnoozed(k1) + assert.False(t, snoozed) + + // after an element is marked as executed it should be snoozed + c.MarkAsExecuted(k1) + snoozed = c.IsSnoozed(k1) + assert.True(t, snoozed) + + // after snoozing an element it should be snoozed + c.Snooze(k2) + snoozed = c.IsSnoozed(k2) + assert.True(t, snoozed) +} + +func TestEvictingElements(t *testing.T) { + c := newSnoozedRoots(1*time.Millisecond, 1*time.Hour, 1*time.Millisecond, 1*time.Millisecond) + + k1 := [32]byte{1} + c.Snooze(k1) + + time.Sleep(1 * time.Second) + + assert.False(t, c.IsSnoozed(k1)) +} diff --git a/core/services/ocr2/plugins/ccip/internal/cache/tokens.go b/core/services/ocr2/plugins/ccip/internal/cache/tokens.go new file mode 100644 index 0000000000..5a53cf964b --- /dev/null +++ b/core/services/ocr2/plugins/ccip/internal/cache/tokens.go @@ -0,0 +1,266 @@ +package cache + +import ( + "context" + "fmt" + "sync" + + "github.com/ethereum/go-ethereum/accounts/abi/bind" + "github.com/ethereum/go-ethereum/common" + "golang.org/x/exp/slices" + + evmclient "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" + "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller" + "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/evm_2_evm_offramp" + "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/price_registry" + "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/generated/link_token_interface" + "github.com/smartcontractkit/chainlink/v2/core/logger" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/abihelpers" +) + +// NewCachedFeeTokens cache fee tokens returned from PriceRegistry +func NewCachedFeeTokens( + lp logpoller.LogPoller, + priceRegistry price_registry.PriceRegistryInterface, + optimisticConfirmations int64, +) *CachedChain[[]common.Address] { + return &CachedChain[[]common.Address]{ + observedEvents: []common.Hash{ + abihelpers.EventSignatures.FeeTokenAdded, + abihelpers.EventSignatures.FeeTokenRemoved, + }, + logPoller: lp, + address: []common.Address{priceRegistry.Address()}, + optimisticConfirmations: optimisticConfirmations, + lock: &sync.RWMutex{}, + value: []common.Address{}, + lastChangeBlock: 0, + origin: &feeTokensOrigin{priceRegistry: priceRegistry}, + } +} + +type CachedTokens struct { + SupportedTokens map[common.Address]common.Address + FeeTokens []common.Address +} + +// NewCachedSupportedTokens cache both fee tokens and supported tokens. Therefore, it uses 4 different events +// when checking for changes in logpoller.LogPoller +func NewCachedSupportedTokens( + lp logpoller.LogPoller, + offRamp evm_2_evm_offramp.EVM2EVMOffRampInterface, + priceRegistry price_registry.PriceRegistryInterface, + optimisticConfirmations int64, +) *CachedChain[CachedTokens] { + return &CachedChain[CachedTokens]{ + observedEvents: []common.Hash{ + abihelpers.EventSignatures.FeeTokenAdded, + abihelpers.EventSignatures.FeeTokenRemoved, + abihelpers.EventSignatures.PoolAdded, + abihelpers.EventSignatures.PoolRemoved, + }, + logPoller: lp, + address: []common.Address{priceRegistry.Address(), offRamp.Address()}, + optimisticConfirmations: optimisticConfirmations, + lock: &sync.RWMutex{}, + value: CachedTokens{}, + lastChangeBlock: 0, + origin: &feeAndSupportedTokensOrigin{ + feeTokensOrigin: feeTokensOrigin{priceRegistry: priceRegistry}, + supportedTokensOrigin: supportedTokensOrigin{offRamp: offRamp}}, + } +} + +func NewTokenToDecimals( + lggr logger.Logger, + lp logpoller.LogPoller, + offRamp evm_2_evm_offramp.EVM2EVMOffRampInterface, + priceRegistry price_registry.PriceRegistryInterface, + client evmclient.Client, + optimisticConfirmations int64, +) *CachedChain[map[common.Address]uint8] { + return &CachedChain[map[common.Address]uint8]{ + observedEvents: []common.Hash{ + abihelpers.EventSignatures.FeeTokenAdded, + abihelpers.EventSignatures.FeeTokenRemoved, + abihelpers.EventSignatures.PoolAdded, + abihelpers.EventSignatures.PoolRemoved, + }, + logPoller: lp, + address: []common.Address{priceRegistry.Address(), offRamp.Address()}, + optimisticConfirmations: optimisticConfirmations, + lock: &sync.RWMutex{}, + value: make(map[common.Address]uint8), + lastChangeBlock: 0, + origin: &tokenToDecimals{ + lggr: lggr, + priceRegistry: priceRegistry, + offRamp: offRamp, + tokenFactory: func(token common.Address) (link_token_interface.LinkTokenInterface, error) { + return link_token_interface.NewLinkToken(token, client) + }, + }, + } +} + +type supportedTokensOrigin struct { + offRamp evm_2_evm_offramp.EVM2EVMOffRampInterface +} + +func (t *supportedTokensOrigin) Copy(value map[common.Address]common.Address) map[common.Address]common.Address { + return copyMap(value) +} + +// CallOrigin Generates the source to dest token mapping based on the offRamp. +// NOTE: this queries the offRamp n+1 times, where n is the number of enabled tokens. +func (t *supportedTokensOrigin) CallOrigin(ctx context.Context) (map[common.Address]common.Address, error) { + srcToDstTokenMapping := make(map[common.Address]common.Address) + sourceTokens, err := t.offRamp.GetSupportedTokens(&bind.CallOpts{Context: ctx}) + if err != nil { + return nil, err + } + + seenDestinationTokens := make(map[common.Address]struct{}) + + for _, sourceToken := range sourceTokens { + dst, err1 := t.offRamp.GetDestinationToken(&bind.CallOpts{Context: ctx}, sourceToken) + if err1 != nil { + return nil, err1 + } + + if _, exists := seenDestinationTokens[dst]; exists { + return nil, fmt.Errorf("offRamp misconfig, destination token %s already exists", dst) + } + + seenDestinationTokens[dst] = struct{}{} + srcToDstTokenMapping[sourceToken] = dst + } + return srcToDstTokenMapping, nil +} + +type feeTokensOrigin struct { + priceRegistry price_registry.PriceRegistryInterface +} + +func (t *feeTokensOrigin) Copy(value []common.Address) []common.Address { + return copyArray(value) +} + +func (t *feeTokensOrigin) CallOrigin(ctx context.Context) ([]common.Address, error) { + return t.priceRegistry.GetFeeTokens(&bind.CallOpts{Context: ctx}) +} + +func copyArray(source []common.Address) []common.Address { + dst := make([]common.Address, len(source)) + copy(dst, source) + return dst +} + +type feeAndSupportedTokensOrigin struct { + feeTokensOrigin feeTokensOrigin + supportedTokensOrigin supportedTokensOrigin +} + +func (t *feeAndSupportedTokensOrigin) Copy(value CachedTokens) CachedTokens { + return CachedTokens{ + SupportedTokens: t.supportedTokensOrigin.Copy(value.SupportedTokens), + FeeTokens: t.feeTokensOrigin.Copy(value.FeeTokens), + } +} + +func (t *feeAndSupportedTokensOrigin) CallOrigin(ctx context.Context) (CachedTokens, error) { + supportedTokens, err := t.supportedTokensOrigin.CallOrigin(ctx) + if err != nil { + return CachedTokens{}, err + } + feeToken, err := t.feeTokensOrigin.CallOrigin(ctx) + if err != nil { + return CachedTokens{}, err + } + return CachedTokens{ + SupportedTokens: supportedTokens, + FeeTokens: feeToken, + }, nil +} + +func copyMap[M ~map[K]V, K comparable, V any](m M) M { + cpy := make(M) + for k, v := range m { + cpy[k] = v + } + return cpy +} + +type tokenToDecimals struct { + lggr logger.Logger + offRamp evm_2_evm_offramp.EVM2EVMOffRampInterface + priceRegistry price_registry.PriceRegistryInterface + tokenFactory func(address common.Address) (link_token_interface.LinkTokenInterface, error) + tokenDecimals sync.Map +} + +func (t *tokenToDecimals) Copy(value map[common.Address]uint8) map[common.Address]uint8 { + return copyMap(value) +} + +// CallOrigin Generates the token to decimal mapping for dest tokens and fee tokens. +// NOTE: this queries token decimals n times, where n is the number of tokens whose decimals are not already cached. +func (t *tokenToDecimals) CallOrigin(ctx context.Context) (map[common.Address]uint8, error) { + mapping := make(map[common.Address]uint8) + + destTokens, err := t.offRamp.GetDestinationTokens(&bind.CallOpts{Context: ctx}) + if err != nil { + return nil, err + } + + feeTokens, err := t.priceRegistry.GetFeeTokens(&bind.CallOpts{Context: ctx}) + if err != nil { + return nil, err + } + + // In case if a fee token is not an offramp dest token, we still want to update its decimals and price + for _, feeToken := range feeTokens { + if !slices.Contains(destTokens, feeToken) { + destTokens = append(destTokens, feeToken) + } + } + + for _, token := range destTokens { + if decimals, exists := t.getCachedDecimals(token); exists { + mapping[token] = decimals + continue + } + + tokenContract, err := t.tokenFactory(token) + if err != nil { + return nil, err + } + + decimals, err := tokenContract.Decimals(&bind.CallOpts{Context: ctx}) + if err != nil { + return nil, fmt.Errorf("get token %s decimals: %w", token, err) + } + + t.setCachedDecimals(token, decimals) + mapping[token] = decimals + } + return mapping, nil +} + +func (t *tokenToDecimals) getCachedDecimals(token common.Address) (uint8, bool) { + rawVal, exists := t.tokenDecimals.Load(token.String()) + if !exists { + return 0, false + } + + decimals, isUint8 := rawVal.(uint8) + if !isUint8 { + return 0, false + } + + return decimals, true +} + +func (t *tokenToDecimals) setCachedDecimals(token common.Address, decimals uint8) { + t.tokenDecimals.Store(token.String(), decimals) +} diff --git a/core/services/ocr2/plugins/ccip/internal/cache/tokens_test.go b/core/services/ocr2/plugins/ccip/internal/cache/tokens_test.go new file mode 100644 index 0000000000..b562c44563 --- /dev/null +++ b/core/services/ocr2/plugins/ccip/internal/cache/tokens_test.go @@ -0,0 +1,227 @@ +package cache + +import ( + "context" + "fmt" + "math/big" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + mock_contracts "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/mocks" + "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/generated/link_token_interface" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" + "github.com/smartcontractkit/chainlink/v2/core/logger" + "github.com/smartcontractkit/chainlink/v2/core/utils" +) + +func Test_tokenToDecimals(t *testing.T) { + tokenPriceMappings := map[common.Address]uint8{ + common.HexToAddress("0xA"): 10, + common.HexToAddress("0xB"): 5, + common.HexToAddress("0xC"): 2, + } + + tests := []struct { + name string + destTokens []common.Address + feeTokens []common.Address + want map[common.Address]uint8 + wantErr bool + }{ + { + name: "empty map for empty tokens from origin", + destTokens: []common.Address{}, + feeTokens: []common.Address{}, + want: map[common.Address]uint8{}, + }, + { + name: "separate destination and fee tokens", + destTokens: []common.Address{common.HexToAddress("0xC")}, + feeTokens: []common.Address{common.HexToAddress("0xB")}, + want: map[common.Address]uint8{ + common.HexToAddress("0xC"): 2, + common.HexToAddress("0xB"): 5, + }, + }, + { + name: "fee tokens and dest tokens are overlapping", + destTokens: []common.Address{common.HexToAddress("0xA")}, + feeTokens: []common.Address{common.HexToAddress("0xA")}, + want: map[common.Address]uint8{ + common.HexToAddress("0xA"): 10, + }, + }, + { + name: "only fee tokens are returned", + destTokens: []common.Address{}, + feeTokens: []common.Address{common.HexToAddress("0xA"), common.HexToAddress("0xC")}, + want: map[common.Address]uint8{ + common.HexToAddress("0xA"): 10, + common.HexToAddress("0xC"): 2, + }, + }, + { + name: "missing tokens are skipped", + destTokens: []common.Address{}, + feeTokens: []common.Address{common.HexToAddress("0xD")}, + want: map[common.Address]uint8{}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + offRamp := &mock_contracts.EVM2EVMOffRampInterface{} + offRamp.On("GetDestinationTokens", mock.Anything).Return(tt.destTokens, nil) + + priceRegistry := &mock_contracts.PriceRegistryInterface{} + priceRegistry.On("GetFeeTokens", mock.Anything).Return(tt.feeTokens, nil) + + tokenToDecimal := &tokenToDecimals{ + lggr: logger.TestLogger(t), + offRamp: offRamp, + priceRegistry: priceRegistry, + tokenFactory: createTokenFactory(tokenPriceMappings), + } + + got, err := tokenToDecimal.CallOrigin(testutils.Context(t)) + if tt.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.want, got) + + // we set token factory to always return an error + // we don't expect it to be used again, decimals should be in cache. + tokenToDecimal.tokenFactory = func(address common.Address) (link_token_interface.LinkTokenInterface, error) { + return nil, fmt.Errorf("some error") + } + got, err = tokenToDecimal.CallOrigin(testutils.Context(t)) + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestCallOrigin(t *testing.T) { + src1 := common.HexToAddress("10") + dst1 := common.HexToAddress("11") + src2 := common.HexToAddress("20") + dst2 := common.HexToAddress("21") + + testCases := []struct { + name string + srcTokens []common.Address + srcToDst map[common.Address]common.Address + expErr bool + }{ + { + name: "base", + srcTokens: []common.Address{src1, src2}, + srcToDst: map[common.Address]common.Address{ + src1: dst1, + src2: dst2, + }, + expErr: false, + }, + { + name: "dup dst token", + srcTokens: []common.Address{src1, src2}, + srcToDst: map[common.Address]common.Address{ + src1: dst1, + src2: dst1, + }, + expErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + offRamp := mock_contracts.NewEVM2EVMOffRampInterface(t) + offRamp.On("GetSupportedTokens", mock.Anything).Return(tc.srcTokens, nil) + for src, dst := range tc.srcToDst { + offRamp.On("GetDestinationToken", mock.Anything, src).Return(dst, nil) + } + o := supportedTokensOrigin{offRamp: offRamp} + srcToDst, err := o.CallOrigin(context.Background()) + + if tc.expErr { + assert.Error(t, err) + return + } + + assert.NoError(t, err) + for src, dst := range tc.srcToDst { + assert.Equal(t, dst, srcToDst[src]) + } + }) + } +} + +func Test_copyArray(t *testing.T) { + t.Run("base", func(t *testing.T) { + a := []common.Address{common.HexToAddress("1"), common.HexToAddress("2")} + b := copyArray(a) + assert.Equal(t, a, b) + b[0] = common.HexToAddress("3") + assert.NotEqual(t, a, b) + }) + + t.Run("empty", func(t *testing.T) { + b := copyArray([]common.Address{}) + assert.Empty(t, b) + }) +} + +func Test_copyMap(t *testing.T) { + t.Run("base", func(t *testing.T) { + val := map[string]int{"a": 100, "b": 50} + cp := copyMap(val) + assert.Len(t, val, 2) + assert.Equal(t, 100, cp["a"]) + assert.Equal(t, 50, cp["b"]) + val["b"] = 10 + assert.Equal(t, 50, cp["b"]) + }) + + t.Run("pointer val", func(t *testing.T) { + val := map[string]*big.Int{"a": big.NewInt(100), "b": big.NewInt(50)} + cp := copyMap(val) + val["a"] = big.NewInt(20) + assert.Equal(t, int64(100), cp["a"].Int64()) + }) +} + +func Test_cachedDecimals(t *testing.T) { + tokenDecimalsCache := &tokenToDecimals{} + addr := utils.RandomAddress() + + decimals, exists := tokenDecimalsCache.getCachedDecimals(addr) + assert.Zero(t, decimals) + assert.False(t, exists) + + tokenDecimalsCache.setCachedDecimals(addr, 123) + decimals, exists = tokenDecimalsCache.getCachedDecimals(addr) + assert.Equal(t, uint8(123), decimals) + assert.True(t, exists) +} + +func createTokenFactory(decimalMapping map[common.Address]uint8) func(address common.Address) (link_token_interface.LinkTokenInterface, error) { + return func(address common.Address) (link_token_interface.LinkTokenInterface, error) { + linkToken := &mock_contracts.LinkTokenInterface{} + if decimals, found := decimalMapping[address]; found { + // Make sure each token is fetched only once + linkToken.On("Decimals", mock.Anything).Return(decimals, nil) + } else { + linkToken.On("Decimals", mock.Anything).Return(uint8(0), errors.New("Error")) + } + return linkToken, nil + } +}