From ee388f65bf43dc5f4ae94f0675812e75c9535d9c Mon Sep 17 00:00:00 2001
From: Matthias Guenther
Date: Wed, 16 Oct 2024 11:20:59 -0700
Subject: [PATCH] Add `unique`, `unique_consecutive` (#8258)

Co-authored-by: mrguenther
---
 experimental/torch_xla2/test/test_ops.py |   4 +-
 .../torch_xla2/torch_xla2/ops/jaten.py    | 137 ++++++++++++++++++
 2 files changed, 138 insertions(+), 3 deletions(-)

diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index aa0d84e71dc..bcc5b05e81c 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -97,8 +97,6 @@
     "svd_lowrank",
     "unfold_copy",
     "unfold",
-    "unique_consecutive",
-    "unique",
     "unravel_index",
     "var_mean",
     "nanmean",
@@ -109,7 +107,7 @@
 not_support_ops_list = {
     "chalf",  # Skip due to jax not support complex32 with backend: https://github.com/google/jax/issues/14180
     "__rpow__",  # NOTE: cannot fix because torch test case has undefined behavior
-    # such as 0 to negative power.
+    # such as 0 to negative power.
     "ceil",  # only failed with python 3.9
     "trunc",  # only failed with python 3.9
     "to_sparse",  # We are not supporting sparse tensors yet.
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index 5ef39d40cc8..2535338fd2a 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -53,6 +53,9 @@
   torch.ops.aten.scatter_reduce_.two: torch.ops.aten.scatter_reduce,
 }
 
+# Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`.
+_jax_version = tuple(int(v) for v in jax.version._version.split("."))
+
 
 def make_mutation(op):
   if type(mutation_ops_to_functional[op]) is tuple:
@@ -2757,6 +2760,140 @@ def _aten_unbind(a, dim=0):
   return [jax.lax.index_in_dim(a, i, dim, keepdims=False) for i in range(a.shape[dim])]
 
 
+# aten.unique_dim
+#
+# NOTE: Like the CUDA and CPU implementations, this implementation always sorts
+# the tensor regardless of the `sorted` argument passed to `torch.unique`.
+@op(torch.ops.aten.unique_dim)
+def _aten_unique_dim(input_tensor,
+                     dim,
+                     sort=True,
+                     return_inverse=False,
+                     return_counts=False):
+  result_tensor_or_tuple = jnp.unique(input_tensor,
+                                      return_index=False,
+                                      return_inverse=return_inverse,
+                                      return_counts=return_counts,
+                                      axis=dim,
+                                      equal_nan=False)
+  result_list = (
+      list(result_tensor_or_tuple) if isinstance(result_tensor_or_tuple, tuple)
+      else [result_tensor_or_tuple])
+
+  if not return_inverse:
+    result_list.insert(1, None)
+  elif _jax_version < (0, 4, 31) and dim is not None:
+    result_list[1] = result_list[1].flatten()
+
+  if not return_counts:
+    result_list.insert(2, None)
+
+  # [result, None, None] if return_inverse=False and return_counts=False
+  # [result, inverse, None] if return_inverse=True and return_counts=False
+  # [result, None, counts] if return_inverse=False and return_counts=True
+  # [result, inverse, counts] if return_inverse=True and return_counts=True
+  return result_list
+
+
+# aten._unique
+#
+# NOTE: Like the CUDA and CPU implementations, this implementation always sorts
+# the tensor regardless of the `sorted` argument passed to `torch.unique`.
+@op(torch.ops.aten._unique)
+def _aten_unique(input_tensor,
+                 sort=True,
+                 return_inverse=False):
+  result_tensor_or_tuple = jnp.unique(input_tensor,
+                                      return_index=False,
+                                      return_inverse=return_inverse,
+                                      return_counts=False,
+                                      axis=None,
+                                      equal_nan=False)
+  if return_inverse:
+    return result_tensor_or_tuple
+  else:
+    return (result_tensor_or_tuple, None)
+
+
+# aten._unique2
+#
+# NOTE: Like the CUDA and CPU implementations, this implementation always sorts
+# the tensor regardless of the `sorted` argument passed to `torch.unique`.
+@op(torch.ops.aten._unique2)
+def _aten_unique2(input_tensor,
+                  sort=True,
+                  return_inverse=False,
+                  return_counts=False):
+  return _aten_unique_dim(input_tensor=input_tensor,
+                          dim=None,
+                          sort=sort,
+                          return_inverse=return_inverse,
+                          return_counts=return_counts)
+
+
+# aten.unique_consecutive
+@op(torch.ops.aten.unique_consecutive)
+def _aten_unique_consecutive(input_tensor,
+                             return_inverse=False,
+                             return_counts=None,
+                             dim=None):
+  # Explanation of computations (shown in 1D for simplicity):
+  #
+  # Input                                        [a b b c c c d d d d e e e e e]
+  # Slice dropping final element (input[:-1])   [a b b c c c d d d d e e e e]
+  # Slice dropping first element (input[1:])    [b b c c c d d d d e e e e e]
+  # Boolean != operation on shifted slices      [1 0 1 0 0 1 0 0 0 1 0 0 0 0]
+  # Prepend 1 to represent the first element    [1 1 0 1 0 0 1 0 0 0 1 0 0 0 0]
+  # Filter input by the resulting bool array    [a b   c     d       e        ]
+  # Output                                       [a b c d e]
+
+  if dim is None:
+    inverse_shape = input_tensor.shape
+    input_tensor = input_tensor.flatten()
+    ndim = 1
+    dim = 0
+  else:
+    inverse_shape = input_tensor.shape[dim]
+    ndim = input_tensor.ndim
+    if dim < 0:
+      dim += ndim
+
+  nd_slice_0 = tuple(slice(None, -1) if d == dim else slice(None)
+                     for d in range(ndim))
+  nd_slice_1 = tuple(slice(1, None) if d == dim else slice(None)
+                     for d in range(ndim))
+
+  axes_to_reduce = tuple(d for d in range(ndim) if d != dim)
+
+  does_not_equal_prior = (
+      jnp.any(input_tensor[nd_slice_0] != input_tensor[nd_slice_1],
+              axis=axes_to_reduce,
+              keepdims=False))
+
+  if input_tensor.shape[dim] != 0:
+    # Prepend `True` to represent the first element of the input.
+    does_not_equal_prior = jnp.insert(does_not_equal_prior, 0, True)
+
+  include_indices = jnp.argwhere(does_not_equal_prior)[:, 0]
+
+  output_tensor = input_tensor[
+      tuple(include_indices if d == dim else slice(None) for d in range(ndim))]
+
+  if return_inverse or return_counts:
+    counts = (jnp.append(include_indices[1:], input_tensor.shape[dim]) -
+              include_indices)
+
+    inverse = (
+        jnp.reshape(jnp.repeat(jnp.arange(len(counts)), counts), inverse_shape)
+        if return_inverse
+        else None
+    )
+
+    return output_tensor, inverse, counts
+
+  return output_tensor, None, None
+
+
 # NOTE: skip aten.upsample_nearest2d and aten.upsample_bilinear2d
 # despite those being core aten ops, they also have decompositions.
 # here we are using torch decompositions.
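
As a quick illustration of the approach (not part of the patch): `_aten_unique_dim`
delegates the actual work to `jnp.unique`, whose `axis` argument plays the role of
`torch.unique`'s `dim`; the function then pads the result out to the fixed
`[values, inverse, counts]` triple the aten op expects, substituting `None` for
outputs the caller did not request. A minimal sketch of the underlying mapping,
assuming a recent jax (all names below are illustrative):

    import jax.numpy as jnp

    x = jnp.array([[1, 2], [1, 2], [3, 4]])
    values, inverse, counts = jnp.unique(
        x, return_inverse=True, return_counts=True, axis=0, equal_nan=False)
    # values  -> [[1 2]
    #             [3 4]]   unique rows, always in sorted order
    # inverse -> [0 0 1]   row i of x equals row inverse[i] of values
    # counts  -> [2 1]     occurrences of each unique row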
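
The shift-and-compare trick in `_aten_unique_consecutive` can be exercised
standalone in 1D. The sketch below (again illustrative, not from the patch;
`unique_consecutive_1d` is a hypothetical helper) mirrors the steps of the worked
example in the function's comment:

    import jax.numpy as jnp

    def unique_consecutive_1d(x):
      # True wherever an element differs from its predecessor; the first
      # element always starts a new run.
      starts_new_run = jnp.insert(x[:-1] != x[1:], 0, True)
      run_starts = jnp.argwhere(starts_new_run)[:, 0]
      # Distance from each run start to the next gives that run's length.
      counts = jnp.append(run_starts[1:], x.shape[0]) - run_starts
      inverse = jnp.repeat(jnp.arange(len(counts)), counts)
      return x[run_starts], inverse, counts

    x = jnp.array([0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4])
    values, inverse, counts = unique_consecutive_1d(x)
    # values  -> [0 1 2 3 4]
    # inverse -> [0 1 1 2 2 2 3 3 3 3 4 4 4 4 4]
    # counts  -> [1 2 3 4 5]

The n-dimensional implementation in the patch generalizes exactly this:
`nd_slice_0` and `nd_slice_1` build the two shifted views along `dim`, and
`jnp.any` over the remaining axes collapses the elementwise `!=` into one boolean
per position along `dim`.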