Skip to content

Commit

Permalink
Run Analysis Tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
gbdubs committed Jan 2, 2024
1 parent c1257de commit 9acec60
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 30 deletions.
62 changes: 45 additions & 17 deletions cmd/runner/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ func run(args []string) error {
validTasks := map[task.Type]func(context.Context, task.ID) error{
task.ParsePortfolio: toRunFn(parsePortfolioReq, h.parsePortfolio),
task.CreateReport: toRunFn(createReportReq, h.createReport),
task.CreateAudit: toRunFn(createAuditReq, h.createAudit),
}

taskID := task.ID(os.Getenv("TASK_ID"))
Expand Down Expand Up @@ -338,6 +339,7 @@ func (h *handler) parsePortfolio(ctx context.Context, taskID task.ID, req *task.
return nil
}

// TODO(grady): Move this line counting into the image to prevent having our code do any read of the actual underlying data.
func countCSVLines(path string) (int, error) {
file, err := os.Open(path)
if err != nil {
Expand All @@ -356,28 +358,46 @@ func countCSVLines(path string) (int, error) {
return lineCount - 1, nil
}

// createAuditReq builds the audit request for this run from the process
// environment. It reads the CREATE_AUDIT_REQUEST variable and decodes its
// contents as JSON into a task.CreateAuditRequest.
//
// It returns an error if the variable is unset or empty, or if the value
// cannot be decoded.
func createAuditReq() (*task.CreateAuditRequest, error) {
	raw := os.Getenv("CREATE_AUDIT_REQUEST")
	if raw == "" {
		return nil, errors.New("no CREATE_AUDIT_REQUEST was given")
	}

	// Named req rather than task so the local does not shadow the task package.
	req := new(task.CreateAuditRequest)
	dec := json.NewDecoder(strings.NewReader(raw))
	if err := dec.Decode(req); err != nil {
		return nil, fmt.Errorf("failed to load CreateAuditRequest: %w", err)
	}
	return req, nil
}

func createReportReq() (*task.CreateReportRequest, error) {
pID := os.Getenv("PORTFOLIO_ID")
if pID == "" {
return nil, errors.New("no PORTFOLIO_ID was given")
crr := os.Getenv("CREATE_REPORT_REQUEST")
if crr == "" {
return nil, errors.New("no CREATE_REPORT_REQUEST was given")
}
var task task.CreateReportRequest
if err := json.NewDecoder(strings.NewReader(crr)).Decode(&task); err != nil {
return nil, fmt.Errorf("failed to load CreateReportRequest: %w", err)
}
return &task, nil
}

return &task.CreateReportRequest{
PortfolioID: pacta.PortfolioID(pID),
}, nil
// createAudit is the handler for create_audit tasks. It is currently a stub:
// it ignores all of its arguments and always returns a "not implemented"
// error.
func (h *handler) createAudit(ctx context.Context, taskID task.ID, req *task.CreateAuditRequest) error {
	// No formatting verbs are involved, so a plain error value is sufficient.
	return errors.New("not implemented")
}

func (h *handler) createReport(ctx context.Context, taskID task.ID, req *task.CreateReportRequest) error {
baseName := string(req.PortfolioID) + ".json"

// Load the parsed portfolio from blob storage, place it in /mnt/
// processed_portfolios, where the `create_report.R` script expects it
// to be.
srcURI := blob.Join(h.blob.Scheme(), h.destPortfolioContainer, baseName)
destPath := filepath.Join("/", "mnt", "processed_portfolios", baseName)

if err := h.downloadBlob(ctx, srcURI, destPath); err != nil {
return fmt.Errorf("failed to download processed portfolio blob: %w", err)
fileNames := []string{}
for i, blobURI := range req.BlobURIs {
// Load the parsed portfolio from blob storage, place it in /mnt/
// processed_portfolios, where the `create_report.R` script expects it
// to be.
fileName := fmt.Sprintf("%d.json", i)
fileNames = append(fileNames, fileName)
destPath := filepath.Join("/", "mnt", "processed_portfolios", fileName)
if err := h.downloadBlob(ctx, string(blobURI), destPath); err != nil {
return fmt.Errorf("failed to download processed portfolio blob: %w", err)
}
}

reportDir := filepath.Join("/", "mnt", "reports")
Expand All @@ -387,7 +407,7 @@ func (h *handler) createReport(ctx context.Context, taskID task.ID, req *task.Cr

cmd := exec.CommandContext(ctx, "/usr/local/bin/Rscript", "/app/create_report.R")
cmd.Env = append(cmd.Env,
"PORTFOLIO="+string(req.PortfolioID),
"PORTFOLIO="+strings.Join(fileNames, ","),
"HOME=/root", /* Required by pandoc */
)
cmd.Stdout = os.Stdout
Expand Down Expand Up @@ -437,3 +457,11 @@ func (a *azureTokenCredential) GetToken(ctx context.Context, options policy.Toke
ExpiresOn: time.Now().AddDate(1, 0, 0),
}, nil
}

// asStrs converts a slice whose element type has string as its underlying
// type into a plain []string, preserving order. The result is always a
// freshly allocated, non-nil slice, even when in is nil or empty.
func asStrs[T ~string](in []T) []string {
	strs := make([]string, 0, len(in))
	for _, elem := range in {
		strs = append(strs, string(elem))
	}
	return strs
}
41 changes: 35 additions & 6 deletions cmd/runner/taskrunner/taskrunner.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,22 @@ func New(cfg *Config) (*TaskRunner, error) {
}, nil
}

func (tr *TaskRunner) ParsePortfolio(ctx context.Context, req *task.ParsePortfolioRequest) (task.ID, task.RunnerID, error) {
func encodeRequestForCommandLineFlag(request any) (string, error) {
var buf bytes.Buffer
if err := json.NewEncoder(&buf).Encode(req); err != nil {
return "", "", fmt.Errorf("failed to encode ParsePortfolioRequest: %w", err)
if err := json.NewEncoder(&buf).Encode(request); err != nil {
return "", fmt.Errorf("failed to encode request: %w", err)
}
value := buf.String()
if len(value) > 128*1024 {
return "", "", fmt.Errorf("ParsePortfolioRequest is too large: %d bytes > 128 kb", len(value))
return "", fmt.Errorf("request is too large: %d bytes > 128 kb", len(value))
}
return value, nil
}

func (tr *TaskRunner) ParsePortfolio(ctx context.Context, req *task.ParsePortfolioRequest) (task.ID, task.RunnerID, error) {
value, err := encodeRequestForCommandLineFlag(req)
if err != nil {
return "", "", fmt.Errorf("failed to encode ParsePortfolioRequest: %w", err)
}
return tr.run(ctx, []task.EnvVar{
{
Expand All @@ -103,15 +111,36 @@ func (tr *TaskRunner) ParsePortfolio(ctx context.Context, req *task.ParsePortfol
})
}

// CreateAudit starts a create_audit task for the given request. The request
// is serialized with encodeRequestForCommandLineFlag and handed to the task
// through the CREATE_AUDIT_REQUEST environment variable, alongside TASK_TYPE.
//
// It returns the values produced by tr.run, or an error if the request could
// not be encoded.
func (tr *TaskRunner) CreateAudit(ctx context.Context, req *task.CreateAuditRequest) (task.ID, task.RunnerID, error) {
	encoded, err := encodeRequestForCommandLineFlag(req)
	if err != nil {
		return "", "", fmt.Errorf("failed to encode CreateAuditRequest: %w", err)
	}

	env := []task.EnvVar{
		{Key: "TASK_TYPE", Value: string(task.CreateAudit)},
		{Key: "CREATE_AUDIT_REQUEST", Value: encoded},
	}
	return tr.run(ctx, env)
}

func (tr *TaskRunner) CreateReport(ctx context.Context, req *task.CreateReportRequest) (task.ID, task.RunnerID, error) {
value, err := encodeRequestForCommandLineFlag(req)
if err != nil {
return "", "", fmt.Errorf("failed to encode CreateReportRequest: %w", err)
}
return tr.run(ctx, []task.EnvVar{
{
Key: "TASK_TYPE",
Value: string(task.CreateReport),
},
{
Key: "PORTFOLIO_ID",
Value: string(req.PortfolioID),
Key: "CREATE_REPORT_REQUEST",
Value: value,
},
})
}
Expand Down
67 changes: 61 additions & 6 deletions cmd/server/pactasrv/analysis.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/RMI/pacta/oapierr"
api "github.com/RMI/pacta/openapi/pacta"
"github.com/RMI/pacta/pacta"
"github.com/RMI/pacta/task"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
Expand Down Expand Up @@ -234,7 +235,8 @@ func (s *Server) RunAnalysis(ctx context.Context, request api.RunAnalysisRequest
fs := append(fields, zap.String(fmt.Sprintf("%s_id", typeName), string(id)))
return oapierr.NotFound(fmt.Sprintf("%s not found", typeName), fs...)
}
var result pacta.AnalysisID
var analysisID pacta.AnalysisID
var blobURIs []pacta.BlobURI
var endUserErr error
err = s.DB.Transactional(ctx, func(tx db.Tx) error {
var pvID pacta.PACTAVersionID
Expand All @@ -254,6 +256,7 @@ func (s *Server) RunAnalysis(ctx context.Context, request api.RunAnalysisRequest
}

var snapshotID pacta.PortfolioSnapshotID
var blobIDs []pacta.BlobID
if pID != "" {
p, err := s.DB.Portfolio(tx, pID)
if err != nil {
Expand All @@ -275,6 +278,7 @@ func (s *Server) RunAnalysis(ctx context.Context, request api.RunAnalysisRequest
return fmt.Errorf("creating snapshot of portfolio: %w", err)
}
snapshotID = sID
blobIDs = []pacta.BlobID{p.Blob.ID}
} else if pgID != "" {
pg, err := s.DB.PortfolioGroup(tx, pgID)
if err != nil {
Expand All @@ -296,8 +300,19 @@ func (s *Server) RunAnalysis(ctx context.Context, request api.RunAnalysisRequest
return fmt.Errorf("creating snapshot of portfolio group: %w", err)
}
snapshotID = sID
pids := []pacta.PortfolioID{}
for _, pm := range pg.PortfolioGroupMemberships {
pids = append(pids, pm.Portfolio.ID)
}
portfolios, err := s.DB.Portfolios(tx, pids)
if err != nil {
return fmt.Errorf("looking up portfolios: %w", err)
}
for _, p := range portfolios {
blobIDs = append(blobIDs, p.Blob.ID)
}
} else if iID != "" {
_, err := s.DB.Initiative(tx, iID)
i, err := s.DB.Initiative(tx, iID)
if err != nil {
if db.IsNotFound(err) {
endUserErr = notFoundErr("initiative", string(iID), zap.Error(err))
Expand All @@ -311,12 +326,31 @@ func (s *Server) RunAnalysis(ctx context.Context, request api.RunAnalysisRequest
return fmt.Errorf("creating snapshot of initiative: %w", err)
}
snapshotID = sID
pids := []pacta.PortfolioID{}
for _, pm := range i.PortfolioInitiativeMemberships {
pids = append(pids, pm.Portfolio.ID)
}
portfolios, err := s.DB.Portfolios(tx, pids)
if err != nil {
return fmt.Errorf("looking up portfolios: %w", err)
}
for _, p := range portfolios {
blobIDs = append(blobIDs, p.Blob.ID)
}
}
if snapshotID == "" {
return fmt.Errorf("snapshot id is empty, something is wrong in the bizlogic")
}

analysisID, err := s.DB.CreateAnalysis(tx, &pacta.Analysis{
blobs, err := s.DB.Blobs(tx, blobIDs)
if err != nil {
return fmt.Errorf("looking up blobs: %w", err)
}
for _, blob := range blobs {
blobURIs = append(blobURIs, blob.BlobURI)
}

aID, err := s.DB.CreateAnalysis(tx, &pacta.Analysis{
AnalysisType: *analysisType,
PortfolioSnapshot: &pacta.PortfolioSnapshot{ID: snapshotID},
PACTAVersion: &pacta.PACTAVersion{ID: pvID},
Expand All @@ -327,7 +361,7 @@ func (s *Server) RunAnalysis(ctx context.Context, request api.RunAnalysisRequest
if err != nil {
return fmt.Errorf("creating analysis: %w", err)
}
result = analysisID
analysisID = aID
return nil
})
if endUserErr != nil {
Expand All @@ -337,7 +371,28 @@ func (s *Server) RunAnalysis(ctx context.Context, request api.RunAnalysisRequest
return nil, oapierr.Internal("failed to create analysis", zap.Error(err))
}

// TODO - here this is where we'd kick off the analysis run.
switch *analysisType {
case pacta.AnalysisType_Audit:
taskID, runnerID, err := s.TaskRunner.CreateAudit(ctx, &task.CreateAuditRequest{
AnalysisID: analysisID,
BlobURIs: blobURIs,
})
if err != nil {
return nil, oapierr.Internal("failed to create audit task", zap.Error(err))
}
s.Logger.Info("created audit task", zap.String("task_id", string(taskID)), zap.String("runner_id", string(runnerID)), zap.String("analysis_id", string(analysisID)))
case pacta.AnalysisType_Report:
taskID, runnerID, err := s.TaskRunner.CreateReport(ctx, &task.CreateReportRequest{
AnalysisID: analysisID,
BlobURIs: blobURIs,
})
if err != nil {
return nil, oapierr.Internal("failed to create report task", zap.Error(err))
}
s.Logger.Info("created report task", zap.String("task_id", string(taskID)), zap.String("runner_id", string(runnerID)), zap.String("analysis_id", string(analysisID)))
default:
return nil, oapierr.Internal("unknown analysis type", zap.String("analysis_type", string(*analysisType)))
}

return api.RunAnalysis200JSONResponse{AnalysisId: string(result)}, nil
return api.RunAnalysis200JSONResponse{AnalysisId: string(analysisID)}, nil
}
1 change: 1 addition & 0 deletions cmd/server/pactasrv/pactasrv.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ var (

// TaskRunner is the set of task-launching operations the server depends on.
// Each method takes a request describing the work and returns the ID of the
// task, the ID of the runner, and an error if the call failed.
type TaskRunner interface {
// ParsePortfolio starts a task for the given ParsePortfolioRequest.
ParsePortfolio(ctx context.Context, req *task.ParsePortfolioRequest) (task.ID, task.RunnerID, error)
// CreateAudit starts a task for the given CreateAuditRequest.
CreateAudit(ctx context.Context, req *task.CreateAuditRequest) (task.ID, task.RunnerID, error)
// CreateReport starts a task for the given CreateReportRequest.
CreateReport(ctx context.Context, req *task.CreateReportRequest) (task.ID, task.RunnerID, error)
}

Expand Down
19 changes: 18 additions & 1 deletion task/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type Type string
// The kinds of task that can be requested. The string values are what is
// passed as the task type (e.g. via the TASK_TYPE environment variable).
const (
	ParsePortfolio Type = "parse_portfolio"
	CreateReport   Type = "create_report"
	CreateAudit    Type = "create_audit"
)

type ParsePortfolioRequest struct {
Expand All @@ -37,8 +38,24 @@ type ParsePortfolioResponse struct {
Outputs []*ParsePortfolioResponseItem
}

// CreateAuditRequest describes the input to a create_audit task.
type CreateAuditRequest struct {
// AnalysisID identifies the analysis this audit is being created for.
AnalysisID pacta.AnalysisID
// BlobURIs lists the blobs the audit should operate on.
// NOTE(review): presumably these are the processed portfolio blobs for the
// analysis — confirm against the caller that populates this request.
BlobURIs []pacta.BlobURI
}

// CreateAuditResponse describes the outcome of a create_audit task.
type CreateAuditResponse struct {
// TaskID is the ID of the task this response belongs to.
TaskID ID
// Request is the audit request associated with this response.
// NOTE(review): presumably the request that triggered the task — confirm
// against the code that populates this struct.
Request *CreateAuditRequest
}

type CreateReportRequest struct {
PortfolioID pacta.PortfolioID
AnalysisID pacta.AnalysisID
BlobURIs []pacta.BlobURI
}

// CreateReportResponse describes the outcome of a create_report task.
type CreateReportResponse struct {
// TaskID is the ID of the task this response belongs to.
TaskID ID
// Request is the report request associated with this response.
// NOTE(review): presumably the request that triggered the task — confirm
// against the code that populates this struct.
Request *CreateReportRequest
}

type EnvVar struct {
Expand Down

0 comments on commit 9acec60

Please sign in to comment.