Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for JSON string in defaultImage field for overriding pipeline specific images in the mappings #293

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
42 changes: 42 additions & 0 deletions worker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,43 @@ type DockerManager struct {
mu *sync.Mutex
}

// updatePipelineMappings updates the specified mapping with pipeline to image overriding.
// It logs a warning if a pipeline is not found in the given mapping.
//
// Parameters:
// - overrides: A map of pipeline names to custom image names.
// - mapping: The map to be updated with the provided overrides.
// - mapName: The name of the map (used for logging purposes).
func updatePipelineMappings(overrides map[string]string, mapping map[string]string, mapName string) {
for pipeline, image := range overrides {
if _, exists := mapping[pipeline]; exists {
mapping[pipeline] = image
} else {
slog.Warn("Pipeline not found in map", "map", mapName, "pipeline", pipeline)
}
}
}

// overridePipelineImages function parses a JSON string containing pipeline-to-image mappings and overrides the default mappings if valid.
// It updates the `pipelineToImage` and `livePipelineToImage` maps with custom images.
// Parameters:
// - defaultImage: A string that can either be containerImage name or a JSON string with overrides for pipeline-to-image mappings.
//
// Returns:
// - error: An error if the JSON parsing fails or if the mapping is not found in existing maps else `nil`.
func overridePipelineImages(defaultImage string) error {
if strings.HasPrefix(defaultImage, "{") || strings.HasSuffix(defaultImage, "}") {
var pipelineOverrides map[string]string
if err := json.Unmarshal([]byte(defaultImage), &pipelineOverrides); err != nil {
slog.Error("Error parsing JSON", "error", err)
return err
}
updatePipelineMappings(pipelineOverrides, pipelineToImage, "pipelineToImage")
updatePipelineMappings(pipelineOverrides, livePipelineToImage, "livePipelineToImage")
}
return nil
}

func NewDockerManager(defaultImage string, gpus []string, modelDir string, client DockerClient) (*DockerManager, error) {
ctx, cancel := context.WithTimeout(context.Background(), containerTimeout)
if err := removeExistingContainers(ctx, client); err != nil {
Expand All @@ -107,6 +144,11 @@ func NewDockerManager(defaultImage string, gpus []string, modelDir string, clien
}
cancel()

// call to handle image overriding logic
if err := overridePipelineImages(defaultImage); err != nil {
return nil, err
}

manager := &DockerManager{
defaultImage: defaultImage,
gpus: gpus,
Expand Down
72 changes: 72 additions & 0 deletions worker/docker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -784,3 +784,75 @@ func TestDockerWaitUntilRunning(t *testing.T) {
mockDockerClient.AssertExpectations(t)
})
}

func TestDockerManager_overridePipelineImages(t *testing.T) {
mockDockerClient := new(MockDockerClient)
dockerManager := createDockerManager(mockDockerClient)

tests := []struct {
name string
inputJSON string
pipeline string
expectedImage string
expectError bool
}{
{
name: "ValidOverride",
inputJSON: `{"segment-anything-2": "custom-image:1.0"}`,
pipeline: "segment-anything-2",
expectedImage: "custom-image:1.0",
expectError: false,
},
{
name: "MultipleOverrides",
inputJSON: `{"segment-anything-2": "custom-image:1.0", "text-to-speech": "speech-image:2.0"}`,
pipeline: "text-to-speech",
expectedImage: "speech-image:2.0",
expectError: false,
},
{
name: "NoOverrideFallback",
inputJSON: `{"segment-anything-2": "custom-image:1.0"}`,
pipeline: "streamdiffusion",
expectedImage: "default-image",
expectError: false,
},
{
name: "EmptyJSON",
inputJSON: `{}`,
pipeline: "segment-anything-2",
expectedImage: "custom-image:1.0",
expectError: false,
},
{
name: "MalformedJSON",
inputJSON: `{"segment-anything-2": "custom-image:1.0"`,
pipeline: "segment-anything-2",
expectError: true,
},
{
name: "RegularStringInput",
inputJSON: "",
pipeline: "image-to-video",
expectedImage: "default-image",
expectError: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Call overridePipelineImages function with the mock data.
err := overridePipelineImages(tt.inputJSON)

if tt.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)

// Verify the expected image.
image, _ := dockerManager.getContainerImageName(tt.pipeline, "")
require.Equal(t, tt.expectedImage, image)
}
})
}
}
Loading