diff --git a/client/client.go b/client/client.go
index d4bbfaa..8ba8b3a 100644
--- a/client/client.go
+++ b/client/client.go
@@ -318,12 +318,12 @@ type LogStateTracker struct {
 	// LatestConsistentRaw holds the raw bytes of the latest proven-consistent
 	// LogState seen by this tracker.
 	LatestConsistentRaw []byte
-	// LatestConsistent is the deserialised form of LatestConsistentRaw
 	LatestConsistent log.Checkpoint
-	// The note with signatures and other metadata about the checkpoint
 	CheckpointNote *note.Note
+	// ProofBuilder for building proofs at LatestConsistent checkpoint.
+	ProofBuilder *ProofBuilder
 
 	CpSigVerifier note.Verifier
 }
@@ -348,6 +348,10 @@ func NewLogStateTracker(ctx context.Context, f Fetcher, h merkle.LogHasher, chec
 			return ret, err
 		}
 		ret.LatestConsistent = *cp
+		ret.ProofBuilder, err = NewProofBuilder(ctx, ret.LatestConsistent, ret.Hasher.HashChildren, ret.Fetcher)
+		if err != nil {
+			return ret, fmt.Errorf("NewProofBuilder: %v", err)
+		}
 		return ret, nil
 	}
 	_, _, _, err := ret.Update(ctx)
@@ -387,29 +391,33 @@ func (lst *LogStateTracker) Update(ctx context.Context) ([]byte, [][]byte, []byt
 	if err != nil {
 		return nil, nil, nil, err
 	}
+	builder, err := NewProofBuilder(ctx, *c, lst.Hasher.HashChildren, lst.Fetcher)
+	if err != nil {
+		return nil, nil, nil, fmt.Errorf("failed to create proof builder: %w", err)
+	}
 	var p [][]byte
 	if lst.LatestConsistent.Size > 0 {
-		if c.Size > lst.LatestConsistent.Size {
-			builder, err := NewProofBuilder(ctx, *c, lst.Hasher.HashChildren, lst.Fetcher)
-			if err != nil {
-				return nil, nil, nil, fmt.Errorf("failed to create proof builder: %w", err)
-			}
-			p, err = builder.ConsistencyProof(ctx, lst.LatestConsistent.Size, c.Size)
-			if err != nil {
-				return nil, nil, nil, err
-			}
-			if err := proof.VerifyConsistency(lst.Hasher, lst.LatestConsistent.Size, c.Size, p, lst.LatestConsistent.Hash, c.Hash); err != nil {
-				return nil, nil, nil, ErrInconsistency{
-					SmallerRaw: lst.LatestConsistentRaw,
-					LargerRaw:  cRaw,
-					Proof:      p,
-					Wrapped:    err,
-				}
+		if c.Size <= lst.LatestConsistent.Size {
+			return lst.LatestConsistentRaw, p, lst.LatestConsistentRaw, nil
+		}
+		p, err = builder.ConsistencyProof(ctx, lst.LatestConsistent.Size, c.Size)
+		if err != nil {
+			return nil, nil, nil, err
+		}
+		if err := proof.VerifyConsistency(lst.Hasher, lst.LatestConsistent.Size, c.Size, p, lst.LatestConsistent.Hash, c.Hash); err != nil {
+			return nil, nil, nil, ErrInconsistency{
+				SmallerRaw: lst.LatestConsistentRaw,
+				LargerRaw:  cRaw,
+				Proof:      p,
+				Wrapped:    err,
 			}
 		}
+		// Update is consistent,
+
 	}
 	oldRaw := lst.LatestConsistentRaw
 	lst.LatestConsistentRaw, lst.LatestConsistent, lst.CheckpointNote = cRaw, *c, cn
+	lst.ProofBuilder = builder
 	return oldRaw, p, lst.LatestConsistentRaw, nil
 }
 
diff --git a/client/client_test.go b/client/client_test.go
index c760d8f..dd8c513 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -17,65 +17,187 @@ package client
 import (
 	"bytes"
 	"context"
-	"encoding/base64"
+	"errors"
+	"fmt"
 	"os"
 	"path/filepath"
+	"strings"
 	"testing"
 
 	"github.com/transparency-dev/formats/log"
 	"github.com/transparency-dev/merkle/compact"
 	"github.com/transparency-dev/merkle/rfc6962"
 	"github.com/transparency-dev/serverless-log/api"
+	"golang.org/x/mod/sumdb/note"
 )
 
 var (
+	testOrigin      = "Log Checkpoint v0"
+	testLogVerifier = mustMakeVerifier("astra+cad5a3d2+AZJqeuyE/GnknsCNh1eCtDtwdAwKBddOlS8M2eI1Jt4b")
 	// Built using serverless/testdata/build_log.sh
-	testCheckpoints = []log.Checkpoint{
-		{
-			Size: 1,
-			Hash: b64("0Nc2CrefWKseHj/mStd+LqC8B+NrX0btIiPt2SmN+ek="),
-		},
-		{
-			Size: 2,
-			Hash: b64("T1X2GdkhUjV3iyufF9b0kVsWFxIU0VI4EpNml2Teci4="),
-		},
-		{
-			Size: 3,
-			Hash: b64("Wqx3HImawpLnS/Gv4ubjAvi1WIOy0b8Ze0amvqbavKk="),
-		},
-		{
-			Size: 4,
-			Hash: b64("zY1lN35vrXYAPixXSd59LsU29xUJtuW4o2dNNg5Y2Co="),
-		},
-		{
-			Size: 5,
-			Hash: b64("gy5gl3aksFyiCO95a/1vLXz88A3dRq+0l9Sxte8ZqZQ="),
-		},
-		{
-			Size: 6,
-			Hash: b64("a6sWvsc2eEzmj72vah7mZ5dwFltivehh2b11qwlp5Jg="),
-		},
-		{
-			Size: 7,
-			Hash: b64("IrSXADBqJ7EQoUODSDKROySgNveeL6CFhik2w/+fS7U="),
-		},
-		{
-			Size: 14,
-			Hash: b64("SvCd38yNade7xEPY1a/aAc1M3A2AHYVF8lIiUnsH1ao="),
-		},
-		{
-			Size: 15,
-			Hash: b64("rKbDipCvhuX1GZ7g5BBe8sA6BbJ7ja/1nk427v383cs="),
-		},
-	}
+	testRawCheckpoints, testCheckpoints = mustLoadTestCheckpoints()
 )
 
-func b64(r string) []byte {
-	ret, err := base64.StdEncoding.DecodeString(r)
+func mustMakeVerifier(vs string) note.Verifier {
+	v, err := note.NewVerifier(vs)
 	if err != nil {
-		panic(err)
+		panic(fmt.Errorf("NewVerifier(%q): %v", vs, err))
+	}
+	return v
+}
+
+func mustLoadTestCheckpoints() ([][]byte, []log.Checkpoint) {
+	raws, cps := make([][]byte, 0), make([]log.Checkpoint, 0)
+	for i := 1; ; i++ {
+		cpName := fmt.Sprintf("checkpoint.%d", i)
+		r, err := testLogFetcher(context.Background(), cpName)
+		if err != nil {
+			if errors.Is(err, os.ErrNotExist) {
+				// Probably just no more checkpoints left
+				break
+			}
+			panic(err)
+		}
+		cp, _, _, err := log.ParseCheckpoint(r, testOrigin, testLogVerifier)
+		if err != nil {
+			panic(fmt.Errorf("ParseCheckpoint(%s): %v", cpName, err))
+		}
+		raws, cps = append(raws, r), append(cps, *cp)
+	}
+	if len(raws) == 0 {
+		panic("no checkpoints loaded")
+	}
+	return raws, cps
+}
+
+// testLogFetcher is a fetcher which reads from the checked-in golden test log
+// data stored in ../testdata/log
+func testLogFetcher(_ context.Context, p string) ([]byte, error) {
+	path := filepath.Join("../testdata/log", p)
+	return os.ReadFile(path)
+}
+
+// fetchCheckpointShim allows fetcher requests for checkpoints to be intercepted.
+type fetchCheckpointShim struct {
+	// Checkpoints holds raw checkpoints to be returned when the fetcher is asked to retrieve a checkpoint path.
+	// The zero-th entry will be returned until Advance is called.
+	Checkpoints [][]byte
+}
+
+// Fetcher intercepts requests for the checkpoint file, returning the zero-th
+// entry in the Checkpoints field. All other requests are passed through
+// to the delegate fetcher.
+func (f *fetchCheckpointShim) Fetcher(deleg Fetcher) Fetcher {
+	return func(ctx context.Context, path string) ([]byte, error) {
+		if strings.HasSuffix(path, "checkpoint") {
+			if len(f.Checkpoints) == 0 {
+				return nil, os.ErrNotExist
+			}
+			r := f.Checkpoints[0]
+			return r, nil
+		}
+		return deleg(ctx, path)
+	}
+}
+
+// Advance causes subsequent intercepted checkpoint requests to return
+// the next entry in the Checkpoints slice.
+func (f *fetchCheckpointShim) Advance() {
+	f.Checkpoints = f.Checkpoints[1:]
+}
+
+func TestCheckLogStateTracker(t *testing.T) {
+	ctx := context.Background()
+	h := rfc6962.DefaultHasher
+
+	for _, test := range []struct {
+		desc       string
+		cpRaws     [][]byte
+		wantCpRaws [][]byte
+	}{
+		{
+			desc: "Consistent",
+			cpRaws: [][]byte{
+				testRawCheckpoints[0],
+				testRawCheckpoints[2],
+				testRawCheckpoints[3],
+				testRawCheckpoints[5],
+				testRawCheckpoints[6],
+				testRawCheckpoints[10],
+			},
+			wantCpRaws: [][]byte{
+				testRawCheckpoints[0],
+				testRawCheckpoints[2],
+				testRawCheckpoints[3],
+				testRawCheckpoints[5],
+				testRawCheckpoints[6],
+				testRawCheckpoints[10],
+			},
+		}, {
+			desc: "Identical CP",
+			cpRaws: [][]byte{
+				testRawCheckpoints[0],
+				testRawCheckpoints[0],
+				testRawCheckpoints[0],
+				testRawCheckpoints[0],
+			},
+			wantCpRaws: [][]byte{
+				testRawCheckpoints[0],
+				testRawCheckpoints[0],
+				testRawCheckpoints[0],
+				testRawCheckpoints[0],
+			},
+		}, {
+			desc: "Identical CP pairs",
+			cpRaws: [][]byte{
+				testRawCheckpoints[0],
+				testRawCheckpoints[0],
+				testRawCheckpoints[5],
+				testRawCheckpoints[5],
+			},
+			wantCpRaws: [][]byte{
+				testRawCheckpoints[0],
+				testRawCheckpoints[0],
+				testRawCheckpoints[5],
+				testRawCheckpoints[5],
+			},
+		}, {
+			desc: "Out of order",
+			cpRaws: [][]byte{
+				testRawCheckpoints[5],
+				testRawCheckpoints[2],
+				testRawCheckpoints[0],
+				testRawCheckpoints[3],
+			},
+			wantCpRaws: [][]byte{
+				testRawCheckpoints[5],
+				testRawCheckpoints[5],
+				testRawCheckpoints[5],
+				testRawCheckpoints[5],
+			},
+		},
+	} {
+		t.Run(test.desc, func(t *testing.T) {
+			shim := fetchCheckpointShim{Checkpoints: test.cpRaws}
+			f := shim.Fetcher(testLogFetcher)
+			lst, err := NewLogStateTracker(ctx, f, h, testRawCheckpoints[0], testLogVerifier, testOrigin, UnilateralConsensus(f))
+			if err != nil {
+				t.Fatalf("NewLogStateTracker: %v", err)
+			}
+
+			for i := range test.cpRaws {
+				_, _, newCP, err := lst.Update(ctx)
+				if err != nil {
+					t.Errorf("Update %d: %v", i, err)
+				}
+				if got, want := newCP, test.wantCpRaws[i]; !bytes.Equal(got, want) {
+					t.Errorf("Update moved to:\n%s\nwant:\n%s", string(got), string(want))
+				}
+
+				shim.Advance()
+			}
+		})
 	}
-	return ret
 }
 
 func TestCheckConsistency(t *testing.T) {
@@ -83,10 +83,6 @@ func TestCheckConsistency(t *testing.T) {
 
 	h := rfc6962.DefaultHasher
 
-	f := func(_ context.Context, p string) ([]byte, error) {
-		path := filepath.Join("../testdata/log", p)
-		return os.ReadFile(path)
-	}
 	for _, test := range []struct {
 		desc    string
 		cp      []log.Checkpoint
@@ -186,7 +304,7 @@ func TestCheckConsistency(t *testing.T) {
 		},
 	} {
 		t.Run(test.desc, func(t *testing.T) {
-			err := CheckConsistency(ctx, h, f, test.cp)
+			err := CheckConsistency(ctx, h, testLogFetcher, test.cp)
 			if gotErr := err != nil; gotErr != test.wantErr {
 				t.Fatalf("wantErr: %t, got %v", test.wantErr, err)
 			}
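
Note (not part of the patch): the sketch below illustrates how a caller outside the package might use the ProofBuilder that NewLogStateTracker and Update now keep populated. The trackAndProve helper, the nil starting checkpoint, and the client import path are assumptions for illustration only; the constructor arguments, Update call, and ConsistencyProof call mirror the code and test above.

package example

import (
	"context"
	"fmt"

	"github.com/transparency-dev/merkle/rfc6962"
	"github.com/transparency-dev/serverless-log/client"
	"golang.org/x/mod/sumdb/note"
)

// trackAndProve is a hypothetical caller-side helper. It starts a tracker with
// no previously saved checkpoint, moves it to the log's latest checkpoint, and
// then reuses the tracker's exposed ProofBuilder to fetch a consistency proof
// between the first size it saw and the size it now trusts.
func trackAndProve(ctx context.Context, f client.Fetcher, v note.Verifier, origin string) ([][]byte, error) {
	lst, err := client.NewLogStateTracker(ctx, f, rfc6962.DefaultHasher, nil, v, origin, client.UnilateralConsensus(f))
	if err != nil {
		return nil, fmt.Errorf("NewLogStateTracker: %v", err)
	}
	first := lst.LatestConsistent.Size

	// Update verifies consistency against the previously tracked checkpoint
	// and, with this change, also refreshes lst.ProofBuilder.
	if _, _, _, err := lst.Update(ctx); err != nil {
		return nil, fmt.Errorf("Update: %v", err)
	}

	if first == 0 || first == lst.LatestConsistent.Size {
		// Nothing to prove for an empty tree or identical sizes.
		return nil, nil
	}
	// The builder is pinned to lst.LatestConsistent, so callers no longer need
	// to construct their own ProofBuilder after every Update.
	return lst.ProofBuilder.ConsistencyProof(ctx, first, lst.LatestConsistent.Size)
}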