Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Adds ability to publish AzureContainerInstanceJob block as a azure-container-instance work pool #130

Merged
merged 8 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Security

## 0.3.3

Released December 11th, 2023.

### Added

- Ability to publish `AzureContainerInstanceJob` blocks as an azure-container-instance work pool - [#130](https://github.com/PrefectHQ/prefect-azure/pull/130)

## 0.3.1

Released October 10th, 2023.
Expand Down
79 changes: 79 additions & 0 deletions prefect_azure/container_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,12 @@
import datetime
import json
import random
import shlex
import string
import sys
import time
import uuid
from copy import deepcopy
from enum import Enum
from typing import Dict, List, Optional, Union

Expand All @@ -95,6 +97,7 @@
ResourceRequirements,
UserAssignedIdentities,
)
from prefect.blocks.core import BlockNotSavedError
from prefect.exceptions import InfrastructureNotAvailable, InfrastructureNotFound
from prefect.infrastructure.base import Infrastructure, InfrastructureResult
from prefect.utilities.asyncutils import run_sync_in_worker_thread, sync_compatible
Expand Down Expand Up @@ -462,6 +465,82 @@ def preview(self) -> str:

return json.dumps(preview)

def get_corresponding_worker_type(self) -> str:
"""Return the corresponding worker type for this infrastructure block."""
from prefect_azure.workers.container_instance import AzureContainerWorker

return AzureContainerWorker.type

async def generate_work_pool_base_job_template(self) -> dict:
"""
Generate a base job template for an `Azure Container Instance` work pool
with the same configuration as this block.

Returns:
- dict: a base job template for an `Azure Container Instance` work pool
"""
from prefect_azure.workers.container_instance import AzureContainerWorker

base_job_template = deepcopy(
AzureContainerWorker.get_default_base_job_template()
)
for key, value in self.dict(exclude_unset=True, exclude_defaults=True).items():
if key == "command":
base_job_template["variables"]["properties"]["command"][
"default"
] = shlex.join(value)
elif key in [
"type",
"block_type_slug",
"_block_document_id",
"_block_document_name",
"_is_anonymous",
]:
continue
elif key == "subscription_id":
base_job_template["variables"]["properties"]["subscription_id"][
"default"
] = value.get_secret_value()
elif key == "aci_credentials":
if not self.aci_credentials._block_document_id:
raise BlockNotSavedError(
"It looks like you are trying to use a block that"
" has not been saved. Please call `.save` on your block"
" before publishing it as a work pool."
)
base_job_template["variables"]["properties"]["aci_credentials"][
"default"
] = {
"$ref": {
"block_document_id": str(
self.aci_credentials._block_document_id
)
}
}
elif key == "image_registry":
if not self.image_registry._block_document_id:
raise BlockNotSavedError(
"It looks like you are trying to use a block that"
" has not been saved. Please call `.save` on your block"
" before publishing it as a work pool."
)
base_job_template["variables"]["properties"]["image_registry"][
"default"
] = {
"$ref": {
"block_document_id": str(self.image_registry._block_document_id)
}
}
elif key in base_job_template["variables"]["properties"]:
base_job_template["variables"]["properties"][key]["default"] = value
else:
self.logger.warning(
f"Variable {key!r} is not supported by `Azure Container Instance`"
" work pools. Skipping."
)

return base_job_template

