
Commit

Pass GPU type as a string to the server (#2785)
erikbern authored Jan 21, 2025
1 parent 0878eff commit 56aae12
Showing 2 changed files with 18 additions and 22 deletions.
37 changes: 16 additions & 21 deletions modal/gpu.py
@@ -9,8 +9,9 @@
 
 @dataclass(frozen=True)
 class _GPUConfig:
-    type: "api_pb2.GPUType.V"
+    type: "api_pb2.GPUType.V"  # Deprecated, at some point
     count: int
+    gpu_type: str
     memory: int = 0
 
     def _to_proto(self) -> api_pb2.GPUConfig:
@@ -19,6 +20,7 @@ def _to_proto(self) -> api_pb2.GPUConfig:
             type=self.type,
             count=self.count,
             memory=self.memory,
+            gpu_type=self.gpu_type,
         )


@@ -33,7 +35,7 @@ def __init__(
         self,
         count: int = 1,  # Number of GPUs per container. Defaults to 1.
     ):
-        super().__init__(api_pb2.GPU_TYPE_T4, count, 0)
+        super().__init__(api_pb2.GPU_TYPE_T4, count, "T4")
 
     def __repr__(self):
         return f"GPU(T4, count={self.count})"
@@ -51,7 +53,7 @@ def __init__(
         self,
         count: int = 1,  # Number of GPUs per container. Defaults to 1.
     ):
-        super().__init__(api_pb2.GPU_TYPE_L4, count, 0)
+        super().__init__(api_pb2.GPU_TYPE_L4, count, "L4")
 
     def __repr__(self):
         return f"GPU(L4, count={self.count})"
@@ -70,21 +72,14 @@ def __init__(
         count: int = 1,  # Number of GPUs per container. Defaults to 1.
         size: Union[str, None] = None,  # Select GiB configuration of GPU device: "40GB" or "80GB". Defaults to "40GB".
     ):
-        allowed_size_values = {"40GB", "80GB"}
-
-        if size:
-            if size not in allowed_size_values:
-                raise ValueError(
-                    f"size='{size}' is invalid. A100s can only have memory values of {allowed_size_values}."
-                )
-            memory = int(size.replace("GB", ""))
-        else:
-            memory = 40
-
-        if memory == 80:
-            super().__init__(api_pb2.GPU_TYPE_A100_80GB, count, memory)
-        else:
-            super().__init__(api_pb2.GPU_TYPE_A100, count, memory)
+        if size == "40GB" or not size:
+            super().__init__(api_pb2.GPU_TYPE_A100, count, "A100-40GB", 40)
+        elif size == "80GB":
+            super().__init__(api_pb2.GPU_TYPE_A100_80GB, count, "A100-80GB", 80)
+        else:
+            raise ValueError(
+                f"size='{size}' is invalid. A100s can only have memory values of 40GB or 80GB."
+            )
 
     def __repr__(self):
         if self.memory == 80:
@@ -109,7 +104,7 @@ def __init__(
         # Useful if you have very large models that don't fit on a single GPU.
         count: int = 1,
     ):
-        super().__init__(api_pb2.GPU_TYPE_A10G, count)
+        super().__init__(api_pb2.GPU_TYPE_A10G, count, "A10G")
 
     def __repr__(self):
         return f"GPU(A10G, count={self.count})"
@@ -131,7 +126,7 @@ def __init__(
         # Useful if you have very large models that don't fit on a single GPU.
         count: int = 1,
     ):
-        super().__init__(api_pb2.GPU_TYPE_H100, count)
+        super().__init__(api_pb2.GPU_TYPE_H100, count, "H100")
 
     def __repr__(self):
         return f"GPU(H100, count={self.count})"
@@ -152,7 +147,7 @@ def __init__(
        # Useful if you have very large models that don't fit on a single GPU.
        count: int = 1,
     ):
-        super().__init__(api_pb2.GPU_TYPE_L40S, count)
+        super().__init__(api_pb2.GPU_TYPE_L40S, count, "L40S")
 
     def __repr__(self):
         return f"GPU(L40S, count={self.count})"
@@ -162,7 +157,7 @@ class Any(_GPUConfig):
     """Selects any one of the GPU classes available within Modal, according to availability."""
 
     def __init__(self, *, count: int = 1):
-        super().__init__(api_pb2.GPU_TYPE_ANY, count)
+        super().__init__(api_pb2.GPU_TYPE_ANY, count, "ANY")
 
     def __repr__(self):
         return f"GPU(Any, count={self.count})"
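
As a quick illustration of what the new field carries (not part of the diff): a minimal sketch assuming the classes above are importable from modal.gpu and that _to_proto(), normally an internal helper, is called directly.

# Illustrative sketch only -- not part of this commit. Assumes the classes
# above are importable from modal.gpu and that _to_proto() remains an
# internal helper on _GPUConfig.
from modal.gpu import A100, T4

# After this change, each GPU class forwards a human-readable string
# ("T4", "A100-80GB", ...) in addition to the deprecated enum value.
proto = A100(size="80GB")._to_proto()
print(proto.gpu_type)  # "A100-80GB" -- the new string field
print(proto.type)      # GPU_TYPE_A100_80GB -- deprecated enum, still populated
print(proto.memory)    # 80

# T4 only sets count and gpu_type; memory keeps its default of 0.
print(T4(count=2)._to_proto().gpu_type)  # "T4"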
3 changes: 2 additions & 1 deletion modal_proto/api.proto
@@ -1667,9 +1667,10 @@ message FunctionUpdateSchedulingParamsRequest {
 message FunctionUpdateSchedulingParamsResponse {}
 
 message GPUConfig {
-  GPUType type = 1;
+  GPUType type = 1; // Deprecated, at some point
   uint32 count = 2;
   uint32 memory = 3;
+  string gpu_type = 4;
 }
 
 message GeneratorDone { // Sent as the output when a generator finishes running.
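
For context on wire compatibility, a hedged sketch of the resulting message, using only field names visible in the diff above; the modal_proto.api_pb2 import path is an assumption about where the generated bindings live.

# Illustrative sketch only -- not part of this commit. Field names come from
# the GPUConfig message above; the import path assumes modal's generated
# protobuf bindings live in modal_proto.api_pb2.
from modal_proto import api_pb2

cfg = api_pb2.GPUConfig(
    type=api_pb2.GPU_TYPE_A100_80GB,  # deprecated enum, kept for older servers
    count=1,
    memory=80,
    gpu_type="A100-80GB",  # new: GPU type passed as a plain string
)

# proto3 scalars default to their zero value, so a client that never sets
# field 4 simply sends gpu_type == "" and older servers are unaffected.
assert api_pb2.GPUConfig(count=1).gpu_type == ""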

