Skip to content
This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Commit

Permalink
[NFC] Clean up collective (#546)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored Jun 27, 2022
1 parent 11a67b8 commit 88b9332
Show file tree
Hide file tree
Showing 5 changed files with 332 additions and 363 deletions.
Original file line number Diff line number Diff line change
@@ -1,195 +1,31 @@
"""Implementation of device_mesh's send/recv/allgather/broadcast."""

from typing import Sequence
"""Utility functions for device mesh workers to call nccl APIs."""
import logging
from typing import Sequence

from jax import device_put
from jax._src.lib import xla_extension as xe

import cupy
import jax.numpy as jnp
from jax import device_put
from jax._src.dlpack import from_dlpack, to_dlpack
from jax._src.lib import xla_bridge as xb, xla_client as xc
import numpy as np
import cupy

import alpa.collective as col
from alpa.collective.collective_group import nccl_util
from alpa.util import (jax_tensor_to_cupy, cupy_to_jax_tensor, jax_tensor_set,
from alpa.util import (jax_tensor_set, jax_tensor_index,
xla_buffer_to_jax_tensor, jax_tensor_to_xla_buffer,
xla_buffer_to_cupy, cupy_to_xla_buffer,
is_continuous_subset, infer_offset_and_n_elements,
jax_tensor_index, infer_start_pos_and_n_elements)
is_continuous_subset, infer_offset_and_n_elements)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Only records at INFO level or above pass this logger; logger.debug() calls
# made through it are filtered out.
logger.setLevel(logging.INFO)


def xla_nccl_send_tile(worker, uuid: int, offset: Sequence[slice],
                       dst_rank: int, dst_gpu_idx: int, group_name: str):
    """Send one tile of ``worker.buffers[uuid]`` through ``col.send_multigpu``.

    If ``is_continuous_subset`` accepts ``offset`` for the buffer's shape, the
    buffer itself is passed on together with the start position and element
    count computed by ``infer_start_pos_and_n_elements``. Otherwise the tile
    is first cut out into a separate buffer, which is sent from position 0.

    Args:
        worker: object exposing a ``buffers`` container indexed by uuid.
        uuid: key of the source buffer in ``worker.buffers``.
        offset: one slice per dimension selecting the tile to send.
        dst_rank: forwarded unchanged to ``col.send_multigpu``.
        dst_gpu_idx: forwarded unchanged to ``col.send_multigpu``.
        group_name: forwarded unchanged to ``col.send_multigpu``.
    """
    source = worker.buffers[uuid]
    full_shape = source.shape

    if not is_continuous_subset(offset, full_shape):
        # Slower path: the tile has to be extracted by indexing first.
        logger.debug("Send goes along the slowest path. "
                     "If this is for transformers, please check the resharding "
                     "specs.")
        origin = tuple(dim.start for dim in offset)
        extent = tuple(dim.stop - dim.start for dim in offset)
        tile = jax_tensor_index(xla_buffer_to_jax_tensor(source), origin,
                                extent)
        col.send_multigpu(jax_tensor_to_xla_buffer(tile),
                          dst_rank,
                          dst_gpu_idx,
                          group_name,
                          start_pos=0,
                          n_elements=np.prod(extent))
        return

    # Fast path: send straight out of the existing buffer.
    start_pos, n_elements = infer_start_pos_and_n_elements(full_shape, offset)
    col.send_multigpu(source,
                      dst_rank,
                      dst_gpu_idx,
                      group_name,
                      start_pos=start_pos,
                      n_elements=n_elements)


def xla_nccl_recv_tile(worker, uuid: int, device_id: int,
                       indices_in_dst_tile: Sequence[slice], src_rank: int,
                       src_gpu_idx: int, group_name: str):
    """Receive a tile via ``col.recv_multigpu`` into ``worker.buffers[uuid]``.

    If ``is_continuous_subset`` accepts ``indices_in_dst_tile`` for the
    buffer's shape, the data is received directly into the existing buffer.
    Otherwise it is received into a temporary buffer placed on
    ``worker.local_devices[device_id]`` and then written into the destination
    with ``jax_tensor_set``; the result replaces ``worker.buffers[uuid]``.

    When the destination dtype was ``np.bool_`` before receiving, the final
    buffer is passed through ``_uint8_to_bool``.

    Args:
        worker: object exposing ``buffers`` and ``local_devices``.
        uuid: key of the destination buffer in ``worker.buffers``.
        device_id: index into ``worker.local_devices`` for the temporary
            buffer used on the slow path.
        indices_in_dst_tile: one slice per dimension locating the tile inside
            the destination buffer.
        src_rank: forwarded unchanged to ``col.recv_multigpu``.
        src_gpu_idx: forwarded unchanged to ``col.recv_multigpu``.
        group_name: forwarded unchanged to ``col.recv_multigpu``.
    """
    target = worker.buffers[uuid]
    full_shape = target.shape
    tile_shape = tuple(dim.stop - dim.start for dim in indices_in_dst_tile)
    # Remember the dtype now; it is checked again after the buffer is replaced.
    restore_bool = target.dtype == np.bool_

    if not is_continuous_subset(indices_in_dst_tile, full_shape):
        # Slow path: receive into a staging buffer, then scatter it in.
        staging = jax_tensor_to_xla_buffer(
            device_put(jnp.ones(tile_shape, dtype=target.dtype),
                       worker.local_devices[device_id]))
        col.recv_multigpu(staging,
                          src_rank,
                          src_gpu_idx,
                          group_name,
                          start_pos=0,
                          n_elements=np.prod(tile_shape))
        origin = tuple(dim.start for dim in indices_in_dst_tile)
        merged = jax_tensor_set(
            xla_buffer_to_jax_tensor(worker.buffers[uuid]),
            xla_buffer_to_jax_tensor(staging), origin)
        worker.buffers[uuid] = jax_tensor_to_xla_buffer(merged)
    else:
        # Fast path: receive directly into the destination buffer.
        start_pos, n_elements = infer_start_pos_and_n_elements(
            full_shape, indices_in_dst_tile)
        col.recv_multigpu(target,
                          src_rank,
                          src_gpu_idx,
                          group_name,
                          start_pos=start_pos,
                          n_elements=n_elements)

    if restore_bool:
        worker.buffers[uuid] = _uint8_to_bool(worker.buffers[uuid])


def xla_nccl_allgather(worker, uuids: Sequence[int], device_ids: Sequence[int],
                       tensor_slices: Sequence[slice], output_slice):
    """Run ``xe.nccl_local_all_gather`` over buffers held by this worker.

    Communicators are cached in ``worker.allgather_communicators`` under
    ``repr(sorted(device_ids))`` and created with
    ``worker.nccl_local_allgather_init_comms`` on first use. After the
    collective call, every participating buffer is stored back into
    ``worker.buffers``; if the first participating buffer had dtype
    ``np.bool_``, each buffer is passed through ``_uint8_to_bool`` first.

    Args:
        worker: object exposing ``buffers``, ``allgather_communicators`` and
            ``nccl_local_allgather_init_comms``.
        uuids: indexed by device id to obtain the buffer key per device.
        device_ids: ids of the participating devices.
        tensor_slices: per-device slices, paired with ``device_ids`` in order.
        output_slice: slices used to compute the global start position.
    """
    cache_key = repr(sorted(device_ids))
    comm_cache = worker.allgather_communicators
    if cache_key not in comm_cache:
        comm_cache[cache_key] = worker.nccl_local_allgather_init_comms(
            list(device_ids))
    communicators = comm_cache[cache_key]

    # Shape and dtype are taken from the first participating buffer.
    reference = worker.buffers[uuids[device_ids[0]]]
    restore_bool = reference.dtype == np.bool_
    tensor_shape = reference.shape
    global_start_pos, _ = infer_start_pos_and_n_elements(
        tensor_shape, output_slice)

    buffers = []
    local_start_pos_list = []
    for dev, piece in zip(device_ids, tensor_slices):
        device_buffer = worker.buffers[uuids[dev]]
        local_start, _ = infer_start_pos_and_n_elements(tensor_shape, piece)
        buffers.append(device_buffer)
        local_start_pos_list.append(local_start)

    # The element count is derived from the first slice only.
    _, local_n_elements = infer_offset_and_n_elements(tensor_slices[0])
    xe.nccl_local_all_gather(communicators, buffers, local_start_pos_list,
                             global_start_pos, local_n_elements)

    for dev, gathered in zip(device_ids, buffers):
        key = uuids[dev]
        worker.buffers[key] = (_uint8_to_bool(gathered)
                               if restore_bool else gathered)


def xla_nccl_broadcast(worker, uuids, comm_key, world_size, devices_ids,
                       devices_global_rank, tensor_slices, group_name):
    """Broadcast tiles via ``col.broadcast_partialgpu`` and store the results.

    One buffer is prepared per participating device:

    * if ``is_continuous_subset`` accepts the device's slice, the existing
      buffer is used with the start position from
      ``infer_start_pos_and_n_elements``;
    * otherwise a temporary buffer is used with start position 0. For global
      rank 0 it holds the indexed tile of the existing buffer; for every other
      rank it is a placeholder of ones on ``worker.local_devices[device_id]``.

    After the collective call, every device whose global rank is not 0 gets
    its entry in ``worker.buffers`` updated, either by replacing it with the
    buffer used in the call or by writing the temporary buffer into the
    existing one with ``jax_tensor_set``. If the first participating buffer
    had dtype ``np.bool_``, updated buffers are passed through
    ``_uint8_to_bool``.

    Args:
        worker: object exposing ``buffers`` and ``local_devices``.
        uuids: indexed by device id to obtain the buffer key per device.
        comm_key: forwarded unchanged to ``col.broadcast_partialgpu``.
        world_size: forwarded unchanged to ``col.broadcast_partialgpu``.
        devices_ids: ids of the participating local devices.
        devices_global_rank: global rank of each device, paired with
            ``devices_ids`` in order.
        tensor_slices: per-device slices, paired with ``devices_ids`` in order.
        group_name: forwarded unchanged to ``col.broadcast_partialgpu``.
    """
    buffers = []
    local_start_pos_list = []
    is_bool = worker.buffers[uuids[devices_ids[0]]].dtype == np.bool_
    # The element count is derived from the first slice only.
    _, n_elements = infer_offset_and_n_elements(tensor_slices[0])
    for device_id, global_rank, tensor_slice in zip(devices_ids,
                                                    devices_global_rank,
                                                    tensor_slices):
        uuid = uuids[device_id]
        tensor_shape = worker.buffers[uuid].shape
        slice_shape = tuple(ind.stop - ind.start for ind in tensor_slice)
        if is_continuous_subset(tensor_slice, tensor_shape):
            # fast path, two cases: (1) same shape, (2) continuous subset.
            start_pos, _ = infer_start_pos_and_n_elements(
                tensor_shape, tensor_slice)
            local_start_pos_list.append(start_pos)
            buffers.append(worker.buffers[uuid])
        else:
            if global_rank == 0:
                # Rank 0 contributes the actual tile, cut out by indexing.
                start_indices = tuple(o.start for o in tensor_slice)
                tmp = jax_tensor_index(
                    xla_buffer_to_jax_tensor(worker.buffers[uuid]),
                    start_indices, slice_shape)
            else:
                # Other ranks only need a placeholder of the right shape.
                tmp = device_put(
                    jnp.ones(slice_shape, dtype=worker.buffers[uuid].dtype),
                    worker.local_devices[device_id])
            local_start_pos_list.append(0)
            buffers.append(jax_tensor_to_xla_buffer(tmp))

    col.broadcast_partialgpu(buffers, n_elements, comm_key, world_size,
                             devices_ids, devices_global_rank, group_name,
                             local_start_pos_list)

    for xla_buffer, device_id, global_rank, tensor_slice in zip(
            buffers, devices_ids, devices_global_rank, tensor_slices):
        if global_rank == 0:
            # Rank 0 entries are left untouched.
            continue
        uuid = uuids[device_id]
        tensor_shape = worker.buffers[uuid].shape
        if is_continuous_subset(tensor_slice, tensor_shape):
            worker.buffers[uuid] = xla_buffer
        else:
            start_indices = tuple(
                ind_in_dst.start for ind_in_dst in tensor_slice)
            new_buffer = jax_tensor_set(
                xla_buffer_to_jax_tensor(worker.buffers[uuid]),
                xla_buffer_to_jax_tensor(xla_buffer), start_indices)
            worker.buffers[uuid] = jax_tensor_to_xla_buffer(new_buffer)
        if is_bool:
            worker.buffers[uuid] = _uint8_to_bool(worker.buffers[uuid])


# Note: in this device mesh code, we will use 3 types of tensors:
# (1) JAX high-level _DeviceArray, which is index-able and exposes the
#     __cuda_array_interface__ attribute
# (2) XLA low-level PyLocalBuffer, which is not index-able
# (3) cupy array, which is an intermediate format for ray collective
def cupy_nccl_send_tile(worker, uuid: int, offset: Sequence[slice],
dst_rank: int, dst_gpu_idx: int, group_name: str):
def send_tile(worker, uuid: int, offset: Sequence[slice], dst_rank: int,
dst_gpu_idx: int, group_name: str):
"""
Send a slice of a source buffer to a target GPU.
Expand Down Expand Up @@ -228,9 +64,9 @@ def cupy_nccl_send_tile(worker, uuid: int, offset: Sequence[slice],
col.send_multigpu(to_send, dst_rank, dst_gpu_idx, group_name)


def cupy_nccl_recv_tile(worker, uuid: int, device_id: int,
indices_in_dst_tile: Sequence[slice], src_rank: int,
src_gpu_idx: int, group_name: str):
def recv_tile(worker, uuid: int, device_id: int,
indices_in_dst_tile: Sequence[slice], src_rank: int,
src_gpu_idx: int, group_name: str):
"""
Receive a slice from a source GPU and in-place write it on the target
buffer.
Expand Down Expand Up @@ -291,9 +127,8 @@ def cupy_nccl_recv_tile(worker, uuid: int, device_id: int,
worker.buffers[uuid] = _uint8_to_bool(worker.buffers[uuid])


def cupy_nccl_allgather(worker, uuids: Sequence[int], device_ids: Sequence[int],
tensor_slices: Sequence[slice], output_slice):

def allgather(worker, uuids: Sequence[int], device_ids: Sequence[int],
tensor_slices: Sequence[slice], output_slice):
cupy_buffers = []
communicators = worker.allgather_communicators[repr(sorted(device_ids))]
relative_idx = dict(zip(sorted(device_ids), range(len(device_ids))))
Expand Down Expand Up @@ -322,8 +157,8 @@ def cupy_nccl_allgather(worker, uuids: Sequence[int], device_ids: Sequence[int],
worker.buffers[uuid] = buf


def cupy_nccl_broadcast(worker, uuids, comm_key, world_size, devices_ids,
devices_global_rank, tensor_slices, group_name):
def broadcast(worker, uuids, comm_key, world_size, devices_ids,
devices_global_rank, tensor_slices, group_name):
to_use = []
for_buffer = []
is_bool = worker.buffers[uuids[devices_ids[0]]].dtype == np.bool_
Expand Down Expand Up @@ -383,6 +218,50 @@ def cupy_nccl_broadcast(worker, uuids, comm_key, world_size, devices_ids,
worker.buffers[uuid] = _uint8_to_bool(worker.buffers[uuid])


# Module-level alias for cupy's ``NcclCommunicator.initAll``.
init_local_comm = cupy.cuda.nccl.NcclCommunicator.initAll


def to_signal_buffer(jax_tensor):
    """Convert ``jax_tensor`` to cupy with ``take_ownership=True``."""
    signal = jax_tensor_to_cupy(jax_tensor, take_ownership=True)
    return signal


def xla_buffer_to_cupy(xla_buf, take_ownership=False):
    """Turn an XLA buffer into a cupy array through DLPack, without going
    through a JAX tensor first.

    ``take_ownership`` is forwarded to the XLA DLPack export call.
    """
    managed_tensor = xc._xla.buffer_to_dlpack_managed_tensor(  # pylint: disable=protected-access
        xla_buf, take_ownership=take_ownership)
    return cupy.fromDlpack(managed_tensor)


def cupy_to_xla_buffer(tensor):
    """Convert a cupy tensor, or a (possibly nested) list of them, to XLA
    buffers through DLPack.

    A list input yields a list with every element converted recursively.
    """
    if isinstance(tensor, list):
        return [cupy_to_xla_buffer(item) for item in tensor]

    cpu_backend = xb.get_backend("cpu")
    # A missing GPU backend is tolerated and passed on as None.
    gpu_backend = None
    try:
        gpu_backend = xb.get_backend("gpu")
    except RuntimeError:
        pass
    return xc._xla.dlpack_managed_tensor_to_buffer(  # pylint: disable=protected-access
        tensor.toDlpack(), cpu_backend, gpu_backend)


def jax_tensor_to_cupy(tensors, take_ownership=False):
    """Convert a Jax DeviceArray to cupy tensor; zero copy.

    Args:
        tensors: a single JAX tensor, or a (possibly nested) list of them.
        take_ownership: forwarded to ``to_dlpack`` for every converted tensor,
            including those inside a list.

    Returns:
        A cupy array, or a list mirroring the structure of ``tensors``.
    """
    if isinstance(tensors, list):
        # Pass take_ownership down explicitly; recursing with the default
        # would silently ignore the caller's choice for list inputs.
        return [
            jax_tensor_to_cupy(item, take_ownership=take_ownership)
            for item in tensors
        ]
    return cupy.fromDlpack(to_dlpack(tensors, take_ownership=take_ownership))


def cupy_to_jax_tensor(tensors):
    """Convert a cupy tensor, or a (possibly nested) list of them, to JAX
    tensors through DLPack."""
    if not isinstance(tensors, list):
        return from_dlpack(tensors.toDlpack())
    return [cupy_to_jax_tensor(item) for item in tensors]


# in XLA pred(bool) and uint8 are different, but xla->dlpack->xla
# turns a bool into uint8. This implementation is slow.
def _uint8_to_bool(xla_buffer):
Expand Down
Loading

0 comments on commit 88b9332

Please sign in to comment.