Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Apr 29, 2024
1 parent 8f46276 commit 775d9d6
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 7 deletions.
2 changes: 1 addition & 1 deletion msccl/language/collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def get_buffer_index(self, rank, buffer, index):
class AllReduce(Collective):

def __init__(self, num_ranks, chunk_factor, inplace):
Collective.__init__(self, num_ranks, chunk_factor, inplace, num_ranks)
Collective.__init__(self, num_ranks, chunk_factor, inplace)
self.name = "allreduce"

def init_buffers(self):
Expand Down
6 changes: 0 additions & 6 deletions msccl/language/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from enum import Enum

from msccl.language.buffer import Buffer
from msccl.language.channel import ChannelType


@dataclass
Expand All @@ -32,7 +31,6 @@ class Gpu:
output_chunks: int = 0
scratch_chunks: int = 0
scratch: dict = field(default_factory=dict)
channels: dict = field(default_factory=dict)

def scratch_size(self):
return max((idx for addr, idx in self.scratch.items()), default=-1) + 1
Expand All @@ -46,7 +44,6 @@ class Threadblock:
recv: int = -1
ops: list = field(default_factory=list)
rbid: int = -1 # threadblock id of the receiver
channels: list = field(default_factory=list)

def __eq__(self, other):
return self is other
Expand All @@ -73,8 +70,6 @@ def __str__(self):


class ReplicationPolicy(Enum):
# this means pack multi instrances to deal with the same chunk and share the channels
packed = "packed"
# this means each instance deal with the different chunk
# Chunk A, Chunk B -> Chunk A0, Chunk B0, Chunk A1, Chunk B1
duplicated = "duplicated"
Expand Down Expand Up @@ -131,7 +126,6 @@ class Op:
recv_match = None
send_match = None
channel: int = -1
channel_type: ChannelType = ChannelType.none
srcs: list = field(default_factory=list)
dsts: list = field(default_factory=list)

Expand Down

0 comments on commit 775d9d6

Please sign in to comment.