Skip to content

Commit

Permalink
Merge branch 'main' into binyli/flush
Browse files (browse the repository at this point in the history)
  • Loading branch information
Binyang2014 authored Sep 20, 2024
2 parents 2547209 + 9c94c02 commit 399df8a
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 0 deletions.
3 changes: 3 additions & 0 deletions msccl/language/mscclpp/__init__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(
protocol: str = "Simple",
instr_fusion: bool = True,
replication_policy: ReplicationPolicy = ReplicationPolicy.duplicated,
num_threads_per_block: int = 1024,
):
self.name = name
self.topo = topo
Expand All @@ -42,6 +43,7 @@ def __init__(
self.protocol = protocol
self.instr_fusion = instr_fusion
self.replication_policy = replication_policy
self.num_threads_per_block = num_threads_per_block
assert protocol == "Simple" or protocol == "LL", f"Given protocol: {protocol}. Must be either Simple, LL"
self.run_opt = True # Runs optimization passes
# Initialize the input buffers
Expand Down Expand Up @@ -134,6 +136,7 @@ def lower(self):
self.protocol,
gpu_prgms,
self.collective.num_chunk_groups * self.instances,
self.num_threads_per_block,
)

def generate_json(self):
Expand Down
1 change: 1 addition & 0 deletions msccl/language/mscclpp/ir.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -318,5 +318,6 @@ def remove_empty_fields(d):
"protocol": program.protocol,
"inplace": program.inplace,
"gpus": gpus,
"num_threads_per_block": program.num_threads_per_block,
}
return json.dumps(obj, indent=2)
1 change: 1 addition & 0 deletions msccl/language/types.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -16,6 +16,7 @@ class Program:
protocol: str
gpus: list = field(default_factory=list)
num_chunk_groups: int = 1
num_threads_per_block: int = 1024


@dataclass
Expand Down

0 comments on commit 399df8a

Please sign in to comment.