Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a separate unit to manage the cache of prepared statements #2937

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
different version can lead to presubmits failing due to unexpected
diffs.

* Use a separate unit (StmtCache) to manage the cache of prepared statements, which wraps the sql.Stmt struct to handle and monitor the execution errors of the prepared statement. When an error occurs during statement execution, it closes the statement and clears the cache, as well as increments the error monitoring indicator.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what 'unit' means?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The word 'separate unit' comes from the current TODO comment:

// TODO(al,martin): consider pulling this all out as a separate unit for reuse


### Misc

* Bump Go version from 1.17 to 1.19.
Expand Down
17 changes: 9 additions & 8 deletions storage/mysql/log_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/google/trillian/monitoring"
"github.com/google/trillian/storage"
"github.com/google/trillian/storage/cache"
"github.com/google/trillian/storage/sqlutil"
"github.com/google/trillian/storage/tree"
"github.com/google/trillian/types"
"github.com/transparency-dev/merkle/compact"
Expand Down Expand Up @@ -135,7 +136,7 @@ func NewLogStorage(db *sql.DB, mf monitoring.MetricFactory) storage.LogStorage {
}
return &mySQLLogStorage{
admin: NewAdminStorage(db),
mySQLTreeStorage: newTreeStorage(db),
mySQLTreeStorage: newTreeStorage(db, mf),
metricFactory: mf,
}
}
Expand All @@ -144,16 +145,16 @@ func (m *mySQLLogStorage) CheckDatabaseAccessible(ctx context.Context) error {
return m.db.PingContext(ctx)
}

