From c57fd57b255699abdf337cc1e1cd2384034ac536 Mon Sep 17 00:00:00 2001 From: Martin Hutchinson Date: Thu, 5 Dec 2024 14:25:00 +0000 Subject: [PATCH] [MySQL Conformance] fix tree size bug (#382) Major changes: - MySQL storage read methods return os.ErrNotExist when values aren't found - ReadTile returns an error if the user requests more data than we have available - Added tests for writing and reading data from tiles - Made tests hermetic (though slower) by resetting the DB for each test case This got a bit bigger than intended. This fixes #364. --- cmd/conformance/mysql/main.go | 15 +++-- storage/mysql/mysql.go | 54 ++++++++++------- storage/mysql/mysql_test.go | 107 +++++++++++++++++++++++++++++++++- 3 files changed, 146 insertions(+), 30 deletions(-) diff --git a/cmd/conformance/mysql/main.go b/cmd/conformance/mysql/main.go index 7751597b..2fe58d08 100644 --- a/cmd/conformance/mysql/main.go +++ b/cmd/conformance/mysql/main.go @@ -165,17 +165,17 @@ func configureTilesReadAPI(mux *http.ServeMux, storage *mysql.Storage) { } return } - impliedSize := (index*256 + width) << (level * 8) - tile, err := storage.ReadTile(r.Context(), level, index, impliedSize) + inferredMinTreeSize := (index*256 + width) << (level * 8) + tile, err := storage.ReadTile(r.Context(), level, index, inferredMinTreeSize) if err != nil { + if os.IsNotExist(err) { + w.WriteHeader(http.StatusNotFound) + return + } klog.Errorf("/tile/{level}/{index...}: %v", err) w.WriteHeader(http.StatusInternalServerError) return } - if tile == nil { - w.WriteHeader(http.StatusNotFound) - return - } w.Header().Set("Cache-Control", "public, max-age=31536000, immutable") @@ -207,6 +207,9 @@ func configureTilesReadAPI(mux *http.ServeMux, storage *mysql.Storage) { } // TODO: Add immutable Cache-Control header. + // Only do this once we're sure we're returning the right number of entries + // Currently a user can request a full tile and we can return a partial tile. + // If cache headers were set then this could cause caches to be poisoned. if _, err := w.Write(entryBundle); err != nil { klog.Errorf("/tile/entries/{index...}: %v", err) diff --git a/storage/mysql/mysql.go b/storage/mysql/mysql.go index 15e236af..a6d62dd2 100644 --- a/storage/mysql/mysql.go +++ b/storage/mysql/mysql.go @@ -18,9 +18,11 @@ package mysql import ( "bytes" "context" + "crypto/sha256" "database/sql" "errors" "fmt" + "os" "strings" "time" @@ -121,7 +123,7 @@ func (s *Storage) maybeInitTree(ctx context.Context) error { }() treeState, err := s.readTreeState(ctx, tx) - if err != nil { + if err != nil && !os.IsNotExist(err) { klog.Errorf("Failed to read tree state: %v", err) return err } @@ -142,7 +144,7 @@ func (s *Storage) maybeInitTree(ctx context.Context) error { } // ReadCheckpoint returns the latest stored checkpoint. -// If the checkpoint is not found, nil is returned with no error. +// If the checkpoint is not found, it returns os.ErrNotExist. func (s *Storage) ReadCheckpoint(ctx context.Context) ([]byte, error) { row := s.db.QueryRowContext(ctx, selectCheckpointByIDSQL, checkpointID) if err := row.Err(); err != nil { @@ -153,7 +155,7 @@ func (s *Storage) ReadCheckpoint(ctx context.Context) ([]byte, error) { var at int64 if err := row.Scan(&checkpoint, &at); err != nil { if err == sql.ErrNoRows { - return nil, nil + return nil, os.ErrNotExist } return nil, fmt.Errorf("scan checkpoint: %v", err) } @@ -207,7 +209,7 @@ type treeState struct { } // readTreeState returns the currently stored tree state information. -// If there is no stored tree state, nil is returned with no error. +// If there is no stored tree state, it returns os.ErrNotExist. func (s *Storage) readTreeState(ctx context.Context, tx *sql.Tx) (*treeState, error) { row := tx.QueryRowContext(ctx, selectTreeStateByIDForUpdateSQL, treeStateID) if err := row.Err(); err != nil { @@ -217,7 +219,7 @@ func (s *Storage) readTreeState(ctx context.Context, tx *sql.Tx) (*treeState, er r := &treeState{} if err := row.Scan(&r.size, &r.root); err != nil { if err == sql.ErrNoRows { - return nil, nil + return nil, os.ErrNotExist } return nil, fmt.Errorf("scan tree state: %v", err) } @@ -234,16 +236,13 @@ func (s *Storage) writeTreeState(ctx context.Context, tx *sql.Tx, size uint64, r return nil } -// ReadTile returns a full tile or a partial tile at the given level, index and width. -// If the tile is not found, nil is returned with no error. +// ReadTile returns a full tile or a partial tile at the given level, index and treeSize. +// If the tile is not found, it returns os.ErrNotExist. // -// TODO: Handle the following scenarios: -// 1. Full tile request with full tile output: Return full tile. -// 2. Full tile request with partial tile output: Return error. -// 3. Partial tile request with full/larger partial tile output: Return trimmed partial tile with correct tile width. -// 4. Partial tile request with partial tile (same width) output: Return partial tile. -// 5. Partial tile request with smaller partial tile output: Return error. -func (s *Storage) ReadTile(ctx context.Context, level, index, width uint64) ([]byte, error) { +// Note that if a partial tile is requested, but a larger tile is available, this +// will return the largest tile available. This could be trimmed to return only the +// number of entries specifically requested if this behaviour becomes problematic. +func (s *Storage) ReadTile(ctx context.Context, level, index, minTreeSize uint64) ([]byte, error) { row := s.db.QueryRowContext(ctx, selectSubtreeByLevelAndIndexSQL, level, index) if err := row.Err(); err != nil { return nil, err @@ -252,20 +251,34 @@ func (s *Storage) ReadTile(ctx context.Context, level, index, width uint64) ([]b var tile []byte if err := row.Scan(&tile); err != nil { if err == sql.ErrNoRows { - return nil, nil + return nil, os.ErrNotExist } return nil, fmt.Errorf("scan tile: %v", err) } - // Return nil when returning a partial tile on a full tile request. - if width == 256 && uint64(len(tile)/32) != width { - return nil, nil + requestedWidth := partialTileSize(level, index, minTreeSize) + numEntries := uint64(len(tile) / sha256.Size) + + if requestedWidth > numEntries { + // If the user has requested a size larger than we have, they can't have it + return nil, os.ErrNotExist } return tile, nil } +// partialTileSize returns the expected number of leaves in a tile at the given location within +// a tree of the specified logSize, or 0 if the tile is expected to be fully populated. +func partialTileSize(level, index, logSize uint64) uint64 { + sizeAtLevel := logSize >> (level * 8) + fullTiles := sizeAtLevel / 256 + if index < fullTiles { + return 256 + } + return sizeAtLevel % 256 +} + // writeTile replaces the tile nodes at the given level and index. func (s *Storage) writeTile(ctx context.Context, tx *sql.Tx, level, index uint64, nodes []byte) error { if _, err := tx.ExecContext(ctx, replaceSubtreeSQL, level, index, nodes); err != nil { @@ -277,7 +290,7 @@ func (s *Storage) writeTile(ctx context.Context, tx *sql.Tx, level, index uint64 } // ReadEntryBundle returns the log entries at the given index. -// If the entry bundle is not found, nil is returned with no error. +// If the entry bundle is not found, it returns os.ErrNotExist. // // TODO: Handle the following scenarios: // 1. Full tile request with full tile output: Return full tile. @@ -294,9 +307,8 @@ func (s *Storage) ReadEntryBundle(ctx context.Context, index, treeSize uint64) ( var entryBundle []byte if err := row.Scan(&entryBundle); err != nil { if err == sql.ErrNoRows { - return nil, nil + return nil, os.ErrNotExist } - return nil, fmt.Errorf("scan entry bundle: %v", err) } diff --git a/storage/mysql/mysql_test.go b/storage/mysql/mysql_test.go index 6d4d98ef..5989a499 100644 --- a/storage/mysql/mysql_test.go +++ b/storage/mysql/mysql_test.go @@ -22,8 +22,10 @@ package mysql_test import ( "bytes" "context" + "crypto/sha256" "database/sql" "flag" + "fmt" "os" "testing" "time" @@ -48,8 +50,8 @@ var ( ) const ( - // Matching public key: "transparency.dev/tessera/example+ae330e15+ASf4/L1zE859VqlfQgGzKy34l91Gl8W6wfwp+vKP62DW" testPrivateKey = "PRIVATE+KEY+transparency.dev/tessera/example+ae330e15+AXEwZQ2L6Ga3NX70ITObzyfEIketMr2o9Kc+ed/rt/QR" + testPublicKey = "transparency.dev/tessera/example+ae330e15+ASf4/L1zE859VqlfQgGzKy34l91Gl8W6wfwp+vKP62DW" ) // TestMain checks whether the test MySQL database is available and starts the tests including database schema initialization. @@ -163,6 +165,93 @@ func TestNew(t *testing.T) { } } +func TestGetTile(t *testing.T) { + ctx := context.Background() + s := newTestMySQLStorage(t, ctx) + + awaiter := tessera.NewIntegrationAwaiter(ctx, s.ReadCheckpoint, 10*time.Millisecond) + + treeSize := 258 + var lastIndex uint64 + for i := range treeSize { + idx, _, err := awaiter.Await(ctx, s.Add(ctx, tessera.NewEntry([]byte(fmt.Sprintf("TestGetTile %d", i))))) + if err != nil { + t.Fatalf("Failed to prep test with entry: %v", err) + } + if idx > lastIndex { + lastIndex = idx + } + } + if got, want := lastIndex, uint64(treeSize-1); got != want { + t.Fatalf("expected only newly created entries in database; tests are not hermetic (got %d, want %d)", got, want) + } + + for _, test := range []struct { + name string + level, index, treeSize uint64 + wantEntries int + wantNotFound bool + }{ + { + name: "requested partial tile for a complete tile", + level: 0, index: 0, treeSize: 10, + wantEntries: 256, + wantNotFound: false, + }, + { + name: "too small but that's ok", + level: 0, index: 1, treeSize: uint64(treeSize) - 1, + wantEntries: 2, + wantNotFound: false, + }, + { + name: "just right", + level: 0, index: 1, treeSize: uint64(treeSize), + wantEntries: 2, + wantNotFound: false, + }, + { + name: "too big", + level: 0, index: 1, treeSize: uint64(treeSize + 1), + wantNotFound: true, + }, + { + name: "level 1 too small", + level: 1, index: 0, treeSize: uint64(treeSize - 1), + wantEntries: 1, + wantNotFound: false, + }, + { + name: "level 1 just right", + level: 1, index: 0, treeSize: uint64(treeSize), + wantEntries: 1, + wantNotFound: false, + }, + { + name: "level 1 too big", + level: 1, index: 0, treeSize: 550, + wantNotFound: true, + }, + } { + t.Run(test.name, func(t *testing.T) { + tile, err := s.ReadTile(ctx, test.level, test.index, test.treeSize) + if err != nil { + if notFound, wantNotFound := os.IsNotExist(err), test.wantNotFound; notFound != wantNotFound { + t.Errorf("wantNotFound %v but notFound %v", wantNotFound, notFound) + } + if test.wantNotFound { + return + } + t.Errorf("got err: %v", err) + } + numEntries := len(tile) / sha256.Size + if got, want := numEntries, test.wantEntries; got != want { + t.Errorf("got %d entries, but want %d", got, want) + } + }) + } +} + func TestReadMissingTile(t *testing.T) { ctx := context.Background() s := newTestMySQLStorage(t, ctx) @@ -183,6 +272,10 @@ func TestReadMissingTile(t *testing.T) { t.Run(test.name, func(t *testing.T) { tile, err := s.ReadTile(ctx, test.level, test.index, test.width) if err != nil { + if os.IsNotExist(err) { + // this is success for this test + return + } t.Errorf("got err: %v", err) } if tile != nil { @@ -212,6 +305,10 @@ func TestReadMissingEntryBundle(t *testing.T) { t.Run(test.name, func(t *testing.T) { entryBundle, err := s.ReadEntryBundle(ctx, test.index, test.index) if err != nil { + if os.IsNotExist(err) { + // this is success for this test + return + } t.Errorf("got err: %v", err) } if entryBundle != nil { @@ -286,7 +383,7 @@ func TestTileRoundTrip(t *testing.T) { } tileLevel, tileIndex, _, nodeIndex := layout.NodeCoordsToTileAddress(0, entryIndex) - tileRaw, err := s.ReadTile(ctx, tileLevel, tileIndex, nodeIndex) + tileRaw, err := s.ReadTile(ctx, tileLevel, tileIndex, nodeIndex+1) if err != nil { t.Errorf("ReadTile got err: %v", err) } @@ -358,8 +455,12 @@ func TestEntryBundleRoundTrip(t *testing.T) { func newTestMySQLStorage(t *testing.T, ctx context.Context) *mysql.Storage { t.Helper() + initDatabaseSchema(ctx) - s, err := mysql.New(ctx, testDB, tessera.WithCheckpointSigner(noteSigner)) + s, err := mysql.New(ctx, testDB, + tessera.WithCheckpointSigner(noteSigner), + tessera.WithCheckpointInterval(200*time.Millisecond), + tessera.WithBatching(128, 100*time.Millisecond)) if err != nil { t.Errorf("Failed to create mysql.Storage: %v", err) }