From cfd02f7f88f481a9a18be8220ec245d46628d65b Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Tue, 10 Sep 2024 07:02:12 +0000 Subject: [PATCH] WIP --- msccl/language/mscclpp/__init__.py | 3 +++ msccl/language/mscclpp/ir.py | 1 + msccl/language/types.py | 1 + 3 files changed, 5 insertions(+) diff --git a/msccl/language/mscclpp/__init__.py b/msccl/language/mscclpp/__init__.py index 8ba171e..d478342 100644 --- a/msccl/language/mscclpp/__init__.py +++ b/msccl/language/mscclpp/__init__.py @@ -33,6 +33,7 @@ def __init__( protocol: str = "Simple", instr_fusion: bool = True, replication_policy: ReplicationPolicy = ReplicationPolicy.duplicated, + num_threads_per_block: int = 1024, ): self.name = name self.topo = topo @@ -42,6 +43,7 @@ def __init__( self.protocol = protocol self.instr_fusion = instr_fusion self.replication_policy = replication_policy + self.num_threads_per_block = num_threads_per_block assert protocol == "Simple" or protocol == "LL", f"Given protocol: {protocol}. Must be either Simple, LL" self.run_opt = True # Runs optimization passes # Initialize the input buffers @@ -134,6 +136,7 @@ def lower(self): self.protocol, gpu_prgms, self.collective.num_chunk_groups * self.instances, + self.num_threads_per_block, ) def generate_json(self): diff --git a/msccl/language/mscclpp/ir.py b/msccl/language/mscclpp/ir.py index baee9e6..8331cd6 100644 --- a/msccl/language/mscclpp/ir.py +++ b/msccl/language/mscclpp/ir.py @@ -310,5 +310,6 @@ def remove_empty_fields(d): "protocol": program.protocol, "inplace": program.inplace, "gpus": gpus, + "num_threads_per_block": program.num_threads_per_block, } return json.dumps(obj, indent=2) diff --git a/msccl/language/types.py b/msccl/language/types.py index be5677d..734e1c5 100644 --- a/msccl/language/types.py +++ b/msccl/language/types.py @@ -16,6 +16,7 @@ class Program: protocol: str gpus: list = field(default_factory=list) num_chunk_groups: int = 1 + num_threads_per_block: int = 1024 @dataclass