From 78cf77967b2d9986c1c3536ec4774d1f0c663890 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 4 Nov 2024 13:10:00 +0000 Subject: [PATCH] WIP --- examples/mscclang/mscclpp/allreduce_nvls.py | 5 +-- msccl/language/mscclpp/instruction_dag.py | 50 +++++++++++++-------- msccl/language/types.py | 9 +++- 3 files changed, 41 insertions(+), 23 deletions(-) diff --git a/examples/mscclang/mscclpp/allreduce_nvls.py b/examples/mscclang/mscclpp/allreduce_nvls.py index 319f7d1..a83228b 100644 --- a/examples/mscclang/mscclpp/allreduce_nvls.py +++ b/examples/mscclang/mscclpp/allreduce_nvls.py @@ -18,7 +18,6 @@ def allreduce_allpairs(gpus, instances): collective, instances, ): - # Each rank sends the nth chunk to the nth rank into scratch space for rank in range(size): index = rank @@ -27,9 +26,9 @@ def allreduce_allpairs(gpus, instances): # make sure the data is ready for nghr in range(size): if rank != nghr: - c_peer = chunk(rank, Buffer.input, index) + c_peer = chunk(nghr, Buffer.input, index) reduce_chunks.append(c_peer) - c_peer.signal(nghr, Buffer.input, index, sendtb=0) + c.signal(nghr, Buffer.input, index, sendtb=0) for nghr in range(size): if rank != nghr: c.wait(nghr, Buffer.input, index, recvtb=0) diff --git a/msccl/language/mscclpp/instruction_dag.py b/msccl/language/mscclpp/instruction_dag.py index f4dc319..ddd5d54 100644 --- a/msccl/language/mscclpp/instruction_dag.py +++ b/msccl/language/mscclpp/instruction_dag.py @@ -235,28 +235,42 @@ def add_group_store(self, rank, send_ref, recv_refs, tb, ch_type): def complete_channels(self): send_op = [Instruction.put, Instruction.signal, Instruction.put_packet] recv_op = [Instruction.wait, Instruction.get, Instruction.read_reduce_copy] + group_send_op = [Instruction.group_store] + group_recv_op = [Instruction.group_load_reduce] for rank, rank_tbs in enumerate(self.tbs): for tbid, tb in rank_tbs.items(): chans = set() for op in tb.ops: - if op.channel_type == ChannelType.none or op.channel_type == ChannelType.nvls: - continue - src_buffer = ( - Buffer.scratch - if op.src.buffer is not Buffer.input and op.src.buffer is not Buffer.output - else op.src.buffer - ) - dst_buffer = ( - Buffer.scratch - if op.dst.buffer is not Buffer.input and op.dst.buffer is not Buffer.output - else op.dst.buffer - ) - if op.inst in send_op: - chan = Channel(src_buffer, dst_buffer, op.channel_type, op.dst.rank) - chans.add(chan) - elif op.inst in recv_op: - chan = Channel(src_buffer, dst_buffer, op.channel_type, op.src.rank) - chans.add(chan) + if op.src != None: + src_buffer = ( + Buffer.scratch + if op.src.buffer is not Buffer.input and op.src.buffer is not Buffer.output + else op.src.buffer + ) + if op.dst != None: + dst_buffer = ( + Buffer.scratch + if op.dst.buffer is not Buffer.input and op.dst.buffer is not Buffer.output + else op.dst.buffer + ) + if op.channel_type == ChannelType.nvls: + if op.inst in group_send_op: + ranks = [dst[0].rank for dst in op.dsts] + ranks.append(op.rank) + chan = Channel(src_buffer, dst_buffer, op.channel_type, ranks) + chans.add(chan) + elif op.inst in group_recv_op: + ranks = [src[0].rank for src in op.srcs] + ranks.append(op.rank) + chan = Channel(src_buffer, dst_buffer, op.channel_type, ranks) + chans.add(chan) + else: + if op.inst in send_op: + chan = Channel(src_buffer, dst_buffer, op.channel_type, op.dst.rank) + chans.add(chan) + elif op.inst in recv_op: + chan = Channel(src_buffer, dst_buffer, op.channel_type, op.src.rank) + chans.add(chan) tb.channels = list(chans) def remove_redundant_signal_wait(self): diff --git a/msccl/language/types.py b/msccl/language/types.py index 66a8a0b..ef2addd 100644 --- a/msccl/language/types.py +++ b/msccl/language/types.py @@ -3,7 +3,7 @@ from dataclasses import dataclass, field from enum import Enum -from typing import Union +from typing import Union, List from msccl.language.buffer import Buffer @@ -158,7 +158,12 @@ class Channel: srcBuffer: Buffer dstBuffer: Buffer type: ChannelType - connected_to: int + connected_to: Union[int, List[int]] + + def __hash__(self): + # Ensure connected_to is converted to a tuple if it's a list + connected_to_hashable = tuple(self.connected_to) if isinstance(self.connected_to, list) else self.connected_to + return hash((self.srcBuffer, self.dstBuffer, self.type, connected_to_hashable)) @dataclass