Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Aug 30, 2024
1 parent a15cb19 commit 71bf147
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 22 deletions.
11 changes: 9 additions & 2 deletions examples/mscclang/mscclpp/send_recv_a100_packet.mscclpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,15 @@ def send_recv(instances):
if nghr == r:
continue
c = chunk(r, Buffer.input, 0)
c = c.trans_to_packet(r, "scratch", 0, sendtb=0)
c.put_packet(nghr, "scratch", 1, sendtb=0, chan_type=ChannelType.proxy, trans_to_packet=False)
c.put_packet(
nghr,
"scratch",
1,
sendtb=0,
chan_type=ChannelType.proxy,
temp_buffer="scratch",
temp_buffer_index=0,
)

for r in range(size):
c = chunk(r, "scratch", 1)
Expand Down
27 changes: 18 additions & 9 deletions msccl/language/mscclpp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,7 @@ def _get_buffer_index(self, remote_rank, buffer, index):
return buffer, self.prog.buffers[remote_rank][buffer].instance_size()
return buffer, index

def _put(
self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, use_packet=False, trans_to_packet=False
):
def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, use_packet=False):
self.prog.check_buffer_exists(dst, buffer)
assert self.rank != dst, "Cannot put to the same rank"
buffer, index = self._get_buffer_index(dst, buffer, index)
Expand All @@ -198,7 +196,7 @@ def _put(
dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size)
if use_packet:
self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type, trans_to_packet)
self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type, True)
self.prog.instr_dag.add_signal(self.rank, self, dst_chunkref, -1, ChannelType.none)
self.prog.instr_dag.add_wait(dst, dst_chunkref, self, -1, ChannelType.none)
else:
Expand All @@ -208,8 +206,22 @@ def _put(
# Non-packet put: forwards to _put without passing use_packet, so _put runs
# with use_packet at its default (False) and takes its non-packet branch.
def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm):
return self._put(dst, buffer, index, sendtb, chan_type)

def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, trans_to_packet=True):
return self._put(dst, buffer, index, sendtb, chan_type, use_packet=True, trans_to_packet=trans_to_packet)
# Put this chunk to `dst` in packet mode: the final positional `True` passed
# to _put below is its `use_packet` argument.
#
# With chan_type == ChannelType.proxy the chunk is first copied on this rank
# (self.rank) into `temp_buffer` at `temp_buffer_index` with
# trans_to_packet=True, and the reference returned by that copy is the one
# that gets put. For every other channel type `temp_buffer` and
# `temp_buffer_index` are not used and `self` is put directly.
def put_packet(
self,
dst,
buffer=None,
index=-1,
sendtb=-1,
chan_type=ChannelType.sm,
temp_buffer=None,
temp_buffer_index=-1,
):
chunk_ref = self
# NOTE(review): temp_buffer defaults to None and temp_buffer_index to -1;
# how _copy treats those values is not visible in this hunk -- confirm that
# proxy callers are expected to always supply both.
if chan_type == ChannelType.proxy:
chunk_ref = self._copy(
self.rank, temp_buffer, temp_buffer_index, sendtb, trans_from_packet=False, trans_to_packet=True
)
# The same sendtb is used for the staging copy above and for the put itself.
return chunk_ref._put(dst, buffer, index, sendtb, chan_type, True)

