Skip to content

Commit

Permalink
Send cloud provider as string (#2786)
Browse files Browse the repository at this point in the history
* Send cloud provider as string

* Appease the type checker
  • Loading branch information
erikbern authored Jan 21, 2025
1 parent 64a85fe commit 575225d
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 4 deletions.
3 changes: 2 additions & 1 deletion modal/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,8 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona
task_idle_timeout_secs=container_idle_timeout or 0,
concurrency_limit=concurrency_limit or 0,
pty_info=pty_info,
cloud_provider=cloud_provider,
cloud_provider=cloud_provider, # Deprecated at some point
cloud_provider_str=cloud.upper() if cloud else "", # Supersedes cloud_provider
warm_pool_size=keep_warm or 0,
runtime=config.get("function_runtime"),
runtime_debug=config.get("function_runtime_debug"),
Expand Down
3 changes: 2 additions & 1 deletion modal/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,8 @@ async def _load(self: _Sandbox, resolver: Resolver, _existing_object_id: Optiona
resources=convert_fn_config_to_resources_config(
cpu=cpu, memory=memory, gpu=gpu, ephemeral_disk=ephemeral_disk
),
cloud_provider=parse_cloud_provider(cloud) if cloud else None,
cloud_provider=parse_cloud_provider(cloud) if cloud else None, # Deprecated at some point
cloud_provider_str=cloud.upper() if cloud else None, # Supersedes cloud_provider
nfs_mounts=network_file_system_mount_protos(validated_network_file_systems, False),
runtime_debug=config.get("function_runtime_debug"),
cloud_bucket_mounts=cloud_bucket_mounts_to_proto(cloud_bucket_mounts),
Expand Down
8 changes: 6 additions & 2 deletions modal_proto/api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -1163,7 +1163,7 @@ message Function {

uint32 task_idle_timeout_secs = 25;

optional CloudProvider cloud_provider = 26;
optional CloudProvider cloud_provider = 26; // Deprecated at some point

uint32 warm_pool_size = 27;

Expand Down Expand Up @@ -1257,6 +1257,8 @@ message Function {
bool method_definitions_set = 75;

bool _experimental_custom_scaling = 76;

string cloud_provider_str = 77; // Supersedes cloud_provider
}

message FunctionAsyncInvokeRequest {
Expand Down Expand Up @@ -2193,7 +2195,7 @@ message Sandbox {
repeated string secret_ids = 4;

Resources resources = 5;
CloudProvider cloud_provider = 6;
CloudProvider cloud_provider = 6; // Deprecated at some point

uint32 timeout_secs = 7;

Expand Down Expand Up @@ -2237,6 +2239,8 @@ message Sandbox {
// Used to pin gVisor version for memory-snapshottable sandboxes.
// This field is set by the server, not the client.
optional uint32 snapshot_version = 25;

string cloud_provider_str = 26; // Supersedes cloud_provider
}

message SandboxCreateRequest {
Expand Down
1 change: 1 addition & 0 deletions test/function_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,7 @@ def test_default_cloud_provider(client, servicer, monkeypatch):
f = servicer.app_functions[object_id]

assert f.cloud_provider == api_pb2.CLOUD_PROVIDER_OCI
assert f.cloud_provider_str == "OCI"


def test_not_hydrated():
Expand Down
1 change: 1 addition & 0 deletions test/gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def test_cloud_provider_selection(client, servicer):
assert len(servicer.app_functions) == 1
func_def = next(iter(servicer.app_functions.values()))
assert func_def.cloud_provider == api_pb2.CLOUD_PROVIDER_GCP
assert func_def.cloud_provider_str == "GCP"

assert func_def.resources.gpu_config.count == 1
assert func_def.resources.gpu_config.type == api_pb2.GPU_TYPE_A100
Expand Down
1 change: 1 addition & 0 deletions test/image_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,7 @@ def test_image_run_function_with_cloud_selection(servicer, client):
assert len(servicer.app_functions) == 2
func_def = next(iter(servicer.app_functions.values()))
assert func_def.cloud_provider == api_pb2.CLOUD_PROVIDER_OCI
assert func_def.cloud_provider_str == "OCI"


def test_poetry(builder_version, servicer, client):
Expand Down

0 comments on commit 575225d

Please sign in to comment.