Skip to content

Commit

Permalink
add broadcast to collective.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu authored and Ubuntu committed Dec 12, 2024
1 parent 146677a commit e88b672
Showing 1 changed file with 64 additions and 1 deletion.
65 changes: 64 additions & 1 deletion msccl/language/collectives.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from dataclasses import dataclass, field
from msccl.language.ir import Buffer
from msccl.language import *
#test

class Collective:

Expand Down Expand Up @@ -71,6 +70,70 @@ def check(self, prog):
correct = False
return correct

class Broadcast(Collective):
def __init__(self, num_ranks, root, chunk_factor, inplace, create_all_chunks=False):
Collective.__init__(self, num_ranks, root, chunk_factor, inplace)
self.name = "broadcast"
# This flag is a temporary solution, which initialize all the chuncks only for inputbuffer
# In this future we need to remove this flag and always initialize all the chunks
self.create_all_chunks = create_all_chunks

# Initializes input buffer for an broadcast
def init_buffers(self):
rank_buffers = []
if self.inplace:
# Inplace broadcast only uses the output buffer
for r in range(self.num_ranks):
input_buffer = [None] * (self.chunk_factor)
#if not self.create_all_chunks:
# for ch in range(self.chunk_factor):
# output_buffer[ch] = Chunk(r, ch, -1, ch)
#else:
for ch in range(self.chunk_factor):
input_buffer[ch] = Chunk(root, ch, -1, ch)
buffers = {
Buffer.input: input_buffer, #this only needs to be set for the root
Buffer.output: input_buffer,
}
rank_buffers.append(buffers)
else:
for r in range(self.num_ranks):
input_buffer = [None] * self.chunk_factor
output_buffer = [None] * (self.chunk_factor)
if r==root:
for ch in range(self.chunk_factor):
input_buffer[ch] = Chunk(root, ch, -1, ch)
buffers = {Buffer.input: input_buffer, Buffer.output: output_buffer} # add if statement
rank_buffers.append(buffers)
return rank_buffers

# Expected output buffer for broadcast
def check(self, prog):
correct = True
buf = Buffer.output
for r in range(self.num_ranks):
output = prog.buffers[r][buf]
for i in range(self.num_ranks):
for ch in range(self.chunk_factor):
index = ch
chunk = output[index]
if chunk is None:
print(f"Rank {r} chunk {index} is incorrect should be ({i}, {ch}) given None")
correct = False
elif chunk.origin_rank != i or chunk.origin_index != ch:
print(
f"Rank {r} chunk {index} is incorrect should be ({i}, {ch}) given ({chunk.origin_rank}, {chunk.origin_index})"
)
correct = False
return correct

def get_buffer_index(self, rank, buffer, index):
# For inplace Broadcast, the input buffer points into the output buffer
return buffer, index
#if self.inplace and buffer == Buffer.input:
# return Buffer.output, index
#else:
# return buffer, index

class AllGather(Collective):
def __init__(self, num_ranks, chunk_factor, inplace, create_all_chunks=False):
Expand Down

0 comments on commit e88b672

Please sign in to comment.