def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm):
self.prog.check_buffer_exists(src, buffer)
Expand Down Expand Up @@ -269,9 +281,6 @@ def _copy(self, dst, buffer=None, index=-1, sendtb=-1, trans_from_packet=False,
def copy(self, dst, buffer=None, index=-1, sendtb=-1):
return self._copy(dst, buffer, index, sendtb)

def trans_to_packet(self, dst, buffer=None, index=-1, sendtb=-1):
return self._copy(dst, buffer, index, sendtb, trans_from_packet=False, trans_to_packet=True)

# Copy to `dst` through _copy with trans_from_packet=True and
# trans_to_packet=False; returns whatever _copy returns.
def copy_packet(self, dst, buffer=None, index=-1, sendtb=-1):
return self._copy(dst, buffer, index, sendtb, trans_from_packet=True, trans_to_packet=False)

Expand Down
16 changes: 8 additions & 8 deletions msccl/language/mscclpp/instruction_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ def add_reduce(self, rank, send_ref, recv_ref, tb, use_packet=False):
return op

# InstructionDAG - adds a put node
def add_put(self, rank, send_ref, recv_ref, tb, ch_type, trans_to_packet=False):
def add_put(self, rank, send_ref, recv_ref, tb, ch_type, use_packet=False):
tb_step = self._get_tb_step(rank, tb)
if trans_to_packet:
if use_packet:
op = Op(
Instruction.trans_put_packet,
Instruction.put_packet,
rank,
send_ref,
recv_ref,
Expand Down Expand Up @@ -167,7 +167,7 @@ def add_read_reduce(self, rank, send_ref, recv_ref, tb, ch_type):
return op

def complete_channels(self):
send_op = [Instruction.put, Instruction.signal, Instruction.trans_put_packet]
send_op = [Instruction.put, Instruction.signal, Instruction.put_packet]
recv_op = [Instruction.wait, Instruction.get, Instruction.read_reduce_copy]
for rank, rank_tbs in enumerate(self.tbs):
for tbid, tb in rank_tbs.items():
Expand Down Expand Up @@ -198,7 +198,7 @@ def _optimize_redundant_signal_wait(self):
queue = list(tb.ops)
while len(queue) > 0:
op = queue[0]
if op.inst == Instruction.trans_put_packet:
if op.inst == Instruction.put_packet:
fused = False
for next_op in op.next:
if next_op.inst == Instruction.signal:
Expand Down Expand Up @@ -439,7 +439,7 @@ def _optimize_rrcs_rs(self):
fused = False
for next_op in op.next:
if (
next_op.inst == Instruction.trans_put_packet
next_op.inst == Instruction.put_packet
and same_count(op, next_op)
and buf_dst_src_match(op, next_op)
and next_op.channel_type == ChannelType.sm
Expand Down Expand Up @@ -533,12 +533,12 @@ def _optimize_get_put(self):
fused = True
if fused:
continue
elif op.inst == Instruction.trans_put_packet:
elif op.inst == Instruction.put_packet:
fused = False
if len(queue) > 1:
seq_op = queue[1]
if (
seq_op.inst == Instruction.trans_put_packet
seq_op.inst == Instruction.put_packet
and same_src_dst_buffer_type(op, seq_op)
and same_chan_type(op, seq_op)
and same_count(op, seq_op)
Expand Down
4 changes: 2 additions & 2 deletions msccl/language/mscclpp/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

_local_src_insts_mscclpp: set = {
Instruction.put,
Instruction.trans_put_packet,
Instruction.put_packet,
Instruction.signal,
Instruction.copy,
Instruction.copy_packet,
Expand Down Expand Up @@ -254,7 +254,7 @@ def remove_empty_fields(d):
"name": op.inst.value,
"deps": list(map(lambda dep: {"tb": dep.tb, "step": dep.step}, op.depends)),
}
elif op.inst == Instruction.put or op.inst == Instruction.trans_put_packet:
elif op.inst == Instruction.put or op.inst == Instruction.put_packet:
dst_channel_ids = get_channel_ids(
op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type
)
Expand Down
2 changes: 1 addition & 1 deletion msccl/language/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class MscclppInstruction(Enum):
reduce_send_packet = "rspkt"
reduce_packet = "rpkt"
put = "put"
trans_put_packet = "ppkt"
put_packet = "ppkt"
get = "get"
wait = "wait"
signal = "signal"
Expand Down

0 comments on commit 71bf147

Please sign in to comment.