Skip to content

Commit

Permalink
WIP
Browse files (browse the repository at this point in the history)
  • Loading branch information
Binyang2014 committed Apr 24, 2024
1 parent c2bd38f commit bb3aebe
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 14 deletions.
27 changes: 13 additions & 14 deletions msccl/language/collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from msccl.language import *

class Collective():
def __init__(self, num_ranks, chunk_factor, inplace):
def __init__(self, num_ranks, chunk_factor, inplace, num_chunk_groups = 1):
self.num_ranks = num_ranks
self.chunk_factor = chunk_factor
self.inplace = inplace
self.name = "custom"
self.num_chunk_groups = num_chunk_groups

def init_buffers(self):
pass
Expand Down Expand Up @@ -35,10 +36,10 @@ def init_buffers(self):
chunk = Chunk(r, index, index//self.chunk_factor, index % self.chunk_factor + r*self.chunk_factor)
input_buffer[index] = chunk
if self.inplace:
buffers = {Buffer.input : input_buffer,
buffers = {Buffer.input : input_buffer,
Buffer.output : input_buffer}
else:
buffers = {Buffer.input : input_buffer,
buffers = {Buffer.input : input_buffer,
Buffer.output : output_buffer}
rank_buffers.append(buffers)
return rank_buffers
Expand Down Expand Up @@ -69,7 +70,7 @@ def __init__(self, num_ranks, chunk_factor, inplace):
def init_buffers(self):
rank_buffers = []
if self.inplace:
# Inplace AllGather only uses the output buffer
# Inplace AllGather only uses the output buffer
for r in range(self.num_ranks):
output_buffer = [None] * (self.num_ranks * self.chunk_factor)
for ch in range(self.chunk_factor):
Expand All @@ -83,11 +84,11 @@ def init_buffers(self):
output_buffer = [None] * (self.num_ranks * self.chunk_factor)
for ch in range(self.chunk_factor):
input_buffer[ch] = Chunk(r, ch, -1, r*self.chunk_factor+ch)
buffers = {Buffer.input : input_buffer,
buffers = {Buffer.input : input_buffer,
Buffer.output : output_buffer}
rank_buffers.append(buffers)
return rank_buffers

# Expected output buffer for allgather
def check(self, prog):
correct = True
Expand All @@ -106,7 +107,7 @@ def check(self, prog):
correct = False
return correct


def get_buffer_index(self, rank, buffer, index):
# For inplace AllGathers, the input buffer points into the output buffer
if self.inplace and buffer == Buffer.input:
Expand All @@ -115,12 +116,11 @@ def get_buffer_index(self, rank, buffer, index):
return buffer, index



class AllReduce(Collective):

def __init__(self, num_ranks, chunk_factor, inplace):
Collective.__init__(self, num_ranks, chunk_factor, inplace)
self.name = 'allreduce'
Collective.__init__(self, num_ranks, chunk_factor, inplace, num_ranks)
self.name = "allreduce"

def init_buffers(self):
chunks_per_node = self.chunk_factor
Expand All @@ -133,10 +133,10 @@ def init_buffers(self):
input_buffer.append(Chunk(r, c, -1, c))
# Input and output buffer are the same.
if self.inplace:
buffers = {Buffer.input : input_buffer,
buffers = {Buffer.input : input_buffer,
Buffer.output : input_buffer}
else:
buffers = {Buffer.input : input_buffer,
buffers = {Buffer.input : input_buffer,
Buffer.output : output_buffer}
rank_buffers.append(buffers)
return rank_buffers
Expand Down Expand Up @@ -190,7 +190,7 @@ def init_buffers(self):
for i in range(self.num_ranks):
for c in range(self.chunk_factor):
input_buffer.append(Chunk(r, i*self.chunk_factor + c, i, c))
buffers = {Buffer.input : input_buffer,
buffers = {Buffer.input : input_buffer,
Buffer.output : output_buffer}
rank_buffers.append(buffers)
return rank_buffers
Expand Down Expand Up @@ -223,4 +223,3 @@ def get_buffer_index(self, rank, buffer, index):
return Buffer.input, index + rank * self.chunk_factor
else:
return buffer, index

1 change: 1 addition & 0 deletions msccl/language/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class Program:
inplace: bool
protocol: str
gpus: list = field(default_factory=list)
num_chunk_groups: int = 1


@dataclass
Expand Down
1 change: 1 addition & 0 deletions msccl/language/ir_mscclpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def remove_empty_fields(d):
"inputChunks": gpu.input_chunks,
"outputChunks": gpu.output_chunks,
"scratchChunks": gpu.scratch_chunks,
"chunkGroups": program.num_chunk_groups,
"threadblocks": [],
"channels": [],
}
Expand Down
1 change: 1 addition & 0 deletions msccl/language/mscclpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def lower(self):
self.collective.inplace,
self.protocol,
gpu_prgms,
self.collective.num_chunk_groups * self.instances
)

def generate_json(self):
Expand Down

0 comments on commit bb3aebe

Please sign in to comment.