
Commit

Add barrier OP (#27)
Add a barrier op that synchronizes multiple threadblocks within the same rank.
Usage:
```python
r = rank(n)
r.barrier(tb_list=list(range(gpus)))
```
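
The barrier is emitted into every threadblock in `tb_list` on the given rank, and each of those threadblocks waits until all of them reach it. Per the `ir.py` change in this commit, a barrier op lowers to a JSON instruction recording the participant count and barrier id; a minimal sketch of that shape (the values are hypothetical):

```python
# Sketch of the lowered instruction (see ir.py below); values are hypothetical.
instr = {
    "name": "barrier",      # MscclppInstruction.barrier.value
    "nthread_blocks": 8,    # len(op.extra["tb_list"])
    "barrier_id": 0,        # op.extra["barrier_id"]
}
```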
Binyang2014 authored Nov 27, 2024
1 parent 514a1c9 commit 8efa06f
Showing 10 changed files with 187 additions and 6 deletions.
47 changes: 47 additions & 0 deletions examples/mscclang/mscclpp/allgather_barrier_mscclpp.py
@@ -0,0 +1,47 @@
import argparse
from msccl.language import *
from msccl.topologies import *
from msccl.language.collectives import AllGather


def allgather_test(gpus, instances):
size = gpus
topology = fully_connected(size)
collective = AllGather(size, 1, False)
with MSCCLPPProgram(
"allgather_test",
topology,
collective,
instances,
protocol="Simple",
replication_policy=ReplicationPolicy.interleaved,
):
for n in range(gpus):
c = chunk(n, Buffer.input, 0, 1)
for peer in range(gpus):
if n != peer:
c.put(peer, Buffer.output, n, sendtb=peer, chan_type=ChannelType.sm)
else:
c.copy(n, Buffer.output, n, sendtb=peer)
# explicit barrier: every threadblock on this rank finishes its put/copy before any signals
r = rank(n)
r.barrier(tb_list=list(range(gpus)))
for peer in range(gpus):
if n != peer:
c.signal(peer, Buffer.output, n, sendtb=peer, chan_type=ChannelType.sm)

for n in range(gpus):
for peer in range(gpus):
c = chunk(n, Buffer.output, peer, 1)
if n != peer:
c.wait(peer, Buffer.input, peer, recvtb=peer, chan_type=ChannelType.sm)

Json()
Check()


parser = argparse.ArgumentParser()
parser.add_argument("num_gpus", type=int, help="number of gpus")
parser.add_argument("instances", type=int, help="number of instances")
args = parser.parse_args()
allgather_test(args.num_gpus, args.instances)
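
For reference, the example can be driven from the CLI (`python allgather_barrier_mscclpp.py <num_gpus> <instances>`) or called directly; a minimal sketch with hypothetical arguments:

```python
# Direct invocation, bypassing argparse; 8 GPUs and 1 instance are hypothetical values.
allgather_test(8, 1)  # prints the JSON plan via Json() and verifies chunks via Check()
```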
3 changes: 3 additions & 0 deletions msccl/language/__init__.py
@@ -289,6 +289,9 @@ def chunk(rank, buffer, index, size=1) -> Union[mscclpp.Ref, Ref]:
return None
return _curr().get_ref(rank, buffer, index, size)

def rank(rank) -> mscclpp.RankRef:
return _curr().get_rank_ref(rank)


def create_scratch(rank, name):
return _curr().create_scratch(rank, name)
3 changes: 3 additions & 0 deletions msccl/language/instruction_dag.py
@@ -91,6 +91,9 @@ def same_buf_src(op1: Op, op2: Op):
def same_chan_type(op1: Op, op2: Op):
return op1.channel_type == op2.channel_type

def same_tb(op1: Op, op2: Op):
return op1.tb == op2.tb


class InstructionDAG(ABC):
def __init__(self, num_ranks, buffers):
19 changes: 19 additions & 0 deletions msccl/language/mscclpp/__init__.py
@@ -6,6 +6,7 @@
from msccl.language.types import ChannelType
from msccl.language.mscclpp.ir import *
from msccl.language.mscclpp.instruction_dag import MscclppInstructionDAG
from msccl.language.mscclpp.rank import Rank
from msccl.language.tb_assignment import *
from msccl.topologies.topology import Topology

@@ -51,7 +52,9 @@ def __init__(
# Initialize the input buffers
self.buffers = collective.init_buffers()
self.instr_dag = MscclppInstructionDAG(self.num_ranks, self.buffers)
self.ranks = []
for r in range(self.num_ranks):
self.ranks.append(Rank(r))
for index, chunk in enumerate(self.buffers[r][Buffer.input]):
buffer, index = self.collective.get_buffer_index(r, Buffer.input, index)
ref = self.get_ref(r, buffer, index, 1)
@@ -81,6 +84,9 @@ def _convert_to_exectuion_plan(self):
tb = self.instr_dag.tbs[rank][tbid]
tb.ops.append(op)

def get_rank_ref(self, rank):
return RankRef(rank, self)

# Tracks a send operation on the buffers
def apply_send(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size):
src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index)
@@ -156,6 +162,19 @@ def Json():
print(_curr().generate_json())


@dataclass
class RankRef:
rank: int
prog: MSCCLPPProgram

def _get_barrier_id(self, tb_list) -> int:
return self.prog.ranks[self.rank].get_barrier_id(tb_list)

def barrier(self, tb_list):
barrier_id = self._get_barrier_id(tb_list)
return self.prog.instr_dag.add_barrier(self.rank, tb_list, barrier_id)


@dataclass
class Ref(ChunkRef):
prog: MSCCLPPProgram
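
Taken together, the additions in this file wire the new API end to end: `rank(n)` returns a `RankRef`, whose `barrier` asks the per-rank `Rank` object for a barrier id and then emits the op into the instruction DAG. A minimal sketch, assuming it runs inside an `MSCCLPPProgram` block:

```python
# Minimal sketch of the new call chain; must run inside a MSCCLPPProgram context.
r = rank(0)                # RankRef bound to the current program (get_rank_ref)
r.barrier(tb_list=[0, 1])  # allocates or reuses a barrier id, then calls
                           # instr_dag.add_barrier for threadblocks 0 and 1
```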
35 changes: 33 additions & 2 deletions msccl/language/mscclpp/instruction_dag.py
@@ -2,6 +2,7 @@
# Licensed under the MIT License.


import copy
from msccl.language.buffer import Buffer
from msccl.language.instruction_dag import (
same_buf_dst,
@@ -131,7 +132,8 @@ def add_signal(self, rank, send_ref, recv_ref, tb, ch_type):
buffer = send_ref.buffer
index = send_ref.index
size = send_ref.size
# treat signal as a write since it can not be executed parallelly with read operations
# treat signal as a write: the signal acts as a barrier that prevents the
# instructions below it from being scheduled above the signal instruction.
self._write(rank, buffer, index, size, op)
op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step))
op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step))
@@ -191,6 +193,15 @@ def add_read_reduce(self, rank, send_ref, recv_ref, tb, ch_type):
self._write(rank, buffer, index, size, op, read=True)
return op

def add_barrier(self, rank, tb_list, barrier_id):
buffers = self.buffers[rank]
for tb in tb_list:
tb_step = self._get_tb_step(rank, tb)
extra = {"tb_list": tb_list, "barrier_id": barrier_id}
op = Op(Instruction.barrier, rank, None, None, next=set(), prev=set(), tb=tb, step=tb_step, extra=extra)
for buffer_type, buffer in buffers.items():
self._write(rank, buffer_type, 0, len(buffer), op)

def add_group_load_reduce(self, rank, send_refs, recv_ref, tb, ch_type):
tb_step = self._get_tb_step(rank, tb)
op = Op(
@@ -245,6 +256,8 @@ def complete_channels(self):
for tbid, tb in rank_tbs.items():
chans = set()
for op in tb.ops:
if op.inst == Instruction.barrier:
continue
if op.src != None:
src_buffer = (
Buffer.scratch
@@ -483,10 +496,19 @@ def get_new_index(rank, buffer, index, size, i):
return len(self.buffers[rank][buffer]) * i + index

def get_instance_ref(ref):
if ref is None:
return None
iindex = get_new_index(ref.rank, ref.buffer, ref.index, ref.size, i)
iref = ChunkRef(ref.rank, ref.buffer, iindex, ref.size)
return iref

def update_extra(op, ori_op):
if op.inst == Instruction.barrier:
tb_list = ori_op.extra["tb_list"]
new_tb_list = [tb * instances + i for tb in tb_list]
op.extra["tb_list"] = new_tb_list
op.extra["barrier_id"] = ori_op.extra["barrier_id"] * instances + i

for i in range(instances):
# Generate all the threadblocks and ops
for rank, rank_tbs in enumerate(self.tbs):
@@ -501,8 +523,17 @@ def get_instance_ref(ref):
idepends = []
# Note: We don't need to fill out the rest of the metadata since replication is the last optimization
iop = Op(
op.inst, op.rank, isrc, idst, idepends, op.step, itbid, channel_type=op.channel_type
op.inst,
op.rank,
isrc,
idst,
idepends,
op.step,
itbid,
channel_type=op.channel_type,
extra=copy.deepcopy(op.extra),
)
update_extra(iop, op)
itb.ops[s] = iop
for src, step in op.srcs:
isrc = get_instance_ref(src)
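
The `update_extra` hook above keeps replicated barriers consistent: under interleaved replication, threadblock `tb` of instance `i` becomes `tb * instances + i`, and barrier ids are offset the same way so each instance synchronizes on a distinct barrier. A worked example with hypothetical values:

```python
# Hypothetical values illustrating update_extra with instances = 2, instance i = 1.
instances, i = 2, 1
tb_list = [0, 1]
new_tb_list = [tb * instances + i for tb in tb_list]  # -> [1, 3]
barrier_id = 0
new_barrier_id = barrier_id * instances + i           # -> 1
```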
4 changes: 4 additions & 0 deletions msccl/language/mscclpp/instruction_optimizer.py
@@ -10,6 +10,7 @@
same_count,
same_buf_dst,
same_buf_src,
same_tb,
all_prevs_visited_after_merge,
)
from msccl.language.types import ChunkRef, ChannelType, MscclppInstruction as Instruction, Op, Threadblock
@@ -38,6 +39,7 @@ def try_merge_same_instructions(
"""
if (
next_op.inst == expected_next_inst
and same_tb(op, next_op)
and same_buf_func(op, next_op)
and same_count(op, next_op)
and same_chan_type(op, next_op)
@@ -122,6 +124,7 @@ def try_fuse_with_put(self, op: Op, next_op: Op, tb: Threadblock, queue: list) -
if (
(next_op.inst == Instruction.put or next_op.inst == Instruction.put_packet)
and same_tb(op, next_op)
and same_count(op, next_op)
and buf_dst_src_match(op, next_op)
and next_op.channel_type == ChannelType.sm
@@ -168,6 +171,7 @@ def try_fuse_instructions_using_proxy_channel(
"""
if (
next_op.inst == expected_next_inst
and same_tb(op, next_op)
and same_count(op, next_op)
and same_buf_dst(op, next_op)
and same_buf_src(op, next_op)
21 changes: 17 additions & 4 deletions msccl/language/mscclpp/ir.py
@@ -43,6 +43,7 @@
Instruction.copy_packet,
Instruction.reduce_packet,
Instruction.reduce_send_packet,
Instruction.barrier,
}


@@ -125,8 +126,17 @@ def ir_to_json(program: Program):
new_ops.append(op)
continue
# Expand extra dependencies into nop operations
nop = Op(Instruction.nop, -1, None, None, [])
for i, dep in enumerate(op.depends):
new_ops.append(Op(Instruction.nop, -1, None, None, [dep]))
# barrier already syncs all threads
if dep.inst != Instruction.barrier:
nop.depends.append(dep)
if len(new_ops) > 0 and (
new_ops[-1].inst == Instruction.barrier or new_ops[-1].inst == Instruction.nop
):
new_ops[-1].depends.extend(nop.depends)
elif len(nop.depends) > 0:
new_ops.append(nop)
new_ops.append(op)
tb.ops = new_ops

@@ -230,8 +240,9 @@ def remove_empty_fields(d):
"chanIds": [id for id, ele in enumerate(channels) if ele[0] == tb.id],
"connectedTo": [ele[1] for ele in channels if ele[0] == tb.id],
}
tb_channel_dict[(srcBuffer, dstBuffer, type)] = obj
tb_channels.append(obj)
if len(obj["chanIds"]) > 0:
tb_channel_dict[(srcBuffer, dstBuffer, type)] = obj
tb_channels.append(obj)
tb_channels = filter(lambda x: x["type"] != "none", tb_channels)
tb_channels = sorted(tb_channels, key=lambda x: (x["srcbuff"], x["dstbuff"]))
for op in tb.ops:
@@ -292,6 +303,8 @@ def remove_empty_fields(d):
"name": op.inst.value,
"deps": list(map(lambda dep: {"tb": dep.tb, "step": dep.step}, op.depends)),
}
elif op.inst == Instruction.barrier:
instr = {"name": op.inst.value, "nthread_blocks": len(op.extra["tb_list"]), "barrier_id": op.extra["barrier_id"]}
elif (
op.inst == Instruction.put
or op.inst == Instruction.put_packet
@@ -325,7 +338,7 @@ def remove_empty_fields(d):
dst_channel_ids = get_channel_ids(
op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type
)
if op.inst != Instruction.nop:
if op.inst != Instruction.nop and op.inst != Instruction.barrier:
instr = {
"name": op.inst.value,
"i_buff": i_buff,
33 changes: 33 additions & 0 deletions msccl/language/mscclpp/rank.py
@@ -0,0 +1,33 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from dataclasses import dataclass, field
from typing import Dict


class BarrierInfo:
def __init__(self, tb_list):
self.tb_list = tb_list

def __eq__(self, other):
return self.tb_list == other.tb_list

def __hash__(self):
return hash(tuple(self.tb_list))


@dataclass
class Rank:
rank_id: int
current_max_barrier_id: int = 0
current_barriers: Dict[BarrierInfo, int] = field(default_factory=dict)

def get_barrier_id(self, tb_list):
barrier_info = BarrierInfo(tb_list)
if barrier_info in self.current_barriers:
return self.current_barriers[barrier_info]
else:
self.current_barriers[barrier_info] = self.current_max_barrier_id
barrier_id = self.current_max_barrier_id
self.current_max_barrier_id += 1
return barrier_id
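
`BarrierInfo` defines `__eq__` and `__hash__` so a rank hands out one barrier id per distinct `tb_list` and reuses it on repeat calls. A small sketch of that behavior, assuming `msccl` is importable:

```python
from msccl.language.mscclpp.rank import Rank

r = Rank(rank_id=0)
first = r.get_barrier_id([0, 1, 2])   # new tb_list -> id 0
again = r.get_barrier_id([0, 1, 2])   # same tb_list -> same id reused
other = r.get_barrier_id([0, 1])      # different tb_list -> next id, 1
assert first == again == 0 and other == 1
```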
2 changes: 2 additions & 0 deletions msccl/language/types.py
@@ -124,6 +124,7 @@ class MscclppInstruction(Enum):
wait = "wait"
signal = "signal"
flush = "flush"
barrier = "barrier"
group_store = "gstore"
group_load_reduce = "glre"
group_load_reduce_store = "glres"
@@ -186,6 +187,7 @@ class Op:
channel_type: ChannelType = ChannelType.none
srcs: list = field(default_factory=list)
dsts: list = field(default_factory=list)
extra: dict = field(default_factory=dict)

def cnt(self):
if self.src:
26 changes: 26 additions & 0 deletions tests/test_language.py
@@ -332,6 +332,32 @@ def test_illegal_tb_assignment():
XML()


def test_group_api():
num_gpus = 4
topology = fully_connected(num_gpus)
collective = AllReduce(num_gpus, num_gpus, True)
prgm = MSCCLPPProgram("allreduce", topology, collective, 1)
with prgm:
for rank in range(num_gpus):
index = rank
reduce_chunks = []
c = chunk(rank, Buffer.input, index)
# make sure the data is ready
for nghr in range(num_gpus):
if rank != nghr:
c_peer = chunk(nghr, Buffer.input, index)
reduce_chunks.append(c_peer)
c = c.group_load_reduce(reduce_chunks, recvtb=0)
ngbrs = [nghr for nghr in range(num_gpus) if nghr != rank]
c.group_store(ngbrs, sendtb=0)
assert Check()
lowered_prgm = prgm.lower()
for gpu in lowered_prgm.gpus:
for tb in gpu.threadblocks:
assert len(tb.ops) == 1
assert tb.ops[0].inst == MscclppInstruction.group_load_reduce_store


def test_routines_allgather_ring_inplace():
size = 4
topology = fully_connected(size)
