From d0d329849f3ce7779d6fd5e2f90a51be16dd917d Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Wed, 25 Sep 2024 14:31:45 -0700 Subject: [PATCH] Update jaten.py --- experimental/torch_xla2/torch_xla2/ops/jaten.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 8f34c675df7..138f582c101 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -2324,6 +2324,9 @@ def _aten_neg(x): # aten.nonzero @op(torch.ops.aten.nonzero) def _aten_nonzero(x): + if jnp.ndim(x) == 0: # when x is scalar, return torch.tensor([], size=(1, 0), dtype=torch.int64) + res = torch.empty(1, 0, dtype=torch.int64) + return jnp.array(res.numpy()) index_tuple = jnp.nonzero(x) index_tuple = [jnp.expand_dims(p, -1) for p in index_tuple] return jnp.concatenate(index_tuple, axis=-1)