Skip to content

Commit

Permalink
add a feature gate (#4472)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored Nov 23, 2023
1 parent fdd8fe9 commit 8a895b9
Show file tree
Hide file tree
Showing 14 changed files with 255 additions and 115 deletions.
12 changes: 12 additions & 0 deletions flyteadmin/pkg/artifacts/artifact_client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package artifacts

import (
"context"
"github.com/stretchr/testify/assert"
"testing"
)

func TestEmpty(t *testing.T) {
c := InitializeArtifactClient(context.Background(), nil)
assert.Nil(t, c)
}
1 change: 0 additions & 1 deletion flyteadmin/pkg/artifacts/config.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package artifacts

// gatepr: add proper config bits for this
// eduardo to consider moving to idl clients.
type Config struct {
Host string `json:"host"`
Port int `json:"port"`
Expand Down
15 changes: 9 additions & 6 deletions flyteadmin/pkg/artifacts/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package artifacts
import (
"context"
"fmt"

"google.golang.org/grpc"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin"
Expand All @@ -18,10 +17,11 @@ type ArtifactRegistry struct {
}

func (a *ArtifactRegistry) RegisterArtifactProducer(ctx context.Context, id *core.Identifier, ti core.TypedInterface) {
if a.client == nil {
if a == nil || a.client == nil {
logger.Debugf(ctx, "Artifact client not configured, skipping registration for task [%+v]", id)
return
}

ap := &artifact.ArtifactProducer{
EntityId: id,
Outputs: ti.Outputs,
Expand All @@ -36,7 +36,7 @@ func (a *ArtifactRegistry) RegisterArtifactProducer(ctx context.Context, id *cor
}

func (a *ArtifactRegistry) RegisterArtifactConsumer(ctx context.Context, id *core.Identifier, pm core.ParameterMap) {
if a.client == nil {
if a == nil || a.client == nil {
logger.Debugf(ctx, "Artifact client not configured, skipping registration for consumer [%+v]", id)
return
}
Expand All @@ -54,7 +54,7 @@ func (a *ArtifactRegistry) RegisterArtifactConsumer(ctx context.Context, id *cor
}

func (a *ArtifactRegistry) RegisterTrigger(ctx context.Context, plan *admin.LaunchPlan) error {
if a.client == nil {
if a == nil || a.client == nil {
logger.Debugf(ctx, "Artifact client not configured, skipping trigger [%+v]", plan)
return fmt.Errorf("artifact client not configured")
}
Expand All @@ -70,11 +70,14 @@ func (a *ArtifactRegistry) RegisterTrigger(ctx context.Context, plan *admin.Laun
}

func (a *ArtifactRegistry) GetClient() artifact.ArtifactRegistryClient {
if a == nil {
return nil
}
return a.client
}

func NewArtifactRegistry(ctx context.Context, config *Config, opts ...grpc.DialOption) ArtifactRegistry {
return ArtifactRegistry{
func NewArtifactRegistry(ctx context.Context, config *Config, opts ...grpc.DialOption) *ArtifactRegistry {
return &ArtifactRegistry{
client: InitializeArtifactClient(ctx, config, opts...),
}
}
28 changes: 28 additions & 0 deletions flyteadmin/pkg/artifacts/registry_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package artifacts

import (
"context"
"github.com/stretchr/testify/assert"
"testing"
)

func TestRegistryNoClient(t *testing.T) {
r := NewArtifactRegistry(context.Background(), nil)
assert.Nil(t, r.GetClient())
}

type Parent struct {
R *ArtifactRegistry
}

func TestPointerReceivers(t *testing.T) {
p := Parent{}
nilClient := p.R.GetClient()
assert.Nil(t, nilClient)
}

func TestNilCheck(t *testing.T) {
r := NewArtifactRegistry(context.Background(), nil)
err := r.RegisterTrigger(context.Background(), nil)
assert.NotNil(t, err)
}
63 changes: 63 additions & 0 deletions flyteadmin/pkg/manager/impl/exec_manager_other_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package impl

import (
"context"
"fmt"
"github.com/flyteorg/flyte/flyteadmin/pkg/artifacts"
eventWriterMocks "github.com/flyteorg/flyte/flyteadmin/pkg/async/events/mocks"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
mockScope "github.com/flyteorg/flyte/flytestdlib/promutils"
"github.com/stretchr/testify/assert"
"testing"
)

func TestResolveNotWorking(t *testing.T) {
mockConfig := getMockExecutionsConfigProvider()

execManager := NewExecutionManager(nil, nil, mockConfig, nil, mockScope.NewTestScope(), mockScope.NewTestScope(), nil, nil, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}, artifacts.NewArtifactRegistry(context.Background(), nil)).(*ExecutionManager)

pm, artifactIDs, err := execManager.ResolveParameterMapArtifacts(context.Background(), nil, nil)
assert.Nil(t, err)
fmt.Println(pm, artifactIDs)

}

func TestTrackingBitExtract(t *testing.T) {
mockConfig := getMockExecutionsConfigProvider()

execManager := NewExecutionManager(nil, nil, mockConfig, nil, mockScope.NewTestScope(), mockScope.NewTestScope(), nil, nil, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}, artifacts.NewArtifactRegistry(context.Background(), nil)).(*ExecutionManager)

lit := core.Literal{
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Primitive{
Primitive: &core.Primitive{
Value: &core.Primitive_Integer{
Integer: 1,
},
},
},
},
},
Metadata: map[string]string{"_ua": "proj/domain/name@version"},
}
inputMap := core.LiteralMap{
Literals: map[string]*core.Literal{
"a": &lit,
},
}
inputColl := core.LiteralCollection{
Literals: []*core.Literal{
&lit,
},
}

trackers := execManager.ExtractArtifactKeys(&lit)
assert.Equal(t, 1, len(trackers))

trackers = execManager.ExtractArtifactKeys(&core.Literal{Value: &core.Literal_Map{Map: &inputMap}})
assert.Equal(t, 1, len(trackers))
trackers = execManager.ExtractArtifactKeys(&core.Literal{Value: &core.Literal_Collection{Collection: &inputColl}})
assert.Equal(t, 1, len(trackers))
assert.Equal(t, "proj/domain/name@version", trackers[0])
}
113 changes: 60 additions & 53 deletions flyteadmin/pkg/manager/impl/execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ type ExecutionManager struct {
cloudEventPublisher notificationInterfaces.Publisher
dbEventWriter eventWriter.WorkflowExecutionEventWriter
pluginRegistry *plugins.Registry
artifactRegistry artifacts.ArtifactRegistry
artifactRegistry *artifacts.ArtifactRegistry
}

func getExecutionContext(ctx context.Context, id *core.WorkflowExecutionIdentifier) context.Context {
Expand Down Expand Up @@ -872,10 +872,10 @@ func (m *ExecutionManager) fillInTemplateArgs(ctx context.Context, query core.Ar
}

// ResolveParameterMapArtifacts will go through the parameter map, and resolve any artifact queries.
func (m *ExecutionManager) ResolveParameterMapArtifacts(ctx context.Context, inputs *core.ParameterMap, inputsForQueryTemplating map[string]*core.Literal) (*core.ParameterMap, map[string]*core.ArtifactID, error) {
func (m *ExecutionManager) ResolveParameterMapArtifacts(ctx context.Context, inputs *core.ParameterMap, inputsForQueryTemplating map[string]*core.Literal) (*core.ParameterMap, []*core.ArtifactID, error) {

// only top level replace for now. Need to make this recursive
var artifactIDs = make(map[string]*core.ArtifactID)
var artifactIDs []*core.ArtifactID
if inputs == nil {
return nil, artifactIDs, nil
}
Expand Down Expand Up @@ -911,7 +911,7 @@ func (m *ExecutionManager) ResolveParameterMapArtifacts(ctx context.Context, inp
if err != nil {
return nil, nil, err
}
artifactIDs[k] = resp.Artifact.GetArtifactId()
artifactIDs = append(artifactIDs, resp.Artifact.GetArtifactId())
logger.Debugf(ctx, "Resolved query for [%s] to [%+v]", k, resp.Artifact.ArtifactId)
outputs[k] = &core.Parameter{
Var: v.Var,
Expand All @@ -931,7 +931,7 @@ func (m *ExecutionManager) ResolveParameterMapArtifacts(ctx context.Context, inp
if err != nil {
return nil, nil, err
}
artifactIDs[k] = v.GetArtifactId()
artifactIDs = append(artifactIDs, v.GetArtifactId())
logger.Debugf(ctx, "Using specified artifactID for [%+v] for [%s]", v.GetArtifactId(), k)
outputs[k] = &core.Parameter{
Var: v.Var,
Expand Down Expand Up @@ -974,60 +974,69 @@ func (m *ExecutionManager) launchExecutionAndPrepareModel(
return nil, nil, err
}

// Literals may have an artifact key in the metadata field. This is something the artifact service should have
// added. Pull these back out so we can keep track of them for lineage purposes. Use a dummy wrapper object for
// easier recursion.
requestInputMap := &core.Literal{
Value: &core.Literal_Map{Map: request.Inputs},
}
fixedInputMap := &core.Literal{
Value: &core.Literal_Map{Map: launchPlan.Spec.FixedInputs},
}
requestInputArtifactKeys := m.ExtractArtifactKeys(requestInputMap)
fixedInputArtifactKeys := m.ExtractArtifactKeys(fixedInputMap)
requestInputArtifactKeys = append(requestInputArtifactKeys, fixedInputArtifactKeys...)

// Put together the inputs that we've already resolved so that the artifact querying bit can fill them in.
// This is to support artifact queries that depend on other inputs using the {{ .inputs.var }} construct.
var inputsForQueryTemplating = make(map[string]*core.Literal)
if request.Inputs != nil {
for k, v := range request.Inputs.Literals {
// TODO: Artifact feature gate, remove when ready
var lpExpectedInputs *core.ParameterMap
var artifactTrackers []string
var usedArtifactIDs []*core.ArtifactID
if m.artifactRegistry.GetClient() != nil {
// Literals may have an artifact key in the metadata field. This is something the artifact service should have
// added. Pull these back out so we can keep track of them for lineage purposes. Use a dummy wrapper object for
// easier recursion.
requestInputMap := &core.Literal{
Value: &core.Literal_Map{Map: request.Inputs},
}
fixedInputMap := &core.Literal{
Value: &core.Literal_Map{Map: launchPlan.Spec.FixedInputs},
}
artifactTrackers = m.ExtractArtifactKeys(requestInputMap)
fixedInputArtifactKeys := m.ExtractArtifactKeys(fixedInputMap)
artifactTrackers = append(artifactTrackers, fixedInputArtifactKeys...)

// Put together the inputs that we've already resolved so that the artifact querying bit can fill them in.
// This is to support artifact queries that depend on other inputs using the {{ .inputs.var }} construct.
var inputsForQueryTemplating = make(map[string]*core.Literal)
if request.Inputs != nil {
for k, v := range request.Inputs.Literals {
inputsForQueryTemplating[k] = v
}
}
for k, v := range launchPlan.Spec.FixedInputs.Literals {
inputsForQueryTemplating[k] = v
}
}
for k, v := range launchPlan.Spec.FixedInputs.Literals {
inputsForQueryTemplating[k] = v
}
logger.Debugf(ctx, "Inputs for query templating: [%+v]", inputsForQueryTemplating)
logger.Debugf(ctx, "Inputs for query templating: [%+v]", inputsForQueryTemplating)

// Resolve artifact queries
// Within the launch plan, the artifact will be in the Parameter map, and can come in form of an ArtifactID,
// or as an ArtifactQuery.
// Also send in the inputsForQueryTemplating for two reasons, so we don't run queries for things we don't need to
// and so we can fill in template args.
// ArtifactIDs are also returned for lineage purposes.
resolvedExpectedInputs, usedArtifactIDs, err := m.ResolveParameterMapArtifacts(ctxPD, launchPlan.Closure.ExpectedInputs, inputsForQueryTemplating)
if err != nil {
logger.Errorf(ctx, "Error looking up launch plan closure parameter map: %v", err)
return nil, nil, err
}

// Resolve artifact queries
// Within the launch plan, the artifact will be in the Parameter map, and can come in form of an ArtifactID,
// or as an ArtifactQuery.
// Also send in the inputsForQueryTemplating for two reasons, so we don't run queries for things we don't need to
// and so we can fill in template args.
// ArtifactIDs are also returned for lineage purposes.
resolvedExpectedInputs, usedArtifactIDs, err := m.ResolveParameterMapArtifacts(ctxPD, launchPlan.Closure.ExpectedInputs, inputsForQueryTemplating)
if err != nil {
logger.Errorf(ctx, "Error looking up launch plan closure parameter map: %v", err)
return nil, nil, err
}
logger.Debugf(ctx, "Resolved launch plan closure expected inputs from [%+v] to [%+v]", launchPlan.Closure.ExpectedInputs, resolvedExpectedInputs)
logger.Debugf(ctx, "Found artifact keys: %v", artifactTrackers)
logger.Debugf(ctx, "Found artifact IDs: %v", usedArtifactIDs)

logger.Debugf(ctx, "Resolved launch plan closure expected inputs from [%+v] to [%+v]", launchPlan.Closure.ExpectedInputs, resolvedExpectedInputs)
logger.Debugf(ctx, "Found artifact keys: %v", requestInputArtifactKeys)
logger.Debugf(ctx, "Found artifact IDs: %v", usedArtifactIDs)
} else {
lpExpectedInputs = launchPlan.Closure.ExpectedInputs
}

// Artifacts retrieved will need to be stored somewhere to ensure that we can re-emit events if necessary
// in the future, and also to make sure that relaunch and recover can use it if necessary.
executionInputs, err := validation.CheckAndFetchInputsForExecution(
request.Inputs,
launchPlan.Spec.FixedInputs,
resolvedExpectedInputs,
lpExpectedInputs,
)

if err != nil {
logger.Debugf(ctx, "Failed to CheckAndFetchInputsForExecution with request.Inputs: %+v"+
"fixed inputs: %+v and expected inputs: %+v with err %v",
request.Inputs, launchPlan.Spec.FixedInputs, resolvedExpectedInputs, err)
request.Inputs, launchPlan.Spec.FixedInputs, lpExpectedInputs, err)
return nil, nil, err
}

Expand Down Expand Up @@ -1061,13 +1070,7 @@ func (m *ExecutionManager) launchExecutionAndPrepareModel(
requestSpec.Metadata = &admin.ExecutionMetadata{}
}
requestSpec.Metadata.Principal = getUser(ctx)
// Construct a list of the values and save to request spec metadata.
// Avoids going through the model creation step.
artifactIDs := make([]*core.ArtifactID, 0, len(usedArtifactIDs))
for _, value := range usedArtifactIDs {
artifactIDs = append(artifactIDs, value)
}
requestSpec.Metadata.ArtifactIds = artifactIDs
requestSpec.Metadata.ArtifactIds = usedArtifactIDs

// Get the node and parent execution (if any) that launched this execution
var parentNodeExecutionID uint
Expand Down Expand Up @@ -1185,7 +1188,11 @@ func (m *ExecutionManager) launchExecutionAndPrepareModel(
notificationsSettings = make([]*admin.Notification, 0)
}

m.publishExecutionStart(ctx, workflowExecutionID, request.Spec.LaunchPlan, workflow.Id, requestInputArtifactKeys, requestSpec.Metadata.ArtifactIds)
// Publish of event is also gated on the artifact client being available, even though it's not directly required.
// TODO: Artifact feature gate, remove when ready
if m.artifactRegistry.GetClient() != nil {
m.publishExecutionStart(ctx, workflowExecutionID, request.Spec.LaunchPlan, workflow.Id, artifactTrackers, usedArtifactIDs)
}

executionModel, err := transformers.CreateExecutionModel(transformers.CreateExecutionModelInput{
WorkflowExecutionID: workflowExecutionID,
Expand Down Expand Up @@ -1958,7 +1965,7 @@ func NewExecutionManager(db repositoryInterfaces.Repository, pluginRegistry *plu
publisher notificationInterfaces.Publisher, urlData dataInterfaces.RemoteURLInterface,
workflowManager interfaces.WorkflowInterface, namedEntityManager interfaces.NamedEntityInterface,
eventPublisher notificationInterfaces.Publisher, cloudEventPublisher cloudeventInterfaces.Publisher,
eventWriter eventWriter.WorkflowExecutionEventWriter, artifactRegistry artifacts.ArtifactRegistry) interfaces.ExecutionInterface {
eventWriter eventWriter.WorkflowExecutionEventWriter, artifactRegistry *artifacts.ArtifactRegistry) interfaces.ExecutionInterface {

queueAllocator := executions.NewQueueAllocator(config, db)
systemMetrics := newExecutionSystemMetrics(systemScope)
Expand Down
Loading

0 comments on commit 8a895b9

Please sign in to comment.