Skip to content

Commit

Permalink
Reduce where clause fanout when updating workflow, node & task execut…
Browse files Browse the repository at this point in the history
…ions (#5953)
  • Loading branch information
katrogan authored Nov 5, 2024
1 parent 636cc23 commit a87585a
Show file tree
Hide file tree
Showing 7 changed files with 11 additions and 9 deletions.
4 changes: 4 additions & 0 deletions flyteadmin/pkg/repositories/gormimpl/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,7 @@ func applyScopedFilters(tx *gorm.DB, inlineFilters []common.InlineFilter, mapFil
}
return tx, nil
}

func getIDFilter(id uint) (query string, args interface{}) {
return fmt.Sprintf("%s = ?", ID), id
}
2 changes: 1 addition & 1 deletion flyteadmin/pkg/repositories/gormimpl/execution_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (r *ExecutionRepo) Get(ctx context.Context, input interfaces.Identifier) (m

func (r *ExecutionRepo) Update(ctx context.Context, execution models.Execution) error {
timer := r.metrics.UpdateDuration.Start()
tx := r.db.WithContext(ctx).Model(&execution).Updates(execution)
tx := r.db.WithContext(ctx).Model(&models.Execution{}).Where(getIDFilter(execution.ID)).Updates(execution)
timer.Stop()
if err := tx.Error; err != nil {
return r.errorTransformer.ToFlyteAdminError(err)
Expand Down
5 changes: 1 addition & 4 deletions flyteadmin/pkg/repositories/gormimpl/execution_repo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,7 @@ func TestUpdateExecution(t *testing.T) {
updated := false

// Only match on queries that append expected filters
GlobalMock.NewMock().WithQuery(`UPDATE "executions" SET "updated_at"=$1,"execution_project"=$2,` +
`"execution_domain"=$3,"execution_name"=$4,"launch_plan_id"=$5,"workflow_id"=$6,"phase"=$7,"closure"=$8,` +
`"spec"=$9,"started_at"=$10,"execution_created_at"=$11,"execution_updated_at"=$12,"duration"=$13 WHERE "` +
`execution_project" = $14 AND "execution_domain" = $15 AND "execution_name" = $16`).WithCallback(
GlobalMock.NewMock().WithQuery(`UPDATE "executions" SET "updated_at"=$1,"execution_project"=$2,"execution_domain"=$3,"execution_name"=$4,"launch_plan_id"=$5,"workflow_id"=$6,"phase"=$7,"closure"=$8,"spec"=$9,"started_at"=$10,"execution_created_at"=$11,"execution_updated_at"=$12,"duration"=$13 WHERE id = $14`).WithCallback(
func(s string, values []driver.NamedValue) {
updated = true
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func (r *NodeExecutionRepo) GetWithChildren(ctx context.Context, input interface

func (r *NodeExecutionRepo) Update(ctx context.Context, nodeExecution *models.NodeExecution) error {
timer := r.metrics.UpdateDuration.Start()
tx := r.db.WithContext(ctx).Model(&nodeExecution).Updates(nodeExecution)
tx := r.db.WithContext(ctx).Model(&models.NodeExecution{}).Where(getIDFilter(nodeExecution.ID)).Updates(nodeExecution)
timer.Stop()
if err := tx.Error; err != nil {
return r.errorTransformer.ToFlyteAdminError(err)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func TestUpdateNodeExecution(t *testing.T) {
GlobalMock := mocket.Catcher.Reset()
// Only match on queries that append the name filter
nodeExecutionQuery := GlobalMock.NewMock()
nodeExecutionQuery.WithQuery(`UPDATE "node_executions" SET "id"=$1,"updated_at"=$2,"execution_project"=$3,"execution_domain"=$4,"execution_name"=$5,"node_id"=$6,"phase"=$7,"input_uri"=$8,"closure"=$9,"started_at"=$10,"node_execution_created_at"=$11,"node_execution_updated_at"=$12,"duration"=$13 WHERE "execution_project" = $14 AND "execution_domain" = $15 AND "execution_name" = $16 AND "node_id" = $17`)
nodeExecutionQuery.WithQuery(`UPDATE "node_executions" SET "id"=$1,"updated_at"=$2,"execution_project"=$3,"execution_domain"=$4,"execution_name"=$5,"node_id"=$6,"phase"=$7,"input_uri"=$8,"closure"=$9,"started_at"=$10,"node_execution_created_at"=$11,"node_execution_updated_at"=$12,"duration"=$13 WHERE id = $14`)
err := nodeExecutionRepo.Update(context.Background(),
&models.NodeExecution{
BaseModel: models.BaseModel{ID: 1},
Expand Down
3 changes: 2 additions & 1 deletion flyteadmin/pkg/repositories/gormimpl/task_execution_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ func (r *TaskExecutionRepo) Get(ctx context.Context, input interfaces.GetTaskExe

func (r *TaskExecutionRepo) Update(ctx context.Context, execution models.TaskExecution) error {
timer := r.metrics.UpdateDuration.Start()
tx := r.db.WithContext(ctx).WithContext(ctx).Save(&execution)
tx := r.db.WithContext(ctx).Model(&models.TaskExecution{}).Where(getIDFilter(execution.ID)).
Updates(&execution)
timer.Stop()

if err := tx.Error; err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func TestUpdateTaskExecution(t *testing.T) {
GlobalMock.Logging = true

taskExecutionQuery := GlobalMock.NewMock()
taskExecutionQuery.WithQuery(`UPDATE "task_executions" SET "id"=$1,"created_at"=$2,"updated_at"=$3,"deleted_at"=$4,"phase"=$5,"phase_version"=$6,"input_uri"=$7,"closure"=$8,"started_at"=$9,"task_execution_created_at"=$10,"task_execution_updated_at"=$11,"duration"=$12 WHERE "project" = $13 AND "domain" = $14 AND "name" = $15 AND "version" = $16 AND "execution_project" = $17 AND "execution_domain" = $18 AND "execution_name" = $19 AND "node_id" = $20 AND "retry_attempt" = $21`)
taskExecutionQuery.WithQuery(`UPDATE "task_executions" SET "updated_at"=$1,"project"=$2,"domain"=$3,"name"=$4,"version"=$5,"execution_project"=$6,"execution_domain"=$7,"execution_name"=$8,"node_id"=$9,"retry_attempt"=$10,"phase"=$11,"input_uri"=$12,"closure"=$13,"started_at"=$14,"task_execution_created_at"=$15,"task_execution_updated_at"=$16,"duration"=$17 WHERE id = $18`)
err := taskExecutionRepo.Update(context.Background(), testTaskExecution)
assert.NoError(t, err)
assert.True(t, taskExecutionQuery.Triggered)
Expand Down

0 comments on commit a87585a

Please sign in to comment.