diff --git a/CHANGELOG.md b/CHANGELOG.md index 462c63d629..9f4bd084c6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ different version can lead to presubmits failing due to unexpected diffs. +* Use a separate unit (StmtCache) to manage the cache of prepared statements, which wraps the sql.Stmt struct to handle and monitor the execution errors of the prepared statement. When an error occurs during statement execution, it closes the statement and clears the cache, as well as increments the error monitoring indicator. + ### Misc * Bump Go version from 1.17 to 1.19. diff --git a/storage/mysql/log_storage.go b/storage/mysql/log_storage.go index 2d57c7d4fd..cdbd4a8670 100644 --- a/storage/mysql/log_storage.go +++ b/storage/mysql/log_storage.go @@ -29,6 +29,7 @@ import ( "github.com/google/trillian/monitoring" "github.com/google/trillian/storage" "github.com/google/trillian/storage/cache" + "github.com/google/trillian/storage/stmtcache" "github.com/google/trillian/storage/tree" "github.com/google/trillian/types" "github.com/transparency-dev/merkle/compact" @@ -135,7 +136,7 @@ func NewLogStorage(db *sql.DB, mf monitoring.MetricFactory) storage.LogStorage { } return &mySQLLogStorage{ admin: NewAdminStorage(db), - mySQLTreeStorage: newTreeStorage(db), + mySQLTreeStorage: newTreeStorage(db, mf), metricFactory: mf, } } @@ -144,16 +145,16 @@ func (m *mySQLLogStorage) CheckDatabaseAccessible(ctx context.Context) error { return m.db.PingContext(ctx) } -func (m *mySQLLogStorage) getLeavesByMerkleHashStmt(ctx context.Context, num int, orderBySequence bool) (*sql.Stmt, error) { +func (m *mySQLLogStorage) getLeavesByMerkleHashStmt(ctx context.Context, num int, orderBySequence bool) (*stmtcache.Stmt, error) { if orderBySequence { - return m.getStmt(ctx, selectLeavesByMerkleHashOrderedBySequenceSQL, num, "?", "?") + return m.stmtCache.GetStmt(ctx, selectLeavesByMerkleHashOrderedBySequenceSQL, num, "?", "?") } - return m.getStmt(ctx, selectLeavesByMerkleHashSQL, num, "?", 
"?") + return m.stmtCache.GetStmt(ctx, selectLeavesByMerkleHashSQL, num, "?", "?") } -func (m *mySQLLogStorage) getLeavesByLeafIdentityHashStmt(ctx context.Context, num int) (*sql.Stmt, error) { - return m.getStmt(ctx, selectLeavesByLeafIdentityHashSQL, num, "?", "?") +func (m *mySQLLogStorage) getLeavesByLeafIdentityHashStmt(ctx context.Context, num int) (*stmtcache.Stmt, error) { + return m.stmtCache.GetStmt(ctx, selectLeavesByLeafIdentityHashSQL, num, "?", "?") } func (m *mySQLLogStorage) GetActiveLogIDs(ctx context.Context) ([]int64, error) { @@ -730,8 +731,8 @@ func (t *logTreeTX) StoreSignedLogRoot(ctx context.Context, root *trillian.Signe return checkResultOkAndRowCountIs(res, err, 1) } -func (t *logTreeTX) getLeavesByHashInternal(ctx context.Context, leafHashes [][]byte, tmpl *sql.Stmt, desc string) ([]*trillian.LogLeaf, error) { - stx := t.tx.StmtContext(ctx, tmpl) +func (t *logTreeTX) getLeavesByHashInternal(ctx context.Context, leafHashes [][]byte, tmpl *stmtcache.Stmt, desc string) ([]*trillian.LogLeaf, error) { + stx := tmpl.WithTx(ctx, t.tx) defer stx.Close() var args []interface{} diff --git a/storage/mysql/tree_storage.go b/storage/mysql/tree_storage.go index ffb0159cd6..515af8940d 100644 --- a/storage/mysql/tree_storage.go +++ b/storage/mysql/tree_storage.go @@ -21,11 +21,12 @@ import ( "encoding/base64" "fmt" "runtime/debug" - "strings" "sync" "github.com/google/trillian" + "github.com/google/trillian/monitoring" "github.com/google/trillian/storage/cache" + "github.com/google/trillian/storage/stmtcache" "github.com/google/trillian/storage/storagepb" "github.com/google/trillian/storage/tree" "google.golang.org/protobuf/proto" @@ -52,7 +53,7 @@ const ( AND Subtree.SubtreeRevision = x.MaxRevision AND Subtree.TreeId = x.TreeId AND Subtree.TreeId = ?` - placeholderSQL = "" + placeholderSQL = stmtcache.PlaceholderSQL ) // mySQLTreeStorage is shared between the mySQLLog- and (forthcoming) mySQLMap- @@ -60,12 +61,7 @@ const ( type mySQLTreeStorage struct 
{ db *sql.DB - // Must hold the mutex before manipulating the statement map. Sharing a lock because - // it only needs to be held while the statements are built, not while they execute and - // this will be a short time. These maps are from the number of placeholder '?' - // in the query to the statement that should be used. - statementMutex sync.Mutex - statements map[string]map[int]*sql.Stmt + stmtCache *stmtcache.StmtCache } // OpenDB opens a database connection for all MySQL-based storage implementations. @@ -85,60 +81,19 @@ func OpenDB(dbURL string) (*sql.DB, error) { return db, nil } -func newTreeStorage(db *sql.DB) *mySQLTreeStorage { +func newTreeStorage(db *sql.DB, mf monitoring.MetricFactory) *mySQLTreeStorage { return &mySQLTreeStorage{ - db: db, - statements: make(map[string]map[int]*sql.Stmt), + db: db, + stmtCache: stmtcache.New(db, mf), } } -// expandPlaceholderSQL expands an sql statement by adding a specified number of '?' -// placeholder slots. At most one placeholder will be expanded. -func expandPlaceholderSQL(sql string, num int, first, rest string) string { - if num <= 0 { - panic(fmt.Errorf("trying to expand SQL placeholder with <= 0 parameters: %s", sql)) - } - - parameters := first + strings.Repeat(","+rest, num-1) - - return strings.Replace(sql, placeholderSQL, parameters, 1) -} - -// getStmt creates and caches sql.Stmt structs based on the passed in statement -// and number of bound arguments. -// TODO(al,martin): consider pulling this all out as a separate unit for reuse -// elsewhere. -func (m *mySQLTreeStorage) getStmt(ctx context.Context, statement string, num int, first, rest string) (*sql.Stmt, error) { - m.statementMutex.Lock() - defer m.statementMutex.Unlock() - - if m.statements[statement] != nil { - if m.statements[statement][num] != nil { - // TODO(al,martin): we'll possibly need to expire Stmts from the cache, - // e.g. when DB connections break etc. 
- return m.statements[statement][num], nil - } - } else { - m.statements[statement] = make(map[int]*sql.Stmt) - } - - s, err := m.db.PrepareContext(ctx, expandPlaceholderSQL(statement, num, first, rest)) - if err != nil { - klog.Warningf("Failed to prepare statement %d: %s", num, err) - return nil, err - } - - m.statements[statement][num] = s - - return s, nil -} - -func (m *mySQLTreeStorage) getSubtreeStmt(ctx context.Context, num int) (*sql.Stmt, error) { - return m.getStmt(ctx, selectSubtreeSQL, num, "?", "?") +func (m *mySQLTreeStorage) getSubtreeStmt(ctx context.Context, num int) (*stmtcache.Stmt, error) { + return m.stmtCache.GetStmt(ctx, selectSubtreeSQL, num, "?", "?") } -func (m *mySQLTreeStorage) setSubtreeStmt(ctx context.Context, num int) (*sql.Stmt, error) { - return m.getStmt(ctx, insertSubtreeMultiSQL, num, "VALUES(?, ?, ?, ?)", "(?, ?, ?, ?)") +func (m *mySQLTreeStorage) setSubtreeStmt(ctx context.Context, num int) (*stmtcache.Stmt, error) { + return m.stmtCache.GetStmt(ctx, insertSubtreeMultiSQL, num, "VALUES(?, ?, ?, ?)", "(?, ?, ?, ?)") } func (m *mySQLTreeStorage) beginTreeTx(ctx context.Context, tree *trillian.Tree, hashSizeBytes int, subtreeCache *cache.SubtreeCache) (treeTX, error) { @@ -183,7 +138,7 @@ func (t *treeTX) getSubtrees(ctx context.Context, treeRevision int64, ids [][]by if err != nil { return nil, err } - stx := t.tx.StmtContext(ctx, tmpl) + stx := tmpl.WithTx(ctx, t.tx) defer stx.Close() args := make([]interface{}, 0, len(ids)+3) @@ -291,7 +246,7 @@ func (t *treeTX) storeSubtrees(ctx context.Context, subtrees []*storagepb.Subtre if err != nil { return err } - stx := t.tx.StmtContext(ctx, tmpl) + stx := tmpl.WithTx(ctx, t.tx) defer stx.Close() r, err := stx.ExecContext(ctx, args...) 
diff --git a/storage/stmtcache/stmtcache.go b/storage/stmtcache/stmtcache.go
new file mode 100644
index 0000000000..2e39e88e1c
--- /dev/null
+++ b/storage/stmtcache/stmtcache.go
@@ -0,0 +1,200 @@
+// Package stmtcache contains tools for managing the prepared-statement cache.
+package stmtcache
+
+import (
+	"context"
+	"database/sql"
+	"fmt"
+	"strings"
+	"sync"
+
+	"github.com/google/trillian/monitoring"
+	"k8s.io/klog/v2"
+)
+
+var (
+	once           sync.Once          // guards the one-time creation of errStmtCounter
+	errStmtCounter monitoring.Counter // counts statement execution errors; created on the first error
+)
+
+// PlaceholderSQL is the token in a statement template that is replaced by the expanded bound-argument list. It must be non-empty: with an empty token, strings.Replace inserts the arguments at the start of the statement.
+const PlaceholderSQL = "<placeholder>"
+
+// Stmt wraps the sql.Stmt struct for handling and monitoring SQL errors.
+// If an execution error occurs, the cached statement is closed and its entry is removed from the cache.
+type Stmt struct {
+	statement      string     // statement template this Stmt was prepared from (cache key)
+	placeholderNum int        // number of expanded placeholder groups (cache key)
+	stmtCache      *StmtCache // owning cache; left nil by WithTx for transaction-bound Stmts
+	stmt           *sql.Stmt  // underlying prepared statement
+	parentStmt     *Stmt      // Stmt this one was derived from by WithTx; nil otherwise
+}
+
+// errHandler reacts to an execution error: it closes the cached statement (the parent, for a Stmt derived by WithTx) and, when that statement belongs to a cache, increments the sql_stmt_errors counter.
+// This err parameter is not currently used, but it may be necessary to perform more granular processing and monitoring of different errs in the future.
+func (s *Stmt) errHandler(_ error) {
+	o := s
+	if s.parentStmt != nil {
+		o = s.parentStmt
+	}
+
+	if err := o.Close(); err != nil {
+		klog.Warningf("Failed to close stmt: %s", err)
+	}
+
+	if o.stmtCache != nil {
+		once.Do(func() { // the counter is package-level: it is created once, from the MetricFactory of the first cache to report an error
+			errStmtCounter = o.stmtCache.mf.NewCounter("sql_stmt_errors", "Number of statement execution errors")
+		})
+
+		errStmtCounter.Inc()
+	}
+}
+
+// SQLStmt returns the referenced sql.Stmt struct.
+func (s *Stmt) SQLStmt() *sql.Stmt {
+	return s.stmt
+}
+
+// Close closes the underlying sql.Stmt.
+// If the Stmt belongs to a cache, its cache entry is cleared first.
+func (s *Stmt) Close() error {
+	if cache := s.stmtCache; cache != nil {
+		cache.clearOne(s)
+	}
+
+	return s.stmt.Close()
+}
+
+// WithTx returns a transaction-specific prepared statement from
+// an existing statement.
+// The transaction-specific Stmt is closed by the caller.
+func (s *Stmt) WithTx(ctx context.Context, tx *sql.Tx) *Stmt {
+	parent := s
+	if s.parentStmt != nil {
+		parent = s.parentStmt // s is itself transaction-bound: derive from the Stmt it came from
+	}
+	return &Stmt{ // stmtCache is left nil, so closing the result does not touch the cache
+		parentStmt: parent,
+		stmt:       tx.StmtContext(ctx, parent.stmt),
+	}
+}
+
+// ExecContext executes a prepared statement with the given arguments and
+// returns a Result summarizing the effect of the statement. A non-nil error is also passed to errHandler.
+func (s *Stmt) ExecContext(ctx context.Context, args ...any) (sql.Result, error) {
+	res, err := s.stmt.ExecContext(ctx, args...)
+	if err != nil {
+		s.errHandler(err)
+	}
+	return res, err
+}
+
+// QueryContext executes a prepared query statement with the given arguments
+// and returns the query results as a *Rows. A non-nil error is also passed to errHandler.
+func (s *Stmt) QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) {
+	res, err := s.stmt.QueryContext(ctx, args...)
+	if err != nil {
+		s.errHandler(err)
+	}
+	return res, err
+}
+
+// QueryRowContext executes a prepared query statement with the given arguments.
+// If an error occurs during the execution of the statement, that error will
+// be returned by a call to Scan on the returned *Row, which is always non-nil.
+// If the query selects no rows, the *Row's Scan will return ErrNoRows.
+// Otherwise, the *Row's Scan scans the first selected row and discards
+// the rest. An error already reported by the *Row's Err is passed to errHandler.
+func (s *Stmt) QueryRowContext(ctx context.Context, args ...any) *sql.Row {
+	res := s.stmt.QueryRowContext(ctx, args...)
+	if err := res.Err(); err != nil {
+		s.errHandler(err)
+	}
+	return res
+}
+
+// StmtCache is a cache of the sql.Stmt structs.
+type StmtCache struct {
+	db             *sql.DB                      // database on which statements are prepared
+	statementMutex sync.Mutex                   // guards statements
+	statements     map[string]map[int]*sql.Stmt // statement template -> placeholder count -> prepared statement
+	mf             monitoring.MetricFactory     // never nil after New
+}
+
+// New creates a StmtCache instance. A nil mf is replaced by monitoring.InertMetricFactory.
+func New(db *sql.DB, mf monitoring.MetricFactory) *StmtCache {
+	if mf == nil {
+		mf = monitoring.InertMetricFactory{}
+	}
+
+	return &StmtCache{
+		db:         db,
+		statements: make(map[string]map[int]*sql.Stmt),
+		mf:         mf,
+	}
+}
+
+// clearOne removes the sql.Stmt wrapped by s from the cache, if it is still the cached entry.
+func (sc *StmtCache) clearOne(s *Stmt) {
+	if s == nil || s.stmt == nil || s.stmtCache != sc {
+		return
+	}
+
+	sc.statementMutex.Lock()
+	defer sc.statementMutex.Unlock()
+	// Evict only if the cached entry is still this very statement; a newer one may have replaced it.
+	if cached, ok := sc.statements[s.statement][s.placeholderNum]; ok && cached == s.stmt {
+		delete(sc.statements[s.statement], s.placeholderNum)
+	}
+}
+
+func (sc *StmtCache) getStmt(ctx context.Context, statement string, num int, first, rest string) (*sql.Stmt, error) {
+	sc.statementMutex.Lock()
+	defer sc.statementMutex.Unlock()
+	// Reuse the prepared statement for this template and placeholder count when one is cached.
+	if sc.statements[statement] != nil {
+		if sc.statements[statement][num] != nil {
+			return sc.statements[statement][num], nil
+		}
+	} else {
+		sc.statements[statement] = make(map[int]*sql.Stmt)
+	}
+	// Not cached: prepare it (still holding the mutex) and remember it.
+	s, err := sc.db.PrepareContext(ctx, expandPlaceholderSQL(statement, num, first, rest))
+	if err != nil {
+		klog.Warningf("Failed to prepare statement %d: %s", num, err)
+		return nil, err
+	}
+
+	sc.statements[statement][num] = s
+
+	return s, nil
+}
+
+// expandPlaceholderSQL expands an sql statement by adding a specified number of '?'
+// placeholder slots. At most one placeholder will be expanded. It panics if num <= 0.
+func expandPlaceholderSQL(sql string, num int, first, rest string) string {
+	if num <= 0 {
+		panic(fmt.Errorf("trying to expand SQL placeholder with <= 0 parameters: %s", sql))
+	}
+
+	parameters := first + strings.Repeat(","+rest, num-1)
+
+	return strings.Replace(sql, PlaceholderSQL, parameters, 1)
+}
+
+// GetStmt returns a Stmt wrapping the cached sql.Stmt for statement with num placeholder groups, preparing and caching it first if needed.
+func (sc *StmtCache) GetStmt(ctx context.Context, statement string, num int, first, rest string) (*Stmt, error) {
+	stmt, err := sc.getStmt(ctx, statement, num, first, rest)
+	if err != nil {
+		return nil, err
+	}
+
+	return &Stmt{
+		statement:      statement,
+		placeholderNum: num,
+		stmtCache:      sc,
+		stmt:           stmt,
+	}, nil
+}
diff --git a/storage/stmtcache/stmtcache_test.go b/storage/stmtcache/stmtcache_test.go
new file mode 100644
index 0000000000..bf645c0879
--- /dev/null
+++ b/storage/stmtcache/stmtcache_test.go
@@ -0,0 +1,174 @@
+package stmtcache_test
+
+import (
+	"context"
+	"database/sql"
+	"os"
+	"testing"
+
+	"github.com/google/trillian/storage/stmtcache"
+	"github.com/google/trillian/storage/testdb"
+	"k8s.io/klog/v2"
+)
+
+// db is the database handle shared by every test in this file; TestMain sets it.
+var db *sql.DB
+
+// TestMain opens the test database before running the tests, and runs none of them when MySQL is unavailable.
+func TestMain(m *testing.M) {
+	if !testdb.MySQLAvailable() {
+		klog.Errorf("MySQL not available, skipping all stmt tests")
+		return
+	}
+	ctx := context.Background()
+
+	var done func(context.Context)
+	var err error
+	db, done, err = testdb.NewTrillianDB(ctx, testdb.DriverMySQL)
+	if err != nil {
+		panic(err)
+	}
+
+	status := m.Run()
+	done(context.Background())
+	os.Exit(status)
+}
+
+// TestStmtExecContext checks that a cached statement can be executed with ExecContext.
+func TestStmtExecContext(t *testing.T) {
+	ctx := context.Background()
+	cache := stmtcache.New(db, nil)
+	query := `SELECT ` + stmtcache.PlaceholderSQL
+	stmt, err := cache.GetStmt(ctx, query, 1, "?", "")
+	if err != nil {
+		t.Fatalf("Failed to cache.GetStmt: %s", err)
+	}
+	if _, err = stmt.ExecContext(ctx, ""); err != nil {
+		t.Fatalf("Failed to stmt.ExecContext: %s", err)
+	}
+}
+
+// TestStmtQueryContext checks that QueryContext returns the value bound to the single placeholder.
+func TestStmtQueryContext(t *testing.T) {
+	ctx := context.Background()
+	cache := stmtcache.New(db, nil)
+	query := `SELECT ` + stmtcache.PlaceholderSQL
+	stmt, err := cache.GetStmt(ctx, query, 1, "?", "")
+	if err != nil {
+		t.Fatalf("Failed to cache.GetStmt: %s", err)
+	}
+	want := "TestQuery"
+	rows, err := stmt.QueryContext(ctx, want)
+	if err != nil {
+		t.Fatalf("Failed to stmt.QueryContext: %s", err)
+	}
+	defer rows.Close()
+
+	if !rows.Next() {
+		t.Fatalf("Failed to rows.Next: %v", rows.Err())
+	}
+	var res string
+	if err := rows.Scan(&res); err != nil {
+		t.Fatalf("Failed to rows.Scan: %s", err)
+	}
+
+	if res != want {
+		t.Errorf("Unexpected results")
+	}
+}
+
+// TestStmtQueryRowContext checks that QueryRowContext returns the value bound to the single placeholder.
+func TestStmtQueryRowContext(t *testing.T) {
+	ctx := context.Background()
+	cache := stmtcache.New(db, nil)
+	query := `SELECT ` + stmtcache.PlaceholderSQL
+	stmt, err := cache.GetStmt(ctx, query, 1, "?", "")
+	if err != nil {
+		t.Fatalf("Failed to cache.GetStmt: %s", err)
+	}
+	want := "TestQueryRow"
+	row := stmt.QueryRowContext(ctx, want)
+	if err := row.Err(); err != nil {
+		t.Fatalf("Failed to stmt.QueryRowContext: %s", err)
+	}
+	var res string
+	if err := row.Scan(&res); err != nil {
+		t.Fatalf("Failed to row.Scan: %s", err)
+	}
+
+	if res != want {
+		t.Errorf("Unexpected results")
+	}
+}
+
+// TestStmtWithTx checks that an insert made through a transaction-bound Stmt disappears when the transaction is rolled back.
+func TestStmtWithTx(t *testing.T) {
+	ctx := context.Background()
+	if _, err := db.ExecContext(ctx, "CREATE TABLE TestStmtWithTx(ID int)"); err != nil {
+		t.Fatalf("Failed to create table: %s", err)
+	}
+	defer func() {
+		if _, err := db.ExecContext(ctx, "DROP TABLE TestStmtWithTx"); err != nil {
+			klog.Errorf("Failed to drop table: %s", err)
+		}
+	}()
+
+	cache := stmtcache.New(db, nil)
+	query := `INSERT INTO TestStmtWithTx(ID) ` + stmtcache.PlaceholderSQL
+	stmt, err := cache.GetStmt(ctx, query, 1, "VALUES(?)", "(?)")
+	if err != nil {
+		t.Fatalf("Failed to cache.GetStmt: %s", err)
+	}
+	tx, err := db.BeginTx(ctx, nil)
+	if err != nil {
+		t.Fatalf("Failed to db.BeginTx: %s", err)
+	}
+	stx := stmt.WithTx(ctx, tx)
+	defer stx.Close()
+
+	id := 1
+	_, err = stx.ExecContext(ctx, id)
+	if err != nil {
+		t.Fatalf("Failed to stx.ExecContext: %s", err)
+	}
+
+	if err := tx.Rollback(); err != nil {
+		klog.Errorf("Failed to tx.Rollback: %s", err)
+	}
+
+	row := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM TestStmtWithTx WHERE ID = ?", id)
+	var count int
+	if err = row.Scan(&count); err != nil {
+		t.Fatalf("Failed to row.Scan: %s", err)
+	}
+
+	if count != 0 {
+		t.Errorf("Transaction not rolled back")
+	}
+}
+
+// TestStmtExecutionError checks that after an execution error the cache hands out a working statement again.
+func TestStmtExecutionError(t *testing.T) {
+	cache := stmtcache.New(db, nil)
+	ctx := context.Background()
+	query := `SELECT ` + stmtcache.PlaceholderSQL
+	stmt, err := cache.GetStmt(ctx, query, 1, "?", "")
+	if err != nil {
+		t.Fatalf("Failed to cache.GetStmt: %s", err)
+	}
+
+	if err = stmt.SQLStmt().Close(); err != nil {
+		t.Fatalf("Failed to close sql.Stmt: %s", err)
+	}
+	// The underlying sql.Stmt is now closed, so execution fails and triggers the cache-clearing logic.
+	if _, err = stmt.ExecContext(ctx, ""); err == nil {
+		t.Fatal("Unexpected execution succeeded")
+	}
+	stmt, err = cache.GetStmt(ctx, query, 1, "?", "")
+	if err != nil {
+		t.Fatalf("Failed to cache.GetStmt: %s", err)
+	}
+	if _, err = stmt.ExecContext(ctx, ""); err != nil {
+		t.Fatalf("Failed to stmt.ExecContext: %s", err)
+	}
+}