def _configure_container(self) -> Container:
"""
Configures an Azure `Container` using data from the block's fields.
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
prefect>=2.13.5
prefect>=2.14.10
azure_mgmt_containerinstance>=10.0
azure_identity>=1.10
azure-mgmt-resource>=21.2
Expand Down
151 changes: 149 additions & 2 deletions tests/test_aci_infrastructure.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import uuid
from copy import deepcopy
from typing import Dict, List, Tuple, Union
from unittest.mock import MagicMock, Mock

Expand All @@ -19,6 +20,8 @@
from prefect.settings import get_current_settings
from pydantic import VERSION as PYDANTIC_VERSION

from prefect_azure.workers.container_instance import AzureContainerWorker

if PYDANTIC_VERSION.startswith("2."):
from pydantic.v1 import SecretStr
else:
Expand Down Expand Up @@ -93,6 +96,7 @@ def aci_credentials():
credentials = AzureContainerInstanceCredentials(
client_id=client_id, client_secret=client_secret, tenant_id=tenant_id
)
credentials.save("test-block", overwrite=True)
return credentials


Expand Down Expand Up @@ -534,9 +538,9 @@ def test_preview(aci_credentials):

preview = json.loads(block.preview())

for (k, v) in block_args.items():
for k, v in block_args.items():
if k == "env":
for (k2, v2) in block_args["env"].items():
for k2, v2 in block_args["env"].items():
assert preview[k][k2] == block_args["env"][k2]

else:
Expand Down Expand Up @@ -925,3 +929,146 @@ def test_azure_container_instance_job_default_factory():
subscription_id="test_sub_id",
)
assert isinstance(instance_job.aci_credentials, AzureContainerInstanceCredentials)


@pytest.fixture
def default_base_job_template():
return deepcopy(AzureContainerWorker.get_default_base_job_template())


@pytest.fixture
def base_job_template_with_defaults(
default_base_job_template, aci_credentials, image_registry_block
):
base_job_template_with_defaults = deepcopy(default_base_job_template)
base_job_template_with_defaults["variables"]["properties"]["command"][
"default"
] = "python my_script.py"
base_job_template_with_defaults["variables"]["properties"]["env"]["default"] = {
"VAR1": "value1",
"VAR2": "value2",
}
base_job_template_with_defaults["variables"]["properties"]["labels"]["default"] = {
"label1": "value1",
"label2": "value2",
}
base_job_template_with_defaults["variables"]["properties"]["name"][
"default"
] = "prefect-job"
base_job_template_with_defaults["variables"]["properties"]["image"][
"default"
] = "docker.io/my_image:latest"
base_job_template_with_defaults["variables"]["properties"]["resource_group_name"][
"default"
] = "testgroup"
base_job_template_with_defaults["variables"]["properties"]["subscription_id"][
"default"
] = "subid"
base_job_template_with_defaults["variables"]["properties"]["aci_credentials"][
"default"
] = {"$ref": {"block_document_id": str(aci_credentials._block_document_id)}}
base_job_template_with_defaults["variables"]["properties"]["identities"][
"default"
] = ["/my/managed_identity/one", "/my/managed_identity/two"]
base_job_template_with_defaults["variables"]["properties"]["entrypoint"][
"default"
] = "/test/entrypoint.sh"
base_job_template_with_defaults["variables"]["properties"]["image_registry"][
"default"
] = {"$ref": {"block_document_id": str(image_registry_block._block_document_id)}}
base_job_template_with_defaults["variables"]["properties"]["cpu"]["default"] = 2.0
base_job_template_with_defaults["variables"]["properties"]["gpu_count"][
"default"
] = 1
base_job_template_with_defaults["variables"]["properties"]["gpu_sku"][
"default"
] = "V100"
base_job_template_with_defaults["variables"]["properties"]["memory"][
"default"
] = 3.0
base_job_template_with_defaults["variables"]["properties"]["subnet_ids"][
"default"
] = ["subnet1", "subnet2", "subnet3"]
base_job_template_with_defaults["variables"]["properties"]["dns_servers"][
"default"
] = ["dns1", "dns2", "dns3"]
base_job_template_with_defaults["variables"]["properties"]["stream_output"][
"default"
] = True
base_job_template_with_defaults["variables"]["properties"][
"task_start_timeout_seconds"
]["default"] = 120
base_job_template_with_defaults["variables"]["properties"][
"task_watch_poll_interval"
]["default"] = 0.1
return base_job_template_with_defaults


@pytest.fixture
async def image_registry_block():
block = DockerRegistry(
username="username",
password="password",
registry_url="https://myregistry.dockerhub.com",
)
await block.save("test-for-publish", overwrite=True)
return block


@pytest.mark.parametrize(
"job_config",
[
"default",
"custom",
],
)
async def test_generate_work_pool_base_job_template(
job_config,
base_job_template_with_defaults,
aci_credentials,
default_base_job_template,
image_registry_block,
):
job = AzureContainerInstanceJob(
aci_credentials=aci_credentials,
resource_group_name="testgroup",
subscription_id="subid",
)
expected_template = default_base_job_template
default_base_job_template["variables"]["properties"]["resource_group_name"][
"default"
] = "testgroup"
default_base_job_template["variables"]["properties"]["subscription_id"][
"default"
] = "subid"
default_base_job_template["variables"]["properties"]["aci_credentials"][
"default"
] = {"$ref": {"block_document_id": str(aci_credentials._block_document_id)}}
if job_config == "custom":
expected_template = base_job_template_with_defaults
job = AzureContainerInstanceJob(
command=["python", "my_script.py"],
env={"VAR1": "value1", "VAR2": "value2"},
labels={"label1": "value1", "label2": "value2"},
name="prefect-job",
image="docker.io/my_image:latest",
aci_credentials=aci_credentials,
resource_group_name="testgroup",
subscription_id="subid",
identities=["/my/managed_identity/one", "/my/managed_identity/two"],
entrypoint="/test/entrypoint.sh",
image_registry=image_registry_block,
cpu=2.0,
gpu_count=1,
gpu_sku="V100",
memory=3.0,
subnet_ids=["subnet1", "subnet2", "subnet3"],
dns_servers=["dns1", "dns2", "dns3"],
stream_output=True,
task_start_timeout_seconds=120,
task_watch_poll_interval=0.1,
)

template = await job.generate_work_pool_base_job_template()

assert template == expected_template
Loading