Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Aug 30, 2024
1 parent ad14244 commit 57c44ba
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 15 deletions.
2 changes: 1 addition & 1 deletion msccl/language/mscclpp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm,
def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm):
return self._put(dst, buffer, index, sendtb, chan_type)

def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, trans_to_packet=False):
def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, trans_to_packet=True):
return self._put(dst, buffer, index, sendtb, chan_type, use_packet=True, trans_to_packet=trans_to_packet)

def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm):
Expand Down
40 changes: 27 additions & 13 deletions msccl/language/mscclpp/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from msccl.language.types import Buffer, ChannelType, Op, Program, MscclppInstruction as Instruction

_local_src_insts_mscclpp = {
_local_src_insts_mscclpp: set = {
Instruction.put,
Instruction.trans_put_packet,
Instruction.signal,
Expand All @@ -18,7 +18,7 @@
Instruction.reduce_send,
Instruction.reduce_send_packet,
}
_local_dst_insts_mscclpp = {
_local_dst_insts_mscclpp: set = {
Instruction.get,
Instruction.wait,
Instruction.read_reduce_copy,
Expand All @@ -32,6 +32,12 @@
Instruction.reduce_send_packet,
}

_insts_no_need_sync_barrier: set = {
Instruction.copy_packet,
Instruction.reduce_packet,
Instruction.reduce_send_packet,
}


def ir_to_json(program: Program):
# Figure out sizes of buffers based on usage
Expand Down Expand Up @@ -104,16 +110,18 @@ def ir_to_json(program: Program):

# Do some additional postprocessing of operations:
# - Expand operations with dependencies with no-ops
if program.protocol != "LL": # TODO(binyli): fix this. Should based on OP type not algorithm
for gpu in program.gpus:
for tb in gpu.threadblocks:
new_ops = []
for op in tb.ops:
# Expand extra dependencies into nop operations
for i, dep in enumerate(op.depends):
new_ops.append(Op(Instruction.nop, -1, None, None, [dep]))
for gpu in program.gpus:
for tb in gpu.threadblocks:
new_ops = []
for op in tb.ops:
if op.inst in _insts_no_need_sync_barrier:
new_ops.append(op)
tb.ops = new_ops
continue
# Expand extra dependencies into nop operations
for i, dep in enumerate(op.depends):
new_ops.append(Op(Instruction.nop, -1, None, None, [dep]))
new_ops.append(op)
tb.ops = new_ops

# update step and tid for ops
for gpu in program.gpus:
Expand Down Expand Up @@ -247,7 +255,9 @@ def remove_empty_fields(d):
"deps": list(map(lambda dep: {"tb": dep.tb, "step": dep.step}, op.depends)),
}
elif op.inst == Instruction.put or op.inst == Instruction.trans_put_packet:
dst_channel_ids = get_channel_ids(op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type)
dst_channel_ids = get_channel_ids(
op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type
)
o_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value}
srcs = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs))
elif op.inst == Instruction.get:
Expand All @@ -256,7 +266,11 @@ def remove_empty_fields(d):
)
i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value}
dsts = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.dsts))
elif op.inst == Instruction.copy or op.inst == Instruction.copy_packet or op.inst == Instruction.transform_to_packet:
elif (
op.inst == Instruction.copy
or op.inst == Instruction.copy_packet
or op.inst == Instruction.transform_to_packet
):
src = op.src
dst = op.dst
elif op.inst == Instruction.transform_to_packet:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[tool.black]
line-length = 140
line-length = 120
target-version = ['py38']
include = '\.pyi?$'

0 comments on commit 57c44ba

Please sign in to comment.