Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Apr 24, 2024
1 parent eb4f612 commit 2617280
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 43 deletions.
36 changes: 18 additions & 18 deletions msccl/language/instruction_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections import defaultdict

from msccl.language.buffer import Buffer
from msccl.language.types import ChunkRef, Gpu, InstancePolicy, Instruction, MscclInstruction, Op, Threadblock
from msccl.language.types import ChunkRef, Gpu, InstancePolicy, Instruction, Op, Threadblock


def remove_op(op: Op):
Expand Down Expand Up @@ -186,7 +186,7 @@ def convert_set_list(self):
op.next = list(op.next)
for o in op.next:
ops.append(o)
elif op.inst != MscclInstruction.copy:
elif op.inst != Instruction.copy:
ops.append(op)

while len(ops) > 0:
Expand Down Expand Up @@ -216,14 +216,14 @@ def replicate(self, instances: int, instance_policy: InstancePolicy):
pass


class MscclInstructionDAG(InstructionDAG):
class InstructionDAG(InstructionDAG):

def __init__(self, num_ranks, buffers):
super().__init__(num_ranks, buffers)

# InstructionDAG - adds a copy node
def add_copy(self, rank, send_ref, recv_ref, tb, ch):
op = Op(MscclInstruction.copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch)
op = Op(Instruction.copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch)
dstbuffer = recv_ref.buffer
dstindex = recv_ref.index
srcbuffer = send_ref.buffer
Expand All @@ -237,7 +237,7 @@ def add_copy(self, rank, send_ref, recv_ref, tb, ch):

# InstructionDAG - adds a redduce node
def add_reduce(self, rank, send_ref, recv_ref, tb, ch):
op = Op(MscclInstruction.reduce, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch)
op = Op(Instruction.reduce, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch)
dstbuffer = recv_ref.buffer
dstindex = recv_ref.index
srcbuffer = send_ref.buffer
Expand All @@ -251,7 +251,7 @@ def add_reduce(self, rank, send_ref, recv_ref, tb, ch):

# InstructionDAG - adds a send node
def add_send(self, rank, send_ref, recv_ref, tb, ch):
op = Op(MscclInstruction.send, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch)
op = Op(Instruction.send, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch)
buffer = send_ref.buffer
index = send_ref.index
size = send_ref.size
Expand All @@ -260,7 +260,7 @@ def add_send(self, rank, send_ref, recv_ref, tb, ch):

# InstructionDAG - adds a recv node
def add_recv(self, rank, send_ref, recv_ref, tb, ch, send_op):
op = Op(MscclInstruction.recv, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch)
op = Op(Instruction.recv, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch)
buffer = recv_ref.buffer
index = recv_ref.index
size = recv_ref.size
Expand All @@ -270,7 +270,7 @@ def add_recv(self, rank, send_ref, recv_ref, tb, ch, send_op):

# InstructionDAG - adds a rrc node
def add_recv_reduce_copy(self, rank, send_ref, recv_ref, tb, ch, send_op):
op = Op(MscclInstruction.recv_reduce_copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch)
op = Op(Instruction.recv_reduce_copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch)
buffer = recv_ref.buffer
index = recv_ref.index
size = recv_ref.size
Expand Down Expand Up @@ -317,14 +317,14 @@ def _optimize_rcs(self):
op = frontier[0]
for next_op in op.next:
if (
op.inst == MscclInstruction.recv
and next_op.inst == MscclInstruction.send
op.inst == Instruction.recv
and next_op.inst == Instruction.send
and same_tb(op, next_op)
and same_count(op, next_op)
and same_buf_dst(op, next_op)
):
# recv -> rcs, remove send
op.inst = MscclInstruction.recv_copy_send
op.inst = Instruction.recv_copy_send
op.dst = next_op.dst
next_op.recv_match.send_match = op
op.recv_match = next_op.recv_match
Expand All @@ -347,27 +347,27 @@ def _optimize_rrcs_rrs(self):
if len(next_op.next) == 1:
nnext_op = next_op.next[0]
if (
op.inst == MscclInstruction.recv_reduce_copy
and next_op.inst == MscclInstruction.send
and nnext_op.inst is MscclInstruction.recv
op.inst == Instruction.recv_reduce_copy
and next_op.inst == Instruction.send
and nnext_op.inst is Instruction.recv
and same_tb(op, next_op)
and same_count(op, next_op)
and same_buf_dst(op, next_op)
):
op.inst = MscclInstruction.recv_reduce_send
op.inst = Instruction.recv_reduce_send
op.dst = next_op.dst
next_op.recv_match.send_match = op
op.recv_match = next_op.recv_match
remove_op(next_op)

