Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Aug 29, 2024
1 parent 2a3e518 commit ad14244
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 71 deletions.
11 changes: 5 additions & 6 deletions examples/mscclang/mscclpp/send_recv_a100_packet.mscclpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,17 @@ def send_recv(instances):
if nghr == r:
continue
c = chunk(r, Buffer.input, 0)
c.put_packet(nghr, "scratch", 0, sendtb=0, channel_type=ChannelType.proxy)
c = c.trans_to_packet(r, "scratch", 0, sendtb=0)
c.put_packet(nghr, "scratch", 1, sendtb=0, chan_type=ChannelType.proxy, trans_to_packet=False)

for r in range(size):
for nghr in range(size):
if nghr == r:
continue
c = chunk(r, "scratch", 0)
c.get_packet(nghr, Buffer.output, 0, recvtb=0, channel_type=ChannelType.proxy)
c = chunk(r, "scratch", 1)
c.copy_packet(r, Buffer.output, 0, sendtb=0)

Json()
Check()


parser = argparse.ArgumentParser()
parser.add_argument("instances", type=int, help="number of instances")

Expand Down
31 changes: 15 additions & 16 deletions msccl/language/mscclpp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
raise RuntimeError("This program is not currently in context")
_current_program = None


def _convert_to_exectuion_plan(self):
ops = self.instr_dag.convert_set_list()
ops = sorted(ops, key=lambda x: x.step)
Expand Down Expand Up @@ -187,12 +186,8 @@ def _get_buffer_index(self, remote_rank, buffer, index):
return buffer, self.prog.buffers[remote_rank][buffer].instance_size()
return buffer, index

def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, use_packet=False):
def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, use_packet=False, trans_to_packet=True):
self.prog.check_buffer_exists(dst, buffer)
temp_chunk=None
if chan_type == ChannelType.proxy:
self.prog.check_buffer_exists(self.rank, "scratch")
temp_chunk = self.prog.get_ref(self.rank, "scratch", 0, self.size)
assert self.rank != dst, "Cannot put to the same rank"
buffer, index = self._get_buffer_index(dst, buffer, index)

Expand All @@ -201,18 +196,21 @@ def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm,
dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size)
if use_packet:
self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type, use_packet, temp_chunk)
if trans_to_packet:
self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type, trans_to_packet)
else:
self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type)
self.prog.instr_dag.add_signal(self.rank, self, dst_chunkref, -1, ChannelType.none)
self.prog.instr_dag.add_wait(dst, dst_chunkref, self, -1, ChannelType.none)
else:
self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type, temp_chunk)
self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type)
return dst_chunkref

def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm):
return self._put(dst, buffer, index, sendtb, chan_type)

def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, channel_type=ChannelType.sm):
return self._put(dst, buffer, index, sendtb, channel_type, use_packet=True)
def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, trans_to_packet=False):
return self._put(dst, buffer, index, sendtb, chan_type, use_packet=True, trans_to_packet=trans_to_packet)

def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm):
self.prog.check_buffer_exists(src, buffer)
Expand Down Expand Up @@ -253,7 +251,7 @@ def wait(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm):
src_chunkref = self.prog.get_ref(src, buffer, index, self.size)
self.prog.instr_dag.add_wait(receiver, self, src_chunkref, recvtb, chan_type)

def _copy(self, dst, buffer=None, index=-1, sendtb=-1, use_packet=False):
def _copy(self, dst, buffer=None, index=-1, sendtb=-1, trans_from_packet=False, trans_to_packet=False):
self.prog.check_buffer_exists(dst, buffer)
buffer, index = self._get_buffer_index(dst, buffer, index)

Expand All @@ -264,24 +262,25 @@ def _copy(self, dst, buffer=None, index=-1, sendtb=-1, use_packet=False):
self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size)

assert self.rank == dst, "Chunk copy only supports intra-rank communication"
self.prog.instr_dag.add_copy(self.rank, self, dst_chunkref, sendtb, use_packet)
self.prog.instr_dag.add_copy(self.rank, self, dst_chunkref, sendtb, trans_from_packet, trans_to_packet)

return dst_chunkref

# Copies the chunk(s) referenced by this chunkref onto Rank dst at location (buffer, index)
def copy(self, dst, buffer=None, index=-1, sendtb=-1):
return self._copy(dst, buffer, index, sendtb)

def trans_to_packet(self, dst, buffer=None, index=-1, sendtb=-1):
return self._copy(dst, buffer, index, sendtb, trans_from_packet=False, trans_to_packet=True)

def copy_packet(self, dst, buffer=None, index=-1, sendtb=-1):
return self._copy(dst, buffer, index, sendtb, use_packet=True)
return self._copy(dst, buffer, index, sendtb, trans_from_packet=True, trans_to_packet=False)

def _reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.sm, use_packet=False):
dst = self.rank
src = other_chunkref.rank
assert self.prog.topo.link(src, dst) or src == dst, f"No link from {src} to {dst}"
self.prog.apply_reduce(
src, other_chunkref.buffer, other_chunkref.index, dst, self.buffer, self.index, self.size
)
self.prog.apply_reduce(src, other_chunkref.buffer, other_chunkref.index, dst, self.buffer, self.index, self.size)
if use_packet:
assert src == dst, "Packet reduce only supports intra-rank communication"

