Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Nov 3, 2024
1 parent de000fd commit 797fb65
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 6 deletions.
56 changes: 50 additions & 6 deletions msccl/language/mscclpp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,13 +331,57 @@ def reduce_packet(self, other_chunkref, recvtb=-1):
# Group operations. These operations are used to perform collective operations across multiple chunks.
# For now, all chunks must have the same buffer type and offset.
# """
# # Reads the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref
# def group_load_reduce(self, other_chunkrefs: list, recvtb=-1, channel_type=ChannelType.none):
# pass
def _assert_same_index(self, other_chunkrefs):
    """Assert that every reference in ``other_chunkrefs`` has the same index as this one.

    Args:
        other_chunkrefs: Iterable of objects exposing an ``index`` attribute.
            An empty iterable passes trivially.

    Raises:
        AssertionError: At the first element whose ``index`` differs from
            ``self.index``.
    """
    # Walk the references directly; the positional counter was only ever used
    # to index back into the list.
    for other in other_chunkrefs:
        assert self.index == other.index, "Group operations only supports chunks with the same index"

def _assert_same_buffer(self, other_chunkrefs):
    """Assert that every reference in ``other_chunkrefs`` has the same buffer as this one.

    Args:
        other_chunkrefs: Iterable of objects exposing a ``buffer`` attribute,
            compared to ``self.buffer`` with ``==``. An empty iterable passes
            trivially.

    Raises:
        AssertionError: At the first element whose ``buffer`` differs from
            ``self.buffer``.
    """
    # Walk the references directly; the positional counter was only ever used
    # to index back into the list.
    for other in other_chunkrefs:
        assert self.buffer == other.buffer, "Group operations only supports chunks with the same buffer"

def _group_load_reduce(self, other_chunkrefs: list, recvtb=-1):
    """Placeholder for the grouped load-reduce step; it currently does nothing.

    TODO: may need to check whether SHARP is supported by the topology.
    """
    return None

def _group_store(self, other_chunkrefs: list, sendtb=-1):
    """Placeholder for the grouped store step; it currently does nothing."""
    return None

# Reads the chunk(s) referenced by other_chunkrefs and reduces them into the chunk referenced by this chunkref
def group_load_reduce(self, other_chunkrefs: list, recvtb: int, chan_type=ChannelType.nvls):
    """Validate a grouped load-reduce into this chunk reference and hand it to ``_group_load_reduce``.

    Args:
        other_chunkrefs: Non-empty list of chunk references to read from. They
            are checked against this reference by ``_assert_same_rank``,
            ``_assert_same_index`` and ``_assert_same_buffer``, in that order.
        recvtb: Passed through unchanged to ``_group_load_reduce``.
        chan_type: Channel type; only ``ChannelType.nvls`` is accepted.

    Raises:
        AssertionError: If ``other_chunkrefs`` is empty, if ``chan_type`` is not
            ``ChannelType.nvls``, or if any of the consistency checks fails.
    """
    # Two separate checks so that an empty list is not misreported as a
    # channel-type problem.
    assert len(other_chunkrefs) > 0, "Group load reduce requires at least one chunk reference"
    assert chan_type == ChannelType.nvls, "Group load reduce only supports nvls channel"
    self._assert_same_rank(other_chunkrefs)
    self._assert_same_index(other_chunkrefs)
    self._assert_same_buffer(other_chunkrefs)
    self._group_load_reduce(other_chunkrefs, recvtb)

# Copies the chunk(s) referenced by this chunkref to every destination in dsts
def group_store(self, dsts: list, index=-1, buffer=None, sendtb=-1, chan_type=ChannelType.nvls):
    """Validate a grouped store of this chunk reference to every destination in ``dsts``.

    Work in progress: each destination is validated and its target location is
    resolved, but nothing is emitted yet. ``sendtb`` and ``chan_type`` are
    accepted but not used by the current body.

    Args:
        dsts: Destinations to store to; each one is passed to
            ``self.prog.check_buffer_exists`` and ``self._get_buffer_index``.
        index: Requested target index. Must be ``-1`` or equal to ``self.index``.
        buffer: Requested target buffer, forwarded as given for every destination.
        sendtb: Currently unused.
        chan_type: Currently unused.

    Raises:
        AssertionError: If ``index`` is neither ``-1`` nor ``self.index`` (only
            evaluated when ``dsts`` is non-empty).
    """
    for dst in dsts:
        self.prog.check_buffer_exists(dst, buffer)
        assert index == -1 or self.index == index, "Group store only supports chunks with the same index"
        # Resolve into per-destination names. Rebinding the `buffer`/`index`
        # parameters here would carry one destination's resolved location into
        # the existence check, the assert and the resolution of the next one.
        dst_buffer, dst_index = self._get_buffer_index(dst, buffer, index)
        # TODO: dst_buffer / dst_index are not consumed yet; the store itself
        # still has to be emitted.

# # Copies the chunk(s) referenced by this chunkref onto Rank dst at location (buffer, index)
# def group_store(self, other_chunkrefs: list, sendtb=-1, channel_type=ChannelType.none):
# pass
# # Direct put
# assert self.prog.topo.link(self.rank, dst) or dst == self.rank, f"No link from {self.rank} to {dst}"
# dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
# self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size)
# if use_packet:
# self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type, True)
# self.prog.instr_dag.add_signal(self.rank, self, dst_chunkref, -1, ChannelType.none)
# self.prog.instr_dag.add_wait(dst, dst_chunkref, self, -1, ChannelType.none)
# else:
# self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type)
# return dst_chunkref
# assert (
# len(other_chunkrefs) > 0 and channel_type == ChannelType.nvls
# ), "Group store only supports nvls channel"
# self._assert_same_rank(other_chunkrefs)
# self._assert_same_index(other_chunkrefs)
# self._group_store(other_chunkrefs, sendtb)

def get_origin_index(self, index=0):
    """Return the ``origin_index`` of the chunk located ``index`` positions past ``self.index``.

    Args:
        index: Offset added to this reference's own index before the lookup.

    Returns:
        The ``origin_index`` attribute of whatever ``self._get_chunk`` returns
        for the combined position.
    """
    position = index + self.index
    chunk = self._get_chunk(position)
    return chunk.origin_index
Expand Down
1 change: 1 addition & 0 deletions msccl/language/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ class ChannelType(Enum):
proxy = "proxy"
sm = "sm"
none = "none"
nvls = "nvls"

def __str__(self):
    """Return this member's ``value`` attribute unchanged."""
    text = self.value
    return text
Expand Down

0 comments on commit 797fb65

Please sign in to comment.