Skip to content

Commit

Permalink
Add types to commit spec (#1045)
Browse files Browse the repository at this point in the history
## Motivation
Improving readability of spec

## Solution
Add typing
  • Loading branch information
connorwstein authored Jun 19, 2024
1 parent 7f834cc commit 95ff331
Show file tree
Hide file tree
Showing 6 changed files with 234 additions and 210 deletions.
71 changes: 38 additions & 33 deletions core/services/ocr3/plugins/ccip/commit/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,13 @@ func NewPlugin(
) *Plugin {
knownSourceChains := mapset.NewSet[cciptypes.ChainSelector]()
for _, inf := range cfg.ObserverInfo {
knownSourceChains = knownSourceChains.Union(mapset.NewSet(inf.Reads...))
var sources []cciptypes.ChainSelector
for _, chain := range inf.Reads {
if chain != cfg.DestChain {
sources = append(sources, chain)
}
}
knownSourceChains = knownSourceChains.Union(mapset.NewSet(sources...))
}

return &Plugin{
Expand Down Expand Up @@ -96,30 +102,10 @@ func (p *Plugin) Query(_ context.Context, _ ocr3types.OutcomeContext) (types.Que
// We discover the token prices only for the tokens that are used to pay for ccip fees.
// The fee tokens are configured in the plugin config.
func (p *Plugin) Observation(ctx context.Context, outctx ocr3types.OutcomeContext, _ types.Query) (types.Observation, error) {
maxSeqNumsPerChain, seqNumsInSync, err := observeMaxSeqNums(
ctx,
p.lggr,
p.ccipReader,
outctx.PreviousOutcome,
p.readableChains,
p.cfg.DestChain,
p.knownSourceChainsSlice(),
)
if err != nil {
return types.Observation{}, fmt.Errorf("observe max sequence numbers per chain: %w", err)
}

newMsgs, err := observeNewMsgs(
ctx,
p.lggr,
p.ccipReader,
p.msgHasher,
p.readableChains,
maxSeqNumsPerChain,
p.cfg.NewMsgScanBatchSize,
)
msgBaseDetails := make([]cciptypes.CCIPMsgBaseDetails, 0)
latestCommittedSeqNumsObservation, err := observeLatestCommittedSeqNums(ctx, p.lggr, p.ccipReader, p.readableChains, p.cfg.DestChain, p.knownSourceChains.ToSlice())
if err != nil {
return types.Observation{}, fmt.Errorf("observe new messages: %w", err)
return types.Observation{}, fmt.Errorf("observe latest committed sequence numbers: %w", err)
}

var tokenPrices []cciptypes.TokenPrice
Expand All @@ -140,25 +126,44 @@ func (p *Plugin) Observation(ctx context.Context, outctx ocr3types.OutcomeContex
if err != nil {
return types.Observation{}, fmt.Errorf("observe gas prices: %w", err)
}
// If there's no previous outcome (first round ever), we only observe the latest committed sequence numbers.
// and on the next round we use those to look for messages.
if outctx.PreviousOutcome == nil {
p.lggr.Debugw("first round ever, can't observe new messages yet")
return cciptypes.NewCommitPluginObservation(msgBaseDetails, gasPrices, tokenPrices, latestCommittedSeqNumsObservation, p.cfg).Encode()
}

prevOutcome, err := cciptypes.DecodeCommitPluginOutcome(outctx.PreviousOutcome)
if err != nil {
return types.Observation{}, fmt.Errorf("decode commit plugin previous outcome: %w", err)
}
p.lggr.Debugw("previous outcome decoded", "outcome", prevOutcome.String())

// Always observe based on previous outcome. We'll filter out stale messages in the outcome phase.
newMsgs, err := observeNewMsgs(
ctx,
p.lggr,
p.ccipReader,
p.msgHasher,
p.readableChains,
prevOutcome.MaxSeqNums, // TODO: Chainlink common PR to rename.
p.cfg.NewMsgScanBatchSize,
)
if err != nil {
return types.Observation{}, fmt.Errorf("observe new messages: %w", err)
}

p.lggr.Infow("submitting observation",
"observedNewMsgs", len(newMsgs),
"gasPrices", len(gasPrices),
"tokenPrices", len(tokenPrices),
"maxSeqNumsPerChain", maxSeqNumsPerChain,
"latestCommittedSeqNums", latestCommittedSeqNumsObservation,
"observerInfo", p.cfg.ObserverInfo)

msgBaseDetails := make([]cciptypes.CCIPMsgBaseDetails, 0)
for _, msg := range newMsgs {
msgBaseDetails = append(msgBaseDetails, msg.CCIPMsgBaseDetails)
}

if !seqNumsInSync {
// If the node was not able to sync the max sequence numbers we don't want to transmit
// the potentially outdated ones. We expect that a sufficient number of nodes will be able to observe them.
maxSeqNumsPerChain = nil
}
return cciptypes.NewCommitPluginObservation(msgBaseDetails, gasPrices, tokenPrices, maxSeqNumsPerChain, p.cfg).Encode()
return cciptypes.NewCommitPluginObservation(msgBaseDetails, gasPrices, tokenPrices, latestCommittedSeqNumsObservation, p.cfg).Encode()
}

func (p *Plugin) ValidateObservation(_ ocr3types.OutcomeContext, _ types.Query, ao types.AttributedObservation) error {
Expand Down
34 changes: 29 additions & 5 deletions core/services/ocr3/plugins/ccip/commit/plugin_e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"reflect"
"testing"

"github.com/stretchr/testify/require"

"github.com/stretchr/testify/assert"

"github.com/smartcontractkit/libocr/commontypes"
Expand All @@ -29,6 +31,7 @@ func TestPlugin(t *testing.T) {
expErr func(*testing.T, error)
expOutcome cciptypes.CommitPluginOutcome
expTransmittedReports []cciptypes.CommitPluginReport
initialOutcome cciptypes.CommitPluginOutcome
}{
{
name: "EmptyOutcome",
Expand Down Expand Up @@ -69,6 +72,15 @@ func TestPlugin(t *testing.T) {
},
},
},
initialOutcome: cciptypes.CommitPluginOutcome{
MaxSeqNums: []cciptypes.SeqNumChain{
{ChainSel: chainA, SeqNum: 10},
{ChainSel: chainB, SeqNum: 20},
},
MerkleRoots: []cciptypes.MerkleRootChain{},
TokenPrices: []cciptypes.TokenPrice{},
GasPrices: []cciptypes.GasPriceChain{},
},
},
{
name: "NodesDoNotAgreeOnMsgs",
Expand Down Expand Up @@ -98,6 +110,15 @@ func TestPlugin(t *testing.T) {
},
},
},
initialOutcome: cciptypes.CommitPluginOutcome{
MaxSeqNums: []cciptypes.SeqNumChain{
{ChainSel: chainA, SeqNum: 10},
{ChainSel: chainB, SeqNum: 20},
},
MerkleRoots: []cciptypes.MerkleRootChain{},
TokenPrices: []cciptypes.TokenPrice{},
GasPrices: []cciptypes.GasPriceChain{},
},
},
}

Expand All @@ -118,7 +139,9 @@ func TestPlugin(t *testing.T) {
for _, n := range nodesSetup {
nodeIDs = append(nodeIDs, n.node.nodeID)
}
runner := testhelpers.NewOCR3Runner(nodes, nodeIDs)
o, err := tc.initialOutcome.Encode()
require.NoError(t, err)
runner := testhelpers.NewOCR3Runner(nodes, nodeIDs, o)

res, err := runner.RunRound(ctx)
if tc.expErr != nil {
Expand Down Expand Up @@ -203,10 +226,6 @@ func setupAllNodesReadAllChains(ctx context.Context, t *testing.T, lggr logger.L
nodes := []nodeSetup{n1, n2, n3}

for _, n := range nodes {
// all nodes observe the same sequence numbers 10 for chainA and 20 for chainB
n.ccipReader.On("NextSeqNum", ctx, []cciptypes.ChainSelector{chainA, chainB}).
Return([]cciptypes.SeqNum{10, 20}, nil)

// then they fetch new msgs, there is nothing new on chainA
n.ccipReader.On(
"MsgsBetweenSeqNums",
Expand All @@ -231,6 +250,11 @@ func setupAllNodesReadAllChains(ctx context.Context, t *testing.T, lggr logger.L
cciptypes.NewBigIntFromInt64(1000),
cciptypes.NewBigIntFromInt64(20_000),
}, nil)

// all nodes observe the same sequence numbers 10 for chainA and 20 for chainB
n.ccipReader.On("NextSeqNum", ctx, []cciptypes.ChainSelector{chainA, chainB}).
Return([]cciptypes.SeqNum{10, 20}, nil)

}

return nodes
Expand Down
66 changes: 17 additions & 49 deletions core/services/ocr3/plugins/ccip/commit/plugin_functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,80 +19,48 @@ import (
"github.com/smartcontractkit/chainlink-common/pkg/merklemulti"
)

// observeMaxSeqNums finds the maximum committed sequence numbers for each source chain.
// If a sequence number is pending (is not on-chain yet), it will be included in the results.
func observeMaxSeqNums(
// observeLatestCommittedSeqNums finds the maximum committed sequence numbers for each source chain.
// If we cannot observe the dest we return an empty slice and no error..
func observeLatestCommittedSeqNums(
ctx context.Context,
lggr logger.Logger,
ccipReader cciptypes.CCIPReader,
previousOutcomeBytes []byte,
readableChains mapset.Set[cciptypes.ChainSelector],
destChain cciptypes.ChainSelector,
knownSourceChains []cciptypes.ChainSelector,
) ([]cciptypes.SeqNumChain, bool, error) {
seqNumsInSync := false

// If there is a previous outcome, start with the sequence numbers of it.
seqNumPerChain := make(map[cciptypes.ChainSelector]cciptypes.SeqNum)
if previousOutcomeBytes != nil {
lggr.Debugw("observing based on previous outcome")
prevOutcome, err := cciptypes.DecodeCommitPluginOutcome(previousOutcomeBytes)
if err != nil {
return nil, false, fmt.Errorf("decode commit plugin previous outcome: %w", err)
}
lggr.Debugw("previous outcome decoded", "outcome", prevOutcome.String())

for _, seqNumChain := range prevOutcome.MaxSeqNums {
if seqNumChain.SeqNum > seqNumPerChain[seqNumChain.ChainSel] {
seqNumPerChain[seqNumChain.ChainSel] = seqNumChain.SeqNum
}
}
lggr.Debugw("discovered sequence numbers from prev outcome", "seqNumPerChain", seqNumPerChain)
}

// If reading destination chain is supported find the latest sequence numbers per chain from the onchain state.
) ([]cciptypes.SeqNumChain, error) {
sort.Slice(knownSourceChains, func(i, j int) bool { return knownSourceChains[i] < knownSourceChains[j] })
latestCommittedSeqNumsObservation := make([]cciptypes.SeqNumChain, 0)
if readableChains.Contains(destChain) {
lggr.Debugw("reading sequence numbers from destination")
onChainSeqNums, err := ccipReader.NextSeqNum(ctx, knownSourceChains)
lggr.Debugw("reading latest committed sequence from destination")
onChainLatestCommittedSeqNums, err := ccipReader.NextSeqNum(ctx, knownSourceChains)
if err != nil {
return nil, false, fmt.Errorf("get next seq nums: %w", err)
return latestCommittedSeqNumsObservation, fmt.Errorf("get next seq nums: %w", err)
}
lggr.Debugw("discovered sequence numbers from destination", "onChainSeqNums", onChainSeqNums)

// Update the seq nums if the on-chain sequence number is greater than previous outcome.
lggr.Debugw("observed latest committed sequence numbers on destination", "latestCommittedSeqNumsObservation", onChainLatestCommittedSeqNums)
for i, ch := range knownSourceChains {
if onChainSeqNums[i] > seqNumPerChain[ch] {
seqNumPerChain[ch] = onChainSeqNums[i]
lggr.Debugw("updated sequence number", "chain", ch, "seqNum", onChainSeqNums[i])
}
latestCommittedSeqNumsObservation = append(latestCommittedSeqNumsObservation, cciptypes.NewSeqNumChain(ch, onChainLatestCommittedSeqNums[i]))
}
seqNumsInSync = true
}

maxChainSeqNums := make([]cciptypes.SeqNumChain, 0)
for ch, seqNum := range seqNumPerChain {
maxChainSeqNums = append(maxChainSeqNums, cciptypes.NewSeqNumChain(ch, seqNum))
}

sort.Slice(maxChainSeqNums, func(i, j int) bool { return maxChainSeqNums[i].ChainSel < maxChainSeqNums[j].ChainSel })
return maxChainSeqNums, seqNumsInSync, nil
return latestCommittedSeqNumsObservation, nil
}

// observeNewMsgs finds the new messages for each supported chain based on the provided max sequence numbers.
// If latestCommitSeqNums is empty (first ever OCR round), it will return an empty slice.
func observeNewMsgs(
ctx context.Context,
lggr logger.Logger,
ccipReader cciptypes.CCIPReader,
msgHasher cciptypes.MessageHasher,
readableChains mapset.Set[cciptypes.ChainSelector],
maxSeqNumsPerChain []cciptypes.SeqNumChain,
latestCommittedSeqNums []cciptypes.SeqNumChain,
msgScanBatchSize int,
) ([]cciptypes.CCIPMsg, error) {
// Find the new msgs for each supported chain based on the discovered max sequence numbers.
newMsgsPerChain := make([][]cciptypes.CCIPMsg, len(maxSeqNumsPerChain))
newMsgsPerChain := make([][]cciptypes.CCIPMsg, len(latestCommittedSeqNums))
eg := new(errgroup.Group)

for chainIdx, seqNumChain := range maxSeqNumsPerChain {
for chainIdx, seqNumChain := range latestCommittedSeqNums {
if !readableChains.Contains(seqNumChain.ChainSel) {
lggr.Debugw("reading chain is not supported", "chain", seqNumChain.ChainSel)
continue
Expand Down Expand Up @@ -140,7 +108,7 @@ func observeNewMsgs(
}

observedNewMsgs := make([]cciptypes.CCIPMsg, 0)
for chainIdx := range maxSeqNumsPerChain {
for chainIdx := range latestCommittedSeqNums {
observedNewMsgs = append(observedNewMsgs, newMsgsPerChain[chainIdx]...)
}
return observedNewMsgs, nil
Expand Down
Loading

0 comments on commit 95ff331

Please sign in to comment.