From 52b1ccdf3578418aa427adef9dcf942d90ae4fdd Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Sun, 19 Dec 2021 10:15:24 +0200 Subject: [PATCH] fix: add Event.QueryTemplate and change Event.Query to be always formatted --- db.go | 51 ++++++++++++++++++------------ extra/bunotel/otel.go | 14 ++++---- hook.go | 7 ++-- internal/dbtest/query_hook_test.go | 2 +- query_base.go | 4 +-- query_select.go | 4 +-- 6 files changed, 48 insertions(+), 34 deletions(-) diff --git a/db.go b/db.go index 62f73b492..78969c019 100644 --- a/db.go +++ b/db.go @@ -227,8 +227,9 @@ func (db *DB) Exec(query string, args ...interface{}) (sql.Result, error) { func (db *DB) ExecContext( ctx context.Context, query string, args ...interface{}, ) (sql.Result, error) { - ctx, event := db.beforeQuery(ctx, nil, query, args, nil) - res, err := db.DB.ExecContext(ctx, db.format(query, args)) + formattedQuery := db.format(query, args) + ctx, event := db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) + res, err := db.DB.ExecContext(ctx, formattedQuery) db.afterQuery(ctx, event, res, err) return res, err } @@ -240,8 +241,9 @@ func (db *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { func (db *DB) QueryContext( ctx context.Context, query string, args ...interface{}, ) (*sql.Rows, error) { - ctx, event := db.beforeQuery(ctx, nil, query, args, nil) - rows, err := db.DB.QueryContext(ctx, db.format(query, args)) + formattedQuery := db.format(query, args) + ctx, event := db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) + rows, err := db.DB.QueryContext(ctx, formattedQuery) db.afterQuery(ctx, event, nil, err) return rows, err } @@ -251,8 +253,9 @@ func (db *DB) QueryRow(query string, args ...interface{}) *sql.Row { } func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { - ctx, event := db.beforeQuery(ctx, nil, query, args, nil) - row := db.DB.QueryRowContext(ctx, db.format(query, args)) + formattedQuery := db.format(query, args) + ctx, event := db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) + row := db.DB.QueryRowContext(ctx, formattedQuery) db.afterQuery(ctx, event, nil, row.Err()) return row } @@ -282,8 +285,9 @@ func (db *DB) Conn(ctx context.Context) (Conn, error) { func (c Conn) ExecContext( ctx context.Context, query string, args ...interface{}, ) (sql.Result, error) { - ctx, event := c.db.beforeQuery(ctx, nil, query, args, nil) - res, err := c.Conn.ExecContext(ctx, c.db.format(query, args)) + formattedQuery := c.db.format(query, args) + ctx, event := c.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) + res, err := c.Conn.ExecContext(ctx, formattedQuery) c.db.afterQuery(ctx, event, res, err) return res, err } @@ -291,15 +295,17 @@ func (c Conn) ExecContext( func (c Conn) QueryContext( ctx context.Context, query string, args ...interface{}, ) (*sql.Rows, error) { - ctx, event := c.db.beforeQuery(ctx, nil, query, args, nil) - rows, err := c.Conn.QueryContext(ctx, c.db.format(query, args)) + formattedQuery := c.db.format(query, args) + ctx, event := c.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) + rows, err := c.Conn.QueryContext(ctx, formattedQuery) c.db.afterQuery(ctx, event, nil, err) return rows, err } func (c Conn) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { - ctx, event := c.db.beforeQuery(ctx, nil, query, args, nil) - row := c.Conn.QueryRowContext(ctx, c.db.format(query, args)) + formattedQuery := c.db.format(query, args) + ctx, event := c.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) + row := c.Conn.QueryRowContext(ctx, formattedQuery) c.db.afterQuery(ctx, event, nil, row.Err()) return row } @@ -413,7 +419,7 @@ func (db *DB) Begin() (Tx, error) { } func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { - ctx, event := db.beforeQuery(ctx, nil, "BEGIN", nil, nil) + ctx, event := db.beforeQuery(ctx, nil, "BEGIN", nil, "BEGIN", nil) tx, err := db.DB.BeginTx(ctx, opts) db.afterQuery(ctx, event, nil, err) if err != nil { @@ -427,14 +433,14 @@ func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { } func (tx Tx) Commit() error { - ctx, event := tx.db.beforeQuery(tx.ctx, nil, "COMMIT", nil, nil) + ctx, event := tx.db.beforeQuery(tx.ctx, nil, "COMMIT", nil, "COMMIT", nil) err := tx.Tx.Commit() tx.db.afterQuery(ctx, event, nil, err) return err } func (tx Tx) Rollback() error { - ctx, event := tx.db.beforeQuery(tx.ctx, nil, "ROLLBACK", nil, nil) + ctx, event := tx.db.beforeQuery(tx.ctx, nil, "ROLLBACK", nil, "ROLLBACK", nil) err := tx.Tx.Rollback() tx.db.afterQuery(ctx, event, nil, err) return err @@ -447,8 +453,9 @@ func (tx Tx) Exec(query string, args ...interface{}) (sql.Result, error) { func (tx Tx) ExecContext( ctx context.Context, query string, args ...interface{}, ) (sql.Result, error) { - ctx, event := tx.db.beforeQuery(ctx, nil, query, args, nil) - res, err := tx.Tx.ExecContext(ctx, tx.db.format(query, args)) + formattedQuery := tx.db.format(query, args) + ctx, event := tx.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) + res, err := tx.Tx.ExecContext(ctx, formattedQuery) tx.db.afterQuery(ctx, event, res, err) return res, err } @@ -460,8 +467,9 @@ func (tx Tx) Query(query string, args ...interface{}) (*sql.Rows, error) { func (tx Tx) QueryContext( ctx context.Context, query string, args ...interface{}, ) (*sql.Rows, error) { - ctx, event := tx.db.beforeQuery(ctx, nil, query, args, nil) - rows, err := tx.Tx.QueryContext(ctx, tx.db.format(query, args)) + formattedQuery := tx.db.format(query, args) + ctx, event := tx.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) + rows, err := tx.Tx.QueryContext(ctx, formattedQuery) tx.db.afterQuery(ctx, event, nil, err) return rows, err } @@ -471,8 +479,9 @@ func (tx Tx) QueryRow(query string, args ...interface{}) *sql.Row { } func (tx Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { - ctx, event := tx.db.beforeQuery(ctx, nil, query, args, nil) - row := tx.Tx.QueryRowContext(ctx, tx.db.format(query, args)) + formattedQuery := tx.db.format(query, args) + ctx, event := tx.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) + row := tx.Tx.QueryRowContext(ctx, formattedQuery) tx.db.afterQuery(ctx, event, nil, row.Err()) return row } diff --git a/extra/bunotel/otel.go b/extra/bunotel/otel.go index 03301d8a1..25e4be312 100644 --- a/extra/bunotel/otel.go +++ b/extra/bunotel/otel.go @@ -147,10 +147,10 @@ func eventQuery(event *bun.QueryEvent) string { var query string - if len(event.Query) > softQueryLimit { - query = unformattedQuery(event) - } else { + if len(event.Query) <= softQueryLimit { query = event.Query + } else { + query = unformattedQuery(event) } if len(query) > hardQueryLimit { @@ -161,10 +161,12 @@ func eventQuery(event *bun.QueryEvent) string { } func unformattedQuery(event *bun.QueryEvent) string { - if b, err := event.QueryAppender.AppendQuery(schema.NewNopFormatter(), nil); err == nil { - return bytesToString(b) + if event.IQuery != nil { + if b, err := event.IQuery.AppendQuery(schema.NewNopFormatter(), nil); err == nil { + return bytesToString(b) + } } - return string(event.Query) + return string(event.QueryTemplate) } func dbSystem(db *bun.DB) attribute.KeyValue { diff --git a/hook.go b/hook.go index 7cca7ef6a..81249329a 100644 --- a/hook.go +++ b/hook.go @@ -13,9 +13,10 @@ import ( type QueryEvent struct { DB *DB - QueryAppender schema.QueryAppender // Deprecated: use IQuery instead + QueryAppender schema.QueryAppender // DEPRECATED: use IQuery instead IQuery Query Query string + QueryTemplate string QueryArgs []interface{} Model Model @@ -51,8 +52,9 @@ type QueryHook interface { func (db *DB) beforeQuery( ctx context.Context, iquery Query, - query string, + queryTemplate string, queryArgs []interface{}, + query string, model Model, ) (context.Context, *QueryEvent) { atomic.AddUint32(&db.stats.Queries, 1) @@ -68,6 +70,7 @@ func (db *DB) beforeQuery( QueryAppender: iquery, IQuery: iquery, Query: query, + QueryTemplate: queryTemplate, QueryArgs: queryArgs, StartTime: time.Now(), diff --git a/internal/dbtest/query_hook_test.go b/internal/dbtest/query_hook_test.go index 06d3976e8..4e0116709 100644 --- a/internal/dbtest/query_hook_test.go +++ b/internal/dbtest/query_hook_test.go @@ -27,7 +27,7 @@ func testQueryHook(t *testing.T, dbName string, db *bun.DB) { require.Equal( t, "SELECT * FROM (SELECT 1) AS t WHERE ('foo' = 'bar')", string(event.Query)) - b, err := event.QueryAppender.AppendQuery(schema.NewNopFormatter(), nil) + b, err := event.IQuery.AppendQuery(schema.NewNopFormatter(), nil) require.NoError(t, err) require.Equal(t, "SELECT * FROM (SELECT 1) AS t WHERE (? = ?)", string(b)) diff --git a/query_base.go b/query_base.go index 09df2dbc6..8b78d25e1 100644 --- a/query_base.go +++ b/query_base.go @@ -468,7 +468,7 @@ func (q *baseQuery) scan( model Model, hasDest bool, ) (sql.Result, error) { - ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, q.model) + ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, query, q.model) rows, err := q.conn.QueryContext(ctx, query) if err != nil { @@ -498,7 +498,7 @@ func (q *baseQuery) exec( iquery Query, query string, ) (sql.Result, error) { - ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, q.model) + ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, query, q.model) res, err := q.conn.ExecContext(ctx, query) q.db.afterQuery(ctx, event, nil, err) return res, err diff --git a/query_select.go b/query_select.go index a07cb5d5f..401bf1acc 100644 --- a/query_select.go +++ b/query_select.go @@ -773,7 +773,7 @@ func (q *SelectQuery) Count(ctx context.Context) (int, error) { } query := internal.String(queryBytes) - ctx, event := q.db.beforeQuery(ctx, qq, query, nil, q.model) + ctx, event := q.db.beforeQuery(ctx, qq, query, nil, query, q.model) var num int err = q.conn.QueryRowContext(ctx, query).Scan(&num) @@ -858,7 +858,7 @@ func (q *SelectQuery) Exists(ctx context.Context) (bool, error) { } query := internal.String(queryBytes) - ctx, event := q.db.beforeQuery(ctx, qq, query, nil, q.model) + ctx, event := q.db.beforeQuery(ctx, qq, query, nil, query, q.model) var exists bool err = q.conn.QueryRowContext(ctx, query).Scan(&exists)