From 57c44bafe2b9be2ea2dbaa8a1d0c2a6e0b1af40b Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Fri, 30 Aug 2024 03:57:26 +0000 Subject: [PATCH] WIP --- msccl/language/mscclpp/__init__.py | 2 +- msccl/language/mscclpp/ir.py | 40 ++++++++++++++++++++---------- pyproject.toml | 2 +- 3 files changed, 29 insertions(+), 15 deletions(-) diff --git a/msccl/language/mscclpp/__init__.py b/msccl/language/mscclpp/__init__.py index 1842ed6..2881684 100644 --- a/msccl/language/mscclpp/__init__.py +++ b/msccl/language/mscclpp/__init__.py @@ -209,7 +209,7 @@ def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm): return self._put(dst, buffer, index, sendtb, chan_type) - def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, trans_to_packet=False): + def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, trans_to_packet=True): return self._put(dst, buffer, index, sendtb, chan_type, use_packet=True, trans_to_packet=trans_to_packet) def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm): diff --git a/msccl/language/mscclpp/ir.py b/msccl/language/mscclpp/ir.py index ed72cc1..d9ebb49 100644 --- a/msccl/language/mscclpp/ir.py +++ b/msccl/language/mscclpp/ir.py @@ -6,7 +6,7 @@ from msccl.language.types import Buffer, ChannelType, Op, Program, MscclppInstruction as Instruction -_local_src_insts_mscclpp = { +_local_src_insts_mscclpp: set = { Instruction.put, Instruction.trans_put_packet, Instruction.signal, @@ -18,7 +18,7 @@ Instruction.reduce_send, Instruction.reduce_send_packet, } -_local_dst_insts_mscclpp = { +_local_dst_insts_mscclpp: set = { Instruction.get, Instruction.wait, Instruction.read_reduce_copy, @@ -32,6 +32,12 @@ Instruction.reduce_send_packet, } +_insts_no_need_sync_barrier: set = { + Instruction.copy_packet, + Instruction.reduce_packet, + Instruction.reduce_send_packet, +} + def ir_to_json(program: Program): # Figure out sizes of buffers based on usage @@ -104,16 +110,18 @@ def ir_to_json(program: Program): # Do some additional postprocessing of operations: # - Expand operations with dependencies with no-ops - if program.protocol != "LL": # TODO(binyli): fix this. Should based on OP type not algorithm - for gpu in program.gpus: - for tb in gpu.threadblocks: - new_ops = [] - for op in tb.ops: - # Expand extra dependencies into nop operations - for i, dep in enumerate(op.depends): - new_ops.append(Op(Instruction.nop, -1, None, None, [dep])) + for gpu in program.gpus: + for tb in gpu.threadblocks: + new_ops = [] + for op in tb.ops: + if op.inst in _insts_no_need_sync_barrier: new_ops.append(op) - tb.ops = new_ops + continue + # Expand extra dependencies into nop operations + for i, dep in enumerate(op.depends): + new_ops.append(Op(Instruction.nop, -1, None, None, [dep])) + new_ops.append(op) + tb.ops = new_ops # update step and tid for ops for gpu in program.gpus: @@ -247,7 +255,9 @@ def remove_empty_fields(d): "deps": list(map(lambda dep: {"tb": dep.tb, "step": dep.step}, op.depends)), } elif op.inst == Instruction.put or op.inst == Instruction.trans_put_packet: - dst_channel_ids = get_channel_ids(op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) + dst_channel_ids = get_channel_ids( + op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type + ) o_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} srcs = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs)) elif op.inst == Instruction.get: @@ -256,7 +266,11 @@ def remove_empty_fields(d): ) i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} dsts = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.dsts)) - elif op.inst == Instruction.copy or op.inst == Instruction.copy_packet or op.inst == Instruction.transform_to_packet: + elif ( + op.inst == Instruction.copy + or op.inst == Instruction.copy_packet + or op.inst == Instruction.transform_to_packet + ): src = op.src dst = op.dst elif op.inst == Instruction.transform_to_packet: diff --git a/pyproject.toml b/pyproject.toml index d891952..3d74b6c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,4 +1,4 @@ [tool.black] -line-length = 140 +line-length = 120 target-version = ['py38'] include = '\.pyi?$'