Op Info tests broadcast_tensors, count_nonzero, cov, cross, and `equal` (pytorch#7951)
ManfeiBai authored and guyao committed Sep 9, 2024
1 parent 636484c commit 3daf9b3
Showing 2 changed files with 75 additions and 6 deletions.
5 changes: 0 additions & 5 deletions experimental/torch_xla2/test/test_ops.py
@@ -16,7 +16,6 @@
 "_upsample_bilinear2d_aa",
 "bincount",  # NOTE: for int input, torch gives a float dtype. This is weird.
 "block_diag",
-"broadcast_tensors",
 "bucketize",
 "byte",
 "cat",
@@ -27,9 +26,6 @@
 "cholesky_solve",
 "combinations",
 "complex",
-"count_nonzero",
-"cov",
-"cross",
 "cummax",
 "cummin",
 "cumsum",
@@ -40,7 +36,6 @@
 "diff",
 "digamma",
 "dist",
-"equal",
 "erfc",
 "erfinv",
 "expand",
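With these ops removed from the skip list, their OpInfo tests now run. Locally they can be exercised with pytest's `-k` filter (assuming a development checkout; the exact runner invocation may differ), e.g. `pytest experimental/torch_xla2/test/test_ops.py -k broadcast_tensors`.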
76 changes: 75 additions & 1 deletion experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -943,7 +943,10 @@ def reduce_fn(a, b):

 @op(torch.ops.aten.min)
 def _aten_min(x, axis=None):
-  return jnp.min(x, axis=axis), jnp.argmin(x, axis=axis).astype(jnp.int64)
+  if axis is not None:  # compare against None so axis=0 still returns (values, indices)
+    return jnp.min(x, axis=axis), jnp.argmin(x, axis=axis).astype(jnp.int64)
+  else:
+    return jnp.min(x, axis=axis)


 @op(torch.ops.aten.amin)
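For reference, the two call forms this lowering mirrors (a minimal sketch; the tensor values are illustrative):

```python
import torch

x = torch.tensor([[3.0, 1.0], [2.0, 4.0]])
print(torch.min(x))          # tensor(1.): full reduction, value only
values, indices = torch.min(x, dim=0)
print(values, indices)       # per-dim reduction returns (values, indices)
```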
@@ -1695,6 +1698,70 @@ def _aten_bitwise_xor(self, other):
 return self ^ other


+# aten.broadcast_tensors
+@op(torch.ops.aten.broadcast_tensors)
+def _aten_broadcast_tensors(*tensors):
+
+  def _get_broadcast_shape(shapes):
+    """
+    Determines the output shape by broadcasting all input shapes.
+    Args:
+      shapes: A list of tuples representing the shapes of the input tensors.
+    Returns:
+      A tuple representing the broadcasted output shape.
+    """
+
+    # Find the maximum number of dimensions among all input tensors.
+    max_dims = max(len(shape) for shape in shapes)
+    # Pad shorter shapes with 1s on the left to match the maximum number of dimensions.
+    padded_shapes = [(1,) * (max_dims - len(shape)) + shape for shape in shapes]
+
+    # Initialize the output shape with 1s.
+    output_shape = [1] * max_dims
+    # Iterate through each dimension and apply the broadcasting rules.
+    for dim in range(max_dims):
+      dim_sizes = [shape[dim] for shape in padded_shapes]
+      max_size = max(dim_sizes)
+      if all(size == 1 or size == max_size for size in dim_sizes):
+        output_shape[dim] = max_size
+      else:
+        raise ValueError("Incompatible shapes for broadcasting")
+    return tuple(output_shape)
+
+  def _broadcast_dimensions(input_shape, output_shape):
+    """
+    Determines the broadcast_dimensions argument for jax.lax.broadcast_in_dim.
+    Args:
+      input_shape: The shape of the input tensor.
+      output_shape: The desired output shape after broadcasting.
+    Returns:
+      A tuple mapping each input dimension to the output dimension it occupies.
+    """
+
+    # Input dimensions align with the rightmost output dimensions,
+    # matching NumPy/PyTorch broadcasting semantics.
+    offset = len(output_shape) - len(input_shape)
+    return tuple(range(offset, len(output_shape)))
+
+  # Unwrap when the op is called with a single sequence of tensors
+  # rather than with varargs.
+  if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)):
+    tensors = tensors[0]
+
+  # Get the shapes of all input tensors.
+  shapes = [t.shape for t in tensors]
+  # Find the output shape by broadcasting all input shapes.
+  output_shape = _get_broadcast_shape(shapes)
+  # Broadcast each tensor to the output shape.
+  broadcasted_tensors = [
+      jax.lax.broadcast_in_dim(t, output_shape,
+                               _broadcast_dimensions(t.shape, output_shape))
+      for t in tensors
+  ]
+
+  return broadcasted_tensors


 # aten.broadcast_to
 @op(torch.ops.aten.broadcast_to)
 def _aten_broadcast_to(input, shape):
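A quick sanity check of the right-aligned dimension mapping used above (a minimal sketch; shapes are illustrative):

```python
import jax
import jax.numpy as jnp

a = jnp.ones((3,))  # rank 1
# The single input dim maps to the last output dim, so broadcast_dimensions=(1,).
out = jax.lax.broadcast_in_dim(a, (2, 3), broadcast_dimensions=(1,))
print(out.shape)  # (2, 3)
```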
@@ -1780,6 +1847,13 @@ def _aten_eq(input1, input2):
 return input1 == input2


+# aten.equal
+@op(torch.ops.aten.equal, is_jax_function=False)
+def _aten_equal(input, other):
+  res = jnp.array_equal(input._elem, other._elem)
+  return bool(res)
+
+
 # aten.erf
 @op(torch.ops.aten.erf)
 @op_base.promote_int_input
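`torch.equal` returns a plain Python bool rather than a tensor, which is why the lowering converts the JAX result with `bool()`; registering with `is_jax_function=False` appears to hand the function the wrapper tensors, hence the explicit `._elem` unwrapping. For reference (values are illustrative):

```python
import torch

a = torch.tensor([1, 2, 3])
print(torch.equal(a, torch.tensor([1, 2, 3])))  # True: same shape and values
print(torch.equal(a, torch.tensor([1, 2])))     # False: shapes differ
```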
