From 7a47cfc4aaeb10163a5c2ad185fe6a38ea00e96f Mon Sep 17 00:00:00 2001
From: chang-l
Date: Tue, 19 Nov 2024 22:50:43 -0800
Subject: [PATCH] Fix format

---
 .../wholegraph_torch/ops/test_wholegraph_gather_scatter.py | 4 ++--
 python/pylibwholegraph/pylibwholegraph/torch/tensor.py     | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py
index f42d2b6..078e6cf 100644
--- a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py
+++ b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py
@@ -26,7 +26,7 @@ def gen_int_embedding(indice_tensor, embedding_dim, output_type):
     if embedding_dim == 0:
-        embedding_dim = 1  # unsqueeze to 2D tensor for input embeddings (2D is required for scatter op)
+        embedding_dim = 1  # unsqueeze to 2D for input (2D is required for scatter op)
     indice_count = indice_tensor.shape[0]
     indice_part = (
         indice_tensor.type(torch.int).reshape(indice_count, 1).repeat(1, embedding_dim)
     )
@@ -167,7 +167,7 @@ def routine_func(world_rank: int, world_size: int):
             wmb.WholeMemoryMemoryLocation.MlHost,
             wmb.WholeMemoryMemoryLocation.MlDevice,
         ]:
-            for embedding_dim in [0, 256]: # 0 is for 1D tensor
+            for embedding_dim in [0, 256]:  # 0 is for 1D tensor
                 if wm_comm.support_type_location(mt, ml):
                     scatter_gather_test_cast(
                         wm_comm,
diff --git a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py
index f169bed..41d8fad 100644
--- a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py
+++ b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py
@@ -88,7 +88,7 @@ def scatter(self, input_tensor: torch.Tensor, indice: torch.Tensor):
         if self.dim() == 2:
             assert input_tensor.shape[1] == self.shape[1]
         else:
-            # unsqueeze input to 2D tensor here because wmb_tensor is unsqueezed within scatter_op
+            # unsqueeze to 2D tensor because wmb_tensor is unsqueezed within scatter_op
            input_tensor = input_tensor.unsqueeze(1)
         wmb.wholememory_scatter_op(
             wrap_torch_tensor(input_tensor),
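
Both reworded comments refer to the same convention: the WholeMemory scatter op works on 2D (row-major) tensors, so in the 1D case (the embedding_dim == 0 test case) both the target and the input are unsqueezed to shape (N, 1) before scattering. Below is a minimal plain-PyTorch sketch of that convention, assuming a hypothetical helper scatter_rows that stands in for the library's scatter path; it is not part of the pylibwholegraph API.

import torch

def scatter_rows(target: torch.Tensor, values: torch.Tensor, indices: torch.Tensor) -> None:
    # Hypothetical helper mirroring the shape handling in tensor.py's scatter():
    # the underlying op requires 2D tensors, so the 1D case is unsqueezed to (N, 1).
    if target.dim() == 2:
        assert values.shape[1] == target.shape[1]
        target_2d, values_2d = target, values
    else:
        # unsqueeze to 2D, as in the patched comments; unsqueeze(1) returns a
        # view, so the in-place write below reaches the original 1D tensor
        target_2d = target.unsqueeze(1)   # (num_rows,) -> (num_rows, 1)
        values_2d = values.unsqueeze(1)   # (n,)        -> (n, 1)
    # row-wise scatter: copy values_2d[i] into target_2d[indices[i]]
    target_2d[indices] = values_2d

emb = torch.zeros(8)                      # 1D table, i.e. the embedding_dim == 0 case
scatter_rows(emb, torch.tensor([10.0, 30.0, 50.0]), torch.tensor([1, 3, 5]))
print(emb)                                # tensor([ 0., 10.,  0., 30.,  0., 50.,  0.,  0.])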