Skip to content

Commit

Permalink
Update jaten.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai authored Nov 14, 2024
1 parent 3c9529e commit e499c87
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2750,6 +2750,21 @@ def _aten_nextafter(input, other, *, out=None):
return jnp.nextafter(input, other)


@op(torch.ops.aten.nonzero_static)
def _aten_nonzero_static(input, size, fill_value = -1):
# import pdb; pdb.set_trace()
# print("arrive here _aten_nonzero")
indices = jnp.argwhere(input)

if size < indices.shape[0]:
indices = indices[:size]
elif size > indices.shape[0]:
padding = jnp.full((size - indices.shape[0], indices.shape[1]), fill_value, dtype=indices.dtype)
indices = jnp.concatenate((indices, padding))

return indices


# aten.nonzero
@op(torch.ops.aten.nonzero)
def _aten_nonzero(x, as_tuple=False):
Expand Down

0 comments on commit e499c87

Please sign in to comment.