From 5c3a30c3c7f370939fb16bc050f818b647ee5b1f Mon Sep 17 00:00:00 2001 From: Gerrit Date: Fri, 23 Aug 2024 13:21:43 +0200 Subject: [PATCH] Small improvements. --- auditing/auditing-interceptor.go | 10 +- auditing/auditing.go | 36 +++---- auditing/meilisearch.go | 8 +- auditing/meilisearch_integration_test.go | 21 ++-- auditing/timescaledb.go | 129 +++++++++++++++++------ auditing/timescaledb_integration_test.go | 21 ++-- 6 files changed, 147 insertions(+), 78 deletions(-) diff --git a/auditing/auditing-interceptor.go b/auditing/auditing-interceptor.go index 94749da..1e2d979 100644 --- a/auditing/auditing-interceptor.go +++ b/auditing/auditing-interceptor.go @@ -71,7 +71,7 @@ func UnaryServerInterceptor(a Auditing, logger *slog.Logger, shouldAudit func(fu auditReqContext.StatusCode = statusCodeFromGrpc(err) if err != nil { - auditReqContext.Error = err.Error() + auditReqContext.Error = err err2 := a.Index(auditReqContext) if err2 != nil { logger.Error("unable to index", "error", err2) @@ -129,7 +129,7 @@ func StreamServerInterceptor(a Auditing, logger *slog.Logger, shouldAudit func(f auditReqContext.StatusCode = statusCodeFromGrpc(err) if err != nil { - auditReqContext.Error = err.Error() + auditReqContext.Error = err err2 := a.Index(auditReqContext) if err2 != nil { logger.Error("unable to index", "error", err2) @@ -244,7 +244,7 @@ func (a auditingConnectInterceptor) WrapStreamingHandler(next connect.StreamingH auditReqContext.StatusCode = statusCodeFromGrpc(err) if err != nil { - auditReqContext.Error = err.Error() + auditReqContext.Error = err err2 := a.auditing.Index(auditReqContext) if err2 != nil { a.logger.Error("unable to index", "error", err2) @@ -311,7 +311,7 @@ func (i auditingConnectInterceptor) WrapUnary(next connect.UnaryFunc) connect.Un auditReqContext.StatusCode = statusCodeFromGrpc(err) if err != nil { - auditReqContext.Error = err.Error() + auditReqContext.Error = err err2 := i.auditing.Index(auditReqContext) if err2 != nil { i.logger.Error("unable to index", "error", err2) @@ -432,7 +432,7 @@ func HttpFilter(a Auditing, logger *slog.Logger) (restful.FilterFunction, error) err = json.Unmarshal(body, &auditReqContext.Body) if err != nil { auditReqContext.Body = strBody - auditReqContext.Error = err.Error() + auditReqContext.Error = err } err = a.Index(auditReqContext) diff --git a/auditing/auditing.go b/auditing/auditing.go index 424470a..b7238ec 100644 --- a/auditing/auditing.go +++ b/auditing/auditing.go @@ -1,6 +1,7 @@ package auditing import ( + "context" "log/slog" "os" "path/filepath" @@ -49,39 +50,38 @@ const ( const EntryFilterDefaultLimit int64 = 100 type Entry struct { - Id string `db:"-"` // filled by the auditing driver - - Component string `db:"component"` - RequestId string `db:"rqid" json:"rqid"` - Type EntryType `db:"type"` - Timestamp time.Time `db:"timestamp"` + Id string // filled by the auditing driver + Component string + RequestId string `json:"rqid"` + Type EntryType + Timestamp time.Time - User string `db:"userid"` - Tenant string `db:"tenant"` + User string + Tenant string // For `EntryDetailHTTP` the HTTP method get, post, put, delete, ... // For `EntryDetailGRPC` unary, stream - Detail EntryDetail `db:"detail"` + Detail EntryDetail // e.g. Request, Response, Error, Opened, Close - Phase EntryPhase `db:"phase"` + Phase EntryPhase // For `EntryDetailHTTP` /api/v1/... // For `EntryDetailGRPC` /api.v1/... (the method name) - Path string `db:"path"` - ForwardedFor string `db:"forwardedfor"` - RemoteAddr string `db:"remoteaddr"` + Path string + ForwardedFor string + RemoteAddr string - Body any `db:"body"` // JSON, string or numbers - StatusCode int `db:"statuscode"` // for `EntryDetailHTTP` the HTTP status code, for EntryDetailGRPC` the grpc status code + Body any // JSON, string or numbers + StatusCode int // for `EntryDetailHTTP` the HTTP status code, for EntryDetailGRPC` the grpc status code // Internal errors - Error string `db:"error"` + Error error } func (e *Entry) prepareForNextPhase() { e.Id = "" e.Timestamp = time.Now() e.Body = nil - e.Error = "" + e.Error = nil switch e.Phase { case EntryPhaseRequest: @@ -133,7 +133,7 @@ type Auditing interface { // Searches for entries matching the given filter. // By default only recent entries will be returned. // The returned entries will be sorted by timestamp in descending order. - Search(EntryFilter) ([]Entry, error) + Search(context.Context, EntryFilter) ([]Entry, error) } func defaultComponent() (string, error) { diff --git a/auditing/meilisearch.go b/auditing/meilisearch.go index df29b78..aa2fb3b 100644 --- a/auditing/meilisearch.go +++ b/auditing/meilisearch.go @@ -129,7 +129,7 @@ func (a *meiliAuditing) Index(entry Entry) error { return nil } -func (a *meiliAuditing) Search(filter EntryFilter) ([]Entry, error) { +func (a *meiliAuditing) Search(_ context.Context, filter EntryFilter) ([]Entry, error) { predicates := make([]string, 0) if filter.Component != "" { predicates = append(predicates, fmt.Sprintf("component = %q", filter.Component)) @@ -274,8 +274,8 @@ func (a *meiliAuditing) encodeEntry(entry Entry) map[string]any { if entry.StatusCode != 0 { doc["status-code"] = entry.StatusCode } - if entry.Error != "" { - doc["error"] = entry.Error + if entry.Error != nil { + doc["error"] = entry.Error.Error() } if entry.Body != nil { doc["body"] = entry.Body @@ -347,7 +347,7 @@ func (a *meiliAuditing) decodeEntry(doc map[string]any) Entry { entry.StatusCode = int(statusCode) } if err, ok := doc["error"].(string); ok { - entry.Error = err + entry.Error = errors.New(err) } if body, ok := doc["body"]; ok { entry.Body = body diff --git a/auditing/meilisearch_integration_test.go b/auditing/meilisearch_integration_test.go index 0d2f9d7..7220880 100644 --- a/auditing/meilisearch_integration_test.go +++ b/auditing/meilisearch_integration_test.go @@ -70,6 +70,7 @@ func StartMeilisearch(t testing.TB) (container testcontainers.Container, c *conn } func TestAuditing_Meilisearch(t *testing.T) { + ctx := context.Background() container, c, err := StartMeilisearch(t) require.NoError(t, err) defer func() { @@ -99,7 +100,7 @@ func TestAuditing_Meilisearch(t *testing.T) { RemoteAddr: "10.0.0.0", Body: "This is the body of 00000000-0000-0000-0000-000000000000", StatusCode: 200, - Error: "", + Error: nil, }, { Component: "auditing.test", @@ -115,7 +116,7 @@ func TestAuditing_Meilisearch(t *testing.T) { RemoteAddr: "10.0.0.1", Body: "This is the body of 00000000-0000-0000-0000-000000000001", StatusCode: 201, - Error: "", + Error: nil, }, { Component: "auditing.test", @@ -131,7 +132,7 @@ func TestAuditing_Meilisearch(t *testing.T) { RemoteAddr: "10.0.0.2", Body: "This is the body of 00000000-0000-0000-0000-000000000002", StatusCode: 0, - Error: "", + Error: nil, }, } } @@ -143,7 +144,7 @@ func TestAuditing_Meilisearch(t *testing.T) { { name: "no entries, no search results", t: func(t *testing.T, a Auditing) { - entries, err := a.Search(EntryFilter{}) + entries, err := a.Search(ctx, EntryFilter{}) require.NoError(t, err) assert.Empty(t, entries) }, @@ -158,7 +159,7 @@ func TestAuditing_Meilisearch(t *testing.T) { err = a.Flush() require.NoError(t, err) - entries, err := a.Search(EntryFilter{ + entries, err := a.Search(ctx, EntryFilter{ Body: "test", }) require.NoError(t, err) @@ -177,7 +178,7 @@ func TestAuditing_Meilisearch(t *testing.T) { err = a.Flush() require.NoError(t, err) - entries, err := a.Search(EntryFilter{}) + entries, err := a.Search(ctx, EntryFilter{}) require.NoError(t, err) assert.Len(t, entries, len(es)) @@ -187,7 +188,7 @@ func TestAuditing_Meilisearch(t *testing.T) { t.Errorf("diff (+got -want):\n %s", diff) } - entries, err = a.Search(EntryFilter{ + entries, err = a.Search(ctx, EntryFilter{ Body: "This", }) require.NoError(t, err) @@ -206,7 +207,7 @@ func TestAuditing_Meilisearch(t *testing.T) { err = a.Flush() require.NoError(t, err) - entries, err := a.Search(EntryFilter{ + entries, err := a.Search(ctx, EntryFilter{ RequestId: es[0].RequestId, }) require.NoError(t, err) @@ -234,7 +235,7 @@ func TestAuditing_Meilisearch(t *testing.T) { err = a.Flush() require.NoError(t, err) - entries, err := a.Search(EntryFilter{ + entries, err := a.Search(ctx, EntryFilter{ Phase: EntryPhaseResponse, }) require.NoError(t, err) @@ -259,7 +260,7 @@ func TestAuditing_Meilisearch(t *testing.T) { err = a.Flush() require.NoError(t, err) - entries, err := a.Search(EntryFilter{ + entries, err := a.Search(ctx, EntryFilter{ // we want to run a phrase search as otherwise we return the other entries as well // https://www.meilisearch.com/docs/reference/api/search#phrase-search-2 Body: fmt.Sprintf("%q", es[0].Body.(string)), diff --git a/auditing/timescaledb.go b/auditing/timescaledb.go index 6b19833..821273a 100644 --- a/auditing/timescaledb.go +++ b/auditing/timescaledb.go @@ -3,6 +3,8 @@ package auditing import ( "context" "database/sql" + "encoding/json" + "errors" "fmt" "log/slog" "reflect" @@ -17,22 +19,50 @@ import ( _ "github.com/lib/pq" ) -type TimescaleDbConfig struct { - Host string - Port string - DB string - User string - Password string -} - -type timescaleAuditing struct { - component string - db *sqlx.DB - log *slog.Logger +type ( + TimescaleDbConfig struct { + Host string + Port string + DB string + User string + Password string + } + + timescaleAuditing struct { + component string + db *sqlx.DB + log *slog.Logger + + cols []string + vals []any + } + + // to keep the public interface free from field tags like "db" and "json" (as these might differ for different dbs) + // we introduce an internal type. unfortunately, this requires a conversion, which takes effort to maintain + timescaleEntry struct { + Component string `db:"component"` + RequestId string `db:"rqid" json:"rqid"` + Type EntryType `db:"type"` + Timestamp time.Time `db:"timestamp"` + User string `db:"userid"` + Tenant string `db:"tenant"` + Detail EntryDetail `db:"detail"` + Phase EntryPhase `db:"phase"` + Path string `db:"path"` + ForwardedFor string `db:"forwardedfor"` + RemoteAddr string `db:"remoteaddr"` + Body any `db:"body"` + StatusCode int `db:"statuscode"` + Error string `db:"error" json:"-"` + } + + sqlCompOp string +) - cols []string - vals []any -} +const ( + equals sqlCompOp = "equals" + like sqlCompOp = "like" +) func NewTimescaleDB(c Config, tc TimescaleDbConfig) (Auditing, error) { if c.Component == "" { @@ -187,12 +217,15 @@ func (a *timescaleAuditing) Index(entry Entry) error { return err } + internalEntry, err := a.toInternal(entry) + if err != nil { + return fmt.Errorf("unable to convert audit trace to database entry: %w", err) + } + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - entry.Id = entry.RequestId - - _, err = a.db.NamedExecContext(ctx, q, entry) + _, err = a.db.NamedExecContext(ctx, q, internalEntry) if err != nil { return fmt.Errorf("unable to index audit trace: %w", err) } @@ -200,18 +233,11 @@ func (a *timescaleAuditing) Index(entry Entry) error { return nil } -type compOp string - -const ( - equals compOp = "equals" - like compOp = "like" -) - -func (a *timescaleAuditing) Search(filter EntryFilter) ([]Entry, error) { +func (a *timescaleAuditing) Search(ctx context.Context, filter EntryFilter) ([]Entry, error) { var ( where []string values = map[string]interface{}{} - addFilter = func(field string, value any, op compOp) error { + addFilter = func(field string, value any, op sqlCompOp) error { if reflect.ValueOf(value).IsZero() { return nil } @@ -303,7 +329,7 @@ func (a *timescaleAuditing) Search(filter EntryFilter) ([]Entry, error) { return nil, err } - rows, err := a.db.NamedQueryContext(context.TODO(), q, values) // TODO: search needs a ctx! + rows, err := a.db.NamedQueryContext(ctx, q, values) if err != nil { return nil, err } @@ -312,17 +338,58 @@ func (a *timescaleAuditing) Search(filter EntryFilter) ([]Entry, error) { var entries []Entry for rows.Next() { - var e Entry + var e timescaleEntry err = rows.StructScan(&e) if err != nil { return nil, err } - e.Id = e.RequestId + entry, err := a.toExternal(e) + if err != nil { + return nil, fmt.Errorf("unable to convert entry: %w", err) + } - entries = append(entries, e) + entries = append(entries, entry) } return entries, nil } + +func (_ *timescaleAuditing) toInternal(e Entry) (*timescaleEntry, error) { + intermediate, err := json.Marshal(e) // nolint + if err != nil { + return nil, err + } + var internalEntry timescaleEntry + err = json.Unmarshal(intermediate, &internalEntry) // nolint + if err != nil { + return nil, err + } + + internalEntry.RequestId = e.RequestId + if e.Error != nil { + internalEntry.Error = e.Error.Error() + } + + return &internalEntry, nil +} + +func (_ *timescaleAuditing) toExternal(e timescaleEntry) (Entry, error) { + intermediate, err := json.Marshal(e) // nolint + if err != nil { + return Entry{}, err + } + var externalEntry Entry + err = json.Unmarshal(intermediate, &externalEntry) // nolint + if err != nil { + return Entry{}, err + } + + externalEntry.Id = e.RequestId + if e.Error != "" { + externalEntry.Error = errors.New(e.Error) + } + + return externalEntry, nil +} diff --git a/auditing/timescaledb_integration_test.go b/auditing/timescaledb_integration_test.go index 12e6635..c6a188d 100644 --- a/auditing/timescaledb_integration_test.go +++ b/auditing/timescaledb_integration_test.go @@ -20,6 +20,7 @@ import ( ) func TestAuditing_TimescaleDB(t *testing.T) { + ctx := context.Background() container, auditing := StartTimescaleDB(t, Config{ Log: slog.Default(), }) @@ -50,7 +51,7 @@ func TestAuditing_TimescaleDB(t *testing.T) { RemoteAddr: "10.0.0.0", Body: "This is the body of 00000000-0000-0000-0000-000000000000", StatusCode: 200, - Error: "", + Error: nil, }, { Component: "auditing.test", @@ -66,7 +67,7 @@ func TestAuditing_TimescaleDB(t *testing.T) { RemoteAddr: "10.0.0.1", Body: "This is the body of 00000000-0000-0000-0000-000000000001", StatusCode: 201, - Error: "", + Error: nil, }, { Component: "auditing.test", @@ -82,7 +83,7 @@ func TestAuditing_TimescaleDB(t *testing.T) { RemoteAddr: "10.0.0.2", Body: "This is the body of 00000000-0000-0000-0000-000000000002", StatusCode: 0, - Error: "", + Error: nil, }, } } @@ -94,7 +95,7 @@ func TestAuditing_TimescaleDB(t *testing.T) { { name: "no entries, no search results", t: func(t *testing.T, a Auditing) { - entries, err := a.Search(EntryFilter{}) + entries, err := a.Search(ctx, EntryFilter{}) require.NoError(t, err) assert.Empty(t, entries) }, @@ -109,7 +110,7 @@ func TestAuditing_TimescaleDB(t *testing.T) { err = a.Flush() require.NoError(t, err) - entries, err := a.Search(EntryFilter{ + entries, err := a.Search(ctx, EntryFilter{ Body: "test", }) require.NoError(t, err) @@ -128,7 +129,7 @@ func TestAuditing_TimescaleDB(t *testing.T) { err := a.Flush() require.NoError(t, err) - entries, err := a.Search(EntryFilter{}) + entries, err := a.Search(ctx, EntryFilter{}) require.NoError(t, err) assert.Len(t, entries, len(es)) @@ -138,7 +139,7 @@ func TestAuditing_TimescaleDB(t *testing.T) { t.Errorf("diff (+got -want):\n %s", diff) } - entries, err = a.Search(EntryFilter{ + entries, err = a.Search(ctx, EntryFilter{ Body: "This", }) require.NoError(t, err) @@ -157,7 +158,7 @@ func TestAuditing_TimescaleDB(t *testing.T) { err := a.Flush() require.NoError(t, err) - entries, err := a.Search(EntryFilter{ + entries, err := a.Search(ctx, EntryFilter{ RequestId: es[0].RequestId, }) require.NoError(t, err) @@ -185,7 +186,7 @@ func TestAuditing_TimescaleDB(t *testing.T) { err := a.Flush() require.NoError(t, err) - entries, err := a.Search(EntryFilter{ + entries, err := a.Search(ctx, EntryFilter{ Phase: EntryPhaseResponse, }) require.NoError(t, err) @@ -210,7 +211,7 @@ func TestAuditing_TimescaleDB(t *testing.T) { // err := a.Flush() // require.NoError(t, err) - // entries, err := a.Search(EntryFilter{ + // entries, err := a.Search(ctx, EntryFilter{ // Body: fmt.Sprintf("%q", es[0].Body.(string)), // }) // require.NoError(t, err)