
Commit 7a47cfc
Fix format
chang-l committed Nov 20, 2024
1 parent 5edebec commit 7a47cfc
Showing 2 changed files with 3 additions and 3 deletions.
@@ -26,7 +26,7 @@

 def gen_int_embedding(indice_tensor, embedding_dim, output_type):
     if embedding_dim == 0:
-        embedding_dim = 1  # unsqueeze to 2D tensor for input embeddings (2D is required for scatter op)
+        embedding_dim = 1  # unsqueeze 2D for input (2D is required for scatter op)
     indice_count = indice_tensor.shape[0]
     indice_part = (
         indice_tensor.type(torch.int).reshape(indice_count, 1).repeat(1, embedding_dim)
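
For context, gen_int_embedding builds deterministic integer-valued test embeddings for the scatter/gather tests, and embedding_dim == 0 is the tests' marker for a 1D tensor case. A minimal sketch of the helper's shape logic follows; only the lines in the hunk above come from the commit, and the final dtype conversion and return are an assumption for illustration:

    import torch

    def gen_int_embedding(indice_tensor, embedding_dim, output_type):
        if embedding_dim == 0:
            embedding_dim = 1  # scatter op requires 2D input, so 1D cases use width 1
        indice_count = indice_tensor.shape[0]
        # Each row repeats its own index value across the embedding dimension,
        # which makes rows easy to verify after a scatter/gather round trip.
        indice_part = (
            indice_tensor.type(torch.int).reshape(indice_count, 1).repeat(1, embedding_dim)
        )
        # Assumed: convert to the requested output dtype before returning.
        return indice_part.type(output_type)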
@@ -167,7 +167,7 @@ def routine_func(world_rank: int, world_size: int):
             wmb.WholeMemoryMemoryLocation.MlHost,
             wmb.WholeMemoryMemoryLocation.MlDevice,
         ]:
-            for embedding_dim in [0, 256]: # 0 is for 1D tensor
+            for embedding_dim in [0, 256]:  # 0 is for 1D tensor
                 if wm_comm.support_type_location(mt, ml):
                     scatter_gather_test_cast(
                         wm_comm,
python/pylibwholegraph/pylibwholegraph/torch/tensor.py (2 changes: 1 addition & 1 deletion)
@@ -88,7 +88,7 @@ def scatter(self, input_tensor: torch.Tensor, indice: torch.Tensor):
         if self.dim() == 2:
             assert input_tensor.shape[1] == self.shape[1]
         else:
-            # unsqueeze input to 2D tensor here because wmb_tensor is unsqueezed within scatter_op
+            # unsqueeze to 2D tensor because wmb_tensor is unsqueezed within scatter_op
             input_tensor = input_tensor.unsqueeze(1)
         wmb.wholememory_scatter_op(
             wrap_torch_tensor(input_tensor),
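
The reworded comment describes real behavior in scatter(): a 1D input is unsqueezed to shape (n, 1) so it matches the 2D (row, embedding_dim) layout the scatter op works on, and wmb_tensor gets the same treatment inside scatter_op. A self-contained illustration of that shape convention, using a plain torch.Tensor as a stand-in for the WholeMemory tensor (the real wholememory_scatter_op operates on distributed memory):

    import torch

    embeddings = torch.zeros(8)           # 1D stand-in for wmb_tensor
    values = torch.tensor([3.0, 5.0])     # 1D input
    indice = torch.tensor([2, 6])

    # Mirror the wrapper: unsqueeze both sides to 2D (n, 1) before scattering.
    values_2d = values.unsqueeze(1)
    embeddings_2d = embeddings.unsqueeze(1)  # a view, so writes reach embeddings

    # Stand-in for wmb.wholememory_scatter_op: write row i of values_2d
    # into row indice[i] of embeddings_2d.
    embeddings_2d[indice] = values_2d
    print(embeddings)  # tensor([0., 0., 3., 0., 0., 0., 5., 0.])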
