Skip to content

Commit

Permalink
add ability to map CPEs directly to packages
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Goodman <[email protected]>
  • Loading branch information
wagoodman committed Nov 25, 2024
1 parent e17dd0c commit 6716371
Show file tree
Hide file tree
Showing 4 changed files with 708 additions and 178 deletions.
63 changes: 61 additions & 2 deletions grype/db/v6/affected_cpe_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func TestAffectedCPEStore_AddAffectedCPEs(t *testing.T) {
VulnerabilityID: 1,
CpeID: 1,
CPE: &Cpe{
Type: "a",
Part: "a",
Vendor: "vendor-1",
Product: "product-1",
Edition: "edition-1",
Expand Down Expand Up @@ -80,10 +80,69 @@ func TestAffectedCPEStore_GetCPEsByProduct(t *testing.T) {
}
}

func TestAffectedCPEStore_PreventDuplicateCPEs(t *testing.T) {
db := setupTestStore(t).db
bw := newBlobStore(db)
s := newAffectedCPEStore(db, bw)

cpe1 := &AffectedCPEHandle{
VulnerabilityID: 1,
CpeID: 1,
CPE: &Cpe{
Part: "a",
Vendor: "vendor-1",
Product: "product-1",
Edition: "edition-1",
},
BlobValue: &AffectedPackageBlob{
CVEs: []string{"CVE-2023-5678"},
},
}

err := s.AddAffectedCPEs(cpe1)
require.NoError(t, err)

// attempt to add a duplicate CPE with the same values
duplicateCPE := &AffectedCPEHandle{
VulnerabilityID: 2, // different VulnerabilityID for testing
CpeID: 2,
CPE: &Cpe{
Part: "a", // same
Vendor: "vendor-1", // same
Product: "product-1", // same
Edition: "edition-1", // same
},
BlobValue: &AffectedPackageBlob{
CVEs: []string{"CVE-2024-1234"},
},
}

err = s.AddAffectedCPEs(duplicateCPE)
require.NoError(t, err)

require.Equal(t, cpe1.CpeID, duplicateCPE.CpeID, "expected the CPE DB ID to be the same")

var existingCPEs []Cpe
err = db.Find(&existingCPEs).Error
require.NoError(t, err)
require.Len(t, existingCPEs, 1, "expected only one CPE to exist")

actualHandles, err := s.GetCPEsByProduct(cpe1.CPE.Product, &GetAffectedCPEOptions{
PreloadCPE: true,
PreloadBlob: true,
})
require.NoError(t, err)
expected := []AffectedCPEHandle{*cpe1, *duplicateCPE}
require.Len(t, actualHandles, len(expected), "expected both handles to be stored")
if d := cmp.Diff(expected, actualHandles); d != "" {
t.Errorf("unexpected result (-want +got):\n%s", d)
}
}

func testAffectedCPEHandle() *AffectedCPEHandle {
return &AffectedCPEHandle{
CPE: &Cpe{
Type: "application",
Part: "application",
Vendor: "vendor",
Product: "product",
Edition: "edition",
Expand Down
232 changes: 161 additions & 71 deletions grype/db/v6/affected_package_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"gorm.io/gorm"

"github.com/anchore/grype/internal/log"
"github.com/anchore/syft/syft/cpe"
)

var NoDistroSpecified = &DistroSpecifier{}
Expand All @@ -18,11 +19,12 @@ var ErrDistroNotPresent = errors.New("distro not present")
var ErrMultipleOSMatches = errors.New("multiple OS matches found but not allowed")

type GetAffectedPackageOptions struct {
PreloadOS bool
PreloadPackage bool
PreloadBlob bool
PackageType string
Distro *DistroSpecifier
PreloadOS bool
PreloadPackage bool
PreloadPackageCPEs bool
PreloadBlob bool
PackageType string
Distro *DistroSpecifier
}

// DistroSpecifier is a struct that represents a distro in a way that can be used to query the affected package store.
Expand Down Expand Up @@ -107,15 +109,6 @@ func newAffectedPackageStore(db *gorm.DB, bs *blobStore) *affectedPackageStore {

func (s *affectedPackageStore) AddAffectedPackages(packages ...*AffectedPackageHandle) error {
for _, v := range packages {
if v.Package != nil {
var existingPackage Package
result := s.db.Where("name = ? AND type = ?", v.Package.Name, v.Package.Type).FirstOrCreate(&existingPackage, v.Package)
if result.Error != nil {
return fmt.Errorf("failed to create package (name=%q type=%q): %w", v.Package.Name, v.Package.Type, result.Error)
}
v.Package = &existingPackage
}

if err := s.blobStore.addBlobable(v); err != nil {
return fmt.Errorf("unable to add affected blob: %w", err)
}
Expand All @@ -126,33 +119,89 @@ func (s *affectedPackageStore) AddAffectedPackages(packages ...*AffectedPackageH
return nil
}

// func (s *affectedPackageStore) writePackage(ah *AffectedPackageHandle) error {
// if ah.Package == nil {
// return nil
// }
// var existingPackage Package
// // check if the package already exists in the table
// result := s.db.Where("type = ? AND name = ?", ah.Package.Type, ah.Package.Name).First(&existingPackage)
// if result.Error != nil {
// if !errors.Is(result.Error, gorm.ErrRecordNotFound) {
// return fmt.Errorf("failed to query package by name and type: %w", result.Error)
// }
//
// // merge CPEs if the package already exists
// for _, c := range ah.Package.CPEs {
// c.PackageID = existingPackage.ID
// err := s.db.FirstOrCreate(&c, c).Error
// if err != nil {
// return fmt.Errorf("failed to create or find CPE (type=%q vendor=%q product=%q): %w", c.Type, c.Vendor, c.Product, err)
// }
// }
// // Use the existing package instead of creating a new one
// ah.Package = &existingPackage
// return nil
// }
//
// // insert the package and its CPEs if it doesn't exist
// err := s.db.Create(ah.Package).Error
// if err != nil {
// return fmt.Errorf("failed to create package (name=%q type=%q): %w", ah.Package.Name, ah.Package.Type, err)
// }
// for _, c := range ah.Package.CPEs {
// c.PackageID = ah.Package.ID
// err := s.db.Create(&c).Error
// if err != nil {
// return fmt.Errorf("failed to create CPE (type=%q vendor=%q product=%q): %w", c.Type, c.Vendor, c.Product, err)
// }
// }
//
// return nil
//}

func (s *affectedPackageStore) GetAffectedPackagesByName(packageName string, config *GetAffectedPackageOptions) ([]AffectedPackageHandle, error) {
if config == nil {
config = &GetAffectedPackageOptions{}
}

log.WithFields("name", packageName, "distro", distroDisplay(config.Distro)).Trace("fetching AffectedPackage record")
log.WithFields("name", packageName, "distro", distroDisplay(config.Distro)).Trace("fetching AffectedPackage by name record")

if hasDistroSpecified(config.Distro) {
return s.getPackageByNameAndDistro(packageName, *config)
return s.getAffectedPackagesWithOptions(
s.handlePackageName(s.db, packageName),
config,
)
}

func (s *affectedPackageStore) GetAffectedPackagesByCPE(cpe cpe.Attributes, config *GetAffectedPackageOptions) ([]AffectedPackageHandle, error) {
if config == nil {
config = &GetAffectedPackageOptions{}
}

return s.getNonDistroPackageByName(packageName, *config)
}
log.WithFields("cpe", cpe.String(), "distro", distroDisplay(config.Distro)).Trace("fetching AffectedPackage by CPE record")

func (s *affectedPackageStore) getNonDistroPackageByName(packageName string, config GetAffectedPackageOptions) ([]AffectedPackageHandle, error) {
var pkgs []AffectedPackageHandle
query := s.db.Joins("JOIN packages ON affected_package_handles.package_id = packages.id")
return s.getAffectedPackagesWithOptions(
s.handlePackageCPE(s.db, cpe),
config)
}

if config.Distro != AnyDistroSpecified {
query = query.Where("operating_system_id IS NULL")
func (s *affectedPackageStore) getAffectedPackagesWithOptions(query *gorm.DB, config *GetAffectedPackageOptions) ([]AffectedPackageHandle, error) {
if config == nil {
config = &GetAffectedPackageOptions{}
}

query = s.handlePackage(query, packageName, config)
query = s.handlePreload(query, config)
query = s.handlePackageOptions(query, *config)

err := query.Find(&pkgs).Error
var err error
query, err = s.handleDistroOptions(query, *config)
if err != nil {
return nil, err
}

query = s.handlePreload(query, *config)

var pkgs []AffectedPackageHandle
err = query.Find(&pkgs).Error
if err != nil {
return nil, fmt.Errorf("unable to fetch non-distro affected package record: %w", err)
}
Expand All @@ -169,10 +218,74 @@ func (s *affectedPackageStore) getNonDistroPackageByName(packageName string, con
return pkgs, nil
}

func (s *affectedPackageStore) getPackageByNameAndDistro(packageName string, config GetAffectedPackageOptions) ([]AffectedPackageHandle, error) {
func (s *affectedPackageStore) handlePackageName(query *gorm.DB, packageName string) *gorm.DB {
return query.Joins("JOIN packages ON affected_package_handles.package_id = packages.id").Where("packages.name = ?", packageName)
}

func (s *affectedPackageStore) handlePackageCPE(query *gorm.DB, c cpe.Attributes) *gorm.DB {
query = query.Joins("JOIN packages ON affected_package_handles.package_id = packages.id").Joins("JOIN cpes ON packages.id = cpes.package_id")

if c.Part != cpe.Any {
query = query.Where("cpes.part = ?", c.Part)
}

if c.Vendor != cpe.Any {
query = query.Where("cpes.vendor = ?", c.Vendor)
}

if c.Product != cpe.Any {
query = query.Where("cpes.product = ?", c.Product)
}

if c.Version != cpe.Any {
query = query.Where("cpes.version = ?", c.Version)
}

if c.Update != cpe.Any {
query = query.Where("cpes.update = ?", c.Update)
}

if c.Edition != cpe.Any {
query = query.Where("cpes.edition = ?", c.Edition)
}

if c.Language != cpe.Any {
query = query.Where("cpes.language = ?", c.Language)
}

if c.SWEdition != cpe.Any {
query = query.Where("cpes.sw_edition = ?", c.SWEdition)
}

if c.TargetSW != cpe.Any {
query = query.Where("cpes.target_sw = ?", c.TargetSW)
}

if c.TargetHW != cpe.Any {
query = query.Where("cpes.target_hw = ?", c.TargetHW)
}

if c.Other != cpe.Any {
query = query.Where("cpes.other = ?", c.Other)
}

return query
}

func (s *affectedPackageStore) handlePackageOptions(query *gorm.DB, config GetAffectedPackageOptions) *gorm.DB {
if config.PackageType != "" {
query = query.Where("packages.type = ?", config.PackageType)
}

return query
}

func (s *affectedPackageStore) handleDistroOptions(query *gorm.DB, config GetAffectedPackageOptions) (*gorm.DB, error) {
var resolvedDistros []OperatingSystem
var err error
if config.Distro != NoDistroSpecified || config.Distro != AnyDistroSpecified {

switch {
case hasDistroSpecified(config.Distro):
resolvedDistros, err = s.resolveDistro(*config.Distro)
if err != nil {
return nil, fmt.Errorf("unable to resolve distro: %w", err)
Expand All @@ -184,31 +297,28 @@ func (s *affectedPackageStore) getPackageByNameAndDistro(packageName string, con
case len(resolvedDistros) > 1 && !config.Distro.AllowMultiple:
return nil, ErrMultipleOSMatches
}
case config.Distro == AnyDistroSpecified:
// TODO: one enhancement we may want to do later is "has OS defined but is not specific" which this does NOT cover. This is "may or may not have an OS defined" which is different.
return query, nil
case *config.Distro == *NoDistroSpecified:
return query.Where("operating_system_id IS NULL"), nil
}

var pkgs []AffectedPackageHandle
query := s.db.Joins("JOIN packages ON affected_package_handles.package_id = packages.id").
Joins("JOIN operating_systems ON affected_package_handles.operating_system_id = operating_systems.id")
query = query.Joins("JOIN operating_systems ON affected_package_handles.operating_system_id = operating_systems.id")

query = s.handlePackage(query, packageName, config)
query = s.handleDistros(query, resolvedDistros)
query = s.handlePreload(query, config)

err = query.Find(&pkgs).Error
if err != nil {
return nil, fmt.Errorf("unable to fetch affected package record: %w", err)
}

if config.PreloadBlob {
for i := range pkgs {
err := s.attachBlob(&pkgs[i])
if err != nil {
return nil, fmt.Errorf("unable to attach blob %#v: %w", pkgs[i], err)
var count int
for _, o := range resolvedDistros {
if o.ID != 0 {
if count == 0 {
query = query.Where("operating_systems.id = ?", o.ID)
} else {
query = query.Or("operating_systems.id = ?", o.ID)
}
count++
}
}

return pkgs, nil
return query, nil
}

func (s *affectedPackageStore) resolveDistro(d DistroSpecifier) ([]OperatingSystem, error) {
Expand Down Expand Up @@ -355,33 +465,13 @@ func (s *affectedPackageStore) applyAlias(d *DistroSpecifier) error {
return nil
}

func (s *affectedPackageStore) handlePackage(query *gorm.DB, packageName string, config GetAffectedPackageOptions) *gorm.DB {
query = query.Where("packages.name = ?", packageName)

if config.PackageType != "" {
query = query.Where("packages.type = ?", config.PackageType)
}
return query
}

func (s *affectedPackageStore) handleDistros(query *gorm.DB, resolvedDistros []OperatingSystem) *gorm.DB {
var count int
for _, o := range resolvedDistros {
if o.ID != 0 {
if count == 0 {
query = query.Where("operating_systems.id = ?", o.ID)
} else {
query = query.Or("operating_systems.id = ?", o.ID)
}
count++
}
}
return query
}

func (s *affectedPackageStore) handlePreload(query *gorm.DB, config GetAffectedPackageOptions) *gorm.DB {
if config.PreloadPackage {
query = query.Preload("Package")

if config.PreloadPackageCPEs {
query = query.Preload("Package.CPEs")
}
}

if config.PreloadOS {
Expand Down
Loading

0 comments on commit 6716371

Please sign in to comment.