Add gather/scatter support for 1D tensors
chang-l committed Nov 19, 2024
1 parent e1e32bc commit 5edebec
Showing 3 changed files with 44 additions and 24 deletions.
@@ -25,6 +25,8 @@


 def gen_int_embedding(indice_tensor, embedding_dim, output_type):
+    if embedding_dim == 0:
+        embedding_dim = 1  # unsqueeze to 2D tensor for input embeddings (2D is required for scatter op)
     indice_count = indice_tensor.shape[0]
     indice_part = (
         indice_tensor.type(torch.int).reshape(indice_count, 1).repeat(1, embedding_dim)
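
As a standalone illustration of the convention introduced above (a minimal sketch, not part of the commit; the real gen_int_embedding also honors output_type, which is simplified to float here), embedding_dim == 0 denotes a 1D tensor, and the reference embedding is generated with a width of 1 so that scatter and comparison can stay 2D:

    import torch

    def gen_int_embedding_sketch(indice_tensor, embedding_dim):
        if embedding_dim == 0:
            embedding_dim = 1  # treat the 1D case as a single-column 2D embedding
        indice_count = indice_tensor.shape[0]
        # row i holds the integer index value, repeated embedding_dim times
        return (
            indice_tensor.type(torch.int)
            .reshape(indice_count, 1)
            .repeat(1, embedding_dim)
            .float()
        )

    indices = torch.tensor([3, 7, 42], dtype=torch.int64)
    print(gen_int_embedding_sketch(indices, 0).shape)    # torch.Size([3, 1])
    print(gen_int_embedding_sketch(indices, 256).shape)  # torch.Size([3, 256])
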
@@ -57,9 +59,14 @@ def scatter_gather_test_cast(
         f"embedding_dim={embedding_dim}, "
         f"indice_count={indice_count}, dt={dt}, mt={mt}, ml={ml}"
     )
-    wm_embedding = wmb.create_wholememory_matrix(
-        dt, embedding_count, embedding_dim, -1, wm_comm, mt, ml, entry_partition
-    )
+    if embedding_dim == 0:
+        wm_embedding = wmb.create_wholememory_array(
+            dt, embedding_count, wm_comm, mt, ml, entry_partition
+        )
+    else:
+        wm_embedding = wmb.create_wholememory_matrix(
+            dt, embedding_count, embedding_dim, -1, wm_comm, mt, ml, entry_partition
+        )

     scatter_indice = torch.arange(
         world_rank, embedding_count, world_size, dtype=torch.int64
@@ -93,9 +100,13 @@ def scatter_gather_test_cast(
     local_ref_start = wm_embedding.get_local_entry_start()
     local_ref_count = wm_embedding.get_local_entry_count()
     assert local_start == local_ref_start
-    assert local_tensor_cuda.dim() == 2
+    assert local_tensor_cuda.dim() == (2 if embedding_dim > 0 else 1)
     assert local_tensor_cuda.shape[0] == local_ref_count
-    assert local_tensor_cuda.shape[1] == embedding_dim
+    if local_tensor_cuda.dim() == 2:
+        assert local_tensor_cuda.shape[1] == embedding_dim
+    else:
+        # unsqueeze to 2D for comparison
+        local_tensor_cuda = local_tensor_cuda.unsqueeze(1)

     local_tensor = local_tensor_cuda.cpu()
     local_indices = torch.arange(
@@ -118,6 +129,9 @@ def scatter_gather_test_cast(
     )
     embedding_after_gather = embedding_after_gather_cuda.cpu()
     ref_embedding_gather = gen_int_embedding(gather_indice, embedding_dim, torch.float)
+    if embedding_after_gather.dim() == 1:
+        # unsqueeze to 2D for comparison
+        embedding_after_gather = embedding_after_gather.unsqueeze(1)
     # print('\ngather_indice=%s\nembedding_after_gather=%s\nref_embedding_gather=%s' % (
     #     gather_indice, embedding_after_gather, ref_embedding_gather))
     assert torch.allclose(embedding_after_gather, ref_embedding_gather)
@@ -138,7 +152,6 @@ def routine_func(world_rank: int, world_size: int):
     wm_comm = wm_comm.wmb_comm

     embedding_count = 1024 * 256 * world_size + 3
-    embedding_dim = 256
     indice_count = 100001
     dt = wmb.WholeMemoryDataType.DtFloat
     entry_partition = random_partition(embedding_count, world_size)
@@ -154,18 +167,19 @@ def routine_func(world_rank: int, world_size: int):
             wmb.WholeMemoryMemoryLocation.MlHost,
             wmb.WholeMemoryMemoryLocation.MlDevice,
         ]:
-            if wm_comm.support_type_location(mt, ml):
-                scatter_gather_test_cast(
-                    wm_comm,
-                    dt,
-                    mt,
-                    ml,
-                    embedding_count,
-                    embedding_dim,
-                    indice_count,
-                    True,
-                    entry_partition,
-                )
+            for embedding_dim in [0, 256]:  # 0 is for 1D tensor
+                if wm_comm.support_type_location(mt, ml):
+                    scatter_gather_test_cast(
+                        wm_comm,
+                        dt,
+                        mt,
+                        ml,
+                        embedding_count,
+                        embedding_dim,
+                        indice_count,
+                        True,
+                        entry_partition,
+                    )
     wmb.finalize()


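The end-to-end check in the test file above boils down to the following comparison pattern, shown here as a plain-PyTorch sketch with no WholeMemory involved (the local tensor store_1d, the index values, and the width-1 reference are stand-ins chosen for illustration): a gather from a 1D store comes back 1D and is unsqueezed to 2D before being compared against the reference embedding.

    import torch

    # Stand-in for the distributed 1D wholememory array: entry i stores the value i.
    embedding_count = 16
    store_1d = torch.arange(embedding_count, dtype=torch.float)

    gather_indice = torch.tensor([1, 5, 9])
    embedding_after_gather = store_1d[gather_indice]              # 1D result, shape [3]
    ref_embedding_gather = gather_indice.reshape(-1, 1).float()   # width-1 2D reference

    if embedding_after_gather.dim() == 1:
        # unsqueeze to 2D for comparison, mirroring the test above
        embedding_after_gather = embedding_after_gather.unsqueeze(1)
    assert torch.allclose(embedding_after_gather, ref_embedding_gather)
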
python/pylibwholegraph/pylibwholegraph/torch/tensor.py (12 changes: 8 additions & 4 deletions)
@@ -62,7 +62,7 @@ def gather(
         self, indice: torch.Tensor, *, force_dtype: Union[torch.dtype, None] = None
     ):
         assert indice.dim() == 1
-        embedding_dim = self.shape[1]
+        embedding_dim = self.shape[1] if self.dim() == 2 else 1
         embedding_count = indice.shape[0]
         current_cuda_device = "cuda:%d" % (torch.cuda.current_device(),)
         output_dtype = force_dtype if force_dtype is not None else self.dtype
@@ -79,13 +79,17 @@ def gather(
             get_wholegraph_env_fns(),
             get_stream(),
         )
-        return output_tensor
+        return output_tensor.view(-1) if self.dim() == 1 else output_tensor

     def scatter(self, input_tensor: torch.Tensor, indice: torch.Tensor):
         assert indice.dim() == 1
-        assert input_tensor.dim() == 2
+        assert input_tensor.dim() == self.dim()
         assert indice.shape[0] == input_tensor.shape[0]
-        assert input_tensor.shape[1] == self.shape[1]
+        if self.dim() == 2:
+            assert input_tensor.shape[1] == self.shape[1]
+        else:
+            # unsqueeze input to 2D tensor here because wmb_tensor is unsqueezed within scatter_op
+            input_tensor = input_tensor.unsqueeze(1)
         wmb.wholememory_scatter_op(
             wrap_torch_tensor(input_tensor),
             wrap_torch_tensor(indice),
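The shape contract that gather and scatter now follow can be mimicked with ordinary torch.index_select / index_copy_ on a local tensor (a rough analogue for illustration only, not the WholeMemory kernels; local_gather and local_scatter are hypothetical helpers): a 1D tensor gathers to a 1D result and scatters from a 1D input, while the 2D path is unchanged.

    import torch

    def local_gather(store, indice):
        # 1D store -> 1D result, 2D store -> [len(indice), embedding_dim]
        return store.index_select(0, indice)

    def local_scatter(store, input_tensor, indice):
        assert input_tensor.dim() == store.dim()
        store.index_copy_(0, indice, input_tensor)

    store_1d = torch.zeros(10)
    local_scatter(store_1d, torch.tensor([1.0, 2.0]), torch.tensor([3, 7]))
    print(local_gather(store_1d, torch.tensor([3, 7])))        # tensor([1., 2.])

    store_2d = torch.zeros(10, 4)
    local_scatter(store_2d, torch.ones(2, 4), torch.tensor([3, 7]))
    print(local_gather(store_2d, torch.tensor([3, 7])).shape)  # torch.Size([2, 4])
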
@@ -39,8 +39,10 @@ def wholememory_gather_forward_functor(
     assert indices_tensor.dtype == torch.int32 or indices_tensor.dtype == torch.int64
     if torch_output_dtype is None:
         torch_output_dtype = wholememory_dtype_to_torch_dtype(wholememory_tensor.dtype)
+
+    embedding_dim = wholememory_tensor.shape[1] if wholememory_tensor.dim() == 2 else 1
     output_tensor = torch.empty(
-        [indices_tensor.shape[0], wholememory_tensor.shape[1]],
+        [indices_tensor.shape[0], embedding_dim],
         device="cuda",
         dtype=torch_output_dtype,
         requires_grad=requires_grad,
@@ -52,7 +54,7 @@
         get_wholegraph_env_fns(),
         get_stream(),
     )
-    return output_tensor
+    return output_tensor.view(-1) if wholememory_tensor.dim() == 1 else output_tensor


 def wholememory_scatter_functor(
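
The allocation change above reduces to the following shape logic, sketched as a plain function (alloc_gather_output and wm_shape are hypothetical stand-ins; the real code allocates on "cuda" and hands the buffer to wmb.wholememory_gather_op): a 1D wholememory tensor is gathered into an [n, 1] buffer that is flattened back to 1D on return.

    import torch

    def alloc_gather_output(wm_shape, num_indices, dtype=torch.float):
        # wm_shape stands in for wholememory_tensor.shape
        embedding_dim = wm_shape[1] if len(wm_shape) == 2 else 1
        output_tensor = torch.empty([num_indices, embedding_dim], dtype=dtype)
        # a 1D wholememory tensor yields a flat 1D gather result
        return output_tensor.view(-1) if len(wm_shape) == 1 else output_tensor

    print(alloc_gather_output((1024,), 5).shape)       # torch.Size([5])
    print(alloc_gather_output((1024, 256), 5).shape)   # torch.Size([5, 256])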
