WIP

Binyang2014 committed Sep 11, 2024
1 parent 62b3b96 commit 3e3104e
Showing 4 changed files with 124 additions and 140 deletions.
11 changes: 11 additions & 0 deletions msccl/language/mscclpp/__init__.py
@@ -252,6 +252,17 @@ def signal(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm
dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
self.prog.instr_dag.add_signal(sender, self, dst_chunkref, sendtb, chan_type)

# Only the proxy channel needs to use this function
def flush(self, dst, buffer=None, index=-1, sendtb=-1):
sender = self.rank
receiver = dst
assert sender != receiver, "Cannot flush to the same rank"
buffer, index = self._get_buffer_index(dst, buffer, index)

assert self.prog.topo.link(self.rank, dst) or dst == self.rank, f"No link from {self.rank} to {dst}"
dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
self.prog.instr_dag.add_flush(sender, self, dst_chunkref, sendtb)

def wait(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm):
sender = src
receiver = self.rank
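For context, a minimal sketch of how the new flush call might slot into a program written against this DSL. It assumes the usual mscclpp-language entry points (chunk, put, Buffer, ChannelType); the ranks, buffers, and thread-block ids are illustrative, not taken from the commit:

# Hypothetical usage: rank 0 puts a chunk to rank 1 over a proxy channel,
# signals completion, then flushes the proxy send queue before reusing
# its buffer. flush applies to proxy channels only, per the new API.
c = chunk(0, Buffer.input, index=0)
c.put(1, Buffer.scratch, index=0, sendtb=0, chan_type=ChannelType.proxy)
c.signal(1, Buffer.scratch, index=0, sendtb=0, chan_type=ChannelType.proxy)
c.flush(1, Buffer.scratch, index=0, sendtb=0)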
194 changes: 54 additions & 140 deletions msccl/language/mscclpp/instruction_dag.py
@@ -3,7 +3,6 @@


from msccl.language.buffer import Buffer
from msccl.language.types import Channel, ChannelType
from msccl.language.instruction_dag import (
buf_dst_src_match,
merge_op,
@@ -16,7 +15,16 @@
same_src_dst_buffer_type,
)
from msccl.language.instruction_dag import InstructionDAG
from msccl.language.types import ChunkRef, MscclppInstruction as Instruction, Op, ReplicationPolicy, Threadblock
from msccl.language.mscclpp.optimizer import Optimizer
from msccl.language.types import (
Channel,
ChannelType,
ChunkRef,
MscclppInstruction as Instruction,
Op,
ReplicationPolicy,
Threadblock,
)


class MscclppInstructionDAG(InstructionDAG):
@@ -133,6 +141,27 @@ def add_signal(self, rank, send_ref, recv_ref, tb, ch_type):
op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step))
return op

def add_flush(self, rank, send_ref, recv_ref, tb):
tb_step = self._get_tb_step(rank, tb)
op = Op(
Instruction.flush,
rank,
send_ref,
recv_ref,
next=set(),
prev=set(),
tb=tb,
channel_type=ChannelType.proxy,
step=tb_step,
)
buffer = send_ref.buffer
index = send_ref.index
size = send_ref.size
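# A flush only reads the source chunks; recording the read orders it after prior writes to this range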
self._read(rank, buffer, index, size, op)
op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step))
op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step))
return op

def add_wait(self, rank, dst_ref, src_ref, tb, ch_type):
tb_step = self._get_tb_step(rank, tb)
op = Op(
@@ -220,152 +249,37 @@ def _optimize_redundant_signal_wait(self):

# rrc(_,_,_,dst,dbuf,di) rrc(_,_,_,dst,dbuf,di) -> rrc(list[src,sbuf,si], dst, dbuf, di)
# signal(_,_,_,dst,dbuf,di) signal(_,_,_,dst,dbuf,di) -> signal(_,_,_,list[dst,dbuf,di])
# flush(_,_,_,dst,dbuf,di) flush(_,_,_,dst,dbuf,di) -> flush(_,_,_,list[dst,dbuf,di])
# wait(src,sbuf,si,_,_,_) wait(src,sbuf,si,_,_,_) -> wait(list[src,sbuf,si],_,_,_,_])
# reduce(_,_,_,dst,dbuf,di) reduce(_,_,_,dst,dbuf,di) -> reduce(list[src,sbuf,si], dst, dbuf, di)
# reduce_packet(_,_,_,dst,dbuf,di) reduce_packet(_,_,_,dst,dbuf,di) -> reduce_packet(list[src,sbuf,si], dst, dbuf, di)
def _optimize_rrc_r_signal_wait(self):
for rank, rank_tbs in enumerate(self.tbs):
for tbid, tb in rank_tbs.items():
def _optimize_fuse_same_instruction(self):
optimizer = Optimizer()
# Map each fusible instruction to its same-buffer check (same_buf_dst or same_buf_src)
instruction_handlers = {
Instruction.read_reduce_copy: same_buf_dst,
Instruction.reduce: same_buf_dst,
Instruction.reduce_packet: same_buf_dst,
Instruction.signal: same_buf_src,
Instruction.wait: same_buf_dst,
Instruction.flush: same_buf_src,
}

for _, rank_tbs in enumerate(self.tbs):
for _, tb in rank_tbs.items():
queue = list(tb.ops)
while len(queue) > 0:
op = queue[0]
if op.inst == Instruction.read_reduce_copy:
fused = False
for next_op in op.next:
if (
next_op.inst == Instruction.read_reduce_copy
and same_count(op, next_op)
and same_buf_dst(op, next_op)
and same_chan_type(op, next_op)
and not circular_dep_after_merge(op, next_op)
):
op.srcs.append(
(
ChunkRef(
next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size
),
next_op.step,
)
)
merge_op(op, next_op)
tb.ops.remove(next_op)
queue.remove(next_op)
fused = True
break
if fused:
continue
elif op.inst == Instruction.reduce:
fused = False
for next_op in op.next:
if (
next_op.inst == Instruction.reduce
and same_buf_dst(op, next_op)
and same_chan_type(op, next_op)
and not circular_dep_after_merge(op, next_op)
):
op.srcs.append(
(
ChunkRef(
next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size
),
next_op.step,
)
)
merge_op(op, next_op)
tb.ops.remove(next_op)
queue.remove(next_op)
fused = True
break
if fused:
continue
elif op.inst == Instruction.reduce_packet:
fused = False
for next_op in op.next:
if (
next_op.inst == Instruction.reduce_packet
and same_buf_dst(op, next_op)
and same_chan_type(op, next_op)
and not circular_dep_after_merge(op, next_op)
):
op.srcs.append(
(
ChunkRef(
next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size
),
next_op.step,
)
)
merge_op(op, next_op)
tb.ops.remove(next_op)
queue.remove(next_op)
fused = True
break
if fused:
continue
elif op.inst == Instruction.signal:
fused = False
fused = False
inst_type = op.inst
if inst_type in instruction_handlers:
for next_op in op.next:
if (
next_op.inst == Instruction.signal
and same_buf_src(op, next_op)
and same_chan_type(op, next_op)
and not circular_dep_after_merge(op, next_op)
):
op.dsts.append(
(
ChunkRef(
next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size
),
next_op.step,
)
)
op.srcs.append(
(
ChunkRef(
next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size
),
next_op.step,
)
)
merge_op(op, next_op)
tb.ops.remove(next_op)
queue.remove(next_op)
fused = True
break
if fused:
continue
elif op.inst == Instruction.wait:
fused = False
for next_op in op.next:
if (
next_op.inst == Instruction.wait
and same_buf_dst(op, next_op)
and same_chan_type(op, next_op)
and not circular_dep_after_merge(op, next_op)
):
op.srcs.append(
(
ChunkRef(
next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size
),
next_op.step,
)
)
op.dsts.append(
(
ChunkRef(
next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size
),
next_op.step,
)
)
merge_op(op, next_op)
tb.ops.remove(next_op)
queue.remove(next_op)
same_buf_func = instruction_handlers[inst_type]
if optimizer.try_merge_same_instruction(op, next_op, tb, queue, inst_type, same_buf_func):
fused = True
break
if fused:
continue
if fused:
continue
queue = queue[1:]

# rrc(_,_,_,dst,dbuf,di) put(dst,dbuf,di,_,_,_) -> rrcs(_,_,_,_,_,_)
@@ -643,7 +557,7 @@ def _get_tb_step(self, rank: int, tb: int):

def optimize(self):
self._optimize_redundant_signal_wait()
self._optimize_rrc_r_signal_wait()
self._optimize_fuse_same_instruction()
self._optimize_rrcs_rs()
self._optimize_get_put()

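Net effect of this refactor: the six near-identical fusion branches (read_reduce_copy, reduce, reduce_packet, signal, wait, and now flush) collapse into a single table-driven loop. For example, two consecutive signal ops in the same threadblock that share a source buffer and channel type are merged into one signal whose srcs/dsts lists carry both chunk references, while the second op is removed from tb.ops and the work queue.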
1 change: 1 addition & 0 deletions msccl/language/mscclpp/ir.py
@@ -10,6 +10,7 @@
Instruction.put,
Instruction.put_packet,
Instruction.signal,
Instruction.flush,
Instruction.copy,
Instruction.copy_packet,
Instruction.transform_to_packet,
58 changes: 58 additions & 0 deletions msccl/language/mscclpp/optimizer.py
@@ -0,0 +1,58 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from msccl.language.instruction_dag import (
merge_op,
circular_dep_after_merge,
same_chan_type,
same_count,
)
from msccl.language.types import ChunkRef, MscclppInstruction as Instruction, Op, ReplicationPolicy, Threadblock


class Optimizer:
def try_merge_same_instruction(
self, op: Op, next_op: Op, tb: Threadblock, queue: list, inst_type: Instruction, same_buf_func: callable
) -> bool:
"""
Attempts to merge two instructions if the fusion conditions are met.
:param op: The current operation.
:param next_op: The next operation to potentially merge with.
:param tb: The thread block containing the operations.
:param queue: The queue of operations.
:param inst_type: The type of the instruction being processed.
:param same_buf_func: The function to check if the buffer is the same (same_buf_dst or same_buf_src).
:return: True if operations are merged, False otherwise.
"""
if (
next_op.inst == inst_type
and same_buf_func(op, next_op)
and same_count(op, next_op)
and same_chan_type(op, next_op)
and not circular_dep_after_merge(op, next_op)
):
# Append the source chunk from next_op
op.srcs.append(
(
ChunkRef(
next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size
),
next_op.step,
)
)
# For 'signal', 'wait', and 'flush' instructions, append destination chunks too
if inst_type in [Instruction.signal, Instruction.wait, Instruction.flush]:
op.dsts.append(
(
ChunkRef(
next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size
),
next_op.step,
)
)
# Merge operations
merge_op(op, next_op)
tb.ops.remove(next_op)
queue.remove(next_op)
return True
return False
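For reference, a condensed sketch of the intended call pattern. It mirrors the _optimize_fuse_same_instruction loop in instruction_dag.py above with an abbreviated handler table; tb comes from the surrounding DAG walk, and same_buf_src/same_buf_dst are the checks imported from msccl.language.instruction_dag:

optimizer = Optimizer()
handlers = {Instruction.signal: same_buf_src, Instruction.wait: same_buf_dst}
queue = list(tb.ops)
while len(queue) > 0:
    op = queue[0]
    fused = False
    if op.inst in handlers:
        for next_op in op.next:
            # On success, next_op is merged into op and removed from tb.ops and queue
            if optimizer.try_merge_same_instruction(op, next_op, tb, queue, op.inst, handlers[op.inst]):
                fused = True
                break
    if fused:
        continue  # the merged op may fuse with another successor on the next pass
    queue = queue[1:]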
