Code refactor for mscclang #3

Merged: 76 commits, merged May 7, 2024

Changes from 1 commit

Commits (76)
58b7d9f
integration branch
Binyang2014 Mar 20, 2024
e039484
WIP
Binyang2014 Mar 20, 2024
3527e32
fix
Binyang2014 Mar 22, 2024
24c0863
WIP
Binyang2014 Mar 23, 2024
c32229d
WIP
Binyang2014 Mar 23, 2024
4929ef6
WIP need algo
Binyang2014 Mar 23, 2024
787645b
WIP
Binyang2014 Mar 24, 2024
0448b1d
WIP
Binyang2014 Mar 24, 2024
e23dc18
WIP need fuse
Binyang2014 Mar 24, 2024
2e08484
WIP
Binyang2014 Mar 24, 2024
e634e6f
WIP
Binyang2014 Mar 25, 2024
f8fe329
need more fuse
Binyang2014 Mar 25, 2024
b4c08c9
WIP
Binyang2014 Mar 25, 2024
e406632
WIP
Binyang2014 Mar 26, 2024
3321c5f
WIP
Binyang2014 Mar 26, 2024
7e4bd8b
WIP
Binyang2014 Mar 26, 2024
7074c01
WIP
Binyang2014 Mar 26, 2024
f31d9b4
Now for deps
Binyang2014 Mar 27, 2024
dc8d44e
let make instance work
Binyang2014 Mar 27, 2024
085be4a
enable instance
Binyang2014 Mar 28, 2024
e613558
fix
Binyang2014 Mar 28, 2024
79f450a
update ignore
Binyang2014 Mar 28, 2024
c4a10dd
bug fix
Binyang2014 Mar 29, 2024
ec4a112
update
Binyang2014 Apr 2, 2024
82de232
update
Binyang2014 Apr 2, 2024
171e894
fix
Binyang2014 Apr 5, 2024
99ff31c
WIP
Binyang2014 Apr 7, 2024
93683b5
WIP
Binyang2014 Apr 8, 2024
10b648c
WIP
Binyang2014 Apr 8, 2024
52fd030
update
Binyang2014 Apr 8, 2024
7dd76b6
WIP
Binyang2014 Apr 8, 2024
451f31d
WIP
Binyang2014 Apr 8, 2024
b2ceb13
WIP
Binyang2014 Apr 8, 2024
b683d7f
update
Binyang2014 Apr 8, 2024
3cf049e
WIP
Binyang2014 Apr 8, 2024
b1fd952
Done for today
Binyang2014 Apr 8, 2024
a4728fa
update packet algo
Binyang2014 Apr 19, 2024
42c4a7d
fix comments
Binyang2014 Apr 19, 2024
b494b75
OPT
Binyang2014 Apr 19, 2024
c2bd38f
Fix
Binyang2014 Apr 22, 2024
bb3aebe
WIP
Binyang2014 Apr 24, 2024
40217f9
WIP
Binyang2014 Apr 24, 2024
2e5bac6
WIP
Binyang2014 Apr 24, 2024
eb4f612
fix
Binyang2014 Apr 24, 2024
2617280
fix
Binyang2014 Apr 24, 2024
66d0495
fix
Binyang2014 Apr 24, 2024
8cfcc30
Fix UT
Binyang2014 Apr 25, 2024
2ba1760
update
Binyang2014 Apr 25, 2024
52783e9
WIP
Binyang2014 Apr 25, 2024
01a0745
WIP
Binyang2014 Apr 25, 2024
13e902a
update
Binyang2014 Apr 25, 2024
e52cabf
revert
Binyang2014 Apr 29, 2024
8f46276
revert
Binyang2014 Apr 29, 2024
775d9d6
fix
Binyang2014 Apr 29, 2024
8a58c84
WIP
Binyang2014 Apr 29, 2024
da282fb
WIP
Binyang2014 Apr 29, 2024
1f5114b
WIP
Binyang2014 Apr 29, 2024
5e978bf
WIP
Binyang2014 Apr 29, 2024
a90647e
WIP
Binyang2014 Apr 29, 2024
0e77d88
fix
Binyang2014 Apr 29, 2024
e05edcf
WIP
Binyang2014 Apr 29, 2024
54463bf
WIP
Binyang2014 Apr 29, 2024
ca3f10a
WIP
Binyang2014 Apr 29, 2024
b9c9734
add back
Binyang2014 Apr 29, 2024
a3b9745
update
Binyang2014 Apr 29, 2024
d81936f
Fix
Binyang2014 Apr 30, 2024
2651faf
revert
Binyang2014 Apr 30, 2024
bb7a584
address comments
Binyang2014 May 6, 2024
d072a34
WIP
Binyang2014 May 6, 2024
dda74f9
Fix
Binyang2014 May 6, 2024
f47dfe9
Fix
Binyang2014 May 6, 2024
daef76a
WIP
Binyang2014 May 6, 2024
3d2d838
WIP
Binyang2014 May 6, 2024
5a34cd5
WIP
Binyang2014 May 6, 2024
02b5de9
WIP
Binyang2014 May 6, 2024
06d7776
done
Binyang2014 May 6, 2024
WIP
Binyang2014 committed Mar 20, 2024
commit e0394848eecb0356fec0a2b43d04615c2db72033
17 changes: 17 additions & 0 deletions .vscode/launch.json
@@ -0,0 +1,17 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python Debugger: Current File with Arguments",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"args": "4 1 --protocol Simple",
"justMyCode": false
}
]
}
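For orientation, a minimal sketch (reviewer addition, not part of the diff) of how the "args" string above is consumed: the example scripts changed in this PR parse two positional values plus a --protocol flag, so "4 1 --protocol Simple" debugs a 4-GPU, 1-instance run.

import argparse

# CLI mirror of the allreduce_a100_allpairs*.py examples below
parser = argparse.ArgumentParser()
parser.add_argument('num_gpus', type=int, help='number of gpus')
parser.add_argument('instances', type=int, help='number of instances')
parser.add_argument('--protocol', type=str, default='LL', choices=['Simple', 'LL'])

args = parser.parse_args(['4', '1', '--protocol', 'Simple'])
assert (args.num_gpus, args.instances, args.protocol) == (4, 1, 'Simple')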
10 changes: 5 additions & 5 deletions examples/mscclang/allreduce_a100_allpairs.py
@@ -11,9 +11,9 @@ def allreduce_allpairs(gpus, instances, protocol):
chunksperloop = gpus * gpus
topology = fully_connected(size)
collective = AllReduce(size, chunksperloop, True)
with MSCCLProgram("allreduce_pairs", topology, collective, instances, protocol=protocol,
interleaved_replication=False, threadblock_policy=ThreadblockPolicy.manual, dependence_nop=True):

# Each rank sends the nth chunk to the nth rank into scratch space
for r1 in range(size):
for r2 in range(size):
@@ -28,15 +28,15 @@ def allreduce_allpairs(gpus, instances, protocol):
for index in range(0, size * (size-1)):
c = chunk(r, Buffer.input, r*size + (index % size))
c.reduce(chunk(r, 'scratch', index), sendtb=(index % size))

# Each rank sends the fully reduced nth chunk to all other gpus
for r1 in range(size):
for r2 in range(size):
if r1 != r2:
index = r1 * size
c = chunk(r1, Buffer.input, index, size)
c.copy(r2, Buffer.input, index, sendtb=r2, recvtb=r1)

XML()
Check()

@@ -47,4 +47,4 @@ def allreduce_allpairs(gpus, instances, protocol):

args = parser.parse_args()

allreduce_allpairs(args.num_gpus, args.instances, args.protocol)
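The allpairs schedule above is easier to verify with a pure-Python model of its data movement (reviewer sketch, not part of the PR; integer addition stands in for the reduction):

def allpairs_model(size):
    # inputs[r][i] is chunk i on rank r; distinct values make sums checkable
    inputs = [[r * size * size + i for i in range(size * size)] for r in range(size)]
    expected = [sum(inputs[r][i] for r in range(size)) for i in range(size * size)]

    # Phase 1: rank r1 stages its copy of r2's block into r2's scratch space
    scratch = [[] for _ in range(size)]
    for r2 in range(size):
        for r1 in range(size):
            if r1 != r2:
                scratch[r2].extend(inputs[r1][r2 * size : (r2 + 1) * size])

    # Phase 2: each rank reduces scratch into its own block, as in the loop above
    for r in range(size):
        for index in range(size * (size - 1)):
            inputs[r][r * size + index % size] += scratch[r][index]

    # Phase 3: each rank broadcasts its fully reduced block to all other ranks
    for r1 in range(size):
        for r2 in range(size):
            if r1 != r2:
                inputs[r2][r1 * size : (r1 + 1) * size] = inputs[r1][r1 * size : (r1 + 1) * size]

    assert all(inputs[r] == expected for r in range(size))

allpairs_model(4)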
51 changes: 51 additions & 0 deletions examples/mscclang/allreduce_a100_allpairs_mscclpp.py
@@ -0,0 +1,51 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
from msccl.language import *
from msccl.topologies import *
from msccl.language.collectives import AllReduce

def allreduce_allpairs(gpus, instances, protocol):
size = gpus
chunksperloop = gpus * gpus
topology = fully_connected(size)
collective = AllReduce(size, chunksperloop, True)
with MSCCLProgram("allreduce_pairs", topology, collective, instances, protocol=protocol,
interleaved_replication=False, threadblock_policy=ThreadblockPolicy.manual, dependence_nop=True):

# Each rank sends the nth chunk to the nth rank into scratch space
for r1 in range(size):
for r2 in range(size):
if r1 != r2:
index = r2 * size
c = chunk(r1, Buffer.input, index, size=size)
c.put(r2, 'scratch', index=r1, sendtb=r2)

# Each rank performs a local reduction on the nth chunk
# Utilize 8 threadblocks for this reduction for better parallelism
for r in range(size):
for index in range(0, size * (size-1)):
c = chunk(r, Buffer.input, r*size + (index % size))
c.reduce(chunk(r, 'scratch', index), sendtb=(index % size))

# Each rank sends the fully reduced nth chunk to all other gpus
for r1 in range(size):
for r2 in range(size):
index = r1 * size
c = chunk(r1, Buffer.input, index + r2)
for r3 in range(size):
if r3 != r1:
c.put(r3, Buffer.input, index, sendtb=r2)

XML()
Check()

parser = argparse.ArgumentParser()
parser.add_argument('num_gpus', type=int, help ='number of gpus')
parser.add_argument('instances', type=int, help='number of instances')
parser.add_argument('--protocol', type=str, default='LL', choices=['Simple', 'LL'], help='Protocol')

args = parser.parse_args()

allreduce_allpairs(args.num_gpus, args.instances, args.protocol)
45 changes: 45 additions & 0 deletions examples/mscclang/allreduce_a100_allpairs_mscclpp_v2.py
@@ -0,0 +1,45 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
from msccl.language import *
from msccl.topologies import *
from msccl.language.collectives import AllReduce

def allreduce_allpairs(gpus, instances, protocol):
size = gpus
chunksperloop = gpus * gpus
topology = fully_connected(size)
collective = AllReduce(size, chunksperloop, True)
with MSCCLProgram("allreduce_pairs", topology, collective, instances, protocol=protocol,
interleaved_replication=False, threadblock_policy=ThreadblockPolicy.manual, dependence_nop=True):

# Each rank sends the nth chunk to the nth rank into scratch space
for rank in range(size):
for tb in range(size):
index = rank * size
c = chunk(rank, Buffer.input, index + tb)
for nghr in range(size):
if rank != nghr:
c.reduce(chunk(nghr, 'input', index + tb), recvtb=tb)

# Each rank sends the fully reduced nth chunk to all other gpus
for rank in range(size):
for tb in range(size):
index = rank * size
c = chunk(rank, Buffer.input, index + tb)
for nghr in range(size):
if rank != nghr:
c.put(nghr, Buffer.input, index, sendtb=tb)

XML()
Check()

parser = argparse.ArgumentParser()
parser.add_argument('num_gpus', type=int, help ='number of gpus')
parser.add_argument('instances', type=int, help='number of instances')
parser.add_argument('--protocol', type=str, default='LL', choices=['Simple', 'LL'], help='Protocol')

args = parser.parse_args()

allreduce_allpairs(args.num_gpus, args.instances, args.protocol)
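As with the first variant, a small pure-Python model helps check the v2 schedule (reviewer sketch; this WIP file is read at face value as "reduce in place from peers, then broadcast", with no scratch staging):

def allpairs_v2_model(size):
    inputs = [[r * size * size + i for i in range(size * size)] for r in range(size)]
    expected = [sum(inputs[r][i] for r in range(size)) for i in range(size * size)]
    original = [row[:] for row in inputs]  # snapshot: the reduction is in place

    # Phase 1: each rank reduces every neighbor's copy of its own chunk
    # rank*size + tb, reading peer input buffers directly
    for rank in range(size):
        for tb in range(size):
            for nghr in range(size):
                if rank != nghr:
                    inputs[rank][rank * size + tb] += original[nghr][rank * size + tb]

    # Phase 2: the fully reduced chunk is put to every neighbor's input buffer
    for rank in range(size):
        for tb in range(size):
            for nghr in range(size):
                if rank != nghr:
                    inputs[nghr][rank * size + tb] = inputs[rank][rank * size + tb]

    assert all(inputs[r] == expected for r in range(size))

allpairs_v2_model(4)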
37 changes: 36 additions & 1 deletion msccl/language/__init__.py
@@ -138,6 +138,9 @@ def print_instr_dags(self, rank):
else:
visualize_instr_dag(self.instr_dags[rank].operations)

class MSCCLPPProgram:
pass

def Print():
_curr().print_chunk_dag()

@@ -190,8 +193,40 @@ def group(self, other):
end = max(first._end(), second._end())
return Ref(self.rank, self.buffer, first.index, end - first.index, self.prog)

-    def put(self, dst, buffer=None, index=-1, sendtb=-1):
+    def put(self, dst, buffer=None, index=-1, sendtb=-1, channel_type="SM"):
self.prog.check_buffer_exists(dst, buffer)
# If index is not specified assume it is going to the same place in the next gpu
if index == -1 and buffer == None:
index = self.index
buffer = self.buffer
elif index == -1 and buffer is not Buffer.input and buffer is not Buffer.output:
index = self.prog.buffers[dst][buffer].instance_size()

# Some inplace collectives have custom logic for buffers and index (ReduceScatter, AllGather)
buffer, index = self.prog.collective.get_buffer_index(self.rank, buffer, index)

# Direct put
assert (self.prog.topo.link(self.rank, dst) or dst == self.rank), f'No link from {self.rank} to {dst}'
dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)

# Check if we are copying the chunk to the same index (easy mistake when we are using inplace)
if dst_chunkref == self:
return

# chunks = self.prog.get_chunks(self.rank, self.buffer, self.index, self.size)
# overwritten_chunks = self.prog.get_chunks(dst, buffer, index, self.size)

self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size)

# self.prog.chunk_dag.add_send(chunks, overwritten_chunks, self, dst_chunkref, sendtb, recvtb, ch)
sender = self.rank
receiver = dst
if sender != receiver:
sop = self.prog.instr_dag.add_send(sender, self, dst_chunkref, sendtb)
else:
self.prog.instr_dag.add_copy(sender, self, dst_chunkref, sendtb)

return dst_chunkref

def get(self, src, buffer=None, index=-1, recvtb=-1):
self.prog.check_buffer_exists(src, buffer)
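A hedged usage sketch of the new one-sided API (reviewer addition; ranks, buffer names and indices are illustrative, and the calls must sit inside a MSCCLProgram block like the examples earlier in this diff):

c = chunk(0, Buffer.input, index=4, size=4)  # a 4-chunk reference on rank 0
c.put(1, 'scratch', index=0, sendtb=2)       # one-sided write into rank 1's scratch
# Unlike copy(), put() takes no recvtb: only the sending rank records an op.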
15 changes: 10 additions & 5 deletions msccl/language/ir.py
@@ -76,8 +76,13 @@ class Instruction(Enum):
recv_reduce_copy_send = 'rrcs'
copy = 'cpy'
reduce = 're'
delete = 'd'
start = 'st'
put = 'put'
get = 'get'
wait = 'wait'
signal = 'signal'
flush = 'flush'

def __str__(self):
return self.value
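As a quick sanity sketch (reviewer addition, not part of the diff), the new one-sided members stringify through the existing __str__ just like the old short opcodes:

from msccl.language.ir import Instruction

assert str(Instruction.put) == 'put'       # __str__ returns self.value
assert str(Instruction.signal) == 'signal'
assert str(Instruction.flush) == 'flush'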
@@ -93,7 +98,7 @@ def __str__(self):

def __lt__(self, other):
return self.value < other.value

def __gt__(self, other):
return self.value > other.value

@@ -172,7 +177,7 @@ def send_peer(self):
if self.is_send():
return self.dst.rank
return -1

def recv_peer(self):
if self.is_recv():
return self.src.rank
@@ -244,7 +249,7 @@ def ir_to_xml(program: Program, old_format=True, use_scratch=True, pretty_print=
op.depends = list(
filter(lambda dep: op_tb_id[dep] != tb_id[tb], op.depends))
# Filter out redundant dependencies
# e.g. if op1 and op2 depend on op, and op1 happens before op2
# then op2 does not need to explicitly depend on op
for gpu in program.gpus:
for tb in gpu.threadblocks:
@@ -276,7 +281,7 @@ def ir_to_xml(program: Program, old_format=True, use_scratch=True, pretty_print=
for dep in op.depends:
if first_dep is None:
first_dep = dep
else:
pre_ops.append(Op(Instruction.nop, -1, None, None, [dep]))
op.depends = []
if first_re is None:
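The redundant-dependency rule quoted in the comment above is worth a tiny worked example (reviewer sketch, plain dicts rather than the real Op class):

# Ops in one threadblock execute in program order, so if op1 already waits on
# op0 and op2 runs after op1, op2's explicit dependency on op0 is redundant.
depends = {'op1': {'op0'}, 'op2': {'op0'}}
depends['op2'] -= depends['op1']   # drop deps already covered by op1
assert depends['op2'] == set()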
45 changes: 27 additions & 18 deletions msccl/language/rank_dag.py
@@ -23,7 +23,7 @@ def same_tb(op1, op2):

def same_count(op1, op2):
return op1.cnt() == op2.cnt()

def same_buf_dst(op1, op2):
return op1.dst.buffer == op2.dst.buffer and op1.dst.index == op2.dst.index

@@ -36,9 +36,9 @@ def __init__(self, num_ranks, buffers):
self.last_writer = {} # slot -> last writing op
self.last_readers = defaultdict(list) # slot -> list of last reading ops
# State for the MSCCL-IR
self.tbs = []
for _ in range(num_ranks):
self.tbs.append({})
self.tb_mapping = {}
self.num_channels = [1] * num_ranks

@@ -62,7 +62,7 @@ def _write(self, rank, buffer, index, size, op, read=False):
prev_ops.update(readers)
elif slot in self.last_writer:
prev_ops.add(self.last_writer[slot])

# Set the last_writer to this op, and clear all readers
self.last_writer[slot] = op
self.last_readers[slot] = []
@@ -82,7 +82,7 @@ def _read(self, rank, buffer, index, size, op):
writer = self.last_writer[slot]
prev_ops.add(writer)
self.last_readers[slot].append(op)

# Update the next pointer of the previous ops
for prev_op in prev_ops:
prev_op.next.add(op)
@@ -133,6 +133,15 @@ def add_send(self, rank, send_ref, recv_ref, tb, ch):
self._read(rank, buffer, index, size, op)
return op

# InstructionDAG - adds a put node
def add_put(self, rank, send_ref, recv_ref, tb, ch):
op = Op(Instruction.put, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch)
buffer = send_ref.buffer
index = send_ref.index
size = send_ref.size
self._read(rank, buffer, index, size, op)
return op
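# Reviewer note (assumption, not in the diff): mirroring add_send above,
# add_put records only a read of the source slot on the sending rank; put is
# one-sided, so this helper creates no matching recv node on the receiver.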

# InstructionDAG - adds a recv node
def add_recv(self, rank, send_ref, recv_ref, tb, ch, send_op):
op = Op(Instruction.recv, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch)
@@ -172,7 +181,7 @@ def convert_set_list(self):
ops = ops[1:] + op.next
else:
ops = ops[1:]

def optimize(self):
self._optimize_rrcs_rrs()
self._optimize_rcs()
@@ -198,13 +207,13 @@ def dfs(op, cs):
for chunk, op in self.operations.items():
if op.inst == Instruction.start:
dfs(op,-2) # Start instructions should start at -1


# Given the set of operations that operate over a particular slot (rank, buffer, idx) fixed
# Try and replace operations with pipelined ops like receive copy send (rcs)
# or receive reduce send (rrs) and receive reduce copy send (rrcs)
# Rules:
# recv-copy-send
# recv(src, sbuf, si, _, _, _ ) send(_, _, _, dst, dbuf, di) -> recv_copy_send(src, sbuf, si, dst, dbuf, di)
def _optimize_rcs(self):
for slot, ops in self.operations.items():
@@ -222,7 +231,7 @@ def _optimize_rcs(self):
break
frontier = frontier[1:] + op.next
# recv-reduce-send - A rrc followed by a send that gets overwritten
# rrc(src, sbuf, si, ...) send(_, _, _, dst, dbuf, di) recv(_, _, _, dst, dbuf, di)
# recv-reduce-copy-send - A rrc followed by a send that does not get overwritten
# rrc(src, sbuf, si, ...) send(_, _, _, dst, dbuf, di)
def _optimize_rrcs_rrs(self):
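Before the hunks below, a stand-alone sketch (reviewer addition, plain strings instead of the real Op class) of the two fusion rules quoted above:

# rrc followed by a send whose destination is later overwritten fuses to rrs;
# rrc followed by a send that is not overwritten fuses to rrcs.
def fuse(op, nxt, overwritten_later):
    if op == 'rrc' and nxt == 's':
        return 'rrs' if overwritten_later else 'rrcs'
    return None

assert fuse('rrc', 's', True) == 'rrs'
assert fuse('rrc', 's', False) == 'rrcs'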
@@ -241,7 +250,7 @@ def _optimize_rrcs_rrs(self):
next_op.recv_match.send_match = op
op.recv_match = next_op.recv_match
remove_op(next_op)

if op.inst == Instruction.recv_reduce_copy and next_op.inst == Instruction.send and same_tb(op, next_op) and same_count(op, next_op) and same_buf_dst(op, next_op):
op.inst = Instruction.recv_reduce_copy_send
op.dst = next_op.dst
@@ -253,7 +262,7 @@ def _optimize_rrcs_rrs(self):
def lower_pt1(self, instances):
self.infer_dependencies()
self.lower_buffers(instances)

def lower_pt2(self, instances, interleaved):
self.replicate(instances, interleaved)
return self.lower_tbs()
@@ -311,14 +320,14 @@ def lower_tbs(self):
# interleaved sets the replication policy
# if True chunks are split as: ChunkA ChunkB -> ChunkA0 ChunkA1 .. ChunkB0 ChunkB1 ...
# if false chunks are divided as ChunkA0 ChunkB0 ChunkA1 ChunkB1 ...
# For collectives where chunks are designated for a particular GPU (e.g. AllToAll)
# only interleaved replication will be correct
# Interleaved policy only supports single count sends/receives from the input/output buffer
# (multicount ops are fine between scratch)
def replicate(self, instances, interleaved):
if instances == 1:
self.instanced_tbs = self.tbs
return

self.instanced_tbs = []
for _ in range(self.num_ranks):
@@ -357,12 +366,12 @@ def get_instance_ref(ref):
for s, op in enumerate(tb.ops):
isrc = get_instance_ref(op.src)
idst = get_instance_ref(op.dst)
idepends = []
# Note: We don't need to fill out the rest of the metadata since replication is the last optimization
iop = Op(op.inst, op.rank, isrc, idst, idepends, op.step, itbid)
itb.ops[s] = iop
self.instanced_tbs[op.rank][itbid] = itb

# Redo dependency analysis
for rank, rank_tbs in enumerate(self.tbs):
for tbid, tb in rank_tbs.items():
@@ -375,5 +384,5 @@ def get_instance_ref(ref):
dep_tbid = dep.tb
dep_itbid = dep_tbid * instances + i
dep_step = dep.step
iop.depends[s] = self.instanced_tbs[op.rank][dep_itbid].ops[dep_step]
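To make the replication comment above concrete, a stand-alone sketch (reviewer addition) of the two chunk orderings for instances=2:

chunks = ['A', 'B']
instances = 2
# interleaved=True:  ChunkA ChunkB -> A0 A1 B0 B1
interleaved = [c + str(i) for c in chunks for i in range(instances)]
assert interleaved == ['A0', 'A1', 'B0', 'B1']
# interleaved=False: ChunkA ChunkB -> A0 B0 A1 B1
noninterleaved = [c + str(i) for i in range(instances) for c in chunks]
assert noninterleaved == ['A0', 'B0', 'A1', 'B1']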