From fdb254ab77c58ff55803e72137311307e4be0a86 Mon Sep 17 00:00:00 2001 From: liqun Date: Thu, 10 Oct 2024 11:27:34 +0800 Subject: [PATCH] support custom image --- taskweaver/ces/__init__.py | 4 ++- taskweaver/ces/environment.py | 35 +++++++++++++++++--------- taskweaver/ces/manager/sub_proc.py | 2 ++ taskweaver/module/execution_service.py | 7 ++++++ 4 files changed, 35 insertions(+), 13 deletions(-) diff --git a/taskweaver/ces/__init__.py b/taskweaver/ces/__init__.py index 6ecc8ee8..7d8a45eb 100644 --- a/taskweaver/ces/__init__.py +++ b/taskweaver/ces/__init__.py @@ -1,4 +1,4 @@ -from typing import Literal +from typing import Literal, Optional from taskweaver.ces.common import Manager from taskweaver.ces.manager.defer import DeferredManager @@ -8,11 +8,13 @@ def code_execution_service_factory( env_dir: str, kernel_mode: Literal["local", "container"] = "local", + custom_image: Optional[str] = None, ) -> Manager: def sub_proc_manager_factory() -> SubProcessManager: return SubProcessManager( env_dir=env_dir, kernel_mode=kernel_mode, + custom_image=custom_image, ) return DeferredManager( diff --git a/taskweaver/ces/environment.py b/taskweaver/ces/environment.py index ae7c21d2..84000204 100644 --- a/taskweaver/ces/environment.py +++ b/taskweaver/ces/environment.py @@ -111,12 +111,15 @@ class EnvMode(enum.Enum): class Environment: + DEFAULT_IMAGE = "taskweavercontainers/taskweaver-executor:latest" + def __init__( self, env_id: Optional[str] = None, env_dir: Optional[str] = None, env_mode: Optional[EnvMode] = EnvMode.Local, port_start_inside_container: Optional[int] = 12345, + custom_image: Optional[str] = None, ) -> None: self.session_dict: Dict[str, EnvSession] = {} self.id = get_id(prefix="env") if env_id is None else env_id @@ -145,19 +148,27 @@ def __init__( except docker.errors.DockerException as e: raise docker.errors.DockerException(f"Failed to connect to Docker daemon: {e}. ") - self.image_name = "taskweavercontainers/taskweaver-executor:latest" - try: - local_image = self.docker_client.images.get(self.image_name) - registry_image = self.docker_client.images.get_registry_data(self.image_name) - if local_image.id != registry_image.id: - logger.info(f"Local image {local_image.id} does not match registry image {registry_image.id}.") - raise docker.errors.ImageNotFound("Local image is outdated.") - except docker.errors.ImageNotFound: - logger.info("Pulling image from docker.io.") + if custom_image: + logger.info(f"Using custom image {custom_image}.") + self.image_name = custom_image + try: + self.docker_client.images.get(self.image_name) + except docker.errors.ImageNotFound: + raise docker.errors.ImageNotFound(f"Custom image {self.image_name} not found.") + else: + self.image_name = self.DEFAULT_IMAGE try: - self.docker_client.images.pull(self.image_name) - except docker.errors.DockerException as e: - raise docker.errors.DockerException(f"Failed to pull image: {e}. ") + local_image = self.docker_client.images.get(self.image_name) + registry_image = self.docker_client.images.get_registry_data(self.image_name) + if local_image.id != registry_image.id: + logger.info(f"Local image {local_image.id} does not match registry image {registry_image.id}.") + raise docker.errors.ImageNotFound("Local image is outdated.") + except docker.errors.ImageNotFound: + logger.info("Pulling image from docker.io.") + try: + self.docker_client.images.pull(self.image_name) + except docker.errors.DockerException as e: + raise docker.errors.DockerException(f"Failed to pull image: {e}. ") self.session_container_dict: Dict[str, str] = {} self.port_start_inside_container = port_start_inside_container diff --git a/taskweaver/ces/manager/sub_proc.py b/taskweaver/ces/manager/sub_proc.py index aee81c28..e7fa450d 100644 --- a/taskweaver/ces/manager/sub_proc.py +++ b/taskweaver/ces/manager/sub_proc.py @@ -57,6 +57,7 @@ def __init__( env_id: Optional[str] = None, env_dir: Optional[str] = None, kernel_mode: KernelModeType = "local", + custom_image: Optional[str] = None, ) -> None: from taskweaver.ces.environment import Environment, EnvMode @@ -76,6 +77,7 @@ def __init__( env_id, env_dir, env_mode=env_mode, + custom_image=custom_image, ) def initialize(self) -> None: diff --git a/taskweaver/module/execution_service.py b/taskweaver/module/execution_service.py index 0640ae39..ac46ac67 100644 --- a/taskweaver/module/execution_service.py +++ b/taskweaver/module/execution_service.py @@ -19,6 +19,7 @@ def _configure(self) -> None: "kernel_mode", "container", ) + assert self.kernel_mode in ["local", "container"], f"Invalid kernel mode: {self.kernel_mode}" if self.kernel_mode == "local": print( "TaskWeaver is running in the `local` mode. This implies that " @@ -27,6 +28,11 @@ def _configure(self) -> None: "More information can be found in the documentation " "(https://microsoft.github.io/TaskWeaver/docs/code_execution/).", ) + self.custom_image = self._get_str( + "custom_image", + default=None, + required=False, + ) class ExecutionServiceModule(Module): @@ -39,5 +45,6 @@ def provide_executor_manager(self, config: ExecutionServiceConfig) -> Manager: self.manager = code_execution_service_factory( env_dir=config.env_dir, kernel_mode=config.kernel_mode, + custom_image=config.custom_image, ) return self.manager