Skip to content

Commit

Permalink
fix: make compliance jobs use multiple integrations
Browse files Browse the repository at this point in the history
  • Loading branch information
artaasadi committed Dec 17, 2024
1 parent 82dd6ae commit c94120e
Show file tree
Hide file tree
Showing 13 changed files with 253 additions and 199 deletions.
26 changes: 14 additions & 12 deletions jobs/compliance-quick-run-job/job.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ import (
)

type AuditJob struct {
JobID uint
FrameworkID string
IntegrationID string
IncludeResult []string
JobID uint
FrameworkID string
IntegrationIDs []string
IncludeResult []string

JobReportControlSummary *types.ComplianceJobReportControlSummary
JobReportControlView *types.ComplianceJobReportControlView
Expand Down Expand Up @@ -77,18 +77,20 @@ func (w *Worker) RunJob(ctx context.Context, job *AuditJob) error {
totalControls := make(map[string]bool)
failedControls := make(map[string]bool)

err := w.RunJobForIntegration(ctx, job, job.IntegrationID, &totalControls, &failedControls)
if err != nil {
w.logger.Error("failed to run audit job for integration", zap.String("integration_id", job.IntegrationID), zap.Error(err))
return err
for _, integrationID := range job.IntegrationIDs {
err := w.RunJobForIntegration(ctx, job, integrationID, &totalControls, &failedControls)
if err != nil {
w.logger.Error("failed to run audit job for integration", zap.String("integration_id", integrationID), zap.Error(err))
return err
}
w.logger.Info("audit job for integration completed", zap.String("integration_id", integrationID))
}
w.logger.Info("audit job for integration completed", zap.String("integration_id", job.IntegrationID))

keys, idx := job.JobReportControlView.KeysAndIndex()
job.JobReportControlView.EsID = es.HashOf(keys...)
job.JobReportControlView.EsIndex = idx

err = sendDataToOpensearch(w.esClient.ES(), *job.JobReportControlView)
err := sendDataToOpensearch(w.esClient.ES(), *job.JobReportControlView)
if err != nil {
return err
}
Expand Down Expand Up @@ -159,8 +161,8 @@ func (w *Worker) RunJobForIntegration(ctx context.Context, job *AuditJob, integr
queryJob := QueryJob{
AuditJobID: job.JobID,
ExecutionPlan: ExecutionPlan{
Query: *control.Query,
IntegrationID: job.IntegrationID,
Query: *control.Query,
IntegrationIDs: job.IntegrationIDs,
},
}
queryResults, err := w.RunQuery(ctx, queryJob)
Expand Down
12 changes: 6 additions & 6 deletions jobs/compliance-quick-run-job/query_result.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ type QueryResult struct {
type ExecutionPlan struct {
Query complianceApi.Query

IntegrationID string
IntegrationIDs []string
}

type QueryJob struct {
Expand All @@ -41,7 +41,7 @@ type QueryJob struct {
func (w *Worker) RunQuery(ctx context.Context, j QueryJob) ([]QueryResult, error) {
w.logger.Info("Running query",
zap.String("query_id", j.ExecutionPlan.Query.ID),
zap.String("integration_ids", j.ExecutionPlan.IntegrationID),
zap.Strings("integration_ids", j.ExecutionPlan.IntegrationIDs),
)

queryParams, err := w.metadataClient.ListQueryParameters(&httpclient.Context{Ctx: ctx, UserRole: authApi.AdminRole})
Expand All @@ -59,15 +59,15 @@ func (w *Worker) RunQuery(ctx context.Context, j QueryJob) ([]QueryResult, error
w.logger.Error("required query parameter not found",
zap.String("key", param.Key),
zap.String("query_id", j.ExecutionPlan.Query.ID),
zap.String("integration_id", j.ExecutionPlan.IntegrationID),
zap.Strings("integration_id", j.ExecutionPlan.IntegrationIDs),
)
return nil, fmt.Errorf("required query parameter not found: %s for query: %s", param.Key, j.ExecutionPlan.Query.ID)
}
if _, ok := queryParamMap[param.Key]; !ok && !param.Required {
w.logger.Info("optional query parameter not found",
zap.String("key", param.Key),
zap.String("query_id", j.ExecutionPlan.Query.ID),
zap.String("integration_id", j.ExecutionPlan.IntegrationID),
zap.Strings("integration_id", j.ExecutionPlan.IntegrationIDs),
)
queryParamMap[param.Key] = ""
}
Expand Down Expand Up @@ -106,7 +106,7 @@ func (w *Worker) runSqlWorkerJob(ctx context.Context, j QueryJob, queryParamMap
w.logger.Error("failed to execute query template",
zap.Error(err),
zap.String("query_id", j.ExecutionPlan.Query.ID),
zap.String("integration_id", j.ExecutionPlan.IntegrationID),
zap.Strings("integration_id", j.ExecutionPlan.IntegrationIDs),
zap.Uint("job_id", j.AuditJobID),
)
return nil, fmt.Errorf("failed to execute query template: %w for query: %s", err, j.ExecutionPlan.Query.ID)
Expand All @@ -119,7 +119,7 @@ func (w *Worker) runSqlWorkerJob(ctx context.Context, j QueryJob, queryParamMap
zap.String("query", queryOutput.String()))
res, err := w.steampipeConn.QueryAll(ctx, queryOutput.String())
if err != nil {
w.logger.Error("failed to run query", zap.Error(err), zap.String("query_id", j.ExecutionPlan.Query.ID), zap.String("integration_id", j.ExecutionPlan.IntegrationID))
w.logger.Error("failed to run query", zap.Error(err), zap.String("query_id", j.ExecutionPlan.Query.ID), zap.Strings("integration_id", j.ExecutionPlan.IntegrationIDs))
return nil, err
}

Expand Down
8 changes: 4 additions & 4 deletions services/compliance/http_routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -7934,15 +7934,15 @@ func (h HttpHandler) GetJobReportSummary(ctx echo.Context) error {
jobId = strconv.Itoa(int(*complianceJob.SummaryJobId))
}

framework, err := h.db.GetBenchmark(ctx.Request().Context(), complianceJob.BenchmarkId)
framework, err := h.db.GetBenchmark(ctx.Request().Context(), complianceJob.FrameworkId)
if err != nil {
h.logger.Error("failed to get framework by frameworkID", zap.String("framework", complianceJob.BenchmarkId), zap.Error(err))
h.logger.Error("failed to get framework by frameworkID", zap.String("framework", complianceJob.FrameworkId), zap.Error(err))
return echo.NewHTTPError(http.StatusInternalServerError, "failed to get framework by frameworkID")
}

controlsMap, err := h.getControlsUnderBenchmark(ctx.Request().Context(), complianceJob.BenchmarkId, make(map[string]BenchmarkControlsCache))
controlsMap, err := h.getControlsUnderBenchmark(ctx.Request().Context(), complianceJob.FrameworkId, make(map[string]BenchmarkControlsCache))
if err != nil {
h.logger.Error("failed to get controls under benchmark", zap.String("framework", complianceJob.BenchmarkId), zap.Error(err))
h.logger.Error("failed to get controls under benchmark", zap.String("framework", complianceJob.FrameworkId), zap.Error(err))
return echo.NewHTTPError(http.StatusInternalServerError, "could not get framework by frameworkID")
}
var controlsStr []string
Expand Down
18 changes: 9 additions & 9 deletions services/describe/api/jobs.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ type GetComplianceJobsHistoryResponse struct {
BenchmarkId string `json:"benchmark_id"`
JobStatus ComplianceJobStatus `json:"job_status"`
DateTime time.Time `json:"date_time"`
IntegrationInfo IntegrationInfo `json:"integration_info"`
IntegrationInfo []IntegrationInfo `json:"integration_info"`
}

type BenchmarkAuditHistoryItem struct {
Expand Down Expand Up @@ -209,14 +209,14 @@ type GetDescribeJobStatusResponse struct {
}

type GetComplianceJobStatusResponse struct {
JobId uint `json:"job_id"`
WithIncidents bool `json:"with_incidents"`
SummaryJobId *uint `json:"summary_job_id"`
IntegrationInfo IntegrationInfo `json:"integration_info"`
JobStatus string `json:"job_status"`
BenchmarkId string `json:"benchmark_id"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
JobId uint `json:"job_id"`
WithIncidents bool `json:"with_incidents"`
SummaryJobId *uint `json:"summary_job_id"`
IntegrationInfo []IntegrationInfo `json:"integration_info"`
JobStatus string `json:"job_status"`
FrameworkId string `json:"framework_id"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}

type GetAsyncQueryRunJobStatusResponse struct {
Expand Down
41 changes: 23 additions & 18 deletions services/describe/db/compliance_job.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package db
import (
"errors"
"fmt"
"github.com/lib/pq"
"math/rand"
"time"

Expand Down Expand Up @@ -114,11 +115,11 @@ func (db Database) CleanupComplianceJobsOlderThan(t time.Time) error {
return nil
}

func (db Database) GetLastComplianceJob(withIncidents bool, benchmarkID string) (*model.ComplianceJob, error) {
func (db Database) GetLastComplianceJob(withIncidents bool, frameworkID string) (*model.ComplianceJob, error) {
var job model.ComplianceJob
tx := db.ORM.Model(&model.ComplianceJob{}).
Where("with_incidents = ?", withIncidents).
Where("benchmark_id = ?", benchmarkID).Order("created_at DESC").First(&job)
Where("framework_id = ?", frameworkID).Order("created_at DESC").First(&job)
if tx.Error != nil {
if errors.Is(tx.Error, gorm.ErrRecordNotFound) {
return nil, nil
Expand Down Expand Up @@ -187,7 +188,7 @@ func (db Database) ListComplianceJobsForInterval(withIncidents *bool, interval,
return job, nil
}

func (db Database) ListComplianceJobsWithSummaryJob(withIncidents *bool, interval, triggerType, createdBy string, benchmarkIDs []string) ([]model.ComplianceJobWithSummarizerJob, error) {
func (db Database) ListComplianceJobsWithSummaryJob(withIncidents *bool, interval, triggerType, createdBy string, frameworkIDs []string) ([]model.ComplianceJobWithSummarizerJob, error) {
var result []model.ComplianceJobWithSummarizerJob

// Base query
Expand All @@ -196,9 +197,9 @@ func (db Database) ListComplianceJobsWithSummaryJob(withIncidents *bool, interva
compliance_jobs.id,
compliance_jobs.created_at,
compliance_jobs.updated_at,
compliance_jobs.benchmark_id,
compliance_jobs.framework_id,
compliance_jobs.status,
compliance_jobs.integration_id,
compliance_jobs.integration_ids,
compliance_jobs.trigger_type,
compliance_jobs.created_by,
COALESCE(array_agg(COALESCE(compliance_summarizers.id::text, '')), '{}') as summarizer_jobs
Expand All @@ -220,8 +221,8 @@ func (db Database) ListComplianceJobsWithSummaryJob(withIncidents *bool, interva
if createdBy != "" {
tx = tx.Where("compliance_jobs.created_by = ?", createdBy)
}
if len(benchmarkIDs) > 0 {
tx = tx.Where("compliance_jobs.benchmark_id IN ?", benchmarkIDs)
if len(frameworkIDs) > 0 {
tx = tx.Where("compliance_jobs.framework_id IN ?", frameworkIDs)
}

// Execute the query
Expand All @@ -237,7 +238,7 @@ func (db Database) ListComplianceJobsWithSummaryJob(withIncidents *bool, interva

func (db Database) ListComplianceJobsByIntegrationID(withIncidents *bool, integrationIds []string) ([]model.ComplianceJob, error) {
var job []model.ComplianceJob
tx := db.ORM.Model(&model.ComplianceJob{}).Where("integration_id IN ?", integrationIds)
tx := db.ORM.Model(&model.ComplianceJob{}).Where("integration_ids && ?", pq.Array(integrationIds))
if withIncidents != nil {
tx = tx.Where("with_incidents = ?", *withIncidents)
}
Expand All @@ -254,7 +255,7 @@ func (db Database) ListComplianceJobsByIntegrationID(withIncidents *bool, integr
func (db Database) ListPendingComplianceJobsByIntegrationID(withIncidents *bool, integrationIds []string) ([]model.ComplianceJob, error) {
var job []model.ComplianceJob
tx := db.ORM.Model(&model.ComplianceJob{}).
Where("integration_id IN ?", integrationIds).
Where("integration_ids && ?", integrationIds).
Where("status IN ?", []model.ComplianceJobStatus{model.ComplianceJobCreated, model.ComplianceJobRunnersInProgress})
if withIncidents != nil {
tx = tx.Where("with_incidents = ?", *withIncidents)
Expand All @@ -269,9 +270,9 @@ func (db Database) ListPendingComplianceJobsByIntegrationID(withIncidents *bool,
return job, nil
}

func (db Database) ListComplianceJobsByBenchmarkID(withIncidents *bool, benchmarkIds []string) ([]model.ComplianceJob, error) {
func (db Database) ListComplianceJobsByFrameworkID(withIncidents *bool, frameworkIDs []string) ([]model.ComplianceJob, error) {
var job []model.ComplianceJob
tx := db.ORM.Model(&model.ComplianceJob{}).Where("benchmark_id IN ?", benchmarkIds)
tx := db.ORM.Model(&model.ComplianceJob{}).Where("framework_id IN ?", frameworkIDs)
if withIncidents != nil {
tx = tx.Where("with_incidents = ?", *withIncidents)
}
Expand Down Expand Up @@ -422,7 +423,7 @@ SELECT * FROM compliance_jobs j WHERE status = 'SUMMARIZER_IN_PROGRESS' AND with
return jobs, nil
}

func (db Database) ListComplianceJobsByFilters(withIncidents *bool, integrationId []string, benchmarkId []string, status []string,
func (db Database) ListComplianceJobsByFilters(withIncidents *bool, integrationId []string, frameworkId []string, status []string,
startTime, endTime *time.Time) ([]model.ComplianceJob, error) {
var jobs []model.ComplianceJob
tx := db.ORM.Model(&model.ComplianceJob{})
Expand All @@ -432,11 +433,11 @@ func (db Database) ListComplianceJobsByFilters(withIncidents *bool, integrationI
}

if len(integrationId) > 0 {
tx = tx.Where("integration_id IN ?", integrationId)
tx = tx.Where("integration_ids && ?", pq.Array(integrationId))
}

if len(benchmarkId) > 0 {
tx = tx.Where("benchmark_id IN ?", benchmarkId)
if len(frameworkId) > 0 {
tx = tx.Where("framework_id IN ?", frameworkId)
}
if len(status) > 0 {
tx = tx.Where("status IN ?", status)
Expand All @@ -458,14 +459,18 @@ func (db Database) ListComplianceJobsByFilters(withIncidents *bool, integrationI

func (db Database) GetComplianceJobsIntegrations() ([]string, error) {
var uniqueIntegrationIDs []string
if err := db.ORM.Model(&model.ComplianceJob{}).Distinct("integration_id").Pluck("integration_id", &uniqueIntegrationIDs).Error; err != nil {
query := `
SELECT DISTINCT unnest(integration_ids) AS integration
FROM compliance_jobs
`
if err := db.ORM.Raw(query).Pluck("integration", &uniqueIntegrationIDs).Error; err != nil {
return nil, err
}
return uniqueIntegrationIDs, nil
}

func (db Database) CleanupAllComplianceJobsForIntegrations(integrations []string) error {
tx := db.ORM.Where("integration_id IN ?", integrations).Unscoped().Delete(&model.ComplianceJob{})
func (db Database) CleanupAllComplianceJobsForIntegrations(integrationId []string) error {
tx := db.ORM.Where("integration_ids && ?", pq.Array(integrationId)).Unscoped().Delete(&model.ComplianceJob{})
if tx.Error != nil {
return tx.Error
}
Expand Down
10 changes: 5 additions & 5 deletions services/describe/db/model/compliance_job.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ func (c ComplianceJobStatus) ToApi() api.ComplianceJobStatus {

type ComplianceJob struct {
gorm.Model
BenchmarkID string
FrameworkID string
WithIncidents bool
Status ComplianceJobStatus
IncludeResults pq.StringArray `gorm:"type:text[]"`
AreAllRunnersQueued bool
IntegrationID string
IntegrationIDs pq.StringArray `gorm:"type:text[]"`
FailureMessage string
TriggerType ComplianceTriggerType
ParentID *uint
Expand All @@ -51,7 +51,7 @@ type ComplianceJob struct {
func (c ComplianceJob) ToApi() api.ComplianceJob {
return api.ComplianceJob{
ID: c.ID,
BenchmarkID: c.BenchmarkID,
BenchmarkID: c.FrameworkID,
Status: c.Status.ToApi(),
FailureMessage: c.FailureMessage,
}
Expand All @@ -61,7 +61,7 @@ type ComplianceRunner struct {
gorm.Model

Callers string
BenchmarkID string
FrameworkID string
QueryID string
IntegrationID *string
ResourceCollectionID *string
Expand All @@ -82,7 +82,7 @@ func (cr *ComplianceRunner) GetKeyIdentifier() string {
if cr.IntegrationID != nil {
cid = *cr.IntegrationID
}
return fmt.Sprintf("%s-%s-%s-%d", cr.BenchmarkID, cr.QueryID, cid, cr.ParentJobID)
return fmt.Sprintf("%s-%s-%s-%d", cr.FrameworkID, cr.QueryID, cid, cr.ParentJobID)
}

func (cr *ComplianceRunner) GetCallers() ([]runner.Caller, error) {
Expand Down
2 changes: 1 addition & 1 deletion services/describe/scheduler_job_sequencer.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func (s *Scheduler) runNextJob(ctx context.Context, job model.JobSequencer) erro
}

runnerJob := model.ComplianceRunner{
BenchmarkID: parameters.BenchmarkID,
FrameworkID: parameters.BenchmarkID,
QueryID: control.Query.ID,
IntegrationID: &connectionID,
StartedAt: time.Time{},
Expand Down
16 changes: 6 additions & 10 deletions services/describe/scheduler_quickscan_sequence.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,13 @@ type RunQuickComplianceScan struct {
}

func (s *RunQuickComplianceScan) Do(ctx context.Context) error {
var jobIDs []uint
for _, i := range s.job.IntegrationIDs {
jobId, err := s.s.complianceScheduler.CreateComplianceReportJobs(false, s.job.FrameworkID, nil, i, true, "QuickScanSequencer", &s.job.ID)
if err != nil {
return fmt.Errorf("error while creating compliance job: %v", err)
}
jobIDs = append(jobIDs, jobId)
jobs, err := s.s.complianceScheduler.CreateComplianceReportJobs(false, s.job.FrameworkID, nil, s.job.IntegrationIDs, true, "QuickScanSequencer", &s.job.ID)
if err != nil {
return fmt.Errorf("error while creating compliance job: %v", err)
}

s.s.logger.Info("Waiting for quick scan", zap.Uint("JobID", s.job.ID))
err := s.s.db.UpdateQuickScanSequenceStatus(s.job.ID, model.QuickScanSequenceComplianceRunning, "")
err = s.s.db.UpdateQuickScanSequenceStatus(s.job.ID, model.QuickScanSequenceComplianceRunning, "")
if err != nil {
return err
}
Expand All @@ -116,8 +112,8 @@ func (s *RunQuickComplianceScan) Do(ctx context.Context) error {

for ; ; <-t.C {
allFinished := true
for _, jobID := range jobIDs {
run, err := s.s.db.GetComplianceJobByID(jobID)
for _, job := range jobs {
run, err := s.s.db.GetComplianceJobByID(job.ID)
if err != nil {
return err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ func (s *JobScheduler) runPublisher(ctx context.Context) error {
s.logger.Info("Fetch Created Query Runner Jobs", zap.Any("Jobs Count", len(jobs)))
for _, job := range jobs {
auditJobMsg := auditjob.AuditJob{
JobID: job.ID,
FrameworkID: job.BenchmarkID,
IntegrationID: job.IntegrationID,
IncludeResult: job.IncludeResults,
JobID: job.ID,
FrameworkID: job.FrameworkID,
IntegrationIDs: job.IntegrationIDs,
IncludeResult: job.IncludeResults,
}

jobJson, err := json.Marshal(auditJobMsg)
Expand Down
Loading

0 comments on commit c94120e

Please sign in to comment.