Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Nov 4, 2024
1 parent 4b30b54 commit 78cf779
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 23 deletions.
5 changes: 2 additions & 3 deletions examples/mscclang/mscclpp/allreduce_nvls.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def allreduce_allpairs(gpus, instances):
collective,
instances,
):

# Each rank sends the nth chunk to the nth rank into scratch space
for rank in range(size):
index = rank
Expand All @@ -27,9 +26,9 @@ def allreduce_allpairs(gpus, instances):
# make sure the data is ready
for nghr in range(size):
if rank != nghr:
c_peer = chunk(rank, Buffer.input, index)
c_peer = chunk(nghr, Buffer.input, index)
reduce_chunks.append(c_peer)
c_peer.signal(nghr, Buffer.input, index, sendtb=0)
c.signal(nghr, Buffer.input, index, sendtb=0)
for nghr in range(size):
if rank != nghr:
c.wait(nghr, Buffer.input, index, recvtb=0)
Expand Down
50 changes: 32 additions & 18 deletions msccl/language/mscclpp/instruction_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,28 +235,42 @@ def add_group_store(self, rank, send_ref, recv_refs, tb, ch_type):
def complete_channels(self):
send_op = [Instruction.put, Instruction.signal, Instruction.put_packet]
recv_op = [Instruction.wait, Instruction.get, Instruction.read_reduce_copy]
group_send_op = [Instruction.group_store]
group_recv_op = [Instruction.group_load_reduce]
for rank, rank_tbs in enumerate(self.tbs):
for tbid, tb in rank_tbs.items():
chans = set()
for op in tb.ops:
if op.channel_type == ChannelType.none or op.channel_type == ChannelType.nvls:
continue
src_buffer = (
Buffer.scratch
if op.src.buffer is not Buffer.input and op.src.buffer is not Buffer.output
else op.src.buffer
)
dst_buffer = (
Buffer.scratch
if op.dst.buffer is not Buffer.input and op.dst.buffer is not Buffer.output
else op.dst.buffer
)
if op.inst in send_op:
chan = Channel(src_buffer, dst_buffer, op.channel_type, op.dst.rank)
chans.add(chan)
elif op.inst in recv_op:
chan = Channel(src_buffer, dst_buffer, op.channel_type, op.src.rank)
chans.add(chan)
if op.src != None:
src_buffer = (
Buffer.scratch
if op.src.buffer is not Buffer.input and op.src.buffer is not Buffer.output
else op.src.buffer
)
if op.dst != None:
dst_buffer = (
Buffer.scratch
if op.dst.buffer is not Buffer.input and op.dst.buffer is not Buffer.output
else op.dst.buffer
)
if op.channel_type == ChannelType.nvls:
if op.inst in group_send_op:
ranks = [dst[0].rank for dst in op.dsts]
ranks.append(op.rank)
chan = Channel(src_buffer, dst_buffer, op.channel_type, ranks)
chans.add(chan)
elif op.inst in group_recv_op:
ranks = [src[0].rank for src in op.srcs]
ranks.append(op.rank)
chan = Channel(src_buffer, dst_buffer, op.channel_type, ranks)
chans.add(chan)
else:
if op.inst in send_op:
chan = Channel(src_buffer, dst_buffer, op.channel_type, op.dst.rank)
chans.add(chan)
elif op.inst in recv_op:
chan = Channel(src_buffer, dst_buffer, op.channel_type, op.src.rank)
chans.add(chan)
tb.channels = list(chans)

def remove_redundant_signal_wait(self):
Expand Down
9 changes: 7 additions & 2 deletions msccl/language/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from dataclasses import dataclass, field
from enum import Enum
from typing import Union
from typing import Union, List

from msccl.language.buffer import Buffer

Expand Down Expand Up @@ -158,7 +158,12 @@ class Channel:
srcBuffer: Buffer
dstBuffer: Buffer
type: ChannelType
connected_to: int
connected_to: Union[int, List[int]]

def __hash__(self):
# Ensure connected_to is converted to a tuple if it's a list
connected_to_hashable = tuple(self.connected_to) if isinstance(self.connected_to, list) else self.connected_to
return hash((self.srcBuffer, self.dstBuffer, self.type, connected_to_hashable))


@dataclass
Expand Down

0 comments on commit 78cf779

Please sign in to comment.