diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 34c72533baf..1af6902ae69 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -2750,6 +2750,21 @@ def _aten_nextafter(input, other, *, out=None): return jnp.nextafter(input, other) +@op(torch.ops.aten.nonzero_static) +def _aten_nonzero_static(input, size, fill_value = -1): + # import pdb; pdb.set_trace() + # print("arrive here _aten_nonzero") + indices = jnp.argwhere(input) + + if size < indices.shape[0]: + indices = indices[:size] + elif size > indices.shape[0]: + padding = jnp.full((size - indices.shape[0], indices.shape[1]), fill_value, dtype=indices.dtype) + indices = jnp.concatenate((indices, padding)) + + return indices + + # aten.nonzero @op(torch.ops.aten.nonzero) def _aten_nonzero(x, as_tuple=False):