Skip to content

Commit

Permalink
fix out of bound error
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu authored and Ubuntu committed Dec 12, 2024
1 parent e88b672 commit daf360d
Showing 1 changed file with 11 additions and 17 deletions.
28 changes: 11 additions & 17 deletions msccl/language/collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@

class Collective:

def __init__(self, num_ranks, chunk_factor, inplace, num_ranks_per_node=-1, **kwargs):
def __init__(self, num_ranks, chunk_factor, inplace, root=0, num_ranks_per_node=-1, **kwargs):
self.num_ranks = num_ranks
self.root=root
self.chunk_factor = chunk_factor
self.inplace = inplace
self.name = "custom"
self.root=root
# Devide the buffer into num_chunk_groups group
if num_ranks_per_node == -1:
self.num_ranks_per_node = num_ranks
Expand Down Expand Up @@ -71,8 +73,8 @@ def check(self, prog):
return correct

class Broadcast(Collective):
def __init__(self, num_ranks, root, chunk_factor, inplace, create_all_chunks=False):
Collective.__init__(self, num_ranks, root, chunk_factor, inplace)
def __init__(self, num_ranks, chunk_factor, inplace, root, create_all_chunks=False):
Collective.__init__(self, num_ranks, chunk_factor, inplace, root)
self.name = "broadcast"
# This flag is a temporary solution, which initialize all the chuncks only for inputbuffer
# In this future we need to remove this flag and always initialize all the chunks
Expand All @@ -82,27 +84,23 @@ def __init__(self, num_ranks, root, chunk_factor, inplace, create_all_chunks=Fal
def init_buffers(self):
rank_buffers = []
if self.inplace:
# Inplace broadcast only uses the output buffer
# Inplace broadcast only uses the input buffer
for r in range(self.num_ranks):
input_buffer = [None] * (self.chunk_factor)
#if not self.create_all_chunks:
# for ch in range(self.chunk_factor):
# output_buffer[ch] = Chunk(r, ch, -1, ch)
#else:
for ch in range(self.chunk_factor):
input_buffer[ch] = Chunk(root, ch, -1, ch)
input_buffer[ch] = Chunk(self.root, ch, -1, ch)
buffers = {
Buffer.input: input_buffer, #this only needs to be set for the root
Buffer.input: input_buffer,
Buffer.output: input_buffer,
}
rank_buffers.append(buffers)
else:
for r in range(self.num_ranks):
input_buffer = [None] * self.chunk_factor
output_buffer = [None] * (self.chunk_factor)
if r==root:
output_buffer = [None] * self.chunk_factor
if r==self.root:
for ch in range(self.chunk_factor):
input_buffer[ch] = Chunk(root, ch, -1, ch)
input_buffer[ch] = Chunk(self.root, ch, -1, ch)
buffers = {Buffer.input: input_buffer, Buffer.output: output_buffer} # add if statement
rank_buffers.append(buffers)
return rank_buffers
Expand Down Expand Up @@ -130,10 +128,6 @@ def check(self, prog):
def get_buffer_index(self, rank, buffer, index):
# For inplace Broadcast, the input buffer points into the output buffer
return buffer, index
#if self.inplace and buffer == Buffer.input:
# return Buffer.output, index
#else:
# return buffer, index

class AllGather(Collective):
def __init__(self, num_ranks, chunk_factor, inplace, create_all_chunks=False):
Expand Down

0 comments on commit daf360d

Please sign in to comment.