Skip to content

Commit

Permalink
Add bitwise_left_shift, bitwise_right_shift (#7343)
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai authored Jun 25, 2024
1 parent b505288 commit 557996a
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
2 changes: 0 additions & 2 deletions experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
"_segment_reduce",
"_upsample_bilinear2d_aa",
"bincount", # NOTE: dtype for int input torch gives float. This is weird.
"bitwise_left_shift",
"bitwise_right_shift",
"block_diag",
"broadcast_tensors",
"broadcast_to",
Expand Down
14 changes: 13 additions & 1 deletion experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,18 @@ def _aten_bitwise_not(self):
return ~self


# aten.bitwise_left_shift
@op(torch.ops.aten.bitwise_left_shift)
def _aten_bitwise_left_shift(input, other):
return jnp.left_shift(input, other)


# aten.bitwise_right_shift
@op(torch.ops.aten.bitwise_right_shift)
def _aten_bitwise_right_shift(input, other):
return jnp.right_shift(input, other)


# aten.embedding_dense_backward


Expand Down Expand Up @@ -2113,4 +2125,4 @@ def _aten_randint(

@op(torch.ops.aten.dim, is_jax_function=False)
def _aten_dim(self):
return len(self.shape)
return len(self.shape)

0 comments on commit 557996a

Please sign in to comment.