func (m *mySQLLogStorage) getLeavesByMerkleHashStmt(ctx context.Context, num int, orderBySequence bool) (*sql.Stmt, error) {
func (m *mySQLLogStorage) getLeavesByMerkleHashStmt(ctx context.Context, num int, orderBySequence bool) (*sqlutil.Stmt, error) {
if orderBySequence {
return m.getStmt(ctx, selectLeavesByMerkleHashOrderedBySequenceSQL, num, "?", "?")
return m.stmtCache.GetStmt(ctx, selectLeavesByMerkleHashOrderedBySequenceSQL, num, "?", "?")
}

return m.getStmt(ctx, selectLeavesByMerkleHashSQL, num, "?", "?")
return m.stmtCache.GetStmt(ctx, selectLeavesByMerkleHashSQL, num, "?", "?")
}

func (m *mySQLLogStorage) getLeavesByLeafIdentityHashStmt(ctx context.Context, num int) (*sql.Stmt, error) {
return m.getStmt(ctx, selectLeavesByLeafIdentityHashSQL, num, "?", "?")
func (m *mySQLLogStorage) getLeavesByLeafIdentityHashStmt(ctx context.Context, num int) (*sqlutil.Stmt, error) {
return m.stmtCache.GetStmt(ctx, selectLeavesByLeafIdentityHashSQL, num, "?", "?")
}

func (m *mySQLLogStorage) GetActiveLogIDs(ctx context.Context) ([]int64, error) {
Expand Down Expand Up @@ -730,8 +731,8 @@ func (t *logTreeTX) StoreSignedLogRoot(ctx context.Context, root *trillian.Signe
return checkResultOkAndRowCountIs(res, err, 1)
}

func (t *logTreeTX) getLeavesByHashInternal(ctx context.Context, leafHashes [][]byte, tmpl *sql.Stmt, desc string) ([]*trillian.LogLeaf, error) {
stx := t.tx.StmtContext(ctx, tmpl)
func (t *logTreeTX) getLeavesByHashInternal(ctx context.Context, leafHashes [][]byte, tmpl *sqlutil.Stmt, desc string) ([]*trillian.LogLeaf, error) {
stx := tmpl.WithTx(ctx, t.tx)
defer stx.Close()

var args []interface{}
Expand Down
71 changes: 13 additions & 58 deletions storage/mysql/tree_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@ import (
"encoding/base64"
"fmt"
"runtime/debug"
"strings"
"sync"

"github.com/google/trillian"
"github.com/google/trillian/monitoring"
"github.com/google/trillian/storage/cache"
"github.com/google/trillian/storage/sqlutil"
"github.com/google/trillian/storage/storagepb"
"github.com/google/trillian/storage/tree"
"google.golang.org/protobuf/proto"
Expand All @@ -52,20 +53,15 @@ const (
AND Subtree.SubtreeRevision = x.MaxRevision
AND Subtree.TreeId = x.TreeId
AND Subtree.TreeId = ?`
placeholderSQL = "<placeholder>"
placeholderSQL = sqlutil.PlaceholderSQL
)

// mySQLTreeStorage is shared between the mySQLLog- and (forthcoming) mySQLMap-
// Storage implementations, and contains functionality which is common to both,
type mySQLTreeStorage struct {
db *sql.DB

// Must hold the mutex before manipulating the statement map. Sharing a lock because
// it only needs to be held while the statements are built, not while they execute and
// this will be a short time. These maps are from the number of placeholder '?'
// in the query to the statement that should be used.
statementMutex sync.Mutex
statements map[string]map[int]*sql.Stmt
stmtCache *sqlutil.StmtCache
}

// OpenDB opens a database connection for all MySQL-based storage implementations.
Expand All @@ -85,60 +81,19 @@ func OpenDB(dbURL string) (*sql.DB, error) {
return db, nil
}

func newTreeStorage(db *sql.DB) *mySQLTreeStorage {
func newTreeStorage(db *sql.DB, mf monitoring.MetricFactory) *mySQLTreeStorage {
return &mySQLTreeStorage{
db: db,
statements: make(map[string]map[int]*sql.Stmt),
db: db,
stmtCache: sqlutil.NewStmtCache(db, mf),
}
}

// expandPlaceholderSQL expands an sql statement by adding a specified number of '?'
// placeholder slots. At most one placeholder will be expanded.
func expandPlaceholderSQL(sql string, num int, first, rest string) string {
if num <= 0 {
panic(fmt.Errorf("trying to expand SQL placeholder with <= 0 parameters: %s", sql))
}

parameters := first + strings.Repeat(","+rest, num-1)

return strings.Replace(sql, placeholderSQL, parameters, 1)
}

// getStmt creates and caches sql.Stmt structs based on the passed in statement
// and number of bound arguments.
// TODO(al,martin): consider pulling this all out as a separate unit for reuse
// elsewhere.
func (m *mySQLTreeStorage) getStmt(ctx context.Context, statement string, num int, first, rest string) (*sql.Stmt, error) {
m.statementMutex.Lock()
defer m.statementMutex.Unlock()

if m.statements[statement] != nil {
if m.statements[statement][num] != nil {
// TODO(al,martin): we'll possibly need to expire Stmts from the cache,
// e.g. when DB connections break etc.
return m.statements[statement][num], nil
}
} else {
m.statements[statement] = make(map[int]*sql.Stmt)
}

s, err := m.db.PrepareContext(ctx, expandPlaceholderSQL(statement, num, first, rest))
if err != nil {
klog.Warningf("Failed to prepare statement %d: %s", num, err)
return nil, err
}

m.statements[statement][num] = s

return s, nil
}

func (m *mySQLTreeStorage) getSubtreeStmt(ctx context.Context, num int) (*sql.Stmt, error) {
return m.getStmt(ctx, selectSubtreeSQL, num, "?", "?")
func (m *mySQLTreeStorage) getSubtreeStmt(ctx context.Context, num int) (*sqlutil.Stmt, error) {
return m.stmtCache.GetStmt(ctx, selectSubtreeSQL, num, "?", "?")
}

func (m *mySQLTreeStorage) setSubtreeStmt(ctx context.Context, num int) (*sql.Stmt, error) {
return m.getStmt(ctx, insertSubtreeMultiSQL, num, "VALUES(?, ?, ?, ?)", "(?, ?, ?, ?)")
func (m *mySQLTreeStorage) setSubtreeStmt(ctx context.Context, num int) (*sqlutil.Stmt, error) {
return m.stmtCache.GetStmt(ctx, insertSubtreeMultiSQL, num, "VALUES(?, ?, ?, ?)", "(?, ?, ?, ?)")
}

func (m *mySQLTreeStorage) beginTreeTx(ctx context.Context, tree *trillian.Tree, hashSizeBytes int, subtreeCache *cache.SubtreeCache) (treeTX, error) {
Expand Down Expand Up @@ -183,7 +138,7 @@ func (t *treeTX) getSubtrees(ctx context.Context, treeRevision int64, ids [][]by
if err != nil {
return nil, err
}
stx := t.tx.StmtContext(ctx, tmpl)
stx := tmpl.WithTx(ctx, t.tx)
defer stx.Close()

args := make([]interface{}, 0, len(ids)+3)
Expand Down Expand Up @@ -291,7 +246,7 @@ func (t *treeTX) storeSubtrees(ctx context.Context, subtrees []*storagepb.Subtre
if err != nil {
return err
}
stx := t.tx.StmtContext(ctx, tmpl)
stx := tmpl.WithTx(ctx, t.tx)
defer stx.Close()

r, err := stx.ExecContext(ctx, args...)
Expand Down
Loading