From 797fb651ddea0610cc8dbe902ae1fe48b707a866 Mon Sep 17 00:00:00 2001
From: Binyang Li
Date: Sun, 3 Nov 2024 02:57:05 +0000
Subject: [PATCH] WIP

---
Review notes (this section is ignored by git am):
- Replaces the commented-out group_load_reduce/group_store stubs with real
  method definitions and adds ChannelType.nvls.
- _group_load_reduce and _group_store are placeholders whose body is `pass`.
- group_store is unfinished: only argument checks are active, the rest of
  its body is commented-out code.
- group_load_reduce calls self._assert_same_rank, which this patch does not
  add; TODO confirm it is already defined in the target file.
- NOTE(review): confirm whether the assert and the _get_buffer_index call in
  group_store belong inside or after the `for dst in dsts` loop; they read
  `dst`, which is only bound by that loop.

 msccl/language/mscclpp/__init__.py | 56 ++++++++++++++++++++++++++----
 msccl/language/types.py            |  1 +
 2 files changed, 51 insertions(+), 6 deletions(-)

diff --git a/msccl/language/mscclpp/__init__.py b/msccl/language/mscclpp/__init__.py
index 2f4b370..dae836f 100644
--- a/msccl/language/mscclpp/__init__.py
+++ b/msccl/language/mscclpp/__init__.py
@@ -331,13 +331,57 @@ def reduce_packet(self, other_chunkref, recvtb=-1):
     # Group operations. These operations are used to perform collective operations across multiple chunks.
     # For now, all chunks must has the same buffer type and offset.
     # """
-    # # Reads the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref
-    # def group_load_reduce(self, other_chunkrefs: list, recvtb=-1, channel_type=ChannelType.none):
-    #     pass
+    def _assert_same_index(self, other_chunkrefs):
+        for i in range(len(other_chunkrefs)):
+            assert self.index == other_chunkrefs[i].index, "Group operations only supports chunks with the same index"
+
+    def _assert_same_buffer(self, other_chunkrefs):
+        for i in range(len(other_chunkrefs)):
+            assert (
+                self.buffer == other_chunkrefs[i].buffer
+            ), "Group operations only supports chunks with the same buffer"
+
+    def _group_load_reduce(self, other_chunkrefs: list, recvtb=-1):
+        # TODO: may need to check whether sharp is supported in the topologies
+        pass
+
+    def _group_store(self, other_chunkrefs: list, sendtb=-1):
+        pass
+
+    # Reads the chunk(s) referenced by other_chunkrefs and reduces them into the chunk referenced by this chunkref
+    def group_load_reduce(self, other_chunkrefs: list, recvtb: int, chan_type=ChannelType.nvls):
+        assert (
+            len(other_chunkrefs) > 0 and chan_type == ChannelType.nvls
+        ), "Group load reduce only supports nvls channel"
+        self._assert_same_rank(other_chunkrefs)
+        self._assert_same_index(other_chunkrefs)
+        self._assert_same_buffer(other_chunkrefs)
+        self._group_load_reduce(other_chunkrefs, recvtb)
+
+    # Copies the chunk(s) referenced by this chunkref onto each destination in dsts
+    def group_store(self, dsts: list, index=-1, buffer=None, sendtb=-1, chan_type=ChannelType.nvls):
+        for dst in dsts:
+            self.prog.check_buffer_exists(dst, buffer)
+        assert index == -1 or self.index == index, "Group store only supports chunks with the same index"
+        buffer, index = self._get_buffer_index(dst, buffer, index)
 
-    # # Copies the chunk(s) referenced by this chunkref onto Rank dst at location (buffer, index)
-    # def group_store(self, other_chunkrefs: list, sendtb=-1, channel_type=ChannelType.none):
-    #     pass
+        # # Direct put
+        # assert self.prog.topo.link(self.rank, dst) or dst == self.rank, f"No link from {self.rank} to {dst}"
+        # dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
+        # self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size)
+        # if use_packet:
+        #     self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type, True)
+        #     self.prog.instr_dag.add_signal(self.rank, self, dst_chunkref, -1, ChannelType.none)
+        #     self.prog.instr_dag.add_wait(dst, dst_chunkref, self, -1, ChannelType.none)
+        # else:
+        #     self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type)
+        # return dst_chunkref
+        # assert (
+        #     len(other_chunkrefs) > 0 and channel_type == ChannelType.nvls
+        # ), "Group store only supports nvls channel"
+        # self._assert_same_rank(other_chunkrefs)
+        # self._assert_same_index(other_chunkrefs)
+        # self._group_store(other_chunkrefs, sendtb)
 
     def get_origin_index(self, index=0):
         return self._get_chunk(index + self.index).origin_index
diff --git a/msccl/language/types.py b/msccl/language/types.py
index 9a8e676..314e657 100644
--- a/msccl/language/types.py
+++ b/msccl/language/types.py
@@ -144,6 +144,7 @@ class ChannelType(Enum):
     proxy = "proxy"
     sm = "sm"
     none = "none"
+    nvls = "nvls"
 
     def __str__(self):
         return self.value