diff --git a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py index 361ae4f..f42d2b6 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py @@ -25,6 +25,8 @@ def gen_int_embedding(indice_tensor, embedding_dim, output_type): + if embedding_dim == 0: + embedding_dim = 1 # unsqueeze to 2D tensor for input embeddings (2D is required for scatter op) indice_count = indice_tensor.shape[0] indice_part = ( indice_tensor.type(torch.int).reshape(indice_count, 1).repeat(1, embedding_dim) @@ -57,9 +59,14 @@ def scatter_gather_test_cast( f"embedding_dim={embedding_dim}, " f"indice_count={indice_count}, dt={dt}, mt={mt}, ml={ml}" ) - wm_embedding = wmb.create_wholememory_matrix( - dt, embedding_count, embedding_dim, -1, wm_comm, mt, ml, entry_partition - ) + if embedding_dim == 0: + wm_embedding = wmb.create_wholememory_array( + dt, embedding_count, wm_comm, mt, ml, entry_partition + ) + else: + wm_embedding = wmb.create_wholememory_matrix( + dt, embedding_count, embedding_dim, -1, wm_comm, mt, ml, entry_partition + ) scatter_indice = torch.arange( world_rank, embedding_count, world_size, dtype=torch.int64 @@ -93,9 +100,13 @@ def scatter_gather_test_cast( local_ref_start = wm_embedding.get_local_entry_start() local_ref_count = wm_embedding.get_local_entry_count() assert local_start == local_ref_start - assert local_tensor_cuda.dim() == 2 + assert local_tensor_cuda.dim() == 2 if embedding_dim > 0 else 1 assert local_tensor_cuda.shape[0] == local_ref_count - assert local_tensor_cuda.shape[1] == embedding_dim + if local_tensor_cuda.dim() == 2: + assert local_tensor_cuda.shape[1] == embedding_dim + else: + # unsqueeze to 2D for comparison + local_tensor_cuda = local_tensor_cuda.unsqueeze(1) local_tensor = local_tensor_cuda.cpu() local_indices = torch.arange( @@ -118,6 +129,9 @@ def scatter_gather_test_cast( ) embedding_after_gather = embedding_after_gather_cuda.cpu() ref_embedding_gather = gen_int_embedding(gather_indice, embedding_dim, torch.float) + if embedding_after_gather.dim() == 1: + # unsqueeze to 2D for comparison + embedding_after_gather = embedding_after_gather.unsqueeze(1) # print('\ngather_indice=%s\nembedding_after_gather=%s\nref_embedding_gather=%s' % ( # gather_indice, embedding_after_gather, ref_embedding_gather)) assert torch.allclose(embedding_after_gather, ref_embedding_gather) @@ -138,7 +152,6 @@ def routine_func(world_rank: int, world_size: int): wm_comm = wm_comm.wmb_comm embedding_count = 1024 * 256 * world_size + 3 - embedding_dim = 256 indice_count = 100001 dt = wmb.WholeMemoryDataType.DtFloat entry_partition = random_partition(embedding_count, world_size) @@ -154,18 +167,19 @@ def routine_func(world_rank: int, world_size: int): wmb.WholeMemoryMemoryLocation.MlHost, wmb.WholeMemoryMemoryLocation.MlDevice, ]: - if wm_comm.support_type_location(mt, ml): - scatter_gather_test_cast( - wm_comm, - dt, - mt, - ml, - embedding_count, - embedding_dim, - indice_count, - True, - entry_partition, - ) + for embedding_dim in [0, 256]: # 0 is for 1D tensor + if wm_comm.support_type_location(mt, ml): + scatter_gather_test_cast( + wm_comm, + dt, + mt, + ml, + embedding_count, + embedding_dim, + indice_count, + True, + entry_partition, + ) wmb.finalize() diff --git a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py index e46ffa2..f169bed 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py @@ -62,7 +62,7 @@ def gather( self, indice: torch.Tensor, *, force_dtype: Union[torch.dtype, None] = None ): assert indice.dim() == 1 - embedding_dim = self.shape[1] + embedding_dim = self.shape[1] if self.dim() == 2 else 1 embedding_count = indice.shape[0] current_cuda_device = "cuda:%d" % (torch.cuda.current_device(),) output_dtype = force_dtype if force_dtype is not None else self.dtype @@ -79,13 +79,17 @@ def gather( get_wholegraph_env_fns(), get_stream(), ) - return output_tensor + return output_tensor.view(-1) if self.dim() == 1 else output_tensor def scatter(self, input_tensor: torch.Tensor, indice: torch.Tensor): assert indice.dim() == 1 - assert input_tensor.dim() == 2 + assert input_tensor.dim() == self.dim() assert indice.shape[0] == input_tensor.shape[0] - assert input_tensor.shape[1] == self.shape[1] + if self.dim() == 2: + assert input_tensor.shape[1] == self.shape[1] + else: + # unsqueeze input to 2D tensor here because wmb_tensor is unsqueezed within scatter_op + input_tensor = input_tensor.unsqueeze(1) wmb.wholememory_scatter_op( wrap_torch_tensor(input_tensor), wrap_torch_tensor(indice), diff --git a/python/pylibwholegraph/pylibwholegraph/torch/wholememory_ops.py b/python/pylibwholegraph/pylibwholegraph/torch/wholememory_ops.py index eedf4bb..773f3a3 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/wholememory_ops.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/wholememory_ops.py @@ -39,8 +39,10 @@ def wholememory_gather_forward_functor( assert indices_tensor.dtype == torch.int32 or indices_tensor.dtype == torch.int64 if torch_output_dtype is None: torch_output_dtype = wholememory_dtype_to_torch_dtype(wholememory_tensor.dtype) + + embedding_dim = wholememory_tensor.shape[1] if wholememory_tensor.dim() == 2 else 1 output_tensor = torch.empty( - [indices_tensor.shape[0], wholememory_tensor.shape[1]], + [indices_tensor.shape[0], embedding_dim], device="cuda", dtype=torch_output_dtype, requires_grad=requires_grad, @@ -52,7 +54,7 @@ def wholememory_gather_forward_functor( get_wholegraph_env_fns(), get_stream(), ) - return output_tensor + return output_tensor.view(-1) if wholememory_tensor.dim() == 1 else output_tensor def wholememory_scatter_functor(