diff --git a/cmd/conformance/aws/main.go b/cmd/conformance/aws/main.go index 8e4626a3..f86c6043 100644 --- a/cmd/conformance/aws/main.go +++ b/cmd/conformance/aws/main.go @@ -43,6 +43,7 @@ var ( dbMaxConns = flag.Int("db_max_conns", 0, "Maximum connections to the database, defaults to 0, i.e unlimited") dbMaxIdle = flag.Int("db_max_idle_conns", 2, "Maximum idle database connections in the connection pool, defaults to 2") signer = flag.String("signer", "", "Note signer to use to sign checkpoints") + publishInterval = flag.Duration("publish_interval", 3*time.Second, "How frequently to publish updated checkpoints") additionalSigners = []string{} ) @@ -64,6 +65,7 @@ func main() { awsCfg := storageConfigFromFlags() storage, err := aws.New(ctx, awsCfg, tessera.WithCheckpointSigner(s, a...), + tessera.WithCheckpointInterval(*publishInterval), tessera.WithBatching(1024, time.Second), tessera.WithPushback(10*4096), ) diff --git a/storage/aws/aws.go b/storage/aws/aws.go index 1970becc..c4766190 100644 --- a/storage/aws/aws.go +++ b/storage/aws/aws.go @@ -46,6 +46,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/aws/smithy-go" "github.com/google/go-cmp/cmp" + "github.com/transparency-dev/merkle/rfc6962" tessera "github.com/transparency-dev/trillian-tessera" "github.com/transparency-dev/trillian-tessera/api" "github.com/transparency-dev/trillian-tessera/api/layout" @@ -75,6 +76,8 @@ type Storage struct { objStore objStore queue *storage.Queue + + treeUpdated chan struct{} } // objStore describes a type which can store and retrieve objects. @@ -82,6 +85,7 @@ type objStore interface { getObject(ctx context.Context, obj string) ([]byte, error) setObject(ctx context.Context, obj string, data []byte, contType string) error setObjectIfNoneMatch(ctx context.Context, obj string, data []byte, contType string) error + lastModified(ctx context.Context, obj string) (time.Time, error) } // sequencer describes a type which knows how to sequence entries. @@ -95,10 +99,14 @@ type sequencer interface { // If forceUpdate is true, then the consumeFunc should be called, with an empty slice of entries if // necessary. This allows the log self-initialise in a transactionally safe manner. consumeEntries(ctx context.Context, limit uint64, f consumeFunc, forceUpdate bool) (bool, error) + + // currentTree returns the sequencer's view of the current tree state. + currentTree(ctx context.Context) (uint64, []byte, error) } // consumeFunc is the signature of a function which can consume entries from the sequencer. -type consumeFunc func(ctx context.Context, from uint64, entries []storage.SequencedEntry) error +// Returns the updated root hash of the tree with the consumed entries integrated. +type consumeFunc func(ctx context.Context, from uint64, entries []storage.SequencedEntry) ([]byte, error) // Config holds AWS project and resource configuration for a storage instance. type Config struct { @@ -113,6 +121,9 @@ type Config struct { } // New creates a new instance of the AWS based Storage. +// +// Storage instances created via this c'tor will participate in integrating newly sequenced entries into the log +// and periodically publishing a new checkpoint which commits to the state of the tree. func New(ctx context.Context, cfg Config, opts ...func(*options.StorageOptions)) (*Storage, error) { opt := storage.ResolveStorageOptions(opts...) if opt.PushbackMaxOutstanding == 0 { @@ -138,6 +149,7 @@ func New(ctx context.Context, cfg Config, opts ...func(*options.StorageOptions)) sequencer: seq, newCP: opt.NewCP, entriesPath: opt.EntriesPath, + treeUpdated: make(chan struct{}), } r.queue = storage.NewQueue(ctx, opt.BatchMaxAge, opt.BatchMaxSize, r.sequencer.assignEntries) @@ -145,29 +157,63 @@ func New(ctx context.Context, cfg Config, opts ...func(*options.StorageOptions)) return nil, fmt.Errorf("failed to initialise log storage: %v", err) } - go func() { - t := time.NewTicker(1 * time.Second) - defer t.Stop() - for { - select { - case <-ctx.Done(): - return - case <-t.C: - } + // Kick off go-routine which handles the integration of entries. + go r.consumeEntriesTask(ctx) - func() { - // Don't quickloop for now, it causes issues updating checkpoint too frequently. - cctx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() + // Kick off go-routing which handles the publication of checkpoints. + go r.publishCheckpointTask(ctx, opt.CheckpointInterval) - if _, err := r.sequencer.consumeEntries(cctx, DefaultIntegrationSizeLimit, r.integrate, false); err != nil { - klog.Errorf("integrate: %v", err) - } - }() + return r, nil +} + +// sequenceEntriesTask periodically integrates newly sequenced entries. +// +// This function does not return until the passed context is done. +func (s *Storage) consumeEntriesTask(ctx context.Context) { + t := time.NewTicker(1 * time.Second) + defer t.Stop() + for { + select { + case <-ctx.Done(): + return + case <-t.C: } - }() - return r, nil + func() { + // Don't quickloop for now, it causes issues updating checkpoint too frequently. + cctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + if _, err := s.sequencer.consumeEntries(cctx, DefaultIntegrationSizeLimit, s.integrate, false); err != nil { + klog.Errorf("integrate: %v", err) + return + } + select { + case s.treeUpdated <- struct{}{}: + default: + } + }() + } +} + +// publishCheckpointTask periodically attempts to publish a new checkpoint representing the current state +// of the tree, once per interval. +// +// This function does not return until the passed in context is done. +func (s *Storage) publishCheckpointTask(ctx context.Context, interval time.Duration) { + t := time.NewTicker(interval) + defer t.Stop() + for { + select { + case <-ctx.Done(): + return + case <-s.treeUpdated: + case <-t.C: + } + if err := s.publishCheckpoint(ctx, interval); err != nil { + klog.Warningf("publishCheckpoint: %v", err) + } + } } // Add is the entrypoint for adding entries to a sequencing log. @@ -210,6 +256,10 @@ func (s *Storage) init(ctx context.Context) error { if _, err := s.sequencer.consumeEntries(cctx, DefaultIntegrationSizeLimit, s.integrate, true); err != nil { return fmt.Errorf("forced integrate: %v", err) } + select { + case s.treeUpdated <- struct{}{}: + default: + } return nil } return fmt.Errorf("failed to read checkpoint: %v", err) @@ -218,11 +268,26 @@ func (s *Storage) init(ctx context.Context) error { return nil } -func (s *Storage) updateCP(ctx context.Context, newSize uint64, newRoot []byte) error { - cpRaw, err := s.newCP(newSize, newRoot) +func (s *Storage) publishCheckpoint(ctx context.Context, minStaleness time.Duration) error { + m, err := s.objStore.lastModified(ctx, layout.CheckpointPath) + // Do not use errors.Is. Keep errors.As to compare by type and not by value. + var nske *types.NoSuchKey + if err != nil && !errors.As(err, &nske) { + return fmt.Errorf("lastModified(%q): %v", layout.CheckpointPath, err) + } + if time.Since(m) < minStaleness { + return nil + } + + size, root, err := s.sequencer.currentTree(ctx) + if err != nil { + return fmt.Errorf("currentTree: %v", err) + } + cpRaw, err := s.newCP(size, root) if err != nil { return fmt.Errorf("newCP: %v", err) } + if err := s.objStore.setObject(ctx, layout.CheckpointPath, cpRaw, ckptContType); err != nil { return fmt.Errorf("writeCheckpoint: %v", err) } @@ -315,7 +380,11 @@ func (s *Storage) setEntryBundle(ctx context.Context, bundleIndex uint64, logSiz } // integrate incorporates the provided entries into the log starting at fromSeq. -func (s *Storage) integrate(ctx context.Context, fromSeq uint64, entries []storage.SequencedEntry) error { +// +// Returns the new root hash of the log with the entries added. +func (s *Storage) integrate(ctx context.Context, fromSeq uint64, entries []storage.SequencedEntry) ([]byte, error) { + var newRoot []byte + getTiles := func(ctx context.Context, tileIDs []storage.TileID, treeSize uint64) ([]*api.HashTile, error) { n, err := s.getTiles(ctx, tileIDs, treeSize) if err != nil { @@ -334,10 +403,11 @@ func (s *Storage) integrate(ctx context.Context, fromSeq uint64, entries []stora }) errG.Go(func() error { - newSize, newRoot, tiles, err := storage.Integrate(ctx, getTiles, fromSeq, entries) + newSize, root, tiles, err := storage.Integrate(ctx, getTiles, fromSeq, entries) if err != nil { return fmt.Errorf("Integrate: %v", err) } + newRoot = root for k, v := range tiles { func(ctx context.Context, k storage.TileID, v *api.HashTile) { errG.Go(func() error { @@ -345,18 +415,13 @@ func (s *Storage) integrate(ctx context.Context, fromSeq uint64, entries []stora }) }(ctx, k, v) } - errG.Go(func() error { - klog.Infof("New CP: %d, %x", newSize, newRoot) - if s.newCP != nil { - return s.updateCP(ctx, newSize, newRoot) - } - return nil - }) + klog.Infof("New tree: %d, %x", newSize, newRoot) return nil }) - return errG.Wait() + err := errG.Wait() + return newRoot, err } // updateEntryBundles adds the entries being integrated into the entry bundles. @@ -500,6 +565,7 @@ func (s *mySQLSequencer) initDB(ctx context.Context) error { `CREATE TABLE IF NOT EXISTS IntCoord( id INT UNSIGNED NOT NULL, seq BIGINT UNSIGNED NOT NULL, + rootHash TINYBLOB NOT NULL, PRIMARY KEY (id) )`); err != nil { return err @@ -514,7 +580,7 @@ func (s *mySQLSequencer) initDB(ctx context.Context) error { return err } if _, err := s.dbPool.ExecContext(ctx, - `INSERT IGNORE INTO IntCoord (id, seq) VALUES (0, 0)`); err != nil { + `INSERT IGNORE INTO IntCoord (id, seq, rootHash) VALUES (0, 0, ?)`, rfc6962.DefaultHasher.EmptyRoot()); err != nil { return err } return nil @@ -618,9 +684,10 @@ func (s *mySQLSequencer) consumeEntries(ctx context.Context, limit uint64, f con }() // Figure out which is the starting index of sequenced entries to start consuming from. - row := tx.QueryRowContext(ctx, "SELECT seq FROM IntCoord WHERE id = ? FOR UPDATE", 0) + row := tx.QueryRowContext(ctx, "SELECT seq, rootHash FROM IntCoord WHERE id = ? FOR UPDATE", 0) var fromSeq uint64 - if err := row.Scan(&fromSeq); err == sql.ErrNoRows { + var rootHash []byte + if err := row.Scan(&fromSeq, &rootHash); err == sql.ErrNoRows { return false, nil } else if err != nil { return false, fmt.Errorf("failed to read IntCoord: %v", err) @@ -665,13 +732,14 @@ func (s *mySQLSequencer) consumeEntries(ctx context.Context, limit uint64, f con } // Call consumeFunc with the entries we've found - if err := f(ctx, uint64(fromSeq), entries); err != nil { + newRoot, err := f(ctx, uint64(fromSeq), entries) + if err != nil { return false, err } // consumeFunc was successful, so we can update our coordination row, and delete the row(s) for // the then consumed entries. - if _, err := tx.ExecContext(ctx, "UPDATE IntCoord SET seq=? WHERE id=?", orderCheck, 0); err != nil { + if _, err := tx.ExecContext(ctx, "UPDATE IntCoord SET seq=?, rootHash=? WHERE id=?", orderCheck, newRoot, 0); err != nil { return false, fmt.Errorf("update intcoord: %v", err) } @@ -691,6 +759,18 @@ func (s *mySQLSequencer) consumeEntries(ctx context.Context, limit uint64, f con return true, nil } +// currentTree returns the size and root hash of the currently integrated tree. +func (s *mySQLSequencer) currentTree(ctx context.Context) (uint64, []byte, error) { + row := s.dbPool.QueryRowContext(ctx, "SELECT seq, rootHash FROM IntCoord WHERE id = ?", 0) + var fromSeq uint64 + var rootHash []byte + if err := row.Scan(&fromSeq, &rootHash); err != nil { + return 0, nil, fmt.Errorf("failed to read IntCoord: %v", err) + } + + return fromSeq, rootHash, nil +} + func placeholder(n int) string { places := make([]string, n) for i := 0; i < n; i++ { @@ -777,3 +857,16 @@ func (s *s3Storage) setObjectIfNoneMatch(ctx context.Context, objName string, da } return nil } + +// lastModified returns the time the specified object was last modified, or an error +func (s *s3Storage) lastModified(ctx context.Context, obj string) (time.Time, error) { + r, err := s.s3Client.GetObject(ctx, &s3.GetObjectInput{ + Bucket: aws.String(s.bucket), + Key: aws.String(obj), + }) + if err != nil { + return time.Time{}, fmt.Errorf("getObject: failed to create reader for object %q in bucket %q: %w", obj, s.bucket, err) + } + + return *r.LastModified, r.Body.Close() +} diff --git a/storage/aws/aws_test.go b/storage/aws/aws_test.go index a078015d..5ce734be 100644 --- a/storage/aws/aws_test.go +++ b/storage/aws/aws_test.go @@ -33,8 +33,9 @@ import ( "reflect" "sync" "testing" + "time" - gcs "cloud.google.com/go/storage" + "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/aws/smithy-go" "github.com/google/go-cmp/cmp" tessera "github.com/transparency-dev/trillian-tessera" @@ -232,18 +233,18 @@ func TestMySQLSequencerRoundTrip(t *testing.T) { } seenIdx := uint64(0) - f := func(_ context.Context, fromSeq uint64, entries []storage.SequencedEntry) error { + f := func(_ context.Context, fromSeq uint64, entries []storage.SequencedEntry) ([]byte, error) { if fromSeq != seenIdx { - return fmt.Errorf("f called with fromSeq %d, want %d", fromSeq, seenIdx) + return nil, fmt.Errorf("f called with fromSeq %d, want %d", fromSeq, seenIdx) } for i, e := range entries { if got, want := e, wantEntries[i]; !reflect.DeepEqual(got, want) { - return fmt.Errorf("entry %d+%d != %d", fromSeq, i, seenIdx) + return nil, fmt.Errorf("entry %d+%d != %d", fromSeq, i, seenIdx) } seenIdx++ } - return nil + return []byte("newroot"), nil } more, err := s.consumeEntries(ctx, 7, f, false) @@ -366,9 +367,80 @@ func TestBundleRoundtrip(t *testing.T) { } } +func TestPublishCheckpoint(t *testing.T) { + ctx := context.Background() + if canSkipMySQLTest(t, ctx) { + klog.Warningf("MySQL not available, skipping %s", t.Name()) + t.Skip("MySQL not available, skipping test") + } + // Clean tables in case there's already something in there. + mustDropTables(t, ctx) + + s, err := newMySQLSequencer(ctx, *mySQLURI, 1000, 0, 0) + if err != nil { + t.Fatalf("newMySQLSequencer: %v", err) + } + + for _, test := range []struct { + name string + cpModifiedAt time.Time + publishInterval time.Duration + wantUpdate bool + }{ + { + name: "works ok", + cpModifiedAt: time.Now().Add(-15 * time.Second), + publishInterval: 10 * time.Second, + wantUpdate: true, + }, { + name: "too soon, skip update", + cpModifiedAt: time.Now().Add(-5 * time.Second), + publishInterval: 10 * time.Second, + wantUpdate: false, + }, + } { + t.Run(test.name, func(t *testing.T) { + m := newMemObjStore() + storage := &Storage{ + objStore: m, + sequencer: s, + entriesPath: layout.EntriesPath, + newCP: func(size uint64, hash []byte) ([]byte, error) { return []byte(fmt.Sprintf("%d/%x,", size, hash)), nil }, + } + // Call init so we've got a zero-sized checkpoint to work with. + if err := storage.init(ctx); err != nil { + t.Fatalf("storage.init: %v", err) + } + cpOld := []byte("bananas") + if err := m.setObject(ctx, layout.CheckpointPath, cpOld, ""); err != nil { + t.Fatalf("setObject(bananas): %v", err) + } + m.lMod = test.cpModifiedAt + if err := storage.publishCheckpoint(ctx, test.publishInterval); err != nil { + t.Fatalf("publishCheckpoint: %v", err) + } + cpNew, err := m.getObject(ctx, layout.CheckpointPath) + cpUpdated := !bytes.Equal(cpOld, cpNew) + if err != nil { + // Do not use errors.Is. Keep errors.As to compare by type and not by value. + var nske *types.NoSuchKey + if !errors.As(err, &nske) { + t.Fatalf("getObject: %v", err) + } + cpUpdated = false + } + if test.wantUpdate != cpUpdated { + t.Fatalf("got cpUpdated=%t, want %t", cpUpdated, test.wantUpdate) + } + }) + } + +} + type memObjStore struct { sync.RWMutex - mem map[string][]byte + mem map[string][]byte + lMod time.Time } func newMemObjStore() *memObjStore { @@ -383,7 +455,7 @@ func (m *memObjStore) getObject(_ context.Context, obj string) ([]byte, error) { d, ok := m.mem[obj] if !ok { - return nil, fmt.Errorf("obj %q not found: %w", obj, gcs.ErrObjectNotExist) + return nil, fmt.Errorf("obj %q not found: %w", obj, &types.NoSuchKey{}) } return d, nil } @@ -408,3 +480,7 @@ func (m *memObjStore) setObjectIfNoneMatch(_ context.Context, obj string, data [ m.mem[obj] = data return nil } + +func (m *memObjStore) lastModified(_ context.Context, obj string) (time.Time, error) { + return m.lMod, nil +} diff --git a/storage/gcp/gcp.go b/storage/gcp/gcp.go index e4954fde..6b5fce2f 100644 --- a/storage/gcp/gcp.go +++ b/storage/gcp/gcp.go @@ -163,10 +163,11 @@ func New(ctx context.Context, cfg Config, opts ...func(*options.StorageOptions)) if _, err := r.sequencer.consumeEntries(cctx, DefaultIntegrationSizeLimit, r.integrate, false); err != nil { klog.Errorf("integrate: %v", err) - select { - case r.cpUpdated <- struct{}{}: - default: - } + return + } + select { + case r.cpUpdated <- struct{}{}: + default: } }() } @@ -372,11 +373,7 @@ func (s *Storage) integrate(ctx context.Context, fromSeq uint64, entries []stora if err != nil { return fmt.Errorf("Integrate: %v", err) } - if newSize > 0 { - newRoot = root - } else { - newRoot = rfc6962.DefaultHasher.EmptyRoot() - } + newRoot = root for k, v := range tiles { func(ctx context.Context, k storage.TileID, v *api.HashTile) { errG.Go(func() error { diff --git a/storage/internal/integrate.go b/storage/internal/integrate.go index 5bad7199..d04ac27b 100644 --- a/storage/internal/integrate.go +++ b/storage/internal/integrate.go @@ -111,6 +111,11 @@ func (t *treeBuilder) integrate(ctx context.Context, fromSize uint64, entries [] } if len(entries) == 0 { klog.V(1).Infof("Nothing to do.") + // C2SP.org/log-tiles says all Merkle operations are those from RFC6962, we need to override + // the root of the empty tree to match (compact.Range will return an empty slice). + if fromSize == 0 { + r = rfc6962.DefaultHasher.EmptyRoot() + } // Nothing to do, nothing done. return fromSize, r, nil, nil }