Skip to content

Commit

Permalink
add new fields
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Nov 28, 2024
1 parent 8efa06f commit 6df0fc1
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 1 deletion.
6 changes: 6 additions & 0 deletions msccl/language/mscclpp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def __init__(
replication_policy: ReplicationPolicy = ReplicationPolicy.duplicated,
num_threads_per_block: int = 1024,
use_double_scratch_buffer: bool = False,
min_message_size: int = -1,
max_message_size: int = -1 ,
):
self.name = name
self.topo = topo
Expand All @@ -47,6 +49,8 @@ def __init__(
self.replication_policy = replication_policy
self.num_threads_per_block = num_threads_per_block
self.use_double_scratch_buffer = use_double_scratch_buffer
self.min_message_size = min_message_size
self.max_message_size = max_message_size
assert protocol == "Simple" or protocol == "LL", f"Given protocol: {protocol}. Must be either Simple, LL"
self.run_opt = True # Runs optimization passes
# Initialize the input buffers
Expand Down Expand Up @@ -147,6 +151,8 @@ def lower(self):
self.collective.num_chunk_groups * self.instances,
self.num_threads_per_block,
self.use_double_scratch_buffer,
self.min_message_size,
self.max_message_size,
)
for gpu in program.gpus:
gpu.input_chunks = len(self.buffers[gpu.rank][Buffer.input]) * self.instances
Expand Down
3 changes: 3 additions & 0 deletions msccl/language/mscclpp/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,5 +377,8 @@ def remove_empty_fields(d):
"gpus": gpus,
"num_threads_per_block": program.num_threads_per_block,
"use_double_scratch_buffer": program.use_double_scratch_buffer,
"min_message_size": program.min_message_size,
"max_message_size": program.max_message_size,
"in_place": program.inplace,
}
return json.dumps(obj, indent=2)
4 changes: 3 additions & 1 deletion msccl/language/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ class Program:
num_chunk_groups: int = 1
num_threads_per_block: int = 1024
use_double_scratch_buffer: bool = False

min_message_size: int = -1
max_message_size: int = -1


@dataclass
class Gpu:
Expand Down

0 comments on commit 6df0fc1

Please sign in to comment.