if (
op.inst == MscclInstruction.recv_reduce_copy
and next_op.inst == MscclInstruction.send
op.inst == Instruction.recv_reduce_copy
and next_op.inst == Instruction.send
and same_tb(op, next_op)
and same_count(op, next_op)
and same_buf_dst(op, next_op)
):
op.inst = MscclInstruction.recv_reduce_copy_send
op.inst = Instruction.recv_reduce_copy_send
op.dst = next_op.dst
next_op.recv_match.send_match = op
op.recv_match = next_op.recv_match
Expand Down
2 changes: 1 addition & 1 deletion msccl/language/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections import defaultdict

from msccl.language.buffer import Buffer
from msccl.language.types import MscclInstruction as Instruction, Op, Program
from msccl.language.types import Instruction, Op, Program


# Instructions where src is on local GPU
Expand Down
4 changes: 2 additions & 2 deletions msccl/language/tb_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def priority(op):
if op.inst == Instruction.start:
visited.add(op)
for o in op.next:
if o.inst == MscclInstruction.send or o.inst == MscclInstruction.copy:
if o.inst == Instruction.send or o.inst == Instruction.copy:
heapq.heappush(ops, (priority(o), o))

while len(ops) > 0:
Expand Down Expand Up @@ -206,7 +206,7 @@ def dfs(op, channels, f):

# Assign channels to flows
for op in instrs:
if op.inst == MscclInstruction.send and op.recv_match.is_fused():
if op.inst == Instruction.send and op.recv_match.is_fused():
dfs(op, all_channels(), [])

# Iterate through and make certain the sends and receives between a pair of GPUs is consistent
Expand Down
38 changes: 16 additions & 22 deletions msccl/language/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,6 @@ def __str__(self):
class Instruction(Enum):
delete = "d"
start = "st"

def __str__(self):
return self.value


class MscclInstruction(Enum):
nop = "nop"
send = "s"
recv = "r"
Expand Down Expand Up @@ -144,7 +138,7 @@ def __hash__(self):

@dataclass
class Op:
inst: Union[Instruction, MscclInstruction, MscclppInstruction]
inst: Union[Instruction, MscclppInstruction]
rank: int
src: ChunkRef
dst: ChunkRef
Expand Down Expand Up @@ -175,35 +169,35 @@ def cnt(self):

def is_send(self):
return (
self.inst == MscclInstruction.send
or self.inst == MscclInstruction.recv_reduce_copy_send
or self.inst == MscclInstruction.recv_copy_send
or self.inst == MscclInstruction.recv_reduce_send
self.inst == Instruction.send
or self.inst == Instruction.recv_reduce_copy_send
or self.inst == Instruction.recv_copy_send
or self.inst == Instruction.recv_reduce_send
)

def is_recv(self):
return (
self.inst == MscclInstruction.recv
or self.inst == MscclInstruction.recv_reduce_copy
or self.inst == MscclInstruction.recv_reduce_copy_send
or self.inst == MscclInstruction.recv_copy_send
or self.inst == MscclInstruction.recv_reduce_send
self.inst == Instruction.recv
or self.inst == Instruction.recv_reduce_copy
or self.inst == Instruction.recv_reduce_copy_send
or self.inst == Instruction.recv_copy_send
or self.inst == Instruction.recv_reduce_send
)

def is_fused(self):
return (
self.inst == MscclInstruction.recv_reduce_copy_send
or self.inst == MscclInstruction.recv_copy_send
or self.inst == MscclInstruction.recv_reduce_send
self.inst == Instruction.recv_reduce_copy_send
or self.inst == Instruction.recv_copy_send
or self.inst == Instruction.recv_reduce_send
)

def is_local(self):
return self.inst == MscclInstruction.copy or self.inst == MscclInstruction.reduce
return self.inst == Instruction.copy or self.inst == Instruction.reduce

def peer(self):
if self.inst == MscclInstruction.send:
if self.inst == Instruction.send:
return self.dst.rank
elif self.inst == MscclInstruction.recv:
elif self.inst == Instruction.recv:
return self.src.rank
else:
return None
Expand Down

0 comments on commit 2617280

Please sign in to comment.