Skip to content

Commit

Permalink
Add MSCCLPPProgram (#4)
Browse files Browse the repository at this point in the history
- Add MSCCLPPProgram example
- Support msccl++ primitives
- Generate execution plan for msccl++ executor
  • Loading branch information
Binyang2014 authored May 8, 2024
1 parent 3bfba6b commit 87ce281
Show file tree
Hide file tree
Showing 11 changed files with 1,796 additions and 59 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
from msccl.language import *
from msccl.topologies import *
from msccl.language.collectives import AllReduce


def allreduce_allpairs(gpus, instances):
    """Define the all-pairs AllReduce example built on the packet primitives.

    Creates a fully connected topology and an ``AllReduce`` collective over
    ``gpus`` ranks with ``gpus * gpus`` chunks, then issues three phases of
    chunk operations inside an ``MSCCLPPProgram`` named ``"allreduce_pairs"``
    that uses the ``"LL"`` protocol. ``Json()`` and ``Check()`` are called
    after all operations have been issued.

    Args:
        gpus: Number of ranks; also the number of chunks in each per-rank
            group of the input buffer.
        instances: Forwarded unchanged to ``MSCCLPPProgram``.
    """
    size = gpus
    topology = fully_connected(size)
    # NOTE(review): the third argument is presumably the in-place flag --
    # confirm against the AllReduce collective's signature.
    collective = AllReduce(size, gpus * gpus, True)
    with MSCCLPPProgram("allreduce_pairs", topology, collective, instances, protocol="LL"):
        ranks = range(size)
        # The second half of "scratch" (starting at size * size) receives the
        # reduced chunks in phase 2 and is read back in phase 3.
        reduced_base = size * size

        # Phase 1: every rank puts the group of `size` input chunks that
        # belongs to each other rank into that rank's scratch buffer, at the
        # offset reserved for the sender. The threadblock id is the peer's rank.
        for src in ranks:
            for dst in ranks:
                if dst != src:
                    group = chunk(src, Buffer.input, dst * size, size)
                    group.put_packet(dst, "scratch", index=src * size, sendtb=dst)

        # Phase 2: every rank reduces its own group one chunk at a time with
        # the matching scratch chunks received from the peers, then puts the
        # result into every peer's scratch buffer. One threadblock is used per
        # chunk offset (`size` threadblocks in total).
        for rank in ranks:
            others = [peer for peer in ranks if peer != rank]
            for offset in ranks:
                acc = chunk(rank, Buffer.input, rank * size + offset)
                for peer in others:
                    acc.reduce_packet(chunk(rank, "scratch", peer * size + offset), recvtb=offset)
                for peer in others:
                    acc.put_packet(peer, "scratch", reduced_base + rank * size + offset, sendtb=offset)

        # Phase 3: every rank copies each peer's reduced group out of scratch
        # back into the corresponding position of its input buffer.
        for rank in ranks:
            for peer in ranks:
                if peer == rank:
                    continue
                reduced = chunk(rank, "scratch", reduced_base + peer * size, size)
                reduced.copy_packet(rank, Buffer.input, peer * size, sendtb=peer)

        Json()
        Check()


# Command-line entry point: both positionals are parsed as integers and
# handed straight to the program builder.
parser = argparse.ArgumentParser()
for arg_name, arg_help in (
    ("num_gpus", "number of gpus"),
    ("instances", "number of instances"),
):
    parser.add_argument(arg_name, type=int, help=arg_help)

args = parser.parse_args()

allreduce_allpairs(gpus=args.num_gpus, instances=args.instances)
58 changes: 58 additions & 0 deletions examples/mscclang/mscclpp/allreduce_a100_allpairs_sm_mscclpp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
from msccl.language import *
from msccl.topologies import *
from msccl.language.collectives import AllReduce


def allreduce_allpairs(gpus, instances, protocol):
    """Define the all-pairs AllReduce example built on signal/wait/reduce/put.

    Creates a fully connected topology and an ``AllReduce`` collective over
    ``gpus`` ranks with ``gpus * gpus`` chunks. Inside an ``MSCCLPPProgram``
    named ``"allreduce_pairs"``, each (rank, threadblock) pair works on one
    chunk of the rank's own group in the input buffer. ``Json()`` and
    ``Check()`` are called after all operations have been issued.

    Args:
        gpus: Number of ranks; also the number of threadblocks used per rank.
        instances: Forwarded unchanged to ``MSCCLPPProgram``.
        protocol: Forwarded to ``MSCCLPPProgram`` as its ``protocol`` keyword.
    """
    size = gpus
    topology = fully_connected(size)
    # NOTE(review): the third argument is presumably the in-place flag --
    # confirm against the AllReduce collective's signature.
    collective = AllReduce(size, gpus * gpus, True)
    with MSCCLPPProgram("allreduce_pairs", topology, collective, instances, protocol=protocol):
        for rank in range(size):
            peers = [other for other in range(size) if other != rank]
            for tb in range(size):
                own_slot = rank * size + tb
                local = chunk(rank, Buffer.input, own_slot)

                # Step 1: make sure the data is ready. Signal each peer about
                # the chunk in that peer's group, then wait on every peer for
                # the chunk this rank is about to reduce.
                for peer in peers:
                    peer_slot = peer * size + tb
                    outgoing = chunk(rank, Buffer.input, peer_slot)
                    outgoing.signal(peer, Buffer.input, peer_slot, sendtb=tb)
                for peer in peers:
                    local.wait(peer, Buffer.input, own_slot, recvtb=tb)

                # Step 2: reduce the peers' copies into the local chunk and
                # put the result to every peer.
                for peer in peers:
                    local.reduce(chunk(peer, Buffer.input, own_slot), recvtb=tb)
                for peer in peers:
                    local.put(peer, Buffer.input, own_slot, sendtb=tb)

                # Step 3: signal the peers that the buffer is ready, then wait
                # for the matching signal from each peer.
                for peer in peers:
                    local.signal(peer, Buffer.input, own_slot, sendtb=tb)
                for peer in peers:
                    peer_slot = peer * size + tb
                    incoming = chunk(rank, Buffer.input, peer_slot)
                    incoming.wait(peer, Buffer.input, peer_slot, recvtb=tb)

        Json()
        Check()


# Command-line entry point: two integer positionals plus an optional
# protocol selector ("Simple" is the only accepted value).
parser = argparse.ArgumentParser()
for arg_name, arg_help in (
    ("num_gpus", "number of gpus"),
    ("instances", "number of instances"),
):
    parser.add_argument(arg_name, type=int, help=arg_help)
parser.add_argument("--protocol", type=str, default="Simple", choices=["Simple"], help="Protocol")

args = parser.parse_args()

allreduce_allpairs(gpus=args.num_gpus, instances=args.instances, protocol=args.protocol)
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
from msccl.language import *
from msccl.topologies import *
from msccl.language.collectives import AllReduce


def allreduce_allpairs(gpus, instances, protocol):
    """Define the all-pairs AllReduce example built on signal/wait/reduce/get.

    Creates a fully connected topology and an ``AllReduce`` collective over
    ``gpus`` ranks with ``gpus * gpus`` chunks. Inside an ``MSCCLPPProgram``
    named ``"allreduce_pairs"``, a first pass reduces each rank's own group
    of chunks in the input buffer, and a second pass issues ``get`` calls for
    the chunks of every other rank's group. ``Json()`` and ``Check()`` are
    called after all operations have been issued.

    Args:
        gpus: Number of ranks; also the number of threadblocks used per rank.
        instances: Forwarded unchanged to ``MSCCLPPProgram``.
        protocol: Forwarded to ``MSCCLPPProgram`` as its ``protocol`` keyword.
    """
    size = gpus
    topology = fully_connected(size)
    # NOTE(review): the third argument is presumably the in-place flag --
    # confirm against the AllReduce collective's signature.
    collective = AllReduce(size, gpus * gpus, True)
    with MSCCLPPProgram("allreduce_pairs", topology, collective, instances, protocol=protocol):

        # Pass 1: each (rank, threadblock) pair reduces one chunk of the
        # rank's own group, operating directly on the input buffers.
        for rank in range(size):
            peers = [other for other in range(size) if other != rank]
            # Peers visited starting at rank + 1 and wrapping around.
            rotated = [(rank + step) % size for step in range(1, size)]
            for tb in range(size):
                own_slot = rank * size + tb
                local = chunk(rank, Buffer.input, own_slot)

                # Make sure the data is ready: signal each peer about the
                # chunk in that peer's group, then wait on every peer.
                for peer in peers:
                    peer_slot = peer * size + tb
                    outgoing = chunk(rank, Buffer.input, peer_slot)
                    outgoing.signal(peer, Buffer.input, peer_slot, sendtb=tb)
                for peer in peers:
                    local.wait(peer, Buffer.input, own_slot, recvtb=tb)

                # Reduce the peers' copies into the local chunk.
                for peer in rotated:
                    local.reduce(chunk(peer, Buffer.input, own_slot), recvtb=tb)

                # Signal every peer once the reduction has been issued.
                for peer in peers:
                    local.signal(peer, Buffer.input, own_slot, sendtb=tb)

        # Pass 2: wait until every peer has signalled its chunk, then issue a
        # get for each of those chunks.
        for rank in range(size):
            peers = [other for other in range(size) if other != rank]
            rotated = [(rank + step) % size for step in range(1, size)]
            for tb in range(size):
                for peer in peers:
                    peer_slot = peer * size + tb
                    pending = chunk(rank, Buffer.input, peer_slot)
                    pending.wait(peer, Buffer.input, peer_slot, recvtb=tb)
                for peer in rotated:
                    peer_slot = peer * size + tb
                    target = chunk(rank, Buffer.input, peer_slot)
                    target.get(peer, Buffer.input, peer_slot, recvtb=tb)

        Json()
        Check()


# Command-line entry point: two integer positionals plus an optional
# protocol selector ("Simple" is the only accepted value).
parser = argparse.ArgumentParser()
for arg_name, arg_help in (
    ("num_gpus", "number of gpus"),
    ("instances", "number of instances"),
):
    parser.add_argument(arg_name, type=int, help=arg_help)
parser.add_argument("--protocol", type=str, default="Simple", choices=["Simple"], help="Protocol")

args = parser.parse_args()

allreduce_allpairs(gpus=args.num_gpus, instances=args.instances, protocol=args.protocol)
54 changes: 54 additions & 0 deletions examples/mscclang/mscclpp/allreduce_a100_ring_mscclpp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
from msccl.language import *
from msccl.topologies import *
from msccl.language.collectives import AllReduce


# Ring all reduce for A100s
def allreduce_ring(size, instances):
    """Define the ring AllReduce example program.

    Creates a fully connected topology and an ``AllReduce`` collective over
    ``size`` ranks with ``size`` chunks, then issues a reduce pass followed
    by a propagate pass inside an ``MSCCLPPProgram`` named
    ``"allreduce_ring"`` using the ``"Simple"`` protocol. Every operation is
    issued on threadblock 0. ``Json()`` and ``Check()`` are called after all
    operations have been issued.

    Args:
        size: Number of ranks in the ring; also the number of chunks.
        instances: Forwarded unchanged to ``MSCCLPPProgram``.
    """
    topology = fully_connected(size)
    # NOTE(review): the third argument is presumably the in-place flag --
    # confirm against the AllReduce collective's signature.
    collective = AllReduce(size, size, True)
    with MSCCLPPProgram(
        "allreduce_ring",
        topology,
        collective,
        instances,
        protocol="Simple",
    ):
        # Reduce ring
        for step in range(size - 1):
            for index in range(size):
                rank = (index + step) % size
                next_rank = (index + step + 1) % size
                prev_rank = (index + step - 1) % size
                # Chunk index handled by the previous rank in this step.
                prev_index = (index + size - 1) % size

                c = chunk(rank, Buffer.input, index)
                c.signal(next_rank, Buffer.input, index, 0)
                c = chunk(rank, Buffer.input, prev_index)
                c.wait(prev_rank, Buffer.input, prev_index, 0)
                c.reduce(chunk(prev_rank, Buffer.input, prev_index), recvtb=0)

        # Propagate ring
        for step in range(-1, size - 2):
            for index in range(size):
                rank = (index + step) % size
                next_rank = (index + step + 1) % size
                prev_rank = (index + step - 1) % size
                prev_index = (index + size - 1) % size

                c = chunk(rank, Buffer.input, index)
                c.put(next_rank, Buffer.input, index, sendtb=0)
                c.signal(next_rank, Buffer.input, index, 0)
                c = chunk(rank, Buffer.input, prev_index)
                c.wait(prev_rank, Buffer.input, prev_index, 0)

        Json()
        Check()


# Command-line entry point: both positionals are parsed as integers and
# handed straight to the program builder.
parser = argparse.ArgumentParser()
for arg_name, arg_help in (
    ("num_gpus", "number of gpus"),
    ("instances", "number of instances"),
):
    parser.add_argument(arg_name, type=int, help=arg_help)
args = parser.parse_args()

allreduce_ring(size=args.num_gpus, instances=args.instances)
9 changes: 7 additions & 2 deletions msccl/language/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from msccl.language.chunk import *
from msccl.language.buffer import *
from msccl.language.instruction_dag import *
import msccl.language.mscclpp as mscclpp
from msccl.language.mscclpp import *
from typing import Union

from msccl.language.types import ReplicationPolicy, ThreadblockPolicy

Expand All @@ -18,8 +21,10 @@

def _curr():
    """Return the program that is currently in context.

    The module-level ``_current_program`` takes precedence; when it is unset,
    ``mscclpp._current_program`` is used instead.

    Raises:
        RuntimeError: If neither program is set.
    """
    if _current_program is not None:
        return _current_program
    if mscclpp._current_program is not None:
        return mscclpp._current_program
    raise RuntimeError("No Program in context")


Expand Down Expand Up @@ -279,7 +284,7 @@ def Print():
_curr().print_chunk_dag()


def chunk(rank, buffer, index, size=1) -> Union[mscclpp.Ref, Ref, None]:
    """Return a reference to ``size`` chunks of ``buffer`` on ``rank``.

    Args:
        rank: Rank whose buffers are looked up in the current program.
        buffer: Key of the buffer on that rank.
        index: Index of the first chunk within the buffer.
        size: Number of chunks the reference should cover (default 1).

    Returns:
        The reference produced by the current program's ``get_ref``, or
        ``None`` when the slot at ``index`` is ``None``.
    """
    program = _curr()
    if program.buffers[rank][buffer][index] is None:
        return None
    return program.get_ref(rank, buffer, index, size)
Expand Down
Loading

0 comments on commit 87ce281

Please sign in to comment.