Skip to content

Commit

Permalink
support custom image
Browse files Browse the repository at this point in the history
  • Loading branch information
liqul committed Oct 10, 2024
1 parent 5494e49 commit fdb254a
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 13 deletions.
4 changes: 3 additions & 1 deletion taskweaver/ces/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Literal
from typing import Literal, Optional

from taskweaver.ces.common import Manager
from taskweaver.ces.manager.defer import DeferredManager
Expand All @@ -8,11 +8,13 @@
def code_execution_service_factory(
env_dir: str,
kernel_mode: Literal["local", "container"] = "local",
custom_image: Optional[str] = None,
) -> Manager:
def sub_proc_manager_factory() -> SubProcessManager:
return SubProcessManager(
env_dir=env_dir,
kernel_mode=kernel_mode,
custom_image=custom_image,
)

return DeferredManager(
Expand Down
35 changes: 23 additions & 12 deletions taskweaver/ces/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,15 @@ class EnvMode(enum.Enum):


class Environment:
DEFAULT_IMAGE = "taskweavercontainers/taskweaver-executor:latest"

def __init__(
self,
env_id: Optional[str] = None,
env_dir: Optional[str] = None,
env_mode: Optional[EnvMode] = EnvMode.Local,
port_start_inside_container: Optional[int] = 12345,
custom_image: Optional[str] = None,
) -> None:
self.session_dict: Dict[str, EnvSession] = {}
self.id = get_id(prefix="env") if env_id is None else env_id
Expand Down Expand Up @@ -145,19 +148,27 @@ def __init__(
except docker.errors.DockerException as e:
raise docker.errors.DockerException(f"Failed to connect to Docker daemon: {e}. ")

self.image_name = "taskweavercontainers/taskweaver-executor:latest"
try:
local_image = self.docker_client.images.get(self.image_name)
registry_image = self.docker_client.images.get_registry_data(self.image_name)
if local_image.id != registry_image.id:
logger.info(f"Local image {local_image.id} does not match registry image {registry_image.id}.")
raise docker.errors.ImageNotFound("Local image is outdated.")
except docker.errors.ImageNotFound:
logger.info("Pulling image from docker.io.")
if custom_image:
logger.info(f"Using custom image {custom_image}.")
self.image_name = custom_image
try:
self.docker_client.images.get(self.image_name)
except docker.errors.ImageNotFound:
raise docker.errors.ImageNotFound(f"Custom image {self.image_name} not found.")
else:
self.image_name = self.DEFAULT_IMAGE
try:
self.docker_client.images.pull(self.image_name)
except docker.errors.DockerException as e:
raise docker.errors.DockerException(f"Failed to pull image: {e}. ")
local_image = self.docker_client.images.get(self.image_name)
registry_image = self.docker_client.images.get_registry_data(self.image_name)
if local_image.id != registry_image.id:
logger.info(f"Local image {local_image.id} does not match registry image {registry_image.id}.")
raise docker.errors.ImageNotFound("Local image is outdated.")
except docker.errors.ImageNotFound:
logger.info("Pulling image from docker.io.")
try:
self.docker_client.images.pull(self.image_name)
except docker.errors.DockerException as e:
raise docker.errors.DockerException(f"Failed to pull image: {e}. ")

self.session_container_dict: Dict[str, str] = {}
self.port_start_inside_container = port_start_inside_container
Expand Down
2 changes: 2 additions & 0 deletions taskweaver/ces/manager/sub_proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
env_id: Optional[str] = None,
env_dir: Optional[str] = None,
kernel_mode: KernelModeType = "local",
custom_image: Optional[str] = None,
) -> None:
from taskweaver.ces.environment import Environment, EnvMode

Expand All @@ -76,6 +77,7 @@ def __init__(
env_id,
env_dir,
env_mode=env_mode,
custom_image=custom_image,
)

def initialize(self) -> None:
Expand Down
7 changes: 7 additions & 0 deletions taskweaver/module/execution_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def _configure(self) -> None:
"kernel_mode",
"container",
)
assert self.kernel_mode in ["local", "container"], f"Invalid kernel mode: {self.kernel_mode}"
if self.kernel_mode == "local":
print(
"TaskWeaver is running in the `local` mode. This implies that "
Expand All @@ -27,6 +28,11 @@ def _configure(self) -> None:
"More information can be found in the documentation "
"(https://microsoft.github.io/TaskWeaver/docs/code_execution/).",
)
self.custom_image = self._get_str(
"custom_image",
default=None,
required=False,
)


class ExecutionServiceModule(Module):
Expand All @@ -39,5 +45,6 @@ def provide_executor_manager(self, config: ExecutionServiceConfig) -> Manager:
self.manager = code_execution_service_factory(
env_dir=config.env_dir,
kernel_mode=config.kernel_mode,
custom_image=config.custom_image,
)
return self.manager

0 comments on commit fdb254a

Please sign in to comment.