Skip to content

Commit

Permalink
Initialization all the chuncks in AllGather Collective (#26)
Browse files Browse the repository at this point in the history
* implementing allgather collective creating all the chuncks

* adding comments
  • Loading branch information
caiomcbr authored Nov 26, 2024
1 parent c7ba9b0 commit 514a1c9
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions msccl/language/collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,12 @@ def check(self, prog):


class AllGather(Collective):
def __init__(self, num_ranks, chunk_factor, inplace):
def __init__(self, num_ranks, chunk_factor, inplace, create_all_chunks=False):
Collective.__init__(self, num_ranks, chunk_factor, inplace)
self.name = "allgather"
# This flag is a temporary solution, which initialize all the chuncks only for inputbuffer
# In this future we need to remove this flag and always initialize all the chunks
self.create_all_chunks = create_all_chunks

# Initializes input buffer for an allgather
def init_buffers(self):
Expand All @@ -84,8 +87,13 @@ def init_buffers(self):
# Inplace AllGather only uses the output buffer
for r in range(self.num_ranks):
output_buffer = [None] * (self.num_ranks * self.chunk_factor)
for ch in range(self.chunk_factor):
output_buffer[r * self.chunk_factor + ch] = Chunk(r, ch, -1, r * self.chunk_factor + ch)
if not self.create_all_chunks:
for ch in range(self.chunk_factor):
output_buffer[r * self.chunk_factor + ch] = Chunk(r, ch, -1, r * self.chunk_factor + ch)
else:
for rank in range(self.num_ranks):
for ch in range(self.chunk_factor):
output_buffer[rank * self.chunk_factor + ch] = Chunk(rank, ch, -1, rank * self.chunk_factor + ch)
buffers = {
Buffer.input: output_buffer[r * self.chunk_factor : (r + 1) * self.chunk_factor],
Buffer.output: output_buffer,
Expand Down

0 comments on commit 514a1c9

Please sign in to comment.