Expand Down
55 changes: 12 additions & 43 deletions msccl/language/mscclpp/instruction_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ def __init__(self, num_ranks, buffers):
super().__init__(num_ranks, buffers)

# InstructionDAG - adds a copy node
def add_copy(self, rank, send_ref, recv_ref, tb, use_packet=False):
def add_copy(self, rank, send_ref, recv_ref, tb, trans_from_packet=False, trans_to_packet=False):
tb_step = self._get_tb_step(rank, tb)
if use_packet:
if trans_from_packet:
op = Op(Instruction.copy_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step)
elif trans_to_packet:
op = Op(Instruction.transform_to_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step)
else:
op = Op(Instruction.copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step)
dstbuffer = recv_ref.buffer
Expand Down Expand Up @@ -61,44 +63,11 @@ def add_reduce(self, rank, send_ref, recv_ref, tb, use_packet=False):
return op

# InstructionDAG - adds a put node
def add_put(self, rank, send_ref, recv_ref, tb, ch_type, use_packet=False, temp_chunk=None):
def add_put(self, rank, send_ref, recv_ref, tb, ch_type, trans_to_packet=False):
tb_step = self._get_tb_step(rank, tb)
if ch_type == ChannelType.proxy and temp_chunk is not None:
op = Op(
Instruction.transform_to_packet,
rank,
send_ref,
temp_chunk,
next=set(),
prev=set(),
tb=tb,
channel_type=ch_type,
step=tb_step,
)
tb_step = self._get_tb_step(rank, tb)
op2 = Op(
Instruction.put,
rank,
send_ref,
recv_ref,
next=set(),
prev=set(),
tb=tb,
channel_type=ch_type,
step=tb_step,
)
buffer = send_ref.buffer
index = send_ref.index
size = send_ref.size
self._read(rank, buffer, index, size, op)
self._write(rank, temp_chunk.buffer, temp_chunk.index, temp_chunk.size, op)
self._read(rank, temp_chunk.buffer, temp_chunk.index, temp_chunk.size, op2)
op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step))
op2.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step))
return op
if use_packet:
if trans_to_packet:
op = Op(
Instruction.put_packet,
Instruction.trans_put_packet,
rank,
send_ref,
recv_ref,
Expand Down Expand Up @@ -198,7 +167,7 @@ def add_read_reduce(self, rank, send_ref, recv_ref, tb, ch_type):
return op

def complete_channels(self):
send_op = [Instruction.put, Instruction.signal, Instruction.put_packet]
send_op = [Instruction.put, Instruction.signal, Instruction.trans_put_packet]
recv_op = [Instruction.wait, Instruction.get, Instruction.read_reduce_copy]
for rank, rank_tbs in enumerate(self.tbs):
for tbid, tb in rank_tbs.items():
Expand Down Expand Up @@ -229,7 +198,7 @@ def _optimize_redundant_signal_wait(self):
queue = list(tb.ops)
while len(queue) > 0:
op = queue[0]
if op.inst == Instruction.put_packet:
if op.inst == Instruction.trans_put_packet:
fused = False
for next_op in op.next:
if next_op.inst == Instruction.signal:
Expand Down Expand Up @@ -470,7 +439,7 @@ def _optimize_rrcs_rs(self):
fused = False
for next_op in op.next:
if (
next_op.inst == Instruction.put_packet
next_op.inst == Instruction.trans_put_packet
and same_count(op, next_op)
and buf_dst_src_match(op, next_op)
and next_op.channel_type == ChannelType.sm
Expand Down Expand Up @@ -564,12 +533,12 @@ def _optimize_get_put(self):
fused = True
if fused:
continue
elif op.inst == Instruction.put_packet:
elif op.inst == Instruction.trans_put_packet:
fused = False
if len(queue) > 1:
seq_op = queue[1]
if (
seq_op.inst == Instruction.put_packet
seq_op.inst == Instruction.trans_put_packet
and same_src_dst_buffer_type(op, seq_op)
and same_chan_type(op, seq_op)
and same_count(op, seq_op)
Expand Down
8 changes: 3 additions & 5 deletions msccl/language/mscclpp/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

_local_src_insts_mscclpp = {
Instruction.put,
Instruction.put_packet,
Instruction.trans_put_packet,
Instruction.signal,
Instruction.copy,
Instruction.copy_packet,
Expand Down Expand Up @@ -246,10 +246,8 @@ def remove_empty_fields(d):
"name": op.inst.value,
"deps": list(map(lambda dep: {"tb": dep.tb, "step": dep.step}, op.depends)),
}
elif op.inst == Instruction.put or op.inst == Instruction.put_packet:
dst_channel_ids = get_channel_ids(
op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type
)
elif op.inst == Instruction.put or op.inst == Instruction.trans_put_packet:
dst_channel_ids = get_channel_ids(op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type)
o_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value}
srcs = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs))
elif op.inst == Instruction.get:
Expand Down
2 changes: 1 addition & 1 deletion msccl/language/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class MscclppInstruction(Enum):
reduce_send_packet = "rspkt"
reduce_packet = "rpkt"
put = "put"
put_packet = "ppkt"
trans_put_packet = "ppkt"
get = "get"
wait = "wait"
signal = "signal"
Expand Down

0 comments on commit ad14244

Please sign in to comment.