From e499c8765055e19e7b8172eab2d7a941b5c2fc7e Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Wed, 13 Nov 2024 16:28:13 -0800 Subject: [PATCH] Update jaten.py --- experimental/torch_xla2/torch_xla2/ops/jaten.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 34c72533baf..1af6902ae69 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -2750,6 +2750,21 @@ def _aten_nextafter(input, other, *, out=None): return jnp.nextafter(input, other) +@op(torch.ops.aten.nonzero_static) +def _aten_nonzero_static(input, size, fill_value = -1): + # import pdb; pdb.set_trace() + # print("arrive here _aten_nonzero") + indices = jnp.argwhere(input) + + if size < indices.shape[0]: + indices = indices[:size] + elif size > indices.shape[0]: + padding = jnp.full((size - indices.shape[0], indices.shape[1]), fill_value, dtype=indices.dtype) + indices = jnp.concatenate((indices, padding)) + + return indices + + # aten.nonzero @op(torch.ops.aten.nonzero) def _aten_nonzero(x, as_tuple=False):