Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Publish Event Grid event from runner after processing portfolio #49

Merged
merged 1 commit into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions azure/azevents/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ go_library(
importpath = "github.com/RMI/pacta/azure/azevents",
visibility = ["//visibility:public"],
deps = [
"//task",
"@com_github_go_chi_chi_v5//:chi",
"@org_uber_go_zap//:zap",
],
Expand Down
50 changes: 39 additions & 11 deletions azure/azevents/azevents.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"path"
"time"

"github.com/RMI/pacta/task"
"github.com/go-chi/chi/v5"
"go.uber.org/zap"
)
Expand All @@ -21,6 +21,8 @@ type Config struct {

Subscription string
ResourceGroup string

ProcessedPortfolioTopicName string
}

func (c *Config) validate() error {
Expand All @@ -33,6 +35,9 @@ func (c *Config) validate() error {
if c.ResourceGroup == "" {
return errors.New("no resource group given")
}
if c.ProcessedPortfolioTopicName == "" {
return errors.New("no resource group given")
}
return nil
}

Expand All @@ -42,6 +47,7 @@ type Server struct {

subscription string
resourceGroup string
pathToTopic map[string]string
}

func NewServer(cfg *Config) (*Server, error) {
Expand All @@ -53,16 +59,15 @@ func NewServer(cfg *Config) (*Server, error) {
logger: cfg.Logger,
subscription: cfg.Subscription,
resourceGroup: cfg.ResourceGroup,
pathToTopic: map[string]string{
"/events/processed_portfolio": cfg.ProcessedPortfolioTopicName,
},
}, nil
}

var pathToTopic = map[string]string{
"/events/processed_portfolio": "processed-portfolios",
}

func (s *Server) verifyWebhook(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
topic, ok := pathToTopic[r.URL.Path]
topic, ok := s.pathToTopic[r.URL.Path]
if !ok {
s.logger.Error("no topic found for path", zap.String("path", r.URL.Path))
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
Expand Down Expand Up @@ -138,12 +143,35 @@ func (s *Server) verifyWebhook(next http.Handler) http.Handler {
func (s *Server) RegisterHandlers(r chi.Router) {
r.Use(s.verifyWebhook)
r.Post("/events/processed_portfolio", func(w http.ResponseWriter, r *http.Request) {
dat, err := io.ReadAll(r.Body)
if err != nil {
s.logger.Error("failed to read webhook request body", zap.Error(err))
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
var reqs []struct {
Data *task.ProcessPortfolioResponse `json:"data"`
EventType string `json:"eventType"`
ID string `json:"id"`
Subject string `json:"subject"`
DataVersion string `json:"dataVersion"`
MetadataVersion string `json:"metadataVersion"`
EventTime time.Time `json:"eventTime"`
Topic string `json:"topic"`
}
if err := json.NewDecoder(r.Body).Decode(&reqs); err != nil {
s.logger.Error("failed to parse webhook request body", zap.Error(err))
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
if len(reqs) != 1 {
s.logger.Error("webhook response had unexpected number of events", zap.Int("event_count", len(reqs)))
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
req := reqs[0]

if req.Data == nil {
s.logger.Error("webhook response had no payload", zap.String("event_grid_id", req.ID))
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
s.logger.Info("processed porfolio", zap.String("portfolio_data", string(dat)))

// TODO: Add any database persistence and other things we'd want to do after a portfolio was processed.
s.logger.Info("processed portfolio", zap.String("task_id", string(req.Data.TaskID)), zap.Strings("outputs", req.Data.Outputs))
})
}
4 changes: 2 additions & 2 deletions azure/aztask/aztask.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func NewRunner(creds azcore.TokenCredential, cfg *Config) (*Runner, error) {
}, nil
}

func (r *Runner) Run(ctx context.Context, cfg *task.Config) (task.ID, error) {
func (r *Runner) Run(ctx context.Context, cfg *task.Config) (task.RunnerID, error) {

name := r.gen.NewID()
identity := r.cfg.Identity.String()
Expand Down Expand Up @@ -204,7 +204,7 @@ func (r *Runner) Run(ctx context.Context, cfg *task.Config) (task.ID, error) {
return "", fmt.Errorf("failed to poll for container app start: %w", err)
}

return task.ID(*res.ID), nil
return task.RunnerID(*res.ID), nil
}

func toPtrs[T any](in []T) []*T {
Expand Down
2 changes: 2 additions & 0 deletions cmd/runner/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ go_library(
"//task",
"@com_github_azure_azure_sdk_for_go_sdk_azcore//:azcore",
"@com_github_azure_azure_sdk_for_go_sdk_azcore//policy",
"@com_github_azure_azure_sdk_for_go_sdk_azcore//to",
"@com_github_azure_azure_sdk_for_go_sdk_azidentity//:azidentity",
"@com_github_azure_azure_sdk_for_go_sdk_messaging_azeventgrid//publisher",
"@com_github_namsral_flag//:flag",
"@org_uber_go_zap//:zap",
"@org_uber_go_zap//zapcore",
Expand Down
3 changes: 3 additions & 0 deletions cmd/runner/configs/dev.conf
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
env dev
min_log_level warn

azure_processed_portfolio_topic processed-portfolios-dev
azure_topic_location centralus-1

azure_storage_account rmipactadev
azure_source_portfolio_container uploadedportfolios
azure_dest_portfolio_container processedportfolios
Expand Down
3 changes: 3 additions & 0 deletions cmd/runner/configs/local.conf
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
env local
min_log_level debug

azure_processed_portfolio_topic processed-portfolios-local
azure_topic_location centralus-1

azure_storage_account rmipactalocal
azure_source_portfolio_container uploadedportfolios
azure_dest_portfolio_container processedportfolios
Expand Down
62 changes: 53 additions & 9 deletions cmd/runner/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ import (

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventgrid/publisher"
"github.com/RMI/pacta/azure/azblob"
"github.com/RMI/pacta/azure/azlog"
"github.com/RMI/pacta/blob"
Expand Down Expand Up @@ -47,6 +49,9 @@ func run(args []string) error {
var (
env = fs.String("env", "", "The environment we're running in.")

azProcessedPortfolioTopic = fs.String("azure_processed_portfolio_topic", "", "The EventGrid topic to send notifications of processed portfolios")
azTopicLocation = fs.String("azure_topic_location", "", "The location (like 'centralus-1') where our EventGrid topics are hosted")

azStorageAccount = fs.String("azure_storage_account", "", "The storage account to authenticate against for blob operations")
azSourcePortfolioContainer = fs.String("azure_source_portfolio_container", "", "The container in the storage account where we read raw portfolios from")
azDestPortfolioContainer = fs.String("azure_dest_portfolio_container", "", "The container in the storage account where we read/write processed portfolios")
Expand Down Expand Up @@ -94,23 +99,36 @@ func run(args []string) error {
}
}

pubsubClient, err := publisher.NewClient(fmt.Sprintf("https://%s.%s.eventgrid.azure.net/api/events", *azProcessedPortfolioTopic, *azTopicLocation), creds, nil)
if err != nil {
return fmt.Errorf("failed to init pub/sub client: %w", err)
}

blobClient, err := azblob.NewClient(creds, *azStorageAccount)
if err != nil {
return fmt.Errorf("failed to init blob client: %w", err)
}

h := handler{
blob: blobClient,
blob: blobClient,
pubsub: pubsubClient,
logger: logger,

sourcePortfolioContainer: *azSourcePortfolioContainer,
destPortfolioContainer: *azDestPortfolioContainer,
reportContainer: *azReportContainer,
}

validTasks := map[task.Type]func(context.Context) error{
validTasks := map[task.Type]func(context.Context, task.ID) error{
task.ProcessPortfolio: toRunFn(processPortfolioReq, h.processPortfolio),
task.CreateReport: toRunFn(createReportReq, h.createReport),
}

taskID := task.ID(os.Getenv("TASK_ID"))
if taskID == "" {
return errors.New("no TASK_ID given")
}

taskType := task.Type(os.Getenv("TASK_TYPE"))
if taskType == "" {
return errors.New("no TASK_TYPE given")
Expand All @@ -123,7 +141,7 @@ func run(args []string) error {

logger.Info("running PACTA task", zap.String("task_type", string(taskType)))

if err := taskFn(ctx); err != nil {
if err := taskFn(ctx, taskID); err != nil {
return fmt.Errorf("error running task: %w", err)
}

Expand All @@ -139,7 +157,10 @@ type Blob interface {
}

type handler struct {
blob Blob
blob Blob
pubsub *publisher.Client
logger *zap.Logger

sourcePortfolioContainer string
destPortfolioContainer string
reportContainer string
Expand Down Expand Up @@ -229,7 +250,7 @@ func (h *handler) downloadBlob(ctx context.Context, srcURI, destPath string) err
return nil
}

func (h *handler) processPortfolio(ctx context.Context, req *task.ProcessPortfolioRequest) error {
func (h *handler) processPortfolio(ctx context.Context, taskID task.ID, req *task.ProcessPortfolioRequest) error {
// Load the portfolio from blob storage, place it in /mnt/raw_portfolios, where
// the `process_portfolios.R` script expects it to be.
for _, assetID := range req.AssetIDs {
Expand Down Expand Up @@ -270,13 +291,36 @@ func (h *handler) processPortfolio(ctx context.Context, req *task.ProcessPortfol
}

// NOTE: This code could benefit from some concurrency, but I'm opting not to prematurely optimize.
var out []string
for _, p := range paths {
destURI := blob.Join(h.blob.Scheme(), h.destPortfolioContainer, filepath.Base(p))
if err := h.uploadBlob(ctx, p, destURI); err != nil {
return fmt.Errorf("failed to copy processed portfolio from %q to %q: %w", p, destURI, err)
}
out = append(out, destURI)
}

events := []publisher.Event{
{
Data: task.ProcessPortfolioResponse{
TaskID: taskID,
AssetIDs: req.AssetIDs,
Outputs: out,
},
DataVersion: to.Ptr("1.0"),
EventType: to.Ptr("processed-portfolio"),
EventTime: to.Ptr(time.Now()),
ID: to.Ptr(string(taskID)),
Subject: to.Ptr("subject"),
},
}

if _, err := h.pubsub.PublishEvents(ctx, events, nil); err != nil {
return fmt.Errorf("failed to publish event: %w", err)
}

h.logger.Info("processed portfolio", zap.String("task_id", string(taskID)))

return nil
}

Expand All @@ -291,7 +335,7 @@ func createReportReq() (*task.CreateReportRequest, error) {
}, nil
}

func (h *handler) createReport(ctx context.Context, req *task.CreateReportRequest) error {
func (h *handler) createReport(ctx context.Context, taskID task.ID, req *task.CreateReportRequest) error {
baseName := string(req.PortfolioID) + ".json"

// Load the processed portfolio from blob storage, place it in /mnt/
Expand Down Expand Up @@ -340,13 +384,13 @@ func (h *handler) createReport(ctx context.Context, req *task.CreateReportReques
return nil
}

func toRunFn[T any](reqFn func() (T, error), runFn func(context.Context, T) error) func(context.Context) error {
return func(ctx context.Context) error {
func toRunFn[T any](reqFn func() (T, error), runFn func(context.Context, task.ID, T) error) func(context.Context, task.ID) error {
return func(ctx context.Context, taskID task.ID) error {
req, err := reqFn()
if err != nil {
return fmt.Errorf("failed to format request: %w", err)
}
return runFn(ctx, req)
return runFn(ctx, taskID, req)
}
}

Expand Down
1 change: 1 addition & 0 deletions cmd/runner/taskrunner/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//task",
"@com_github_google_uuid//:uuid",
"@org_uber_go_zap//:zap",
],
)
23 changes: 16 additions & 7 deletions cmd/runner/taskrunner/taskrunner.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"fmt"

"github.com/RMI/pacta/task"
"github.com/google/uuid"
"go.uber.org/zap"
)

Expand Down Expand Up @@ -58,7 +59,7 @@ func validateImage(bi *task.BaseImage) error {
}

type Runner interface {
Run(ctx context.Context, cfg *task.Config) (task.ID, error)
Run(ctx context.Context, cfg *task.Config) (task.RunnerID, error)
}

type TaskRunner struct {
Expand All @@ -81,10 +82,10 @@ func New(cfg *Config) (*TaskRunner, error) {
}, nil
}

func (tr *TaskRunner) ProcessPortfolio(ctx context.Context, req *task.ProcessPortfolioRequest) (task.ID, error) {
func (tr *TaskRunner) ProcessPortfolio(ctx context.Context, req *task.ProcessPortfolioRequest) (task.ID, task.RunnerID, error) {
var buf bytes.Buffer
if err := json.NewEncoder(&buf).Encode(req.AssetIDs); err != nil {
return "", fmt.Errorf("failed to encode asset IDs: %w", err)
return "", "", fmt.Errorf("failed to encode asset IDs: %w", err)
}
return tr.run(ctx, []task.EnvVar{
{
Expand All @@ -98,7 +99,7 @@ func (tr *TaskRunner) ProcessPortfolio(ctx context.Context, req *task.ProcessPor
})
}

func (tr *TaskRunner) CreateReport(ctx context.Context, req *task.CreateReportRequest) (task.ID, error) {
func (tr *TaskRunner) CreateReport(ctx context.Context, req *task.CreateReportRequest) (task.ID, task.RunnerID, error) {
return tr.run(ctx, []task.EnvVar{
{
Key: "TASK_TYPE",
Expand All @@ -111,10 +112,14 @@ func (tr *TaskRunner) CreateReport(ctx context.Context, req *task.CreateReportRe
})
}

func (tr *TaskRunner) run(ctx context.Context, env []task.EnvVar) (task.ID, error) {
func (tr *TaskRunner) run(ctx context.Context, env []task.EnvVar) (task.ID, task.RunnerID, error) {
tr.logger.Info("triggering task run", zap.Any("env", env))
return tr.runner.Run(ctx, &task.Config{
Env: env,
taskID := uuid.NewString()
runnerID, err := tr.runner.Run(ctx, &task.Config{
Env: append(env, task.EnvVar{
Key: "TASK_ID",
Value: taskID,
}),
Flags: []string{"--config=" + tr.configPath},
Command: []string{"/runner"},
Image: &task.Image{
Expand All @@ -123,4 +128,8 @@ func (tr *TaskRunner) run(ctx context.Context, env []task.EnvVar) (task.ID, erro
Tag: "latest",
},
})
if err != nil {
return "", "", fmt.Errorf("failed to run task %q, %q: %w", taskID, runnerID, err)
}
return task.ID(taskID), runnerID, nil
}
1 change: 1 addition & 0 deletions cmd/server/configs/dev.conf
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ port 80

azure_event_subscription 69b6db12-37e3-4e1f-b48c-aa41dba612a9
azure_event_resource_group rmi-pacta-dev
azure_event_processed_portfolio_topic processed-portfolios-dev
1 change: 1 addition & 0 deletions cmd/server/configs/local.conf
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ allowed_cors_origin http://localhost:3000

azure_event_subscription 69b6db12-37e3-4e1f-b48c-aa41dba612a9
azure_event_resource_group rmi-pacta-local
azure_event_processed_portfolio_topic processed-portfolios-local

secret_postgres_host UNUSED
# Also unused
Expand Down
Loading