diff --git a/worker/docker.go b/worker/docker.go index 893d371c..3ca710fe 100644 --- a/worker/docker.go +++ b/worker/docker.go @@ -99,6 +99,43 @@ type DockerManager struct { mu *sync.Mutex } +// updatePipelineMappings updates the specified mapping with pipeline to image overriding. +// It logs a warning if a pipeline is not found in the given mapping. +// +// Parameters: +// - overrides: A map of pipeline names to custom image names. +// - mapping: The map to be updated with the provided overrides. +// - mapName: The name of the map (used for logging purposes). +func updatePipelineMappings(overrides map[string]string, mapping map[string]string, mapName string) { + for pipeline, image := range overrides { + if _, exists := mapping[pipeline]; exists { + mapping[pipeline] = image + } else { + slog.Warn("Pipeline not found in map", "map", mapName, "pipeline", pipeline) + } + } +} + +// overridePipelineImages function parses a JSON string containing pipeline-to-image mappings and overrides the default mappings if valid. +// It updates the `pipelineToImage` and `livePipelineToImage` maps with custom images. +// Parameters: +// - defaultImage: A string that can either be containerImage name or a JSON string with overrides for pipeline-to-image mappings. +// +// Returns: +// - error: An error if the JSON parsing fails or if the mapping is not found in existing maps else `nil`. +func overridePipelineImages(defaultImage string) error { + if strings.HasPrefix(defaultImage, "{") || strings.HasSuffix(defaultImage, "}") { + var pipelineOverrides map[string]string + if err := json.Unmarshal([]byte(defaultImage), &pipelineOverrides); err != nil { + slog.Error("Error parsing JSON", "error", err) + return err + } + updatePipelineMappings(pipelineOverrides, pipelineToImage, "pipelineToImage") + updatePipelineMappings(pipelineOverrides, livePipelineToImage, "livePipelineToImage") + } + return nil +} + func NewDockerManager(defaultImage string, gpus []string, modelDir string, client DockerClient) (*DockerManager, error) { ctx, cancel := context.WithTimeout(context.Background(), containerTimeout) if err := removeExistingContainers(ctx, client); err != nil { @@ -107,6 +144,11 @@ func NewDockerManager(defaultImage string, gpus []string, modelDir string, clien } cancel() + // call to handle image overriding logic + if err := overridePipelineImages(defaultImage); err != nil { + return nil, err + } + manager := &DockerManager{ defaultImage: defaultImage, gpus: gpus, diff --git a/worker/docker_test.go b/worker/docker_test.go index cbb20086..fc1b81c6 100644 --- a/worker/docker_test.go +++ b/worker/docker_test.go @@ -784,3 +784,75 @@ func TestDockerWaitUntilRunning(t *testing.T) { mockDockerClient.AssertExpectations(t) }) } + +func TestDockerManager_overridePipelineImages(t *testing.T) { + mockDockerClient := new(MockDockerClient) + dockerManager := createDockerManager(mockDockerClient) + + tests := []struct { + name string + inputJSON string + pipeline string + expectedImage string + expectError bool + }{ + { + name: "ValidOverride", + inputJSON: `{"segment-anything-2": "custom-image:1.0"}`, + pipeline: "segment-anything-2", + expectedImage: "custom-image:1.0", + expectError: false, + }, + { + name: "MultipleOverrides", + inputJSON: `{"segment-anything-2": "custom-image:1.0", "text-to-speech": "speech-image:2.0"}`, + pipeline: "text-to-speech", + expectedImage: "speech-image:2.0", + expectError: false, + }, + { + name: "NoOverrideFallback", + inputJSON: `{"segment-anything-2": "custom-image:1.0"}`, + pipeline: "streamdiffusion", + expectedImage: "default-image", + expectError: false, + }, + { + name: "EmptyJSON", + inputJSON: `{}`, + pipeline: "segment-anything-2", + expectedImage: "custom-image:1.0", + expectError: false, + }, + { + name: "MalformedJSON", + inputJSON: `{"segment-anything-2": "custom-image:1.0"`, + pipeline: "segment-anything-2", + expectError: true, + }, + { + name: "RegularStringInput", + inputJSON: "", + pipeline: "image-to-video", + expectedImage: "default-image", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Call overridePipelineImages function with the mock data. + err := overridePipelineImages(tt.inputJSON) + + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + + // Verify the expected image. + image, _ := dockerManager.getContainerImageName(tt.pipeline, "") + require.Equal(t, tt.expectedImage, image) + } + }) + } +}