Skip to content

Commit

Permalink
Update jaten.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai authored Nov 14, 2024
1 parent 0eae76f commit 4613713
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2752,7 +2752,10 @@ def _aten_nextafter(input, other, *, out=None):

# aten.nonzero
@op(torch.ops.aten.nonzero)
def _aten_nonzero(x):
def _aten_nonzero(x, as_tuple=False):
if jnp.ndim(x) == 0 and (as_tuple or x.item()==0):
print("arrive here")
return torch.empty(0, 0, dtype=torch.int64) # x
if jnp.ndim(x) == 0: # when x is scalar, return torch.tensor([], size=(1, 0), dtype=torch.int64)
res = torch.empty(1, 0, dtype=torch.int64)
return jnp.array(res.numpy())
Expand Down

0 comments on commit 4613713

Please sign in to comment.