Skip to content

Commit

Permalink
change store to support search
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Goodman <[email protected]>
  • Loading branch information
wagoodman committed Dec 18, 2024
1 parent 69330e5 commit 01f1def
Show file tree
Hide file tree
Showing 12 changed files with 660 additions and 276 deletions.
59 changes: 47 additions & 12 deletions grype/db/internal/gormadapter/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,71 @@ package gormadapter

import (
"context"
"fmt"
"time"

"gorm.io/gorm/logger"

anchoreLogger "github.com/anchore/go-logger"
"github.com/anchore/grype/internal/log"
)

// logAdapter is meant to adapt the gorm logger interface (see https://github.com/go-gorm/gorm/blob/v1.25.12/logger/logger.go)
// to the anchore logger interface.
type logAdapter struct {
debug bool
slowThreshold time.Duration
level logger.LogLevel
}

func newLogger() logger.Interface {
return logAdapter{}
// LogMode sets the log level for the logger and returns a new instance
func (l *logAdapter) LogMode(level logger.LogLevel) logger.Interface {
newlogger := *l
newlogger.level = level
return &newlogger
}

func (l logAdapter) LogMode(logger.LogLevel) logger.Interface {
return l
}

func (l logAdapter) Info(_ context.Context, _ string, _ ...interface{}) {
// unimplemented
func (l logAdapter) Info(_ context.Context, fmt string, v ...interface{}) {
if l.level >= logger.Info {
if l.debug {
log.Infof("[sql] "+fmt, v...)
}
}
}

func (l logAdapter) Warn(_ context.Context, fmt string, v ...interface{}) {
log.Warnf("gorm: "+fmt, v...)
if l.level >= logger.Warn {
log.Warnf("[sql] "+fmt, v...)
}
}

func (l logAdapter) Error(_ context.Context, fmt string, v ...interface{}) {
log.Errorf("gorm: "+fmt, v...)
if l.level >= logger.Error {
log.Errorf("[sql] "+fmt, v...)
}
}

func (l logAdapter) Trace(_ context.Context, _ time.Time, _ func() (sql string, rowsAffected int64), _ error) {
// unimplemented
// Trace logs the SQL statement and the duration it took to run the statement
func (l logAdapter) Trace(_ context.Context, t time.Time, fn func() (sql string, rowsAffected int64), _ error) {
if l.level <= logger.Silent {
return
}

if l.debug {
sql, rowsAffected := fn()
elapsed := time.Since(t)
fields := anchoreLogger.Fields{
"rows": rowsAffected,
"duration": elapsed,
}

isSlow := l.slowThreshold != 0 && elapsed > l.slowThreshold
if isSlow {
fields["is-slow"] = isSlow
fields["slow-threshold"] = fmt.Sprintf("> %s", l.slowThreshold)
log.WithFields(fields).Warnf("[sql] %s", sql)
} else {
log.WithFields(fields).Tracef("[sql] %s", sql)
}
}
}
36 changes: 27 additions & 9 deletions grype/db/internal/gormadapter/open.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"os"
"path/filepath"
"strings"
"time"

"github.com/glebarez/sqlite"
"gorm.io/gorm"
Expand Down Expand Up @@ -34,6 +35,7 @@ var readConnectionOptions = []string{
}

type config struct {
debug bool
path string
writable bool
truncate bool
Expand All @@ -46,6 +48,12 @@ type config struct {

type Option func(*config)

func WithDebug(debug bool) Option {
return func(c *config) {
c.debug = debug
}
}

func WithTruncate(truncate bool, models []any, initialData []any) Option {
return func(c *config) {
c.truncate = truncate
Expand All @@ -70,9 +78,10 @@ func WithModels(models []any) Option {
}
}

func WithWritable(write bool) Option {
func WithWritable(write bool, models []any) Option {
return func(c *config) {
c.writable = write
c.models = models
}
}

Expand Down Expand Up @@ -129,55 +138,64 @@ func Open(path string, options ...Option) (*gorm.DB, error) {
}
}

dbObj, err := gorm.Open(sqlite.Open(cfg.connectionString()), &gorm.Config{Logger: newLogger()})
dbObj, err := gorm.Open(sqlite.Open(cfg.connectionString()), &gorm.Config{Logger: &logAdapter{
debug: cfg.debug,
slowThreshold: 400 * time.Millisecond,
}})
if err != nil {
return nil, fmt.Errorf("unable to connect to DB: %w", err)
}

return prepareDB(dbObj, cfg)
}

func prepareDB(dbObj *gorm.DB, cfg config) (*gorm.DB, error) {
if cfg.writable {
log.Trace("applying writable DB statements")
log.Trace("using writable DB statements")
if err := applyStatements(dbObj, writerStatements); err != nil {
return nil, fmt.Errorf("unable to apply DB writer statements: %w", err)
}
}

if cfg.truncate && cfg.allowLargeMemoryFootprint {
log.Trace("applying large memory footprint DB statements")
log.Trace("using large memory footprint DB statements")
if err := applyStatements(dbObj, heavyWriteStatements); err != nil {
return nil, fmt.Errorf("unable to apply DB heavy writer statements: %w", err)
}
}

if len(commonStatements) > 0 {
log.Trace("applying common DB statements")
if err := applyStatements(dbObj, commonStatements); err != nil {
return nil, fmt.Errorf("unable to apply DB common statements: %w", err)
}
}

if len(cfg.statements) > 0 {
log.Trace("applying custom DB statements")
if err := applyStatements(dbObj, cfg.statements); err != nil {
return nil, fmt.Errorf("unable to apply DB custom statements: %w", err)
}
}

if len(cfg.models) > 0 {
if len(cfg.models) > 0 && cfg.writable {
log.Trace("applying DB migrations")
if err := dbObj.AutoMigrate(cfg.models...); err != nil {
return nil, fmt.Errorf("unable to migrate: %w", err)
}
}

if len(cfg.initialData) > 0 {
log.Trace("applying initial data")
if len(cfg.initialData) > 0 && cfg.truncate {
log.Trace("writing initial data")
for _, d := range cfg.initialData {
if err := dbObj.Create(d).Error; err != nil {
return nil, fmt.Errorf("unable to create initial data: %w", err)
}
}
}

if cfg.debug {
dbObj = dbObj.Debug()
}

return dbObj, nil
}

Expand Down
78 changes: 54 additions & 24 deletions grype/db/v6/affected_cpe_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package v6

import (
"fmt"
"time"

"gorm.io/gorm"

Expand All @@ -22,7 +23,8 @@ type GetAffectedCPEOptions struct {
PreloadCPE bool
PreloadVulnerability bool
PreloadBlob bool
Vulnerability *VulnerabilitySpecifier
Vulnerabilities []VulnerabilitySpecifier
Limit int
}

type affectedCPEStore struct {
Expand Down Expand Up @@ -63,42 +65,62 @@ func (s *affectedCPEStore) GetAffectedCPEs(cpe *cpe.Attributes, config *GetAffec
} else {
fields["cpe"] = cpe.String()
}
log.WithFields(fields).Trace("fetching AffectedCPE record")
start := time.Now()
defer func() {
fields["duration"] = time.Since(start)
log.WithFields(fields).Trace("fetched affected CPE record")
}()

query := s.handleCPE(s.db, cpe)

var err error
query, err = s.handleVulnerabilityOptions(query, config.Vulnerability)
query, err = s.handleVulnerabilityOptions(query, config.Vulnerabilities)
if err != nil {
return nil, err
}

query = s.handlePreload(query, *config)

var pkgs []AffectedCPEHandle
if err = query.Find(&pkgs).Error; err != nil {
return nil, fmt.Errorf("unable to fetch affected package record: %w", err)
}
var models []AffectedCPEHandle

if config.PreloadBlob {
for i := range pkgs {
err := s.blobStore.attachBlobValue(&pkgs[i])
if err != nil {
return nil, fmt.Errorf("unable to attach blob %#v: %w", pkgs[i], err)
var results []*AffectedCPEHandle
if err := query.FindInBatches(&results, batchSize, func(_ *gorm.DB, _ int) error { // nolint:dupl
if config.PreloadBlob {
var blobs []blobable
for _, r := range results {
blobs = append(blobs, r)
}
if err := s.blobStore.attachBlobValue(blobs...); err != nil {
return fmt.Errorf("unable to attach blobs: %w", err)
}
}
}

if config.PreloadVulnerability {
for i := range pkgs {
err := s.blobStore.attachBlobValue(pkgs[i].Vulnerability)
if err != nil {
return nil, fmt.Errorf("unable to attach vulnerability blob %#v: %w", pkgs[i], err)
if config.PreloadVulnerability {
var vulns []blobable
for _, r := range results {
if r.Vulnerability != nil {
vulns = append(vulns, r.Vulnerability)
}
}
if err := s.blobStore.attachBlobValue(vulns...); err != nil {
return fmt.Errorf("unable to attach vulnerability blob: %w", err)
}
}

for _, r := range results {
models = append(models, *r)
}

if config.Limit > 0 && len(models) >= config.Limit {
return ErrLimitReached
}

return nil
}).Error; err != nil {
return models, fmt.Errorf("unable to fetch affected CPE records: %w", err)
}

return pkgs, nil
return models, nil
}

func (s *affectedCPEStore) handleCPE(query *gorm.DB, c *cpe.Attributes) *gorm.DB {
Expand All @@ -110,23 +132,31 @@ func (s *affectedCPEStore) handleCPE(query *gorm.DB, c *cpe.Attributes) *gorm.DB
return handleCPEOptions(query, c)
}

func (s *affectedCPEStore) handleVulnerabilityOptions(query *gorm.DB, config *VulnerabilitySpecifier) (*gorm.DB, error) {
if config == nil {
func (s *affectedCPEStore) handleVulnerabilityOptions(query *gorm.DB, configs []VulnerabilitySpecifier) (*gorm.DB, error) {
if len(configs) == 0 {
return query, nil
}

query = query.Joins("JOIN vulnerability_handles ON affected_cpe_handles.vulnerability_id = vulnerability_handles.id")

return handleVulnerabilityOptions(query, config)
return handleVulnerabilityOptions(s.db, query, configs...)
}

func (s *affectedCPEStore) handlePreload(query *gorm.DB, config GetAffectedCPEOptions) *gorm.DB {
var limitArgs []interface{}
if config.Limit > 0 {
query = query.Limit(config.Limit)
limitArgs = append(limitArgs, func(db *gorm.DB) *gorm.DB {
return db.Limit(config.Limit)
})
}

if config.PreloadCPE {
query = query.Preload("CPE")
query = query.Preload("CPE", limitArgs...)
}

if config.PreloadVulnerability {
query = query.Preload("Vulnerability").Preload("Vulnerability.Provider")
query = query.Preload("Vulnerability", limitArgs...).Preload("Vulnerability.Provider", limitArgs...)
}

return query
Expand Down
Loading

0 comments on commit 01f1def

Please sign in to comment.