Skip to content

Commit

Permalink
Update debugging.py
Browse files Browse the repository at this point in the history
ManfeiBai authored Nov 29, 2023
1 parent 3af91e8 commit 4af73e2
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion torch_xla/distributed/spmd/debugging.py
Original file line number Diff line number Diff line change
@@ -10,6 +10,8 @@
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
import torch_xla.utils.utils as xu
import torch_xla.core.xla_env_vars as xenv
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor

try:
@@ -215,7 +217,7 @@ def visualize_sharding(shape: torch.Size,

def visualize_tensor_sharding(t, **kwargs):
"""Visualizes an array's sharding."""
if (isinstance(t, torch.tensor)):
if torch.is_tensor(t):
import torch_xla
sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
return visualize_sharding(t.shape, sharding, **kwargs)

0 comments on commit 4af73e2

Please sign in to comment.