diff --git a/BUILD.bazel b/BUILD.bazel index 49ca5b4dcce..03754b5fd92 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -311,6 +311,7 @@ cc_library( ], strip_include_prefix = ".", deps = [ + ":base", ":interpreter_ops_inc_gen", ":reference_numpy", ":reference_ops", diff --git a/docs/images/spec/gather.svg b/docs/images/spec/gather.svg index ad7414c56c1..d2891862c5f 100644 --- a/docs/images/spec/gather.svg +++ b/docs/images/spec/gather.svg @@ -1 +1 @@ - \ No newline at end of file + \ No newline at end of file diff --git a/docs/images/spec/scatter.svg b/docs/images/spec/scatter.svg index c257d692aeb..e0e6a72a2fb 100644 --- a/docs/images/spec/scatter.svg +++ b/docs/images/spec/scatter.svg @@ -1 +1 @@ - \ No newline at end of file + diff --git a/docs/interpreter_status.md b/docs/interpreter_status.md index df68df9a4e8..3d1bf406b07 100644 --- a/docs/interpreter_status.md +++ b/docs/interpreter_status.md @@ -60,23 +60,14 @@ interpreter supports resides in [hlo_expand_main.cc](https://github.com/openxla/ ### Not in HLO -Apart from the specced ops, this category consists of 10 unspecced ops (see -[StableHLO Ops Categories](#stablehlo-ops-categories)) which are planed to be -moved out of StableHLO. Some of these ops have existing passes in +Apart from the specced ops, this category consists of 8 unspecced ops (see +[StableHLO Ops Categories](#stablehlo-ops-categories)) which are planned to be +moved out of StableHLO. Most of these ops have existing passes in [mhlo](https://github.com/openxla/xla/tree/main/xla/mlir_hlo/mhlo/transforms) to -convert them to StableHLO equivalent ops. There are three ops the interpreter -does not support because there are no existing decompositions to StableHLO ops: - -* `compute_reshape_shape` -* `cstr_reshapable` -* `trace` - -`compute_reshape_shape` and `cstr_reshapable` ops are part of the ongoing -Dynamism work, and they are planned to be removed from StableHLO (see -[#1668](https://github.com/openxla/stablehlo/issues/1668)). 
- -`trace` op is private to XLA and there no no users in JAX, PyTorch or TensorFlow -(see [#604](https://github.com/openxla/stablehlo/issues/604)). +convert them to StableHLO equivalent ops. There is one op the interpreter +does not support because there is no existing decomposition to StableHLO ops: +`trace`. `trace` op is private to XLA and there are no users in JAX, PyTorch or +TensorFlow (see [#604](https://github.com/openxla/stablehlo/issues/604)). The tool to convert remaining ops in this category to equivalent StableHLO ops @@ -254,18 +245,12 @@ hlo-expand --triangular_solve_expander # broadcast mlir-hlo-opt -mhlo-legalize-broadcast-to-broadcast-in-dim -# compute_reshape_shape -# This op will be removed from StableHLO as part of Dynamism work (see #1668). - # create_token mlir-hlo-opt -mhlo-legalize-create-token-to-after-all # cross-replica-sum mlir-hlo-opt -mhlo-legalize-cross-replica-sum-to-all-reduce -# cstr_reshapable -# This op will be removed from StableHLO as part of Dynamism work (see #1668). 
- # dot mlir-hlo-opt -mhlo-legalize-dot-to-dot-general @@ -295,6 +280,6 @@ mlir-hlo-opt --canonicalize -mhlo-legalize-einsum-to-dot-general | Extensibility | custom_call, get_tuple_element, tuple | 3 | | Miscellaneous | batch_norm_grad, batch_norm_inference, batch_norm_training, cholesky, constant, fft, iota, rng, rng_bit_generator, triangular_solve | 10 | | Modularity | call, func, module, return | 4 | -| Not In HLO | broadcast, compute_reshape_shape, create_token, cross-replica-sum, cstr_reshapable, dot, einsum, torch_index_select, trace, unary_einsum | 10 | +| Not In HLO | broadcast, create_token, cross-replica-sum, dot, einsum, torch_index_select, trace, unary_einsum | 8 | | Quantization | uniform_dequantize, uniform_quantize | 2 | | Reduction | convolution, dot_general, reduce, reduce_window, select_and_scatter | 5 | diff --git a/docs/spec.md b/docs/spec.md index aa27e58cb65..95aaabfb79c 100644 --- a/docs/spec.md +++ b/docs/spec.md @@ -207,8 +207,8 @@ constraints: * For per-tensor quantization: * No additional constraints. * For per-axis quantization: - * (C12) `quantization_dimension < rank(self)`. - * (C13) `dim(self, quantization_dimension) = size(scales)`. + * (C13) `quantization_dimension < rank(self)`. + * (C14) `dim(self, quantization_dimension) = size(scales)`. ```ebnf TokenType ::= 'token' @@ -331,9 +331,8 @@ in StableHLO programs. In the meanwhile, here is the list of these operations: ([#3](https://github.com/openxla/stablehlo/issues/3)), and `trace` ([#604](https://github.com/openxla/stablehlo/issues/604)). 
* "Dynamism" category of StableHLO operations - they were bootstrapped from - MHLO, but we haven't specced them yet: `compute_reshape_shape`, - `cstr_reshapable`, `dynamic_broadcast_in_dim`, `dynamic_conv`, - `dynamic_gather`, `dynamic_iota`, `dynamic_pad`, `dynamic_reshape`, + MHLO,and we are in the process of speccing them: `dynamic_broadcast_in_dim`, + `dynamic_conv`, `dynamic_gather`, `dynamic_iota`, `dynamic_pad`, `dynamic_reshape`, `real_dynamic_slice`, `set_dimension_size` ([#8](https://github.com/openxla/stablehlo/issues/8)). * Shape computations, including `arith`, `shape` and `tensor` operations @@ -2636,6 +2635,50 @@ planning to address this in  [More Examples](https://github.com/openxla/stablehlo/tree/main/stablehlo/tests/interpret/dot_general.mlir) +### dynamic_iota + +#### Semantics + +This operation is functionally identical to +[iota](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#iota) +op, but the result shape is specified dynamically via `output_shape`. + +#### Inputs + +| Label | Name | Type | Constraints | +|-------|------------------|------------------------------------------------------------------------------------|-------------| +| (I1) | `output_shape` | 1-dimensional tensor constant of type `si64` | (C1), (C2) | +| (I2) | `iota_dimension` | `si64` | (C1) | + +#### Outputs + +| Name | Type | Constraints | +|----------|-----------------------------------------------------------------------------------|-------------| +| `result` | tensor of integer, floating-point, or complex type or per-tensor quantized tensor | (C2) | + +#### Constraints + +* (C1) `0 <= iota_dimension < size(output_shape)`. +* (C2) `rank(result) = size(output_shape)`. 
+ +#### Examples + +```mlir + +%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64> +%result = "stablehlo.dynamic_iota"(%output_shape) { + iota_dimension = 0 : i64 +} : (tensor<2xi64>) -> tensor<4x5xi64> +// %result: [ +// [0, 0, 0, 0, 0], +// [1, 1, 1, 1, 1], +// [2, 2, 2, 2, 2], +// [3, 3, 3, 3, 3] +// ] +``` + + [More Examples](https://github.com/openxla/stablehlo/tree/main/stablehlo/tests/interpret/dynamic_iota.mlir) + ### dynamic_slice #### Semantics diff --git a/docs/status.md b/docs/status.md index 85512cc6528..28cf6d318c7 100644 --- a/docs/status.md +++ b/docs/status.md @@ -66,7 +66,6 @@ one of the following tracking labels. | compare | yes | yes | yes | yes | yes | | complex | yes | yes | yes | yes | yes | | composite | yes | yes | infeasible | yes | yes | -| compute_reshape_shape | no | revisit | no | yes | no | | concatenate | yes | yes | yes | yes | yes | | constant | yes | yes | yes | yes | yes | | convert | yes | yes | infeasible | yes | yes | @@ -75,7 +74,6 @@ one of the following tracking labels. | count_leading_zeros | yes | yes | yes | yes | yes | | create_token | no | yes\* | yes\* | yes | revisit | | cross-replica-sum | no | revisit | yes\* | no | revisit | -| cstr_reshapable | no | revisit | no | yes | no | | custom_call | yes | yes | infeasible | yes | yes | | divide | yes | yes | yes | yes | yes | | dot | no | revisit | infeasible | yes | revisit | @@ -83,7 +81,7 @@ one of the following tracking labels. 
| dynamic_broadcast_in_dim | no | revisit | infeasible | no | no | | dynamic_conv | no | revisit | no | no | no | | dynamic_gather | no | revisit | revisit | no | no | -| dynamic_iota | no | revisit | infeasible | yes | no | +| dynamic_iota | yes | yes | infeasible | yes | revisit | | dynamic_pad | no | revisit | no | yes | no | | dynamic_reshape | no | revisit | infeasible | yes | no | | dynamic_slice | yes | yes | yes | yes | yes | diff --git a/rfcs/20240311-gather-scatter-batching-dims.md b/rfcs/20240311-gather-scatter-batching-dims.md new file mode 100644 index 00000000000..7f2f6933885 --- /dev/null +++ b/rfcs/20240311-gather-scatter-batching-dims.md @@ -0,0 +1,466 @@ +# [RFC] Add batching dims to `stablehlo.gather` and `stablehlo.scatter` specification + +Status: Review
+Initial version: 03/11/2024
+Last updated: 03/11/2024
+Discussion thread: TBD + +## Overview + +This RFC proposes adding `operand_batching_dims` and +`start_indices_batching_dims` attributes to `stablehlo.gather`. +`operand_batching_dims` refers to the dimensions of the `operand` that are +treated as batch. `start_indices_batching_dims` refers to the dimensions of the +`start_indices` that are treated as batch. The corresponding dimension sizes +must be equal. The semantics is equivalent to concatenating the outputs of the +gather with each slice of `operand` and `start_indices`. + +Similarly, this RFC proposes adding `input_batching_dims` and +`scatter_indices_batching_dims` attributes to `stablehlo.scatter`. +`input_batching_dims` refers to the dimensions of each tensor in `inputs` that +are treated as batch. `scatter_indices_batching_dims` refers to the dimensions +of the `scatter_indices` that are treated as batch. + +## Motivation + +StableHLO gather and scatter ops currently have no way of specifying batch +dimensions that correspond between all operands and result, only the indices +tensor can have implicit batch dimensions that only exist in the +result/update tensor (`batch_dims` in specification). + +This is important when the user wants a batched/vectorized version of +gather/scatter across all operands and result (e.g., when using [jax.vmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html)). + +The current workaround is to use the implicit batch dimensions in the indices +tensor (all dimensions but `index_vector_dim`) with `stablehlo.concatenate` and +`stablehlo.iota` (and other ops like `stablehlo.clamp`) to mimic batch +dimensions in the operand. + +This isn't ideal as it hides the fact that those are batch dimensions that +correspond between the operands and result tensors. 
This information is crucial +when doing sharding propagation between the operands and result tensors of the +gather/scatter op, where partitioning batch dimensions across tensors is +trivial, as they can be sharded in the same way with no communication needed. +The above workaround requires pattern matching to identify those batch +dimensions, which can be error prone and hard to maintain. + +This proposal is inspired by `lhs_batching_dims` and `rhs_batching_dims` of +`stablehlo.dot_general`, which serve the same purpose. + +## Compatibility + +The new `stablehlo.gather` and `stablehlo.scatter` can be decomposed to the old +ops by applying the workaround above. For `stablehlo.gather` this would mean +making the `operand_batching_dims` as `collapsed_slice_dims` and +`start_indices_batching_dims` as implicit batch dimensions in `start_indices`, +by incrementing `index_vector_dim` by the size of `operand_batching_dims` (and +updating `start_index_map` accordingly), and concatenating an iota for each +batch dimension to the original `start_indices`. + +In the backward compatibility window (assuming 6 months), the old ops being +loaded will automatically get an empty tensor for these added attributes. + +## Alternatives considered + +We could do the workaround above (using concatenated iota for the indices +tensor), but in addition mark the batch dimension in each tensor using an +unregistered attribute, so that the information won't be lost and partitioning +systems can use it to partition and propagate through these batch dimensions. +However, unregistered attributes can be discarded at any time and are harder to +maintain (e.g. if the iota is replaced by another op, who is responsible for +updating the unregistered attributes). + +Another option is to start with `stablehlo.dynamic_slice` and +`stablehlo.dynamic_update_slice`, which are simpler ops that would suffice for a +lot of the use cases (but not all) that we've encountered. 
Currently +`stablehlo.gather` and `stablehlo.scatter` are used to get a vectorized version of +`stablehlo.dynamic_slice` and `stablehlo.dynamic_update_slice` respectively. +However, for the same reason that they can be expressed by gather/scatter, we +propose to go with the more general solution, that would address all use cases. + +## Naming + +As indicated above, the proposed naming of the new attributes is inspired by +`lhs_batching_dims` and `rhs_batching_dims` of `stablehlo.dot_general`. Note +that the gather op specification already uses the term `batch_dims` +(`update_scatter_dims` for scatter op), to refer to all dimensions in the result +tensor that aren't offset dimensions, which have corresponding dimensions in the +start-indices tensor. The difference is that the proposed dimensions are +explicit batching dimensions that exist in all operands and result, whereas the +existing `batch_dims` are implicit batch dimensions (as they are derived by the +offset dimensions) that exist only in the start-indices and result tensors. + +## Proposed Specification + +### gather + +#### Semantics + +Gathers slices from `operand` tensor from offsets specified in `start_indices` +and produces a `result` tensor. + +The following diagram shows how elements in `result` map on elements in +`operand` using a concrete example. The diagram picks a few example `result` +indices and explains in detail which `operand` indices they correspond to. + +![gather](images/20240311-gather-scatter-batching-dims/gather.svg) + +More formally, `result[result_index] = operand[operand_index]` where: + + +* `batch_dims = [d for d in axes(result) and d not in offset_dims]`. +* `batch_index = result_index[batch_dims...]`. +* `start_index` is defined as: + * `start_indices[bi0, ..., :, ..., biN]` where `bi` are individual elements in + `batch_index` and `:` is inserted at the `index_vector_dim` index, if + `index_vector_dim` < `rank(start_indices)`. + * `[start_indices[batch_index]]` otherwise. 
+* For `d_operand` in `axes(operand)`, + * `full_start_index[d_operand] = clamp(start_index[d_start], 0, + dim(operand, d_operand) - slice_sizes[d_operand])` + if `d_operand = start_index_map[d_start]`. + * `full_start_index[d_operand] = 0` otherwise. +* For `d_operand` in `axes(operand)`, + * `full_batching_index[d_operand] = + batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)]` + if `d_operand = operand_batching_dims[i_batching]` and + `d_start = start_indices_batching_dims[i_batching]`. + * `full_batching_index[d_operand] = 0` otherwise. +* `offset_index = result_index[offset_dims...]`. +* `full_offset_index = [oi0, ..., 0, ..., oiN]` where `oi` are individual + elements in `offset_index`, and `0` is inserted at indices from + `collapsed_slice_dims` and `operand_batching_dims`. +* `operand_index = full_start_index + full_batching_index + full_offset_index`. + + +If `indices_are_sorted` is `true` then the implementation can assume that +`start_indices` are sorted with respect to `start_index_map`, otherwise the +behavior is undefined. More formally, for all `i1 < i2` from `indices(result)`, +`full_start_index(i1) <= full_start_index(i2)`. 
+ +#### Inputs + +| Label | Name | Type | Constraints | +|-------|-------------------------------|----------------------------------------------|--------------------------------------------| +| (I1) | `operand` | tensor or per-tensor quantized tensor | (C1), (C8), (C11), (C17), (C19-C21), (C23) | +| (I2) | `start_indices` | tensor of integer type | (C2-C3), (C14), (C17), (C22) | +| (I3) | `offset_dims` | 1-dimensional tensor constant of type `si64` | (C1), (C4-C5), (C22) | +| (I4) | `collapsed_slice_dims` | 1-dimensional tensor constant of type `si64` | (C1), (C6-C9), (C22) | +| (I5) | `operand_batching_dims` | 1-dimensional tensor constant of type `si64` | (C1), (C6), (C10-C12), (C16-C18), (C22) | +| (I6) | `start_indices_batching_dims` | 1-dimensional tensor constant of type `si64` | (C13-C17) | +| (I7) | `start_index_map` | 1-dimensional tensor constant of type `si64` | (C3), (C18-C19) | +| (I8) | `index_vector_dim` | constant of type `si64` | (C2-C3), (C15), (C22) | +| (I9) | `slice_sizes` | 1-dimensional tensor constant of type `si64` | (C9), (C12), (C20-C22) | +| (I10) | `indices_are_sorted` | constant of type `i1` | | + +#### Outputs + +| Name | Type | Constraints | +|----------|---------------------------------------|-----------------| +| `result` | tensor or per-tensor quantized tensor | (C5), (C22-C23) | + +#### Constraints + +* (C1) `rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + + size(operand_batching_dims)`. +* (C2) `0 <= index_vector_dim <= rank(start_indices)`. +* (C3) `size(start_index_map) = + index_vector_dim < rank(start_indices) ? + dim(start_indices, index_vector_dim) : 1`. +* (C4) `is_unique(offset_dims) and is_sorted(offset_dims)`. +* (C5) `0 <= offset_dims < rank(result)`. +* (C6) `is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))` +* (C7) `is_sorted(collapsed_slice_dims)`. +* (C8) `0 <= collapsed_slice_dims < rank(operand)`. +* (C9) `slice_sizes[collapsed_slice_dims...] <= 1`. 
+* (C10) `is_sorted(operand_batching_dims)`. +* (C11) `0 <= operand_batching_dims < rank(operand)`. +* (C12) `slice_sizes[operand_batching_dims...] <= 1`. +* (C13) `is_unique(start_indices_batching_dims)`. +* (C14) `0 <= start_indices_batching_dims < rank(start_indices)`. +* (C15) `index_vector_dim not in start_indices_batching_dims`. +* (C16) `size(operand_batching_dims) == size(start_indices_batching_dims)`. +* (C17) `dim(operand, operand_batching_dims...) = + dim(start_indices, start_indices_batching_dims...)`. +* (C18) `is_unique(concatenate(start_index_map, operand_batching_dims))`. +* (C19) `0 <= start_index_map < rank(operand)`. +* (C20) `size(slice_sizes) = rank(operand)`. +* (C21) `0 <= slice_sizes <= shape(operand)`. +* (C22) `shape(result) = combine(batch_dim_sizes, offset_dim_sizes)` where: + * `batch_dim_sizes = shape(start_indices)` except that the dimension size + of `start_indices` corresponding to `index_vector_dim` is not included. + * `offset_dim_sizes = slice_sizes` except that the dimension sizes in + `slice_sizes` corresponding to `collapsed_slice_dims` and + `operand_batching_dims` are not included. + * `combine` puts `batch_dim_sizes` at axes corresponding to `batch_dims` and + `offset_dim_sizes` at axes corresponding to `offset_dims`. +* (C23) `element_type(operand) = element_type(result)`. 
+ +#### Examples + +```mlir +// %operand: [ +// [ +// [[1, 2], [3, 4], [5, 6], [7, 8]], +// [[9, 10],[11, 12], [13, 14], [15, 16]], +// [[17, 18], [19, 20], [21, 22], [23, 24]] +// ], +// [ +// [[25, 26], [27, 28], [29, 30], [31, 32]], +// [[33, 34], [35, 36], [37, 38], [39, 40]], +// [[41, 42], [43, 44], [45, 46], [47, 48]] +// ] +// ] +// %start_indices: [ +// [ +// [[0, 0], [1, 0], [2, 1]], +// [[0, 1], [1, 1], [0, 9]] +// ], +// [ +// [[0, 0], [2, 1], [2, 2]], +// [[1, 2], [0, 1], [1, 0]] +// ] +// ] +%result = "stablehlo.gather"(%operand, %start_indices) { + dimension_numbers = #stablehlo.gather< + offset_dims = [3, 4], + collapsed_slice_dims = [1], + operand_batching_dims = [0], + start_indices_batching_dims = [1], + start_index_map = [2, 1], + index_vector_dim = 3>, + slice_sizes = array, + indices_are_sorted = false +} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32> +// %result: [ +// [ +// [ +// [[1, 2], [3, 4]], +// [[3, 4], [5, 6]], +// [[13, 14], [15, 16]] +// ], +// [ +// [[33, 34], [35, 36]], +// [[35, 36], [37, 38]], +// [[41, 42], [43, 44]] +// ] +// ], +// [ +// [ +// [[1, 2], [3, 4]], +// [[13, 14], [15, 16]], +// [[21, 22], [23, 24]] +// ], +// [ +// [[43, 44], [45, 46]], +// [[33, 34], [35, 36]], +// [[27, 28], [29, 30]] +// ] +// ] +// ] +``` + + [More Examples](../stablehlo/tests/interpret/gather.mlir) + +### scatter + +#### Semantics + +Produces `results` tensors which are equal to `inputs` tensors except that +several slices specified by `scatter_indices` are updated with the values +`updates` using `update_computation`. + +The following diagram shows how elements in `updates...` map on elements in +`results...` using a concrete example. The diagram picks a few example +`updates...` indices and explains in detail which `results...` indices they +correspond to. 
+ +![](images/20240311-gather-scatter-batching-dims/scatter.svg) + +More formally, for all `update_index` in `index_space(updates[0])`: + +* `update_scatter_dims = [d for d in axes(updates[0]) and d not in + update_window_dims]`. +* `update_scatter_index = update_index[update_scatter_dims...]`. +* `start_index` is defined as: + * `scatter_indices[si0, ..., :, ..., siN]` where `si` are individual + elements in `update_scatter_index` and `:` is inserted at the + `index_vector_dim` index, if `index_vector_dim` < + `rank(scatter_indices)`. + * `[scatter_indices[update_scatter_index]]` otherwise. +* For `d_input` in `axes(inputs[0])`, + * `full_start_index[d_input] = start_index[d_start]` if + `d_input = scatter_dims_to_operand_dims[d_start]`. + * `full_start_index[d_input] = 0` otherwise. +* For `d_input` in `axes(inputs[0])`, + * `full_batching_index[d_input] = + update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]` + if `d_input = input_batching_dims[i_batching]` and + `d_start = scatter_indices_batching_dims[i_batching]`. + * `full_batching_index[d_input] = 0` otherwise. +* `update_window_index = update_index[update_window_dims...]`. +* `full_window_index = [wi0, ..., 0, ..., wiN]` where `wi` are individual + elements in `update_window_index`, and `0` is inserted at indices from + `inserted_window_dims` and `input_batching_dims`. +* `result_index = full_start_index + full_batching_index + full_window_index`. + +Given that, `results = exec(schedule, inputs)`, where: + +* `schedule` is an implementation-defined permutation of + `index_space(updates[0])`. +* `exec([update_index, ...], results) = exec([...], updated_results)` where: + * If `result_index` is in bounds for `shape(results...)` + * `updates_converted = to_destination_type( + updates...[update_index], type(func_inputs(update_computation) + [len(func_inputs(update_computation))//2:])... 
)` + * `updated_values = update_computation(results...[result_index], + updates_converted)` + * `updated_results` is a copy of `results` with `results...[result_index]` + set to `updated_values...`. + * Otherwise + * `updated_results = results`. +* `exec([], results) = results`. + +If `indices_are_sorted` is `true` then the implementation can assume that +`scatter_indices` are sorted with respect to `scatter_dims_to_operand_dims`, +otherwise the behavior is undefined. More formally, for all `i1 < i2` from +`indices(result)`, `full_start_index(i1)` <= `full_start_index(i2)`. + +If `unique_indices` is `true` then the implementation can assume that all +`result_index` indices being scattered to are unique. If `unique_indices` is +`true` but the indices being scattered to are not unique then the behavior is +undefined. + +#### Inputs + +| Label | Name | Type | Constraints | +|-------|---------------------------------------|------------------------------------------------------------|------------------------------------------------------------| +| (I1) | `inputs` | variadic number of tensors or per-tensor quantized tensors | (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24) | +| (I2) | `scatter_indices` | tensor of integer type | (C4), (C15), (C19), (C22) | +| (I3) | `updates` | variadic number of tensors or per-tensor quantized tensors | (C3-C6), (C8) | +| (I4) | `update_window_dims` | 1-dimensional tensor constant of type `si64` | (C2), (C4), (C7-C8) | +| (I5) | `inserted_window_dims` | 1-dimensional tensor constant of type `si64` | (C2), (C4), (C9-C11) | +| (I6) | `input_batching_dims` | 1-dimensional tensor constant of type `si64` | (C2), (C4), (C9), (C12-13), (C17-18), (C20) | +| (I7) | `scatter_indices_batching_dims` | 1-dimensional tensor constant of type `si64` | (C14-C18) | +| (I8) | `scatter_dims_to_operand_dims` | 1-dimensional tensor constant of type `si64` | (C19-C21) | +| (I9) | `index_vector_dim` | constant of type `si64` | (C4), (C16), (C19), 
(C22) | +| (I10) | `indices_are_sorted` | constant of type `i1` | | +| (I11) | `unique_indices` | constant of type `i1` | | +| (I12) | `update_computation` | function | (C23) | + +#### Outputs + +| Name | Type | Constraints | +|-----------|------------------------------------------------------------|-------------| +| `results` | variadic number of tensors or per-tensor quantized tensors | (C24-C25) | + +#### Constraints + +* (C1) `same(shape(inputs...))`. +* (C2) `rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims) + + size(input_batching_dims)`. +* (C3) `same(shape(updates...))`. +* (C4) `shape(updates[0]) = combine(update_scatter_dim_sizes, + update_window_dim_sizes)` where: + * `update_scatter_dim_sizes = shape(scatter_indices)` except that + the dimension size of `scatter_indices` corresponding to + `index_vector_dim` is not included. + * `update_window_dim_sizes <= shape(inputs[0])` except that + the dimension sizes in `inputs[0]` corresponding to `inserted_window_dims` + and `input_batching_dims` are not included. + * `combine` puts `update_scatter_dim_sizes` at axes corresponding to + `update_scatter_dims` and `update_window_dim_sizes` at axes corresponding + to `update_window_dims`. +* (C5) `0 < size(inputs) = size(updates) = N`. +* (C6) `element_type(updates...) = element_type(inputs...)`. +* (C7) `is_unique(update_window_dims) and is_sorted(update_window_dims)`. +* (C8) `0 <= update_window_dims < rank(updates[0])`. +* (C9) `is_unique(concatenate(inserted_window_dims, input_batching_dims))` +* (C10) `is_sorted(inserted_window_dims)`. +* (C11) `0 <= inserted_window_dims < rank(inputs[0])`. +* (C12) `is_sorted(input_batching_dims)`. +* (C13) `0 <= input_batching_dims < rank(inputs[0]))`. +* (C14) `is_unique(scatter_indices_batching_dims)`. +* (C15) `0 <= scatter_indices_batching_dims < rank(scatter_indices)`. +* (C16) `index_vector_dim not in scatter_indices_batching_dims`. 
+* (C17) `size(input_batching_dims) == size(scatter_indices_batching_dims)`. +* (C18) `dim(inputs[0], input_batching_dims...) = + dim(scatter_indices, scatter_indices_batching_dims...)`. +* (C19) `size(scatter_dims_to_operand_dims) = + index_vector_dim < rank(scatter_indices) ? + dim(scatter_indices, index_vector_dim) : 1`. +* (C20) `is_unique(concatenate(scatter_dims_to_operand_dims, + input_batching_dims))`. +* (C21) `0 <= scatter_dims_to_operand_dims < rank(inputs[0])`. +* (C22) `0 <= index_vector_dim <= rank(scatter_indices)`. +* (C23) `update_computation` has type `(tensor, ..., tensor, + tensor, ..., tensor) -> (tensor, ..., tensor)`, + where `is_promotable(element_type(inputs[i]), Ei)`. +* (C24) `shape(inputs...) = shape(results...)`. +* (C25) `element_type(results[i]) = Ei` for all `i` in `[0,N)`. + +#### Examples + +```mlir +// %input: [ +// [ +// [[1, 2], [3, 4], [5, 6], [7, 8]], +// [[9, 10],[11, 12], [13, 14], [15, 16]], +// [[17, 18], [19, 20], [21, 22], [23, 24]] +// ], +// [ +// [[25, 26], [27, 28], [29, 30], [31, 32]], +// [[33, 34], [35, 36], [37, 38], [39, 40]], +// [[41, 42], [43, 44], [45, 46], [47, 48]] +// ] +// ] +// %scatter_indices: [ +// [ +// [[0, 0], [1, 0], [2, 1]], +// [[0, 1], [1, 1], [0, 9]] +// ], +// [ +// [[0, 0], [2, 1], [2, 2]], +// [[1, 2], [0, 1], [1, 0]] +// ] +// ] +// %update: [ +// [ +// [[1, 1], [1, 1], [1, 1]], +// [[1, 1], [1, 1], [1, 1]] +// ], +// [ +// [[1, 1], [1, 1], [1, 1]], +// [[1, 1], [1, 1], [1, 1]] +// ] +// ] +%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () +}) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [3, 4], + inserted_window_dims = [1], + input_batching_dims = [0], + scatter_indices_batching_dims = [1], + scatter_dims_to_operand_dims = [2, 1], + index_vector_dim = 3>, + indices_are_sorted = false, + 
unique_indices = false +} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64> +// %result: [ +// [ +// [[3, 4], [6, 7], [6, 7], [7, 8]], +// [[9, 10],[11, 12], [15, 16], [17, 18]], +// [[17, 18], [19, 20], [22, 23], [24, 25]] +// ], +// [ +// [[25, 26], [28, 29], [30, 31], [31, 32]], +// [[35, 36], [38, 39], [38, 39], [39, 40]], +// [[41, 42], [44, 45], [46, 47], [47, 48]] +// ] +// ] +``` + + [More Examples](../stablehlo/tests/interpret/scatter.mlir) diff --git a/rfcs/images/20240311-gather-scatter-batching-dims/gather.svg b/rfcs/images/20240311-gather-scatter-batching-dims/gather.svg new file mode 100644 index 00000000000..d2891862c5f --- /dev/null +++ b/rfcs/images/20240311-gather-scatter-batching-dims/gather.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/rfcs/images/20240311-gather-scatter-batching-dims/scatter.svg b/rfcs/images/20240311-gather-scatter-batching-dims/scatter.svg new file mode 100644 index 00000000000..b7581cf9031 --- /dev/null +++ b/rfcs/images/20240311-gather-scatter-batching-dims/scatter.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/stablehlo/dialect/Base.cpp b/stablehlo/dialect/Base.cpp index f938f2c233a..316cb1073ce 100644 --- a/stablehlo/dialect/Base.cpp +++ b/stablehlo/dialect/Base.cpp @@ -132,6 +132,7 @@ bool isCompatibleForHloTypeInference(TypeRange tp1, TypeRange tp2) { } bool isCompatibleForHloTypeInference(ArrayRef shape1, Type tp2) { + if (llvm::any_of(shape1, [&](int64_t x) { return x < 0; })) return false; auto stp2 = dyn_cast(tp2); if (!stp2) return false; return isCompatibleForHloTypeInference( @@ -141,11 +142,7 @@ bool isCompatibleForHloTypeInference(ArrayRef shape1, Type tp2) { bool isCompatibleForHloTypeInference(Value shape1, Type tp2) { SmallVector shapeVec1; if (!succeeded(matchInts(shape1, shapeVec1))) return true; - if (llvm::any_of(shapeVec1, [&](int64_t x) { return x < 0; })) return false; - auto stp2 = dyn_cast(tp2); - if (!stp2) return false; - auto 
tp1 = RankedTensorType::get(shapeVec1, stp2.getElementType()); - return isCompatibleForHloTypeInference(tp1, tp2); + return isCompatibleForHloTypeInference(shapeVec1, tp2); } LogicalResult matchInt(Value value, int64_t& result) { @@ -628,5 +625,75 @@ mlir::Speculation::Speculatability getShapedSpeculatability( : mlir::Speculation::NotSpeculatable; } +bool isValidStablehloQuantizedElementType(Type elementType) { + auto quantizedElementType = dyn_cast(elementType); + if (!quantizedElementType) return false; + + int64_t storageTypeMin = quantizedElementType.getStorageTypeMin(); + int64_t storageTypeMax = quantizedElementType.getStorageTypeMax(); + + SmallVector zeroPoints; + SmallVector scales; + if (auto quantizedPerTensorElementType = + dyn_cast(elementType)) { + zeroPoints.push_back(quantizedPerTensorElementType.getZeroPoint()); + scales.push_back(quantizedPerTensorElementType.getScale()); + } else { + auto quantizedPerAxisElementType = + cast(elementType); + zeroPoints.insert(zeroPoints.begin(), + quantizedPerAxisElementType.getZeroPoints().begin(), + quantizedPerAxisElementType.getZeroPoints().end()); + scales.insert(scales.begin(), + quantizedPerAxisElementType.getScales().begin(), + quantizedPerAxisElementType.getScales().end()); + } + + // quantized_type_c5 + auto maxPosFiniteNum = + APFloat::getLargest(quantizedElementType.getExpressedType() + .cast() + .getFloatSemantics()) + .convertToDouble(); + auto minPosFiniteNum = + APFloat::getSmallest(quantizedElementType.getExpressedType() + .cast() + .getFloatSemantics()) + .convertToDouble(); + if (llvm::any_of(scales, [&](double scale) { + return scale < minPosFiniteNum || scale > maxPosFiniteNum; + })) { + return false; + } + + // quantized_type_c8, quantized_type_c9 + if (llvm::any_of(zeroPoints, [&](int64_t zeroPoint) { + return storageTypeMin > zeroPoint || zeroPoint > storageTypeMax; + })) { + return false; + } + + return true; +} + +bool isValidQuantizedDimension(Type type) { + auto rankedType = 
dyn_cast(type); + if (!rankedType) return true; + + auto quantizedPerAxisElementType = + dyn_cast( + rankedType.getElementType()); + + if (!quantizedPerAxisElementType) return true; + + // quantized_type_c12, quantized_type_c13, quantized_type_c14 + int64_t quantDim = quantizedPerAxisElementType.getQuantizedDimension(); + int64_t numScales = + static_cast(quantizedPerAxisElementType.getScales().size()); + return quantDim >= 0 && quantDim < rankedType.getRank() && + (!rankedType.isDynamicDim(quantDim) && + numScales == rankedType.getDimSize(quantDim)); +} + } // namespace hlo } // namespace mlir diff --git a/stablehlo/dialect/Base.h b/stablehlo/dialect/Base.h index 7b045db7064..d1ae4d59219 100644 --- a/stablehlo/dialect/Base.h +++ b/stablehlo/dialect/Base.h @@ -90,6 +90,16 @@ bool isCompatibleForHloTypeInference(Value shape1, Type tp2); // compatible with the given type for the purposes of HLO type inference. bool isCompatibleForHloTypeInference(ArrayRef shape1, Type tp2); +// Returns true if the given element-type is a mlir::quant::QuantizedType +// and follow the constraints corresponding to quantization parameters as +// mentioned in the StableHLO specification. +bool isValidStablehloQuantizedElementType(Type elementType); + +// Returns true if the given type is a ranked per-axis tensor type +// and follow the constraints corresponding to quantized dimension as +// mentioned in the StableHLO specification. +bool isValidQuantizedDimension(Type type); + // TODO(zhouxin) Move type inference related methods to TypeInference.cpp std::pair inferConcatenatedDimAndBound(int64_t leftSize, diff --git a/stablehlo/dialect/Base.td b/stablehlo/dialect/Base.td index 29bbf0c1ef7..88fa81414b6 100644 --- a/stablehlo/dialect/Base.td +++ b/stablehlo/dialect/Base.td @@ -52,6 +52,9 @@ def HLO_Complex : Complex>; // Quantized element type definitions. 
//===----------------------------------------------------------------------===// +def IsValidStablehloQuantizedElementType : CPred<"mlir::hlo::isValidStablehloQuantizedElementType($_self)">; +def IsValidQuantizedDimension : CPred<"mlir::hlo::isValidQuantizedDimension($_self)">; + // TODO(b/230381284): Upstream width-specific uniform quantized element types. class StableHLO_UniformQuantizedSignedInt : Type($_self)">, @@ -59,7 +62,7 @@ class StableHLO_UniformQuantizedSignedInt ".getStorageTypeIntegralWidth() == " # width>, CPred<"cast(" # "cast($_self).getStorageType()" # - ").isSignless()">]>, + ").isSignless()">, IsValidStablehloQuantizedElementType]>, "QI" # width # " type"> { string name = "UniformQuantizedSignedInt"; int bitwidth = width; @@ -71,7 +74,7 @@ class StableHLO_UniformQuantizedPerAxisSignedInt ".getStorageTypeIntegralWidth() == " # width>, CPred<"cast(" # "cast($_self).getStorageType()" # - ").isSignless()">]>, + ").isSignless()">, IsValidStablehloQuantizedElementType]>, "QI" # width # " type"> { string name = "UniformQuantizedPerAxisSignedInt"; int bitwidth = width; @@ -82,7 +85,7 @@ class StableHLO_UniformQuantizedUnsignedInt CPred<"cast($_self)" # ".getStorageTypeIntegralWidth() == " # width>, CPred<"!cast($_self)" # - ".isSigned()">]>, + ".isSigned()">, IsValidStablehloQuantizedElementType]>, "QUI" # width # " type"> { string name = "UniformQuantizedUnsignedInt"; int bitwidth = width; @@ -93,7 +96,7 @@ class StableHLO_UniformQuantizedPerAxisUnsignedInt CPred<"cast($_self)" # ".getStorageTypeIntegralWidth() == " # width>, CPred<"!cast($_self)" # - ".isSigned()">]>, + ".isSigned()">, IsValidStablehloQuantizedElementType]>, "QUI" # width # " type"> { string name = "UniformQuantizedPerAxisUnsignedInt"; int bitwidth = width; @@ -152,10 +155,12 @@ def HLO_Fp32Or64Tensor : RankedTensorOf<[HLO_Float32Or64]>; def HLO_QuantizedIntTensor : RankedTensorOf<[HLO_QuantizedInt]>; // per-axis quantized integer tensor type -def HLO_PerAxisQuantizedIntTensor : 
RankedTensorOf<[HLO_PerAxisQuantizedInt]>; +def HLO_PerAxisQuantizedIntTensor : RankedTensorOf<[HLO_PerAxisQuantizedInt], + [IsValidQuantizedDimension]>; // per-tensor or per-axis quantized integer tensor type -def HLO_QuantizedIntOrPerAxisQuantizedIntTensor : RankedTensorOf<[HLO_QuantizedInt, HLO_PerAxisQuantizedInt]>; +def HLO_QuantizedIntOrPerAxisQuantizedIntTensor : RankedTensorOf<[HLO_QuantizedInt, HLO_PerAxisQuantizedInt], + [IsValidQuantizedDimension]>; def HLO_PredTensor : RankedTensorOf<[HLO_Pred]>; @@ -163,7 +168,8 @@ def HLO_Tensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_ def HLO_NonQuantizedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex]>; -def HLO_TensorOrPerAxisQuantizedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt]>; +def HLO_TensorOrPerAxisQuantizedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt], + [IsValidQuantizedDimension]>; def HLO_ComplexTensor : RankedTensorOf<[HLO_Complex]>; @@ -188,7 +194,8 @@ def HLO_DimensionTensor : 1DTensorOf<[HLO_DimensionValue]>; // TODO(b/326463552): Remove these when CHLO no longer needs unranked dynamism. 
//===----------------------------------------------------------------------===// -def HLO_AnyTensor : TensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt]>; +def HLO_AnyTensor : TensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt], + [IsValidQuantizedDimension]>; def HLO_AnyPredTensor : TensorOf<[HLO_Pred]>; @@ -248,8 +255,8 @@ def HLO_StaticDimensionTensor : RankedTensorOf<[HLO_DimensionValue], [HasStaticS def HLO_StaticShapeTensor : StaticShapeTensorOf<[ HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt]>; -def HLO_StaticShapeTensorOrPerAxisQuantizedTensor : StaticShapeTensorOf<[ - HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt]>; +def HLO_StaticShapeTensorOrPerAxisQuantizedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt], + [IsValidQuantizedDimension, HasStaticShapePred], "statically shaped tensor">; def HLO_StaticShapeTensorOrPerAxisQuantizedTensorOrToken : AnyTypeOf<[HLO_StaticShapeTensor, HLO_StaticShapeTensorOrPerAxisQuantizedTensor, HLO_Token]>; diff --git a/stablehlo/dialect/ChloOps.cpp b/stablehlo/dialect/ChloOps.cpp index f39d3065cdb..915cbc9a240 100644 --- a/stablehlo/dialect/ChloOps.cpp +++ b/stablehlo/dialect/ChloOps.cpp @@ -334,23 +334,6 @@ LogicalResult ConstantLikeOp::verify() { return success(); } -//===----------------------------------------------------------------------===// -// MinimumBroadcastShapesOp -//===----------------------------------------------------------------------===// -LogicalResult MinimumBroadcastShapesOp::verify() { - // Check that the number of operands matches the number of outputs. 
- unsigned resultShapesCount = getResults().size(); - unsigned operandShapesCount = getShapes().size(); - if (operandShapesCount != resultShapesCount) - return emitOpError() << "number of operand shapes (" << operandShapesCount - << ") does not match number of result shapes (" - << resultShapesCount << ")"; - if (operandShapesCount < 2) - return emitOpError() << "number of operand shapes (" << operandShapesCount - << ") should be >= 2"; - return success(); -} - LogicalResult ConstantLikeOp::inferReturnTypeComponents( MLIRContext* /*context*/, std::optional location, ValueShapeRange operands, DictionaryAttr attributes, diff --git a/stablehlo/dialect/ChloOps.td b/stablehlo/dialect/ChloOps.td index c901d8de136..fc7a7054d49 100644 --- a/stablehlo/dialect/ChloOps.td +++ b/stablehlo/dialect/ChloOps.td @@ -834,71 +834,4 @@ row for matrices).}]>:$k }]; } -//===----------------------------------------------------------------------===// -// Helper ops -//===----------------------------------------------------------------------===// - -def CHLO_MinimumBroadcastShapesOp : - CHLO_Op<"minimum_broadcast_shapes", [Pure]> { - string summary = "Minimizes the rank of two or more shapes to be broadcasted"; - - string description = [{ - Given two or more 1D tensors representing shapes, returns one 1D tensor for - each operand, where operand `i` corresponds to output `i`. - - The returned tensors have the property that they specify a shape which is a - reshape of the corresponding input shape, and the broadcasted output shape - (using shape::BroadcastOp) of the returned shapes is a reshape of the - broadcasted output shape of the input shapes. Among all possibilities with - this property, the one is chosen which minimizes the rank of each returned - shape. - - The general idea of this op is that it can be used for ops which have a - broadcasting semantic to operate on shapes with a possibly smaller rank - while preserving equivalence of the computed values. 
After computing the - result of the op using reshaped operands, the result can be reshaped to the - result that would have been originally computed. - - Here is an example with two input shapes: - - ```mlir - chlo.minimum_broadcast_shapes [1, 2, 3, 1, 2, 1], - [1, 1, 1, 2, 3] -> [6, 2, 1], [2, 3] - ``` - - The broadcasted output shape of the operands is [1, 2, 3, 1, 2, 3], the - broadcasted output shape of the outputs is [6, 2, 3]. These two shapes are - reshapes of each other, and also each output is a reshape of the - corresponding input. - }]; - - let arguments = (ins Variadic<1DTensorOf<[Index]>>:$shapes); - let results = (outs Variadic<1DTensorOf<[Index]>>:$results); - - let assemblyFormat = "$shapes attr-dict `:` type($shapes) `->` type($results)"; - - let hasVerifier = 1; -} - -def CHLO_DynamicReshapeOp: CHLO_Op<"dynamic_reshape", [Pure, - DeclareOpInterfaceMethods]> { - let summary = "Reshape a tensor to a given, possibly dynamic, shape."; - let description = [{ - Reshapes `operand` to `output_shape`. This allows a single dimension to be - specified as -1, and it will be computed by the op to create a valid - reshape. - - Requires: - - The length of `output_shape` is equal to the rank of `result`. - - The number of elements in `operand` (that is, the product of extents of - its shape) is equal to the number of elements in `output_shape` (that is, - the product of values in `output_shape` with one dimension possibly - computed). - - All shape values should be at least -1, and only one extent can be -1. 
- }]; - - let arguments = (ins HLO_AnyTensor:$operand, HLO_DimensionTensor:$output_shape); - let results = (outs HLO_AnyTensor:$result); -} - #endif // STABLEHLO_DIALECT_CHLO_OPS diff --git a/stablehlo/dialect/StablehloAttrs.td b/stablehlo/dialect/StablehloAttrs.td index 28c896eb2e1..0fa7cc9a130 100644 --- a/stablehlo/dialect/StablehloAttrs.td +++ b/stablehlo/dialect/StablehloAttrs.td @@ -30,7 +30,7 @@ def GenericDenseI64ArrayAttr : Attr setNameFn) { mlir::TensorType type = getType(); - if (type.getElementType().isa()) { + if (isa(type.getElementType())) { setNameFn(getResult(), "c"); } else { setNameFn(getResult(), "cst"); diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index dabf8b3ef71..ed4d3a86dac 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -124,20 +124,27 @@ def StableHLO_IotaOp : StableHLO_Op<"iota", [Pure]> { def StableHLO_DynamicIotaOp: StableHLO_ShapedInterfaceOp<"dynamic_iota", [ConditionallySpeculatable, NoMemoryEffect]> { let summary = "DynamicIota operation"; let description = [{ - This operation is a work in progress, so it is not yet included in - the StableHLO specification: https://github.com/openxla/stablehlo/issues/8. + Fills a `result` tensor with values in increasing order starting from zero + along the `iota_dimension` dimension. 
Informally, this operation does the same thing as IotaOp except that the result shape is specified dynamically via `output_shape`: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#iota + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_iota + Example: ```mlir - %0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<1xindex>) -> tensor<4xi32> + %output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64> + %0 = stablehlo.dynamic_iota %output_shape, dim = 0 : (tensor<2xi64>) -> tensor<4x5xi64> ``` }]; - let arguments = (ins HLO_StaticDimensionTensor:$output_shape, I64Attr:$iota_dimension); + let arguments = (ins + HLO_StaticDimensionTensor:$output_shape /*dynamic_iota_i1*/, + I64Attr:$iota_dimension /*dynamic_iota_i2*/ + ); let results = (outs HLO_Tensor:$result); let hasVerifier = 1; @@ -3519,58 +3526,4 @@ def StableHLO_DynamicConvOp : StableHLO_Op<"dynamic_conv", let results = (outs HLO_Tensor); } -def StableHLO_ComputeReshapeShapeOp : StableHLO_Op< - "compute_reshape_shape", - [Pure, AllShapesMatch<["dynamic_shape", "result"]>]> { - let summary = "ComputeReshapeShape operation"; - let description = [{ - This operation is a work in progress, so it is not yet included in - the StableHLO specification: https://github.com/openxla/stablehlo/issues/8. - - Informally, this operation computes an output_shape for DynamicReshapeOp - from the `num_elements` number of elements in an operand of DynamicReshapeOp - and the `dynamic_shape` shape provided to TF's reshape: - https://www.tensorflow.org/api_docs/python/tf/reshape - - For example, for `num_elements = 12` and `dynamic_shape = [2, -1]`, - the `result` is going to be `[2, 6]`. If operands are not valid (e.g. if - dimensions do not evenly divide the number of elements, or if there are - multiple -1 values in dimensions), this leads to undefined behavior. 
- - Example: - ```mlir - %result = stablehlo.compute_reshape_shape %num_elements, %dynamic_shape - : (index, tensor<2xi32>) -> tensor<2xi32> - ``` - }]; - - let arguments = (ins Index:$num_elements, 1DTensorOf<[AnyInteger, Index]>:$dynamic_shape); - let results = (outs 1DTensorOf<[AnyInteger, Index]>:$result); - - let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; -} - -def StableHLO_CstrReshapableOp : - StableHLO_Op<"cstr_reshapable", [Pure]> { - let summary = "CstrReshapable operation"; - let description = [{ - This operation is a work in progress, so it is not yet included in - the StableHLO specification: https://github.com/openxla/stablehlo/issues/8. - - Informally, this operation creates a witness on the constraint that - ComputeReshapeShape would succeed with the provided operands. - - Example: - ```mlir - %result = stablehlo.cstr_reshapable %num_elements, %dynamic_shape - : (index, tensor<3xi32>) -> !shape.witness - ``` - }]; - - let arguments = (ins Index:$num_elements, 1DTensorOf<[AnyInteger, Index]>:$dynamic_shape); - let results = (outs Shape_WitnessType:$result); - - let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; -} - #endif // STABLEHLO_DIALECT_STABLEHLO_OPS diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index 7ad992705d5..3973543437d 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -44,6 +44,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/Regex.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Quant/QuantTypes.h" #include "mlir/IR/Attributes.h" @@ -66,9 +67,11 @@ limitations under the License. 
namespace mlir { namespace hlo { namespace { + //===----------------------------------------------------------------------===// // Utils for quantization specific verifications //===----------------------------------------------------------------------===// + template bool allQuantized(ArrayRef typeRange) { return llvm::all_of( @@ -468,6 +471,27 @@ LogicalResult verifyAddOp(std::optional location, Operation* op, return success(); } +// If the shape operand is constant, checks that it is compatible with the +// result's shape. Emits an error if the shapes are incompatible. +LogicalResult verifyShapeOperandIsCompatibleWithResultType( + std::optional loc, Value shapeOperand, Type resultType) { + if (SmallVector shape; + succeeded(matchInts(shapeOperand, shape)) && + !isCompatibleForHloTypeInference(shape, resultType)) { + std::string str; + llvm::raw_string_ostream os(str); + llvm::interleaveComma(shape, os, [&](int64_t i) { os << i; }); + return emitOptionalError(loc, "output shape [", os.str(), + "] is incompatible with return type of operation ", + resultType); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// Verifiers +//===----------------------------------------------------------------------===// + LogicalResult verifyTransposeOp(std::optional location, Type operandType, ArrayRef permutation, Type resultType) { @@ -490,7 +514,7 @@ LogicalResult verifyTransposeOp(std::optional location, if (operandQDim != permutation[resultQDim]) return emitOptionalError(location, "operand quantization_dimension ", operandQDim, " is not same as permutation[", - resultQDim, "] ", permutation[resultQDim]); + resultQDim, "] = ", permutation[resultQDim]); } return success(); } @@ -3356,9 +3380,9 @@ LogicalResult verifyBroadcastInDimOp(std::optional location, auto resultQDim = resultQType.getQuantizedDimension(); if (resultQDim != broadcastDimensions[operandQDim]) return emitOptionalError(location, "result 
quantization_dimension ", - resultQDim, " not same as broadcast_dimensions ", - operandQDim, " (", - broadcastDimensions[operandQDim], ")"); + resultQDim, " not same as broadcast_dimensions[", + operandQDim, + "] = ", broadcastDimensions[operandQDim]); if (operandType.getDimSize(operandQDim) == 1) { for (int64_t j = 0; j != resultType.getDimSize(resultQDim); ++j) { if (resultQType.getScales()[j] != operandQType.getScales()[0]) @@ -3760,11 +3784,9 @@ LogicalResult verifyDynamicBroadcastInDimOp( } } - if (!isCompatibleForHloTypeInference(outputDimensions, resultType)) - return emitOptionalError( - location, - "output_dimensions are incompatible with return type of operation ", - resultType); + if (failed(verifyShapeOperandIsCompatibleWithResultType( + location, outputDimensions, resultType))) + return failure(); return success(); } @@ -3772,17 +3794,17 @@ LogicalResult verifyDynamicBroadcastInDimOp( LogicalResult verifyDynamicIotaOp(std::optional location, Value outputShape, int64_t iotaDimension, Value result) { - auto shape = cast(result.getType()); + auto resultType = cast(result.getType()); - if (!isCompatibleForHloTypeInference(outputShape, shape)) - return emitOptionalError( - location, "output_shape is incompatible with return type of operation ", - result.getType()); - - if (iotaDimension >= shape.getRank() || iotaDimension < 0) + // dynamic_iota_c1 + if (iotaDimension >= resultType.getRank() || iotaDimension < 0) return emitOptionalError( location, "iota dimension cannot go beyond the output rank or be negative."); + // dynamic_iota_c2 + if (failed(verifyShapeOperandIsCompatibleWithResultType(location, outputShape, + resultType))) + return failure(); return success(); } @@ -3850,18 +3872,9 @@ LogicalResult verifyDynamicReshapeOp(std::optional location, } } - if (SmallVector shape; - succeeded(matchInts(outputShape, shape)) && - !isCompatibleForHloTypeInference(shape, resultType)) { - std::string str; - llvm::raw_string_ostream os(str); - os << "["; - 
llvm::interleaveComma(shape, os, [&](int64_t i) { os << i; }); - os << "]"; - return emitOptionalError(location, "output_shape ", os.str(), - " is incompatible with return type of operation ", - resultType); - } + if (failed(verifyShapeOperandIsCompatibleWithResultType(location, outputShape, + resultType))) + return failure(); return success(); } diff --git a/stablehlo/dialect/VhloOps.td b/stablehlo/dialect/VhloOps.td index 03cf7b8ec32..3c5617b7fdc 100644 --- a/stablehlo/dialect/VhloOps.td +++ b/stablehlo/dialect/VhloOps.td @@ -296,14 +296,6 @@ def VHLO_CompositeOpV1 : VHLO_Op<"composite_v1", "0.19.0", "current"> { let results = (outs Variadic:$results); } -// TODO(#8): ComputeReshapeShapeOp is not part of the StableHLO spec. -// This operation is a work in progress, so it is not yet included in -// the StableHLO specification. -def VHLO_ComputeReshapeShapeOpV1 : VHLO_Op<"compute_reshape_shape_v1", "0.9.0", "current"> { - let arguments = (ins VHLO_AnyType:$num_elements, VHLO_AnyType:$dynamic_shape); - let results = (outs VHLO_AnyType:$result); -} - def VHLO_ConcatenateOpV1 : VHLO_Op<"concatenate_v1", "0.9.0", "current"> { let arguments = (ins Variadic:$inputs, @@ -370,14 +362,6 @@ def VHLO_CrossReplicaSumOpV1 : VHLO_Op<"cross-replica-sum_v1", "0.9.0", "current let results = (outs VHLO_AnyType:$result); } -// TODO(#8): CstrReshapableOp is not part of the StableHLO spec. -// This operation is a work in progress, so it is not yet included in -// the StableHLO specification. -def VHLO_CstrReshapableOpV1 : VHLO_Op<"cstr_reshapable_v1", "0.9.0", "current"> { - let results = (outs VHLO_AnyType:$result); - let arguments = (ins VHLO_AnyType:$num_elements, VHLO_AnyType:$dynamic_shape); -} - // TODO(#1187): api_version is different between VHLO and the StableHLO spec: // in VHLO it's an enum, in the spec it's an si32. // TODO(#629): operand_layouts/result_layouts are not part of the spec. 
diff --git a/stablehlo/reference/CMakeLists.txt b/stablehlo/reference/CMakeLists.txt index ba75441b464..b1d4f8c35e4 100644 --- a/stablehlo/reference/CMakeLists.txt +++ b/stablehlo/reference/CMakeLists.txt @@ -95,6 +95,7 @@ add_mlir_dialect_library(InterpreterOps InterpreterOpsIncGen LINK_LIBS PUBLIC + StablehloBase StablehloReferenceValue StablehloReferenceNumPy StablehloReferenceOps diff --git a/stablehlo/reference/InterpreterOps.cpp b/stablehlo/reference/InterpreterOps.cpp index 2e20b86d235..ec502ad88a5 100644 --- a/stablehlo/reference/InterpreterOps.cpp +++ b/stablehlo/reference/InterpreterOps.cpp @@ -30,6 +30,7 @@ limitations under the License. #include "mlir/IR/SymbolTable.h" #include "mlir/Support/DebugStringHelper.h" #include "mlir/Support/LLVM.h" +#include "stablehlo/dialect/Base.h" #include "stablehlo/reference/NumPy.h" #include "stablehlo/reference/Ops.h" #include "stablehlo/reference/ProcessGrid.h" diff --git a/stablehlo/reference/Ops.cpp b/stablehlo/reference/Ops.cpp index 7f3c263b773..d98c11b423a 100644 --- a/stablehlo/reference/Ops.cpp +++ b/stablehlo/reference/Ops.cpp @@ -579,6 +579,11 @@ SmallVector eval(Region ®ion, lhs, rhs, lhsBatchingDimensions, rhsBatchingDimensions, lhsContractingDimensions, rhsContractingDimensions, op.getType()); scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto iotaDimension = op.getIotaDimension(); + auto outputShape = scope.findTensor(op.getOutputShape()); + auto result = dynamicIotaOp(iotaDimension, outputShape, op.getType()); + scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); auto startIndices = scope.findTensors(op.getStartIndices()); @@ -1576,6 +1581,14 @@ Tensor dotGeneralOp(const Tensor &lhs, const Tensor &rhs, return result; } +Tensor dynamicIotaOp(Axis iotaDimension, const Tensor &outputShape, + ShapedType resultType) { + if (resultType.hasStaticShape()) return iotaOp(iotaDimension, resultType); + 
+ llvm::report_fatal_error( + "dynamic result types are not supported at the moment"); +} + Tensor dynamicSliceOp(const Tensor &operand, ArrayRef startIndices, const Sizes &sliceSizes, ShapedType resultType) { Tensor result(resultType); diff --git a/stablehlo/reference/Ops.h b/stablehlo/reference/Ops.h index af546e45ac9..bb014ffa1a6 100644 --- a/stablehlo/reference/Ops.h +++ b/stablehlo/reference/Ops.h @@ -91,6 +91,8 @@ Tensor dotGeneralOp(const Tensor &lhs, const Tensor &rhs, const Axes &lhsContractingDimensions, const Axes &rhsContractingDimensions, ShapedType resultType); +Tensor dynamicIotaOp(Axis iotaDimension, const Tensor &outputShape, + ShapedType resultType); Tensor dynamicSliceOp(const Tensor &operand, ArrayRef startIndices, const Sizes &sliceSizes, ShapedType resultType); Tensor dynamicUpdateSliceOp(const Tensor &operand, const Tensor &update, diff --git a/stablehlo/tests/BUILD.bazel b/stablehlo/tests/BUILD.bazel index 7116e4a05fb..80e7c8030cb 100644 --- a/stablehlo/tests/BUILD.bazel +++ b/stablehlo/tests/BUILD.bazel @@ -32,6 +32,7 @@ cc_library( strip_include_prefix = ".", deps = [ ":check_ops_inc_gen", + "//:base", "//:reference_errors", "//:reference_numpy", "//:reference_tensor", diff --git a/stablehlo/tests/CMakeLists.txt b/stablehlo/tests/CMakeLists.txt index 676f9cdd925..84246c0ee7f 100644 --- a/stablehlo/tests/CMakeLists.txt +++ b/stablehlo/tests/CMakeLists.txt @@ -64,6 +64,7 @@ add_mlir_dialect_library(CheckOps CheckOpsIncGen LINK_LIBS PUBLIC + StablehloBase StablehloReferenceNumPy StablehloReferenceTensor MLIRIR diff --git a/stablehlo/tests/CheckOps.cpp b/stablehlo/tests/CheckOps.cpp index c39d156793f..a4595ef94d5 100644 --- a/stablehlo/tests/CheckOps.cpp +++ b/stablehlo/tests/CheckOps.cpp @@ -23,6 +23,7 @@ limitations under the License. 
#include "llvm/Support/Path.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Support/DebugStringHelper.h" +#include "stablehlo/dialect/Base.h" #include "stablehlo/reference/Errors.h" #include "stablehlo/reference/NumPy.h" #include "stablehlo/reference/Tensor.h" diff --git a/stablehlo/tests/interpret/dynamic_iota.mlir b/stablehlo/tests/interpret/dynamic_iota.mlir new file mode 100644 index 00000000000..5f9c79af3bd --- /dev/null +++ b/stablehlo/tests/interpret/dynamic_iota.mlir @@ -0,0 +1,17 @@ +// RUN: stablehlo-translate --interpret -split-input-file %s + +func.func @dynamic_iota_op_test_si64_dim_0() { + %output_shape = stablehlo.constant dense<[3, 4]> : tensor<2xi64> + %0 = stablehlo.dynamic_iota %output_shape, dim = 0 : (tensor<2xi64>) -> tensor<3x4xi64> + check.expect_eq_const %0, dense<[[0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]]> : tensor<3x4xi64> + func.return +} + +// ----- + +func.func @dynamic_iota_op_test_si64_dim_1() { + %output_shape = stablehlo.constant dense<[3, 4]> : tensor<2xi64> + %0 = stablehlo.dynamic_iota %output_shape, dim = 1 : (tensor<2xi64>) -> tensor<3x4xi64> + check.expect_eq_const %0, dense<[[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]> : tensor<3x4xi64> + func.return +} diff --git a/stablehlo/tests/ops_chlo.mlir b/stablehlo/tests/ops_chlo.mlir index bb3e77cfea5..16f399f0c37 100644 --- a/stablehlo/tests/ops_chlo.mlir +++ b/stablehlo/tests/ops_chlo.mlir @@ -73,33 +73,6 @@ func.func @constant_like(%arg0: tensor<1x2xi64>) -> (tensor<1x2xi32>) { // ----- -// CHECK-LABEL: func @minimum_broadcast_shapes -func.func @minimum_broadcast_shapes(%lhs: tensor, %rhs: tensor) - -> (tensor, tensor) { - %0, %1 = chlo.minimum_broadcast_shapes %lhs, %rhs : - tensor, tensor -> tensor, tensor - func.return %0, %1 : tensor, tensor -} - -// ----- - -func.func @minimum_broadcast_shapes_mismatch_operand_and_result_count(%lhs: tensor, %rhs: tensor) { - // expected-error @+1{{number of operand shapes (2) does not match number of result shapes (1)}} - %0 = 
chlo.minimum_broadcast_shapes %lhs, %rhs : - tensor, tensor -> tensor - func.return -} - -// ----- - -func.func @minimum_broadcast_shapes_one_operand(%arg: tensor) { - // expected-error @+1{{number of operand shapes (1) should be >= 2}} - %0 = chlo.minimum_broadcast_shapes %arg : tensor -> tensor - func.return -} - -// ----- - func.func @top_k(%arg0 : tensor) { // expected-error @+2 {{failed to infer returned types}} // @expected-error @+1{{operand's rank must be at least 1}} diff --git a/stablehlo/tests/ops_chlo_roundtrip.mlir b/stablehlo/tests/ops_chlo_roundtrip.mlir index f7dc94cbdee..37b303878ae 100644 --- a/stablehlo/tests/ops_chlo_roundtrip.mlir +++ b/stablehlo/tests/ops_chlo_roundtrip.mlir @@ -396,27 +396,6 @@ func.func @chlo_top_k(%arg : tensor<16x16xi32>) -> (tensor<16x8xi32>, tensor<16x return %1#0, %1#1 : tensor<16x8xi32>, tensor<16x8xi32> } -// CHECK-LABEL: func @chlo_minimum_broadcast_shapes( -// CHECK-SAME: %[[A0:.*]]: tensor, -// CHECK-SAME: %[[A1:.*]]: tensor -// CHECK: %[[T:.*]]:2 = chlo.minimum_broadcast_shapes %[[A0]], %[[A1]] : tensor, tensor -> tensor, tensor -// CHECK: return %[[T]]#0, %[[T]]#1 : tensor, tensor -func.func @chlo_minimum_broadcast_shapes(%lhs: tensor, %rhs: tensor) -> (tensor, tensor) { - %0, %1 = chlo.minimum_broadcast_shapes %lhs, %rhs : - tensor, tensor -> tensor, tensor - func.return %0, %1 : tensor, tensor -} - -// CHECK-LABEL: func @chlo_reshape_dynamic( -// CHECK-SAME: %[[A0:.*]]: tensor, -// CHECK-SAME: %[[A1:.*]]: tensor<2xi32> -// CHECK: %[[T:.*]] = "chlo.dynamic_reshape"(%[[A0]], %[[A1]]) : (tensor, tensor<2xi32>) -> tensor -// CHECK: return %[[T]] : tensor -func.func @chlo_reshape_dynamic(%arg0: tensor, %arg1: tensor<2xi32>) -> tensor { - %0 = "chlo.dynamic_reshape"(%arg0, %arg1) : (tensor, tensor<2xi32>) -> tensor - func.return %0 : tensor -} - // CHECK-LABEL: func @chlo_erf_inv // CHECK-SAME: %[[A0:.*0]]: tensor<16x16xf32>) // CHECK: chlo.erf_inv %[[A0]] : tensor<16x16xf32> -> tensor<16x16xf32> diff --git 
a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir index dba67cdc266..22907a942a1 100644 --- a/stablehlo/tests/ops_stablehlo.mlir +++ b/stablehlo/tests/ops_stablehlo.mlir @@ -51,18 +51,18 @@ func.func @all_reduce_with_promotable_types(%operand: tensor) -> tensor>) - -> tensor> { + -> tensor> { %result = "stablehlo.all_reduce"(%operand) ({ - ^bb0(%arg0: tensor>, %arg1: tensor>): - %0 = stablehlo.add %arg0, %arg1 : tensor> - "stablehlo.return"(%0) : (tensor>) -> () + ^bb0(%arg0: tensor>, %arg1: tensor>): + %0 = stablehlo.add %arg0, %arg1 : tensor> + "stablehlo.return"(%0) : (tensor>) -> () }) { replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, channel_handle = #stablehlo.channel_handle - } : (tensor>) -> tensor> + } : (tensor>) -> tensor> - func.return %result : tensor> + func.return %result : tensor> } // ----- @@ -334,16 +334,16 @@ func.func @reduce_scatter_with_promotable_types(%data: tensor<4x16xf32>) -> tens // CHECK-LABEL: func @reduce_scatter_with_promotable_quantized_types func.func @reduce_scatter_with_promotable_quantized_types( %data: tensor<4x16x!quant.uniform>) -> - tensor<4x4x!quant.uniform> { + tensor<4x4x!quant.uniform> { %0 = "stablehlo.reduce_scatter"(%data) ({ - ^bb0(%arg2: tensor>, %arg3: tensor>): - %1 = stablehlo.add %arg2, %arg3 : tensor> - "stablehlo.return"(%1) : (tensor>) -> () + ^bb0(%arg2: tensor>, %arg3: tensor>): + %1 = stablehlo.add %arg2, %arg3 : tensor> + "stablehlo.return"(%1) : (tensor>) -> () }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, scatter_dimension = 1 : i64, channel_handle = #stablehlo.channel_handle, - use_global_device_ids} : (tensor<4x16x!quant.uniform>) -> tensor<4x4x!quant.uniform> - func.return %0 : tensor<4x4x!quant.uniform> + use_global_device_ids} : (tensor<4x16x!quant.uniform>) -> tensor<4x4x!quant.uniform> + func.return %0 : tensor<4x4x!quant.uniform> } // ----- @@ -1026,7 +1026,7 @@ func.func @dynamic_broadcast_in_dim_shape_mismatch(%arg0: tensor<32xf32>, %shape // ----- 
func.func @dynamic_broadcast_in_dim_output_dimensions_negative_size(%arg0: tensor<4xf32>) -> tensor<3x4xf32> { - // @expected-error@+2 {{output_dimensions are incompatible with return type of operation 'tensor<3x4xf32>'}} + // @expected-error@+2 {{output shape [-1, 4] is incompatible with return type of operation 'tensor<3x4xf32>'}} %0 = stablehlo.constant dense<[-1, 4]> : tensor<2xi64> %1 = stablehlo.dynamic_broadcast_in_dim %arg0, %0, dims = [1] : (tensor<4xf32>, tensor<2xi64>) -> tensor<3x4xf32> return %1 : tensor<3x4xf32> @@ -1035,7 +1035,7 @@ func.func @dynamic_broadcast_in_dim_output_dimensions_negative_size(%arg0: tenso // ----- func.func @dynamic_broadcast_in_dim_output_dimensions_mismatching_size(%arg0: tensor<4xf32>) -> tensor<3x4xf32> { - // @expected-error@+2 {{output_dimensions are incompatible with return type of operation 'tensor<3x4xf32>'}} + // @expected-error@+2 {{output shape [1, 4] is incompatible with return type of operation 'tensor<3x4xf32>'}} %0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64> %1 = stablehlo.dynamic_broadcast_in_dim %arg0, %0, dims = [1] : (tensor<4xf32>, tensor<2xi64>) -> tensor<3x4xf32> return %1 : tensor<3x4xf32> @@ -1043,6 +1043,22 @@ func.func @dynamic_broadcast_in_dim_output_dimensions_mismatching_size(%arg0: te // ----- +func.func @dynamic_broadcast_in_dim_output_dimensions_match_result(%arg0: tensor<4xf32>) -> tensor<3x4xf32> { + %0 = stablehlo.constant dense<[3, 4]> : tensor<2xi64> + %1 = stablehlo.dynamic_broadcast_in_dim %arg0, %0, dims = [1] : (tensor<4xf32>, tensor<2xi64>) -> tensor<3x4xf32> + return %1 : tensor<3x4xf32> +} + +// ----- + +func.func @dynamic_broadcast_in_dim_output_dimensions_compatible_with_result(%arg0: tensor<4xf32>) -> tensor { + %0 = stablehlo.constant dense<[3, 4]> : tensor<2xi64> + %1 = stablehlo.dynamic_broadcast_in_dim %arg0, %0, dims = [1] : (tensor<4xf32>, tensor<2xi64>) -> tensor + return %1 : tensor +} + +// ----- + func.func @dynamic_broadcast_in_dim_negative_size(%arg0: 
tensor<1xf32>, %shape: tensor<3xi64>) -> tensor<7x8x9xf32> { // expected-error@+1 {{broadcast_dimensions contains invalid value -1 for result with rank 3}} %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = array} : (tensor<1xf32>, tensor<3xi64>) -> tensor<7x8x9xf32> @@ -3174,7 +3190,7 @@ func.func @dynamic_reshape_incompatible_shapes(%arg0: tensor, %shape: ten // ----- func.func @dynamic_reshape_output_shape_mismatching_size(%arg0: tensor<4xf32>) -> tensor<1x4xf32> { - // expected-error@+2 {{output_shape [2, 2] is incompatible with return type of operation 'tensor<1x4xf32>'}} + // expected-error@+2 {{output shape [2, 2] is incompatible with return type of operation 'tensor<1x4xf32>'}} %0 = stablehlo.constant dense<[2, 2]> : tensor<2xi64> %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<1x4xf32> return %1 : tensor<1x4xf32> @@ -3182,6 +3198,22 @@ func.func @dynamic_reshape_output_shape_mismatching_size(%arg0: tensor<4xf32>) - // ----- +func.func @dynamic_reshape_output_shape_matches_result(%arg0: tensor<4xf32>) -> tensor<1x4xf32> { + %0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64> + %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<1x4xf32> + return %1 : tensor<1x4xf32> +} + +// ----- + +func.func @dynamic_reshape_output_shape_compatible_with_result(%arg0: tensor<4xf32>) -> tensor { + %0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64> + %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor + return %1 : tensor +} + +// ----- + func.func @dynamic_reshape_dynamic_output_shape(%arg0: tensor, %shape: tensor) -> tensor<1x4xf32> { // expected-error@+1 {{op operand #1 must be statically shaped}} %0 = "stablehlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor) -> tensor<1x4xf32> @@ -4846,9 +4878,9 @@ func.func @quantized_dot_i8_per_axis(%arg0: tensor<2x2x!quant.uniform>, %arg1: tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> { 
- %0 = "stablehlo.dot"(%arg0, %arg1) : (tensor<2x2x!quant.uniform>, tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> - func.return %0: tensor<2x2x!quant.uniform> +func.func @quantized_dot_i4(%arg0: tensor<2x2x!quant.uniform>, %arg1: tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> { + %0 = "stablehlo.dot"(%arg0, %arg1) : (tensor<2x2x!quant.uniform>, tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> + func.return %0: tensor<2x2x!quant.uniform> } // ----- @@ -5125,7 +5157,7 @@ func.func @add_c4(%arg0: tensor<1x!quant.uniform>) { func.func @add_c3(%arg0: tensor<1x!quant.uniform>) { // expected-error@+1 {{mismatched operands and result quantization storage types}} - %0 = "stablehlo.add"(%arg0, %arg0) : (tensor<1x!quant.uniform>, tensor<1x!quant.uniform>) -> tensor<1x!quant.uniform> + %0 = "stablehlo.add"(%arg0, %arg0) : (tensor<1x!quant.uniform>, tensor<1x!quant.uniform>) -> tensor<1x!quant.uniform> func.return } @@ -5149,15 +5181,15 @@ func.func @add_c5(%arg0: tensor<1x!quant.uniform>) { func.func @add_c6(%arg0: tensor<1x2x!quant.uniform>) { // expected-error@+1 {{quantization_dimension of lhs and result are not same}} - %0 = "stablehlo.add"(%arg0, %arg0) : (tensor<1x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> + %0 = "stablehlo.add"(%arg0, %arg0) : (tensor<1x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> func.return } // ----- -func.func @add_c7(%arg0: tensor<1x2x!quant.uniform>, %arg1: tensor<1x2x!quant.uniform>) { +func.func @add_c7(%arg0: tensor<1x2x!quant.uniform>, %arg1: tensor<1x2x!quant.uniform>) { // expected-error@+1 {{quantization_dimension of rhs and result are not same}} - %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<1x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<1x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> func.return } @@ -5431,7 +5463,7 @@ func.func 
@dynamic_iota_invalid_iota_dimension_too_big() -> tensor { // ----- func.func @dynamic_iota_output_shape_negative_size() -> tensor<4xf32> { - // @expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<4xf32>'}} + // @expected-error@+2 {{output shape [-1] is incompatible with return type of operation 'tensor<4xf32>'}} %0 = stablehlo.constant dense<[-1]> : tensor<1xi64> %1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<4xf32> func.return %1 : tensor<4xf32> @@ -5440,7 +5472,7 @@ func.func @dynamic_iota_output_shape_negative_size() -> tensor<4xf32> { // ----- func.func @dynamic_iota_output_shape_mismatching_size() -> tensor<4xf32> { - // @expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<4xf32>'}} + // @expected-error@+2 {{output shape [1] is incompatible with return type of operation 'tensor<4xf32>'}} %0 = stablehlo.constant dense<[1]> : tensor<1xi64> %1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<4xf32> func.return %1 : tensor<4xf32> @@ -5448,6 +5480,22 @@ func.func @dynamic_iota_output_shape_mismatching_size() -> tensor<4xf32> { // ----- +func.func @dynamic_iota_output_shape_matches_result() -> tensor<4xf32> { + %0 = stablehlo.constant dense<[4]> : tensor<1xi64> + %1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<4xf32> + func.return %1 : tensor<4xf32> +} + +// ----- + +func.func @dynamic_iota_output_shape_compatible_with_result() -> tensor { + %0 = stablehlo.constant dense<[4]> : tensor<1xi64> + %1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor + func.return %1 : tensor +} + +// ----- + func.func @first(%arg0: tensor, %arg1: tensor) -> tensor { func.return %arg0 : tensor } diff --git a/stablehlo/tests/ops_stablehlo_quantized.mlir b/stablehlo/tests/ops_stablehlo_quantized.mlir index d9a073f2c1b..a4ff81a48dc 100644 --- a/stablehlo/tests/ops_stablehlo_quantized.mlir +++ b/stablehlo/tests/ops_stablehlo_quantized.mlir @@ -29,7 
+29,7 @@ func.func @ops_per_axis_quantization( // %arg1 can be a per-axis Quantized func.func @dot_general_per_axis_quantization( %arg0: tensor<2x3x4x!quant.uniform>, - %arg1: tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> { + %arg1: tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> { %0 = "stablehlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #stablehlo.dot< lhs_batching_dimensions = [0], @@ -38,8 +38,8 @@ func.func @dot_general_per_axis_quantization( rhs_contracting_dimensions = [1] > } : (tensor<2x3x4x!quant.uniform>, - tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> - func.return %0 : tensor<2x4x5x!quant.uniform> + tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> + func.return %0 : tensor<2x4x5x!quant.uniform> } // ----- @@ -856,30 +856,31 @@ func.func @broadcast_in_dim_c1_mismatch_zero_point( // ----- func.func @broadcast_in_dim_c6( - %arg0: tensor<1x2x1x!quant.uniform:f32:2, {0.1:-30, 0.5:-20}>>) { - // expected-error@+1 {{result quantization_dimension 3 not same as broadcast_dimensions 2 (2)}} + %arg0: tensor<1x2x1x!quant.uniform:f32:2, {0.1:-30}>>) { + // expected-error@+1 {{result quantization_dimension 3 not same as broadcast_dimensions[2] = 2}} %broadcast_in_dim = "stablehlo.broadcast_in_dim" (%arg0) {broadcast_dimensions = array - } : (tensor<1x2x1x!quant.uniform:f32:2, {0.1:-30, 0.5:-20}>>) -> tensor<1x2x3x2x!quant.uniform:f32:3, {0.1:-30, 0.1:-30}>> + } : (tensor<1x2x1x!quant.uniform:f32:2, {0.1:-30}>>) -> + tensor<1x2x3x2x!quant.uniform:f32:3, {0.1:-30, 0.1:-30}>> func.return } // ----- func.func @broadcast_in_dim_c6( - %arg0: tensor<1x2x1x!quant.uniform:f32:2, {0.1:-30, 0.5:-20}>>) { + %arg0: tensor<1x2x1x!quant.uniform:f32:2, {0.1:-30}>>) { // expected-error@+1 {{mismatch result scale 0 (2.000000e-01) and operand scale 0 (1.000000e-01)}} %broadcast_in_dim = "stablehlo.broadcast_in_dim" (%arg0) {broadcast_dimensions = array - } : (tensor<1x2x1x!quant.uniform:f32:2, {0.1:-30, 0.5:-20}>>) 
-> tensor<1x2x3x2x!quant.uniform:f32:3, {0.2:2, 0.5:-20}>> + } : (tensor<1x2x1x!quant.uniform:f32:2, {0.1:-30}>>) -> tensor<1x2x3x2x!quant.uniform:f32:3, {0.2:2, 0.5:-20}>> func.return } // ----- func.func @broadcast_in_dim_c6( - %arg0: tensor<1x2x1x!quant.uniform:f32:2, {0.1:-30, 0.5:-20}>>) { + %arg0: tensor<1x2x1x!quant.uniform:f32:2, {0.1:-30}>>) { // expected-error@+1 {{mismatch result zero_point 1 (-20) and operand zero_point 0 (-30)}} %broadcast_in_dim = "stablehlo.broadcast_in_dim" (%arg0) {broadcast_dimensions = array - } : (tensor<1x2x1x!quant.uniform:f32:2, {0.1:-30, 0.5:-20}>>) -> tensor<1x2x3x2x!quant.uniform:f32:3, {0.1:-30, 0.1:-20}>> + } : (tensor<1x2x1x!quant.uniform:f32:2, {0.1:-30}>>) -> tensor<1x2x3x2x!quant.uniform:f32:3, {0.1:-30, 0.1:-20}>> func.return } @@ -903,28 +904,28 @@ func.func @transpose_c1_mismatched_zp(%arg0: tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30, 0.5:-20}>>) { +func.func @transpose_c1_mismatched_scales(%arg0: tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30}>>) { // expected-error@+1 {{expect same quantization scales and zero_points}} %transpose = "stablehlo.transpose"(%arg0) {permutation = array - } : (tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30, 0.5:-20}>>) -> tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30, 0.9:-20}>> + } : (tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30}>>) -> tensor<1x2x2x!quant.uniform:f32:0, {0.2:-30}>> func.return } // ----- -func.func @transpose_c1_mismatched_zps(%arg0: tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30, 0.5:-20}>>) { +func.func @transpose_c1_mismatched_zps(%arg0: tensor<1x2x2x!quant.uniform:f32:0, {0.1:-20}>>) { // expected-error@+1 {{expect same quantization scales and zero_points}} %transpose = "stablehlo.transpose"(%arg0) {permutation = array - } : (tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30, 0.5:-20}>>) -> tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30, 0.5:-10}>> + } : (tensor<1x2x2x!quant.uniform:f32:0, {0.1:-20}>>) -> tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30}>> func.return } // ----- 
-func.func @transpose_c4(%arg0: tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30, 0.5:-20}>>) { - // expected-error@+1 {{operand quantization_dimension 0 is not same as permutation[1] 2}} +func.func @transpose_c4(%arg0: tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30}>>) { + // expected-error@+1 {{operand quantization_dimension 0 is not same as permutation[1] = 2}} %transpose = "stablehlo.transpose"(%arg0) {permutation = array - } : (tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30, 0.5:-20}>>) -> tensor<1x2x2x!quant.uniform:f32:1, {0.1:-30, 0.5:-20}>> + } : (tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30}>>) -> tensor<1x2x2x!quant.uniform:f32:1, {0.1:-30, 0.5:-20}>> func.return } @@ -956,7 +957,7 @@ func.func @reshape_c1(%arg0: tensor<1x2x2x!quant.uniform>){ func.func @reshape_c3_mismatch_qdim_size(%arg0: tensor<1x2x3x4x5x!quant.uniform>){ // expected-error@+1 {{expect same quantization dimension size for operand and result}} - %reshape = "stablehlo.reshape" (%arg0) : (tensor<1x2x3x4x5x!quant.uniform>) -> tensor<2x3x20x!quant.uniform> + %reshape = "stablehlo.reshape" (%arg0) : (tensor<1x2x3x4x5x!quant.uniform>) -> tensor<2x3x20x!quant.uniform> func.return } @@ -1111,7 +1112,7 @@ func.func @dot_general_c15_per_tensor(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x func.func @dot_general_c15_per_axis( %arg0: tensor<2x3x4x!quant.uniform>, - %arg1: tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> { + %arg1: tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> { // expected-error@+1 {{Zero points of rhs should be 0}} %0 = "stablehlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #stablehlo.dot< @@ -1121,15 +1122,15 @@ func.func @dot_general_c15_per_axis( rhs_contracting_dimensions = [1] > } : (tensor<2x3x4x!quant.uniform>, - tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> - func.return %0 : tensor<2x4x5x!quant.uniform> + tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> + func.return %0 : tensor<2x4x5x!quant.uniform> } // ----- func.func 
@dot_general_c16( %arg0: tensor<2x3x4x!quant.uniform>, - %arg1: tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> { + %arg1: tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> { // expected-error@+1 {{Quantization dimension of rhs should not be in the contracting dimension of rhs}} %0 = "stablehlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #stablehlo.dot< @@ -1139,8 +1140,8 @@ func.func @dot_general_c16( rhs_contracting_dimensions = [0] > } : (tensor<2x3x4x!quant.uniform>, - tensor<2x3x5x!quant.uniform>) -> tensor<3x4x5x!quant.uniform> - func.return %0 : tensor<3x4x5x!quant.uniform> + tensor<2x3x5x!quant.uniform>) -> tensor<3x4x5x!quant.uniform> + func.return %0 : tensor<3x4x5x!quant.uniform> } // ----- @@ -1175,7 +1176,7 @@ func.func @dot_general_c18(%arg0: tensor<2x3x4x!quant.uniform>, // ----- -func.func @dot_general_c19(%arg0: tensor<2x3x4x!quant.uniform>, %arg1: tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> { +func.func @dot_general_c19(%arg0: tensor<2x3x4x!quant.uniform>, %arg1: tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> { // expected-error@+1 {{mismatched rhs and result quantization granularity}} %0 = "stablehlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #stablehlo.dot< @@ -1184,8 +1185,8 @@ func.func @dot_general_c19(%arg0: tensor<2x3x4x!quant.uniform>, lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [1] > - } : (tensor<2x3x4x!quant.uniform>, tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> - func.return %0 : tensor<2x4x5x!quant.uniform> + } : (tensor<2x3x4x!quant.uniform>, tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> + func.return %0 : tensor<2x4x5x!quant.uniform> } // ----- @@ -1202,3 +1203,59 @@ func.func @dot_general_c20(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x5x!quant. 
} : (tensor<2x3x4xf32>, tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5xf32> func.return %0 : tensor<2x4x5xf32> } + +// ----- + +func.func @quantized_element_type_c8(%arg0: tensor<1x2x!quant.uniform:f32, 1.0:300>>) { + // expected-error-re@+1 {{operand #0 must be ranked tensor of {{.*}} 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer or 4/8/16/32-bit uniform quantized per axis signed integer or 4/8/16/32-bit uniform quantized per axis unsigned integer values, but got 'tensor<1x2x!quant.uniform>'}} + %0 = stablehlo.add %arg0, %arg0 : tensor<1x2x!quant.uniform:f32, 1.0:300>> + func.return +} + +// ----- + +func.func @quantized_element_type_c8(%arg0: tensor<1x2x!quant.uniform:f32, 1.0:-129>>) { + // expected-error-re@+1 {{operand #0 must be ranked tensor of {{.*}} 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer or 4/8/16/32-bit uniform quantized per axis signed integer or 4/8/16/32-bit uniform quantized per axis unsigned integer values, but got 'tensor<1x2x!quant.uniform>'}} + %0 = stablehlo.add %arg0, %arg0 : tensor<1x2x!quant.uniform:f32, 1.0:-129>> + func.return +} + +// ----- + +func.func @quantized_element_type_c5(%arg0: tensor<1x2x!quant.uniform>) { + // expected-error-re@+1 {{operand #0 must be ranked tensor of {{.*}} 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer or 4/8/16/32-bit uniform quantized per axis signed integer or 4/8/16/32-bit uniform quantized per axis unsigned integer values, but got 'tensor<1x2x!quant.uniform>'}} + %0 = stablehlo.add %arg0, %arg0 : tensor<1x2x!quant.uniform> + func.return +} + +// ----- + +func.func @quantized_element_type_c5(%arg0: tensor<1x2x!quant.uniform>) { + // expected-error-re@+1 {{operand #0 must be ranked tensor of {{.*}} 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer or 4/8/16/32-bit uniform quantized per axis signed 
integer or 4/8/16/32-bit uniform quantized per axis unsigned integer values, but got 'tensor<1x2x!quant.uniform>'}} + %0 = stablehlo.add %arg0, %arg0 : tensor<1x2x!quant.uniform> + func.return +} + +// ----- + +func.func @quantized_element_type_c12(%arg0: tensor<1x5x2x!quant.uniform:f32:-1, {0.1:-30, 0.1:-30}>>) { + // expected-error-re@+1 {{operand #0 must be ranked tensor of {{.*}} 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer or 4/8/16/32-bit uniform quantized per axis signed integer or 4/8/16/32-bit uniform quantized per axis unsigned integer values, but got 'tensor<1x5x2x!quant.uniform>'}} + %0 = stablehlo.add %arg0, %arg0 : tensor<1x5x2x!quant.uniform:f32:-1, {0.1:-30, 0.1:-30}>> + func.return +} + +// ----- + +func.func @quantized_element_type_c13(%arg0: tensor<1x5x2x!quant.uniform:f32:10, {0.1:-30, 0.1:-30}>>) { + // expected-error-re@+1 {{operand #0 must be ranked tensor of {{.*}} 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer or 4/8/16/32-bit uniform quantized per axis signed integer or 4/8/16/32-bit uniform quantized per axis unsigned integer values, but got 'tensor<1x5x2x!quant.uniform>'}} + %0 = stablehlo.add %arg0, %arg0 : tensor<1x5x2x!quant.uniform:f32:10, {0.1:-30, 0.1:-30}>> + func.return +} + +// ----- + +func.func @quantized_element_type_c14(%arg0: tensor<1x5x2x!quant.uniform:f32:1, {0.1:-30,0.1:-30 }>>) { + // expected-error-re@+1 {{operand #0 must be ranked tensor of {{.*}} 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer or 4/8/16/32-bit uniform quantized per axis signed integer or 4/8/16/32-bit uniform quantized per axis unsigned integer values, but got 'tensor<1x5x2x!quant.uniform>'}} + %0 = stablehlo.add %arg0, %arg0 : tensor<1x5x2x!quant.uniform:f32:1, {0.1:-30,0.1:-30 }>> + func.return +} diff --git a/stablehlo/tests/print_stablehlo.mlir b/stablehlo/tests/print_stablehlo.mlir index 
3cf998dd4f7..a215bd1836f 100644 --- a/stablehlo/tests/print_stablehlo.mlir +++ b/stablehlo/tests/print_stablehlo.mlir @@ -172,30 +172,6 @@ func.func @type_convert_ops(%arg0 : tensor<2xf32>) -> () { "stablehlo.return"() : () -> () } -// CHECK-LABEL: func @no_attr_ops -func.func @no_attr_ops(%arg0 : tensor<4xf32>, %arg1 : !stablehlo.token, - %arg2 : tensor<4xi32>, %arg3 : index) -> !stablehlo.token { - // CHECK-NEXT: %0 = stablehlo.clamp %arg0, %arg0, %arg0 : tensor<4xf32> - // CHECK-NEXT: %1 = stablehlo.complex %arg0, %arg0 : tensor<4xcomplex> - // CHECK-NEXT: %2 = stablehlo.compute_reshape_shape %arg3, %arg2 : (index, tensor<4xi32>) -> tensor<4xi32> - // CHECK-NEXT: %3 = stablehlo.uniform_quantize %arg0 : (tensor<4xf32>) -> tensor<4x!quant.uniform> - // CHECK-NEXT: %4 = stablehlo.uniform_dequantize %3 : (tensor<4x!quant.uniform>) -> tensor<4xf32> - // CHECK-NEXT: %5 = stablehlo.after_all %arg1, %arg1 : !stablehlo.token - // CHECK-NEXT: %6 = stablehlo.after_all : !stablehlo.token - // CHECK-NEXT: %7 = stablehlo.cstr_reshapable %arg3, %arg2 : (index, tensor<4xi32>) -> !shape.witness - // CHECK-NEXT: %8 = stablehlo.compute_reshape_shape %arg3, %arg2 : (index, tensor<4xi32>) -> tensor<4xi32> - %0 = "stablehlo.clamp"(%arg0, %arg0, %arg0) : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - %1 = "stablehlo.complex"(%arg0, %arg0) {} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> - %2 = "stablehlo.compute_reshape_shape"(%arg3, %arg2) : (index, tensor<4xi32>) -> tensor<4xi32> - %3 = "stablehlo.uniform_quantize"(%arg0) : (tensor<4xf32>) -> tensor<4x!quant.uniform> - %4 = "stablehlo.uniform_dequantize"(%3) : (tensor<4x!quant.uniform>) -> tensor<4xf32> - %5 = "stablehlo.after_all"(%arg1, %arg1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token - %6 = "stablehlo.after_all"() : () -> !stablehlo.token - %7 = "stablehlo.cstr_reshapable"(%arg3, %arg2) : (index, tensor<4xi32>) -> !shape.witness - %8 = "stablehlo.compute_reshape_shape"(%arg3, %arg2) : 
(index, tensor<4xi32>) -> tensor<4xi32> - "stablehlo.return"(%arg1) : (!stablehlo.token) -> () -} - // CHECK-LABEL: func @multiple_attr_ops func.func @multiple_attr_ops(%arg0 : tensor<3x4xf32>) -> () { // CHECK: %0 = stablehlo.reduce_precision %arg0, format = e8m10 : tensor<3x4xf32> diff --git a/stablehlo/tests/shape_legalize_to_stablehlo.mlir b/stablehlo/tests/shape_legalize_to_stablehlo.mlir index f69dca5d554..f47d94889c7 100644 --- a/stablehlo/tests/shape_legalize_to_stablehlo.mlir +++ b/stablehlo/tests/shape_legalize_to_stablehlo.mlir @@ -1,31 +1,5 @@ // RUN: stablehlo-opt --shape-legalize-to-stablehlo --split-input-file --verify-diagnostics %s | FileCheck %s -// CHECK-LABEL: func.func @compute_reshape_shape -func.func @compute_reshape_shape(%arg0: index, %arg1: tensor<2xi32>) -> tensor<2xi32> { - %0 = stablehlo.compute_reshape_shape %arg0, %arg1 : (index, tensor<2xi32>) -> tensor<2xi32> - func.return %0 : tensor<2xi32> - // CHECK: %[[ARG0_I32:.*]] = builtin.unrealized_conversion_cast %arg0 : index to tensor - // CHECK-NEXT: %[[TMP0:.*]] = stablehlo.constant dense<-1> : tensor - // CHECK-NEXT: %[[INPUT_SIZE0x1:.*]] = stablehlo.slice %arg1 [0:1] : (tensor<2xi32>) -> tensor<1xi32> - // CHECK-NEXT: %[[INPUT_SIZE0:.*]] = stablehlo.reshape %[[INPUT_SIZE0x1]] : (tensor<1xi32>) -> tensor - // CHECK-NEXT: %[[TMP1:.*]] = stablehlo.multiply %[[TMP0]], %[[INPUT_SIZE0]] : tensor - // CHECK-NEXT: %[[INPUT_SIZE1x1:.*]] = stablehlo.slice %arg1 [1:2] : (tensor<2xi32>) -> tensor<1xi32> - // CHECK-NEXT: %[[INPUT_SIZE1:.*]] = stablehlo.reshape %[[INPUT_SIZE1x1]] : (tensor<1xi32>) -> tensor - // CHECK-NEXT: %[[INPUT_SIZE_PRODUCT:.*]] = stablehlo.multiply %[[TMP1]], %[[INPUT_SIZE1]] : tensor - // CHECK-NEXT: %[[COMPUTED_SIZE:.*]] = stablehlo.divide %[[ARG0_I32]], %[[INPUT_SIZE_PRODUCT]] : tensor - // CHECK-NEXT: %[[M1:.*]] = stablehlo.constant dense<-1> : tensor - // CHECK-NEXT: %[[INPUT_SIZE0_EQ_M1:.*]] = stablehlo.compare EQ, %[[INPUT_SIZE0]], %[[M1]], NOTYPE : (tensor, tensor) 
-> tensor - // CHECK-NEXT: %[[RESULT_SIZE0:.*]] = stablehlo.select %[[INPUT_SIZE0_EQ_M1]], %[[COMPUTED_SIZE]], %[[INPUT_SIZE0]] : tensor, tensor - // CHECK-NEXT: %[[RESULT_SIZE0x1:.*]] = stablehlo.reshape %[[RESULT_SIZE0]] : (tensor) -> tensor<1xi32> - // CHECK-NEXT: %[[INPUT_SIZE1_EQ_M1:.*]] = stablehlo.compare EQ, %[[INPUT_SIZE1]], %[[M1]], NOTYPE : (tensor, tensor) -> tensor - // CHECK-NEXT: %[[RESULT_SIZE1:.*]] = stablehlo.select %[[INPUT_SIZE1_EQ_M1]], %[[COMPUTED_SIZE]], %[[INPUT_SIZE1]] : tensor, tensor - // CHECK-NEXT: %[[RESULT_SIZE1x1:.*]] = stablehlo.reshape %[[RESULT_SIZE1]] : (tensor) -> tensor<1xi32> - // CHECK-NEXT: %[[RESULT:.*]] = stablehlo.concatenate %[[RESULT_SIZE0x1]], %[[RESULT_SIZE1x1]], dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - // CHECK-NEXT: return %[[RESULT]] : tensor<2xi32> -} - -// ----- - // CHECK-LABEL: func.func @num_elements_tensor_to_index func.func @num_elements_tensor_to_index(%arg0: tensor<2xindex>) -> index { %0 = shape.num_elements %arg0 : tensor<2xindex> -> index @@ -213,18 +187,6 @@ func.func @shape_cstr_broadcastable(%arg0: tensor<2xindex>, %arg1: tensor<2xinde // ----- -func.func @mhlo_cstr_reshapable(%arg0: index, %arg1: tensor<2xindex>, %arg2: tensor) -> tensor { - // expected-error@+1 {{failed to legalize operation 'stablehlo.cstr_reshapable' that was explicitly marked illegal}} - %0 = stablehlo.cstr_reshapable %arg0, %arg1 : (index, tensor<2xindex>) -> !shape.witness - %1 = shape.assuming %0 -> (tensor) { - %2 = stablehlo.dynamic_reshape %arg2, %arg1 : (tensor, tensor<2xindex>) -> tensor - shape.assuming_yield %2 : tensor - } - func.return %1 : tensor -} - -// ----- - // CHECK-LABEL: func @const_shape func.func @const_shape() -> tensor<2xindex> { %0 = shape.const_shape [6, 4] : tensor<2xindex> diff --git a/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/tests/stablehlo_refine_shapes.mlir index 11db39fe933..d756482607e 100644 --- a/stablehlo/tests/stablehlo_refine_shapes.mlir +++ 
b/stablehlo/tests/stablehlo_refine_shapes.mlir @@ -231,85 +231,6 @@ func.func @eval_compare_lt() -> tensor { // ----- -// CHECK-LABEL: func @eval_compute_reshape_shape -func.func @eval_compute_reshape_shape() -> tensor<4xi32> { - // CHECK-NOT: stablehlo.compute_reshape_shape - // CHECK: [[RESULT:%.*]] = stablehlo.constant dense<[2, 128, 2, 64]> : tensor<4xi32> - // CHECK: return [[RESULT]] - %0 = arith.constant dense<[2, 128, 2, 64]> : tensor<4xi32> - %1 = arith.constant 32768 : index - %2 = stablehlo.compute_reshape_shape %1, %0 : (index, tensor<4xi32>) -> tensor<4xi32> - func.return %2 : tensor<4xi32> -} - -// ----- - -// CHECK-LABEL: func @eval_compute_reshape_shape_zero_dynamic_shape -func.func @eval_compute_reshape_shape_zero_dynamic_shape() -> tensor<0xi32> { - // CHECK-NOT: stablehlo.compute_reshape_shape - // CHECK: [[RESULT:%.*]] = stablehlo.constant dense<> : tensor<0xi32> - // CHECK: return [[RESULT]] - %0 = arith.constant dense<[]> : tensor<0xi32> - %1 = arith.constant 32768 : index - %2 = stablehlo.compute_reshape_shape %1, %0 : (index, tensor<0xi32>) -> tensor<0xi32> - func.return %2 : tensor<0xi32> -} - -// ----- - -// CHECK-LABEL: func @eval_compute_reshape_shape_unknown_dimension -func.func @eval_compute_reshape_shape_unknown_dimension() -> (tensor<4xi32>, tensor<1xi32>) { - // CHECK-NOT: stablehlo.compute_reshape_shape - // CHECK: [[RESULT1:%.*]] = stablehlo.constant dense<[2, 128, 2, 64]> : tensor<4xi32> - // CHECK: [[RESULT2:%.*]] = stablehlo.constant dense<32768> : tensor<1xi32> - // CHECK: return [[RESULT1]], [[RESULT2]] - %0 = arith.constant dense<[2, -1, 2, 64]> : tensor<4xi32> - %1 = arith.constant dense<[-1]> : tensor<1xi32> - %2 = arith.constant 32768 : index - %3 = stablehlo.compute_reshape_shape %2, %0 : (index, tensor<4xi32>) -> tensor<4xi32> - %4 = stablehlo.compute_reshape_shape %2, %1 : (index, tensor<1xi32>) -> tensor<1xi32> - func.return %3, %4 : tensor<4xi32>, tensor<1xi32> -} - -// ----- - -// CHECK-LABEL: func 
@eval_compute_reshape_shape_two_unknown_dims -func.func @eval_compute_reshape_shape_two_unknown_dims() -> tensor<4xi32> { - // CHECK: [[RESULT:%.*]] = stablehlo.compute_reshape_shape - // CHECK: return [[RESULT]] - %0 = arith.constant dense<[2, -1, -1, 64]> : tensor<4xi32> - %1 = arith.constant 32768 : index - %2 = stablehlo.compute_reshape_shape %1, %0 : (index, tensor<4xi32>) -> tensor<4xi32> - func.return %2 : tensor<4xi32> -} - -// ----- - -// CHECK-LABEL: func @eval_compute_reshape_shape_non_divisible_shape -func.func @eval_compute_reshape_shape_non_divisible_shape() -> (tensor<4xi32>, tensor<4xi32>) { - // CHECK: [[RESULT1:%.*]] = stablehlo.compute_reshape_shape - // CHECK: [[RESULT2:%.*]] = stablehlo.compute_reshape_shape - // CHECK: return [[RESULT1]], [[RESULT2]] - %0 = arith.constant dense<[2, 128, 3, -1]> : tensor<4xi32> - %1 = arith.constant dense<[2, 128, 2, 63]> : tensor<4xi32> - %2 = arith.constant 32768 : index - %3 = stablehlo.compute_reshape_shape %2, %0 : (index, tensor<4xi32>) -> tensor<4xi32> - %4 = stablehlo.compute_reshape_shape %2, %1 : (index, tensor<4xi32>) -> tensor<4xi32> - func.return %3, %4 : tensor<4xi32>, tensor<4xi32> -} - -// ----- - -// CHECK-LABEL: func @eval_compute_reshape_shape_non_specializable -func.func @eval_compute_reshape_shape_non_specializable(%arg0 : tensor<4xi32>, %arg1 : index) -> tensor<4xi32> { - // CHECK: [[RESULT:%.*]] = stablehlo.compute_reshape_shape - // CHECK: return [[RESULT]] - %0 = stablehlo.compute_reshape_shape %arg1, %arg0 : (index, tensor<4xi32>) -> tensor<4xi32> - func.return %0 : tensor<4xi32> -} - -// ----- - // CHECK-LABEL: func @eval_concatenate_1d func.func @eval_concatenate_1d() -> tensor<4xi64> { // CHECK-NOT: stablehlo.concatenate diff --git a/stablehlo/tests/verify_reduce.mlir b/stablehlo/tests/verify_reduce.mlir index 2e2a7d20239..2e5d16f5627 100644 --- a/stablehlo/tests/verify_reduce.mlir +++ b/stablehlo/tests/verify_reduce.mlir @@ -80,13 +80,13 @@ func.func 
@reduce_with_promotable_types(%arg0: tensor<4x4xf32>, %arg1 : tensor>, - %arg1: tensor>) -> tensor<4x!quant.uniform> { - %0 = stablehlo.reduce(%arg0 init: %arg1) across dimensions = [0] : (tensor<4x4x!quant.uniform>, tensor>) -> tensor<4x!quant.uniform> - reducer(%arg2: tensor>, %arg3: tensor>) { - %1 = stablehlo.add %arg2, %arg3 : tensor> - stablehlo.return %1 : tensor> + %arg1: tensor>) -> tensor<4x!quant.uniform> { + %0 = stablehlo.reduce(%arg0 init: %arg1) across dimensions = [0] : (tensor<4x4x!quant.uniform>, tensor>) -> tensor<4x!quant.uniform> + reducer(%arg2: tensor>, %arg3: tensor>) { + %1 = stablehlo.add %arg2, %arg3 : tensor> + stablehlo.return %1 : tensor> } - return %0 : tensor<4x!quant.uniform> + return %0 : tensor<4x!quant.uniform> } // ----- @@ -425,17 +425,17 @@ func.func @reduce_c6(%arg0: tensor<4x4x!quant.uniform>, // ----- func.func @reduce_c6(%arg0: tensor<4x4x!quant.uniform>, - %arg1: tensor>) -> tensor<4x!quant.uniform> { + %arg1: tensor>) -> tensor<4x!quant.uniform> { - // expected-error@+1 {{The element-type of reduction-region's argument at index 1 is expected to be promotable from '!quant.uniform', but got '!quant.uniform'}} + // expected-error@+1 {{The element-type of reduction-region's argument at index 1 is expected to be promotable from '!quant.uniform', but got '!quant.uniform'}} %0 = stablehlo.reduce(%arg0 init: %arg1) across dimensions = [0] : (tensor<4x4x!quant.uniform>, - tensor>) -> tensor<4x!quant.uniform> + tensor>) -> tensor<4x!quant.uniform> - reducer(%arg2: tensor>, %arg3: tensor>) { - %1 = stablehlo.add %arg2, %arg3 : tensor> - stablehlo.return %1 : tensor> + reducer(%arg2: tensor>, %arg3: tensor>) { + %1 = stablehlo.add %arg2, %arg3 : tensor> + stablehlo.return %1 : tensor> } - return %0 : tensor<4x!quant.uniform> + return %0 : tensor<4x!quant.uniform> } // The following invalid cases arises while parsing a pretty-printed version of reduce-op will "non-eligible" inner-op. 
diff --git a/stablehlo/tests/verify_reduce_window.mlir b/stablehlo/tests/verify_reduce_window.mlir index f3fd57faa92..123a023db37 100644 --- a/stablehlo/tests/verify_reduce_window.mlir +++ b/stablehlo/tests/verify_reduce_window.mlir @@ -82,19 +82,19 @@ func.func @reduce_window_with_promotable_types(%arg0: tensor<4x2xf32>, // CHECK-LABEL: func @reduce_window_with_promotable_quantized_types func.func @reduce_window_with_promotable_quantized_types(%arg0: tensor<4x2x!quant.uniform>, - %init0: tensor>) -> (tensor<2x2x!quant.uniform>) { + %init0: tensor>) -> (tensor<2x2x!quant.uniform>) { %0 = "stablehlo.reduce_window"(%arg0, %init0) ({ - ^bb0(%a0: tensor>, %b0: tensor>): - %1 = stablehlo.add %a0, %b0 : tensor> - "stablehlo.return"(%1) : (tensor>) -> () + ^bb0(%a0: tensor>, %b0: tensor>): + %1 = stablehlo.add %a0, %b0 : tensor> + "stablehlo.return"(%1) : (tensor>) -> () }) { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, window_dimensions = array, window_strides = array } - : (tensor<4x2x!quant.uniform>, tensor>) -> (tensor<2x2x!quant.uniform>) - func.return %0 : tensor<2x2x!quant.uniform> + : (tensor<4x2x!quant.uniform>, tensor>) -> (tensor<2x2x!quant.uniform>) + func.return %0 : tensor<2x2x!quant.uniform> } // ----- @@ -619,20 +619,20 @@ func.func @reduce_window_c13(%arg0: tensor<4x2x!quant.uniform>, - %init0: tensor>) -> (tensor<2x2x!quant.uniform>) { + %init0: tensor>) -> (tensor<2x2x!quant.uniform>) { - // expected-error@+1 {{The element-type of reduction-region's argument at index 1 is expected to be promotable from '!quant.uniform', but got '!quant.uniform'}} + // expected-error@+1 {{The element-type of reduction-region's argument at index 1 is expected to be promotable from '!quant.uniform', but got '!quant.uniform'}} %0 = "stablehlo.reduce_window"(%arg0, %init0) ({ - ^bb0(%a0: tensor>, %b0: tensor>): - %1 = stablehlo.add %a0, %b0 : tensor> - "stablehlo.return"(%1) : (tensor>) -> () + ^bb0(%a0: tensor>, %b0: tensor>): + %1 = stablehlo.add %a0, %b0 : tensor> 
+ "stablehlo.return"(%1) : (tensor>) -> () }) { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, window_dimensions = array, window_strides = array } - : (tensor<4x2x!quant.uniform>, tensor>) -> (tensor<2x2x!quant.uniform>) - func.return %0 : tensor<2x2x!quant.uniform> + : (tensor<4x2x!quant.uniform>, tensor>) -> (tensor<2x2x!quant.uniform>) + func.return %0 : tensor<2x2x!quant.uniform> } // ----- diff --git a/stablehlo/tests/verify_scatter.mlir b/stablehlo/tests/verify_scatter.mlir index 135ea404c08..383a7a89e70 100644 --- a/stablehlo/tests/verify_scatter.mlir +++ b/stablehlo/tests/verify_scatter.mlir @@ -73,12 +73,12 @@ func.func @scatter_with_promotable_types(%input_tensor: tensor<200x100x300xf32>, // CHECK: func @scatter_with_promotable_quantized_types func.func @scatter_with_promotable_quantized_types(%input_tensor: tensor<200x100x300x!quant.uniform>, - %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300x!quant.uniform>) -> - tensor<200x100x300x!quant.uniform> { + %scatter_indices: tensor<10x2xi16>, %updates: tensor<10x300x!quant.uniform>) -> + tensor<200x100x300x!quant.uniform> { %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ - ^bb0(%lhs: tensor>, %rhs: tensor>): - %add = stablehlo.add %lhs, %rhs : tensor> - "stablehlo.return"(%add) : (tensor>) -> () + ^bb0(%lhs: tensor>, %rhs: tensor>): + %add = stablehlo.add %lhs, %rhs : tensor> + "stablehlo.return"(%add) : (tensor>) -> () }) { scatter_dimension_numbers = #stablehlo.scatter< update_window_dims = [1], @@ -88,10 +88,10 @@ func.func @scatter_with_promotable_quantized_types(%input_tensor: tensor<200x100 >, indices_are_sorted = true, unique_indices = true - } : (tensor<200x100x300x!quant.uniform>, tensor<10x2xi32>, + } : (tensor<200x100x300x!quant.uniform>, tensor<10x2xi16>, tensor<10x300x!quant.uniform>) -> - tensor<200x100x300x!quant.uniform> - func.return %0 : tensor<200x100x300x!quant.uniform> + tensor<200x100x300x!quant.uniform> + func.return %0 : 
tensor<200x100x300x!quant.uniform> } // ----- @@ -898,14 +898,14 @@ func.func @scatter_c15(%input_tensor: tensor<200x100x300x!quant.uniform>, - %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300x!quant.uniform>) -> - tensor<200x100x300x!quant.uniform> { + %scatter_indices: tensor<10x2xi16>, %updates: tensor<10x300x!quant.uniform>) -> + tensor<200x100x300x!quant.uniform> { - // expected-error@+1 {{The element-type of reduction-region's argument at index 1 is expected to be promotable from '!quant.uniform', but got '!quant.uniform'}} + // expected-error@+1 {{The element-type of reduction-region's argument at index 1 is expected to be promotable from '!quant.uniform', but got '!quant.uniform'}} %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ - ^bb0(%lhs: tensor>, %rhs: tensor>): - %add = stablehlo.add %lhs, %rhs : tensor> - "stablehlo.return"(%add) : (tensor>) -> () + ^bb0(%lhs: tensor>, %rhs: tensor>): + %add = stablehlo.add %lhs, %rhs : tensor> + "stablehlo.return"(%add) : (tensor>) -> () }) { scatter_dimension_numbers = #stablehlo.scatter< update_window_dims = [1], @@ -915,7 +915,7 @@ func.func @scatter_c15(%input_tensor: tensor<200x100x300x!quant.uniform, indices_are_sorted = true, unique_indices = true - } : (tensor<200x100x300x!quant.uniform>, tensor<10x2xi32>, tensor<10x300x!quant.uniform>) -> - tensor<200x100x300x!quant.uniform> - func.return %0 : tensor<200x100x300x!quant.uniform> + } : (tensor<200x100x300x!quant.uniform>, tensor<10x2xi16>, tensor<10x300x!quant.uniform>) -> + tensor<200x100x300x!quant.uniform> + func.return %0 : tensor<200x100x300x!quant.uniform> } diff --git a/stablehlo/tests/verify_select_and_scatter.mlir b/stablehlo/tests/verify_select_and_scatter.mlir index d8d2b9d610f..1d7af4b37e1 100644 --- a/stablehlo/tests/verify_select_and_scatter.mlir +++ b/stablehlo/tests/verify_select_and_scatter.mlir @@ -57,7 +57,7 @@ func.func @select_and_scatter_with_promotable_quantized_types( %arg0: 
tensor<10x24x24x64x!quant.uniform>, %arg1: tensor<10x12x12x64x!quant.uniform>, %arg2 : tensor>) -> - tensor<10x24x24x64x!quant.uniform> { + tensor<10x24x24x64x!quant.uniform> { %1 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ ^bb0(%arg3: tensor>, %arg4: tensor>): @@ -67,17 +67,17 @@ func.func @select_and_scatter_with_promotable_quantized_types( } : (tensor>, tensor>) -> tensor "stablehlo.return"(%2) : (tensor) -> () }, { - ^bb0(%arg3: tensor>, %arg4: tensor>): - %2 = stablehlo.add %arg3, %arg4 : tensor> - "stablehlo.return"(%2) : (tensor>) -> () + ^bb0(%arg3: tensor>, %arg4: tensor>): + %2 = stablehlo.add %arg3, %arg4 : tensor> + "stablehlo.return"(%2) : (tensor>) -> () }) { window_dimensions = array, window_strides = array } : (tensor<10x24x24x64x!quant.uniform>, tensor<10x12x12x64x!quant.uniform>, tensor>) -> - tensor<10x24x24x64x!quant.uniform> - func.return %1 : tensor<10x24x24x64x!quant.uniform> + tensor<10x24x24x64x!quant.uniform> + func.return %1 : tensor<10x24x24x64x!quant.uniform> } // ----- diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir index 7d3abcd1c24..3b661cedd56 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir @@ -1000,13 +1000,6 @@ func.func @op_complex(%arg0: tensor, %arg1: tensor) -> tensor> } -// CHECK-LABEL: "op_compute_reshape_shape" -func.func @op_compute_reshape_shape(%arg0: index, %arg1: tensor<1xindex>) -> tensor<1xindex> { - // CHECK: "vhlo.compute_reshape_shape_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1<1x!vhlo.index_v1> - %0 = "stablehlo.compute_reshape_shape"(%arg0, %arg1) : (index, tensor<1xindex>) -> tensor<1xindex> - func.return %0 : tensor<1xindex> -} - // CHECK-LABEL: "op_concatenate" func.func @op_concatenate(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<16xf32> { // CHECK: 
"vhlo.concatenate_v1"(%arg0, %arg1) <{ @@ -1096,13 +1089,6 @@ func.func @op_cross_replica_sum(%arg0: tensor) -> tensor { func.return %0 : tensor } -// CHECK-LABEL: "op_cstr_reshapable" -func.func @op_cstr_reshapable(%arg0: index, %arg1: tensor<1xindex>) -> !shape.witness { - // CHECK: "vhlo.cstr_reshapable_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.witness_v1 - %0 = "stablehlo.cstr_reshapable"(%arg0, %arg1) : (index, tensor<1xindex>) -> !shape.witness - func.return %0 : !shape.witness -} - // CHECK-LABEL: "op_custom_call" func.func @op_custom_call(%arg0: tensor) -> tensor { // CHECK: "vhlo.custom_call_v1"(%arg0) <{ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir.bc index 6a23d459f80..57c7217934d 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir index 26cbf16e2a8..ef44a742a69 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir @@ -1000,13 +1000,6 @@ func.func @op_complex(%arg0: tensor, %arg1: tensor) -> tensor> } -// CHECK-LABEL: "op_compute_reshape_shape" -func.func @op_compute_reshape_shape(%arg0: index, %arg1: tensor<1xindex>) -> tensor<1xindex> { - // CHECK: "vhlo.compute_reshape_shape_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1<1x!vhlo.index_v1> - %0 = "stablehlo.compute_reshape_shape"(%arg0, %arg1) : (index, tensor<1xindex>) -> tensor<1xindex> - func.return %0 : tensor<1xindex> -} - // CHECK-LABEL: "op_concatenate" func.func @op_concatenate(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<16xf32> { // CHECK: "vhlo.concatenate_v1"(%arg0, %arg1) <{ @@ 
-1096,13 +1089,6 @@ func.func @op_cross_replica_sum(%arg0: tensor) -> tensor { func.return %0 : tensor } -// CHECK-LABEL: "op_cstr_reshapable" -func.func @op_cstr_reshapable(%arg0: index, %arg1: tensor<1xindex>) -> !shape.witness { - // CHECK: "vhlo.cstr_reshapable_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.witness_v1 - %0 = "stablehlo.cstr_reshapable"(%arg0, %arg1) : (index, tensor<1xindex>) -> !shape.witness - func.return %0 : !shape.witness -} - // CHECK-LABEL: "op_custom_call" func.func @op_custom_call(%arg0: tensor) -> tensor { // CHECK: "vhlo.custom_call_v1"(%arg0) <{ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir.bc index 1359e5aa0d9..207f87cff80 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir index 8c14213f51c..4264a708b01 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir @@ -1000,13 +1000,6 @@ func.func @op_complex(%arg0: tensor, %arg1: tensor) -> tensor> } -// CHECK-LABEL: "op_compute_reshape_shape" -func.func @op_compute_reshape_shape(%arg0: index, %arg1: tensor<1xindex>) -> tensor<1xindex> { - // CHECK: "vhlo.compute_reshape_shape_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1<1x!vhlo.index_v1> - %0 = "stablehlo.compute_reshape_shape"(%arg0, %arg1) : (index, tensor<1xindex>) -> tensor<1xindex> - func.return %0 : tensor<1xindex> -} - // CHECK-LABEL: "op_concatenate" func.func @op_concatenate(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<16xf32> { // CHECK: "vhlo.concatenate_v1"(%arg0, %arg1) <{ @@ -1096,13 +1089,6 @@ func.func 
@op_cross_replica_sum(%arg0: tensor) -> tensor { func.return %0 : tensor } -// CHECK-LABEL: "op_cstr_reshapable" -func.func @op_cstr_reshapable(%arg0: index, %arg1: tensor<1xindex>) -> !shape.witness { - // CHECK: "vhlo.cstr_reshapable_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.witness_v1 - %0 = "stablehlo.cstr_reshapable"(%arg0, %arg1) : (index, tensor<1xindex>) -> !shape.witness - func.return %0 : !shape.witness -} - // CHECK-LABEL: "op_custom_call" func.func @op_custom_call(%arg0: tensor) -> tensor { // CHECK: "vhlo.custom_call_v1"(%arg0) <{ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir.bc index 5419e894caf..2b5e8532454 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir index 3042dd39316..1f0c6135a53 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir @@ -1000,13 +1000,6 @@ func.func @op_complex(%arg0: tensor, %arg1: tensor) -> tensor> } -// CHECK-LABEL: "op_compute_reshape_shape" -func.func @op_compute_reshape_shape(%arg0: index, %arg1: tensor<1xindex>) -> tensor<1xindex> { - // CHECK: "vhlo.compute_reshape_shape_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1<1x!vhlo.index_v1> - %0 = "stablehlo.compute_reshape_shape"(%arg0, %arg1) : (index, tensor<1xindex>) -> tensor<1xindex> - func.return %0 : tensor<1xindex> -} - // CHECK-LABEL: "op_concatenate" func.func @op_concatenate(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<16xf32> { // CHECK: "vhlo.concatenate_v1"(%arg0, %arg1) <{ @@ -1096,13 +1089,6 @@ func.func @op_cross_replica_sum(%arg0: tensor) -> tensor { 
func.return %0 : tensor } -// CHECK-LABEL: "op_cstr_reshapable" -func.func @op_cstr_reshapable(%arg0: index, %arg1: tensor<1xindex>) -> !shape.witness { - // CHECK: "vhlo.cstr_reshapable_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.witness_v1 - %0 = "stablehlo.cstr_reshapable"(%arg0, %arg1) : (index, tensor<1xindex>) -> !shape.witness - func.return %0 : !shape.witness -} - // CHECK-LABEL: "op_custom_call" func.func @op_custom_call(%arg0: tensor) -> tensor { // CHECK: "vhlo.custom_call_v1"(%arg0) <{ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir.bc index d3fd5b4a209..8cf9ae29e88 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir index 29157786fcd..bee620a3a0f 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir @@ -1000,13 +1000,6 @@ func.func @op_complex(%arg0: tensor, %arg1: tensor) -> tensor> } -// CHECK-LABEL: "op_compute_reshape_shape" -func.func @op_compute_reshape_shape(%arg0: index, %arg1: tensor<1xindex>) -> tensor<1xindex> { - // CHECK: "vhlo.compute_reshape_shape_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1<1x!vhlo.index_v1> - %0 = "stablehlo.compute_reshape_shape"(%arg0, %arg1) : (index, tensor<1xindex>) -> tensor<1xindex> - func.return %0 : tensor<1xindex> -} - // CHECK-LABEL: "op_concatenate" func.func @op_concatenate(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<16xf32> { // CHECK: "vhlo.concatenate_v1"(%arg0, %arg1) <{ @@ -1096,13 +1089,6 @@ func.func @op_cross_replica_sum(%arg0: tensor) -> tensor { func.return %0 : tensor } -// CHECK-LABEL: 
"op_cstr_reshapable" -func.func @op_cstr_reshapable(%arg0: index, %arg1: tensor<1xindex>) -> !shape.witness { - // CHECK: "vhlo.cstr_reshapable_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.witness_v1 - %0 = "stablehlo.cstr_reshapable"(%arg0, %arg1) : (index, tensor<1xindex>) -> !shape.witness - func.return %0 : !shape.witness -} - // CHECK-LABEL: "op_custom_call" func.func @op_custom_call(%arg0: tensor) -> tensor { // CHECK: "vhlo.custom_call_v1"(%arg0) <{ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir.bc index 70753428fff..e2dfd9857e5 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir index 888648b1076..c69529c05f3 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir @@ -1006,13 +1006,6 @@ func.func @op_complex(%arg0: tensor, %arg1: tensor) -> tensor> } -// CHECK-LABEL: "op_compute_reshape_shape" -func.func @op_compute_reshape_shape(%arg0: index, %arg1: tensor<1xindex>) -> tensor<1xindex> { - // CHECK: "vhlo.compute_reshape_shape_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1<1x!vhlo.index_v1> - %0 = "stablehlo.compute_reshape_shape"(%arg0, %arg1) : (index, tensor<1xindex>) -> tensor<1xindex> - func.return %0 : tensor<1xindex> -} - // CHECK-LABEL: "op_concatenate" func.func @op_concatenate(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<16xf32> { // CHECK: "vhlo.concatenate_v1"(%arg0, %arg1) <{ @@ -1102,13 +1095,6 @@ func.func @op_cross_replica_sum(%arg0: tensor) -> tensor { func.return %0 : tensor } -// CHECK-LABEL: "op_cstr_reshapable" -func.func 
@op_cstr_reshapable(%arg0: index, %arg1: tensor<1xindex>) -> !shape.witness { - // CHECK: "vhlo.cstr_reshapable_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.witness_v1 - %0 = "stablehlo.cstr_reshapable"(%arg0, %arg1) : (index, tensor<1xindex>) -> !shape.witness - func.return %0 : !shape.witness -} - // CHECK-LABEL: "op_custom_call" func.func @op_custom_call(%arg0: tensor) -> tensor { // CHECK: "vhlo.custom_call_v1"(%arg0) <{ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir.bc index deaa7ce3387..f09e1ece959 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir index 4af6ff02f01..629925d15da 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir @@ -1018,13 +1018,6 @@ func.func @op_complex(%arg0: tensor, %arg1: tensor) -> tensor> } -// CHECK-LABEL: "op_compute_reshape_shape" -func.func @op_compute_reshape_shape(%arg0: index, %arg1: tensor<1xindex>) -> tensor<1xindex> { - // CHECK: "vhlo.compute_reshape_shape_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1<1x!vhlo.index_v1> - %0 = "stablehlo.compute_reshape_shape"(%arg0, %arg1) : (index, tensor<1xindex>) -> tensor<1xindex> - func.return %0 : tensor<1xindex> -} - // CHECK-LABEL: "op_concatenate" func.func @op_concatenate(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<16xf32> { // CHECK: "vhlo.concatenate_v1"(%arg0, %arg1) <{ @@ -1114,13 +1107,6 @@ func.func @op_cross_replica_sum(%arg0: tensor) -> tensor { func.return %0 : tensor } -// CHECK-LABEL: "op_cstr_reshapable" -func.func @op_cstr_reshapable(%arg0: index, %arg1: 
tensor<1xindex>) -> !shape.witness { - // CHECK: "vhlo.cstr_reshapable_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.witness_v1 - %0 = "stablehlo.cstr_reshapable"(%arg0, %arg1) : (index, tensor<1xindex>) -> !shape.witness - func.return %0 : !shape.witness -} - // CHECK-LABEL: "op_custom_call" func.func @op_custom_call(%arg0: tensor) -> tensor { // CHECK: "vhlo.custom_call_v1"(%arg0) <{ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir.bc index 8fc24e44f8e..585bff6c9e2 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir index e2da21d4cf0..8bccdad3efe 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir @@ -1037,13 +1037,6 @@ func.func @op_complex(%arg0: tensor, %arg1: tensor) -> tensor> } -// CHECK-LABEL: "op_compute_reshape_shape" -func.func @op_compute_reshape_shape(%arg0: index, %arg1: tensor<1xindex>) -> tensor<1xindex> { - // CHECK: "vhlo.compute_reshape_shape_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1<1x!vhlo.index_v1> - %0 = "stablehlo.compute_reshape_shape"(%arg0, %arg1) : (index, tensor<1xindex>) -> tensor<1xindex> - func.return %0 : tensor<1xindex> -} - // CHECK-LABEL: "op_concatenate" func.func @op_concatenate(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<16xf32> { // CHECK: "vhlo.concatenate_v1"(%arg0, %arg1) <{ @@ -1133,13 +1126,6 @@ func.func @op_cross_replica_sum(%arg0: tensor) -> tensor { func.return %0 : tensor } -// CHECK-LABEL: "op_cstr_reshapable" -func.func @op_cstr_reshapable(%arg0: index, %arg1: tensor<1xindex>) -> !shape.witness { - // CHECK: 
"vhlo.cstr_reshapable_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.witness_v1 - %0 = "stablehlo.cstr_reshapable"(%arg0, %arg1) : (index, tensor<1xindex>) -> !shape.witness - func.return %0 : !shape.witness -} - // CHECK-LABEL: "op_custom_call" func.func @op_custom_call(%arg0: tensor) -> tensor { // CHECK: "vhlo.custom_call_v1"(%arg0) <{ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir.bc index a807120a74c..700069cc9bb 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir index f879ac17e06..1c325fe8dab 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir @@ -1037,13 +1037,6 @@ func.func @op_complex(%arg0: tensor, %arg1: tensor) -> tensor> } -// CHECK-LABEL: "op_compute_reshape_shape" -func.func @op_compute_reshape_shape(%arg0: index, %arg1: tensor<1xindex>) -> tensor<1xindex> { - // CHECK: "vhlo.compute_reshape_shape_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1<1x!vhlo.index_v1> - %0 = "stablehlo.compute_reshape_shape"(%arg0, %arg1) : (index, tensor<1xindex>) -> tensor<1xindex> - func.return %0 : tensor<1xindex> -} - // CHECK-LABEL: "op_concatenate" func.func @op_concatenate(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<16xf32> { // CHECK: "vhlo.concatenate_v1"(%arg0, %arg1) <{ @@ -1133,13 +1126,6 @@ func.func @op_cross_replica_sum(%arg0: tensor) -> tensor { func.return %0 : tensor } -// CHECK-LABEL: "op_cstr_reshapable" -func.func @op_cstr_reshapable(%arg0: index, %arg1: tensor<1xindex>) -> !shape.witness { - // CHECK: "vhlo.cstr_reshapable_v1"(%arg0, %arg1) : 
(!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.witness_v1 - %0 = "stablehlo.cstr_reshapable"(%arg0, %arg1) : (index, tensor<1xindex>) -> !shape.witness - func.return %0 : !shape.witness -} - // CHECK-LABEL: "op_custom_call" func.func @op_custom_call(%arg0: tensor) -> tensor { // CHECK: "vhlo.custom_call_v1"(%arg0) <{ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir.bc index c23c9aa4d35..7e0959152ef 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_19_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_19_0.mlir index d8ee1a47dc8..a9fdca57ee6 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_19_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_19_0.mlir @@ -1072,13 +1072,6 @@ func.func @op_composite(%arg0: tensor) -> tensor { func.return %0 : tensor } -// CHECK-LABEL: "op_compute_reshape_shape" -func.func @op_compute_reshape_shape(%arg0: index, %arg1: tensor<1xindex>) -> tensor<1xindex> { - // CHECK: "vhlo.compute_reshape_shape_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1<1x!vhlo.index_v1> - %0 = "stablehlo.compute_reshape_shape"(%arg0, %arg1) : (index, tensor<1xindex>) -> tensor<1xindex> - func.return %0 : tensor<1xindex> -} - // CHECK-LABEL: "op_concatenate" func.func @op_concatenate(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<16xf32> { // CHECK: "vhlo.concatenate_v1"(%arg0, %arg1) <{ @@ -1168,13 +1161,6 @@ func.func @op_cross_replica_sum(%arg0: tensor) -> tensor { func.return %0 : tensor } -// CHECK-LABEL: "op_cstr_reshapable" -func.func @op_cstr_reshapable(%arg0: index, %arg1: tensor<1xindex>) -> !shape.witness { - // CHECK: "vhlo.cstr_reshapable_v1"(%arg0, %arg1) : (!vhlo.index_v1, 
!vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.witness_v1 - %0 = "stablehlo.cstr_reshapable"(%arg0, %arg1) : (index, tensor<1xindex>) -> !shape.witness - func.return %0 : !shape.witness -} - // CHECK-LABEL: "op_custom_call" func.func @op_custom_call(%arg0: tensor) -> tensor { // CHECK: "vhlo.custom_call_v1"(%arg0) <{ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_19_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_19_0.mlir.bc index 7b377862210..677f8cb116d 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_19_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_19_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir index 9f04f45056a..4dcdd8a4aa0 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir @@ -1000,13 +1000,6 @@ func.func @op_complex(%arg0: tensor, %arg1: tensor) -> tensor> } -// CHECK-LABEL: "op_compute_reshape_shape" -func.func @op_compute_reshape_shape(%arg0: index, %arg1: tensor<1xindex>) -> tensor<1xindex> { - // CHECK: "vhlo.compute_reshape_shape_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1<1x!vhlo.index_v1> - %0 = "stablehlo.compute_reshape_shape"(%arg0, %arg1) : (index, tensor<1xindex>) -> tensor<1xindex> - func.return %0 : tensor<1xindex> -} - // CHECK-LABEL: "op_concatenate" func.func @op_concatenate(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<16xf32> { // CHECK: "vhlo.concatenate_v1"(%arg0, %arg1) <{ @@ -1096,13 +1089,6 @@ func.func @op_cross_replica_sum(%arg0: tensor) -> tensor { func.return %0 : tensor } -// CHECK-LABEL: "op_cstr_reshapable" -func.func @op_cstr_reshapable(%arg0: index, %arg1: tensor<1xindex>) -> !shape.witness { - // CHECK: "vhlo.cstr_reshapable_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> 
!vhlo.witness_v1 - %0 = "stablehlo.cstr_reshapable"(%arg0, %arg1) : (index, tensor<1xindex>) -> !shape.witness - func.return %0 : !shape.witness -} - // CHECK-LABEL: "op_custom_call" func.func @op_custom_call(%arg0: tensor) -> tensor { // CHECK: "vhlo.custom_call_v1"(%arg0) <{ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir.bc index c5d4a4f2f30..c0ce9717ad4 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir index d9a32643410..39fb0de0634 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir @@ -1072,13 +1072,6 @@ func.func @op_composite(%arg0: tensor) -> tensor { func.return %0 : tensor } -// CHECK-LABEL: "op_compute_reshape_shape" -func.func @op_compute_reshape_shape(%arg0: index, %arg1: tensor<1xindex>) -> tensor<1xindex> { - // CHECK: "vhlo.compute_reshape_shape_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1<1x!vhlo.index_v1> - %0 = "stablehlo.compute_reshape_shape"(%arg0, %arg1) : (index, tensor<1xindex>) -> tensor<1xindex> - func.return %0 : tensor<1xindex> -} - // CHECK-LABEL: "op_concatenate" func.func @op_concatenate(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<16xf32> { // CHECK: "vhlo.concatenate_v1"(%arg0, %arg1) <{ @@ -1168,13 +1161,6 @@ func.func @op_cross_replica_sum(%arg0: tensor) -> tensor { func.return %0 : tensor } -// CHECK-LABEL: "op_cstr_reshapable" -func.func @op_cstr_reshapable(%arg0: index, %arg1: tensor<1xindex>) -> !shape.witness { - // CHECK: "vhlo.cstr_reshapable_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.witness_v1 - %0 = "stablehlo.cstr_reshapable"(%arg0, %arg1) : 
(index, tensor<1xindex>) -> !shape.witness - func.return %0 : !shape.witness -} - // CHECK-LABEL: "op_custom_call" func.func @op_custom_call(%arg0: tensor) -> tensor { // CHECK: "vhlo.custom_call_v1"(%arg0) <{ diff --git a/stablehlo/transforms/ChloLegalizeToStablehlo.cpp b/stablehlo/transforms/ChloLegalizeToStablehlo.cpp index cea610fcbad..1339f64743b 100644 --- a/stablehlo/transforms/ChloLegalizeToStablehlo.cpp +++ b/stablehlo/transforms/ChloLegalizeToStablehlo.cpp @@ -436,38 +436,6 @@ struct ConvertSelectOp final } }; -struct ConvertDynamicReshapeOp final - : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(mlir::chlo::DynamicReshapeOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - TypedValue tensor = op.getOperand(); - TypedValue shape = op.getOutputShape(); - - auto shapeTy = cast(shape.getType()); - auto resultTy = cast(op.getType()); - - Value inputShape = rewriter.create(loc, tensor); - Value numEls = rewriter.create(loc, inputShape); - Value cstr = - rewriter.create(loc, numEls, shape); - rewriter.replaceOpWithNewOp( - op, cstr, [&](OpBuilder &b, Location l) { - Value computedShape = - b.create(l, shapeTy, - numEls, shape); - SmallVector result; - result.push_back(b.create( - l, resultTy, tensor, computedShape)); - return result; - }); - - return success(); - } -}; - //===----------------------------------------------------------------------===// // Decomposition Patterns. 
//===----------------------------------------------------------------------===// @@ -2163,7 +2131,6 @@ struct ChloLegalizeToStablehloPass final LogicalResult initialize(MLIRContext *context) override { target = std::make_shared(*context); target->addIllegalDialect(); - target->addLegalOp(); target->addLegalDialect( context, patterns, 5); - patterns - ->add( - context); + patterns->add(context); } static void populateChloDecompositionPatterns(MLIRContext *context, diff --git a/stablehlo/transforms/MapStablehloToVhlo.h b/stablehlo/transforms/MapStablehloToVhlo.h index 95dd139e571..5c43b9eca05 100644 --- a/stablehlo/transforms/MapStablehloToVhlo.h +++ b/stablehlo/transforms/MapStablehloToVhlo.h @@ -73,7 +73,6 @@ MAP_STABLEHLO_TO_VHLO(CollectivePermuteOp, V1) MAP_STABLEHLO_TO_VHLO(CompareOp, V1) MAP_STABLEHLO_TO_VHLO(ComplexOp, V1) MAP_STABLEHLO_TO_VHLO(CompositeOp, V1) -MAP_STABLEHLO_TO_VHLO(ComputeReshapeShapeOp, V1) MAP_STABLEHLO_TO_VHLO(ConcatenateOp, V1) MAP_STABLEHLO_TO_VHLO(ConstantOp, V1) MAP_STABLEHLO_TO_VHLO(ConvertOp, V1) @@ -81,7 +80,6 @@ MAP_STABLEHLO_TO_VHLO(ConvolutionOp, V1) MAP_STABLEHLO_TO_VHLO(CosineOp, V1) MAP_STABLEHLO_TO_VHLO(CreateTokenOp, V1) MAP_STABLEHLO_TO_VHLO(CrossReplicaSumOp, V1) -MAP_STABLEHLO_TO_VHLO(CstrReshapableOp, V1) MAP_STABLEHLO_TO_VHLO(CustomCallOp, V1) MAP_STABLEHLO_TO_VHLO(DivOp, V1) MAP_STABLEHLO_TO_VHLO(DotGeneralOp, V1) diff --git a/stablehlo/transforms/ShapeLegalizeToStablehlo.cpp b/stablehlo/transforms/ShapeLegalizeToStablehlo.cpp index a44bad060a9..8aec09e944b 100644 --- a/stablehlo/transforms/ShapeLegalizeToStablehlo.cpp +++ b/stablehlo/transforms/ShapeLegalizeToStablehlo.cpp @@ -134,78 +134,6 @@ Value convertToConstantOrI32Cast(Value value, PatternRewriter& rewriter) { return castToI32(rewriter, value.getLoc(), value); } -struct ConvertComputeReshapeShapeOpPattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ComputeReshapeShapeOp op, - PatternRewriter& 
rewriter) const override { - // Cast num_elements from index to tensor. - // Cast dynamic_shape from tensor to tensor if needed. - // (stablehlo.compute_reshape_shape supports both index- and integer-based - // dynamic_shape operands). - // This cannot error out given how the operation is currently defined. - auto numElementsI32 = castToI32(rewriter, op.getLoc(), op.getNumElements()); - auto dynamicShapeI32x1 = - castToI32(rewriter, op.getLoc(), op.getDynamicShape()); - if (!numElementsI32 || !dynamicShapeI32x1) - return rewriter.notifyMatchFailure(op, "cast to i32 failed"); - auto rank = cast(dynamicShapeI32x1.getType()).getNumElements(); - - // Obtain individual input dimension sizes and also compute the product of - // all these dimension sizes. - auto i32Type = RankedTensorType::get({}, rewriter.getI32Type()); - Value dynamicNumElementsI32 = rewriter.create( - op.getLoc(), DenseIntElementsAttr::get(i32Type, -1)); - SmallVector dynamicSizesI32; - for (auto i = 0; i < rank; ++i) { - auto dynamicSizeI32x1 = rewriter.create( - op.getLoc(), dynamicShapeI32x1, rewriter.getDenseI64ArrayAttr(i), - rewriter.getDenseI64ArrayAttr(i + 1), - rewriter.getDenseI64ArrayAttr(1)); - auto dynamicSizeI32 = - rewriter.create(op.getLoc(), i32Type, dynamicSizeI32x1); - dynamicSizesI32.push_back(dynamicSizeI32); - dynamicNumElementsI32 = rewriter.create( - op.getLoc(), dynamicNumElementsI32, dynamicSizeI32); - } - - // Compute the dimension size that corresponds to -1 in dynamic_shape. - // If such a dimension doesn't exist, then this value doesn't matter. - auto computedSizeI32 = rewriter.create(op.getLoc(), numElementsI32, - dynamicNumElementsI32); - - // Compute individual output dimension sizes, replacing a potential -1 - // with the value computed above. 
- auto i32x1Type = RankedTensorType::get({1}, rewriter.getI32Type()); - Value minusOneI32 = rewriter.create( - op.getLoc(), DenseIntElementsAttr::get(i32Type, -1)); - SmallVector resultSizesI32x1; - for (auto i = 0; i < rank; ++i) { - auto eqMinusOne = - rewriter.create(op.getLoc(), dynamicSizesI32[i], - minusOneI32, ComparisonDirection::EQ); - auto resultSizeI32 = rewriter.create( - op.getLoc(), eqMinusOne, computedSizeI32, dynamicSizesI32[i]); - auto resultSizeI32x1 = - rewriter.create(op.getLoc(), i32x1Type, resultSizeI32); - resultSizesI32x1.push_back(resultSizeI32x1); - } - auto resultI32 = - rewriter.create(op.getLoc(), resultSizesI32x1, - /*dimension=*/0); - - // Cast the result to tensor if needed. - // (stablehlo.compute_reshape_shape supports both index- and integer-based - // results). - // This cannot error out given how the operation is currently defined. - auto resultIndex = maybeCastToIndex(op.getResult(), resultI32, rewriter); - if (!resultIndex || resultIndex.getType() != op.getType()) - return rewriter.notifyMatchFailure(op, "cast to index failed"); - rewriter.replaceOp(op, resultIndex); - return success(); - } -}; - struct ConvertNumElementsOpPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -623,7 +551,6 @@ struct ShapeLegalizeToStablehloPass target = std::make_shared(*context); target->addIllegalDialect(); target->addIllegalDialect(); - target->addIllegalOp(); target->addIllegalOp(); target->addIllegalOp(); target->addDynamicallyLegalDialect( @@ -662,7 +589,6 @@ struct ShapeLegalizeToStablehloPass void populateShapeToStablehloPatterns(MLIRContext* context, RewritePatternSet* patterns) { - patterns->add(context); patterns->add(context); patterns->add(context); patterns->add(context); diff --git a/stablehlo/transforms/StablehloAggressiveSimplification.cpp b/stablehlo/transforms/StablehloAggressiveSimplification.cpp index 8daac7c027a..9d4cb22d708 100644 --- a/stablehlo/transforms/StablehloAggressiveSimplification.cpp 
+++ b/stablehlo/transforms/StablehloAggressiveSimplification.cpp @@ -822,7 +822,7 @@ struct UnusedResultReduceOpCanon final auto newOp = rewriter.create( op.getLoc(), newInputs, newInitVals, - op.getDimensionsAttr().cast(), newElementTypes); + cast(op.getDimensionsAttr()), newElementTypes); Block *newReducerBlock = rewriter.createBlock(&newOp.getBody()); IRMapping mapper; diff --git a/stablehlo/transforms/StablehloLegalizeCompositeToCall.cpp b/stablehlo/transforms/StablehloLegalizeCompositeToCall.cpp index c8c5a1ba3af..bb1d855072c 100644 --- a/stablehlo/transforms/StablehloLegalizeCompositeToCall.cpp +++ b/stablehlo/transforms/StablehloLegalizeCompositeToCall.cpp @@ -1,10 +1,17 @@ -// Copyright 2024 The StableHLO Authors -// -// Licensed under the Apache License, Version 2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +/* Copyright 2024 The StableHLO Authors. -// Implements composite inlining. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ #include diff --git a/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/transforms/StablehloRefineShapes.cpp index 1c1ba528f21..7c5f04ab6b5 100644 --- a/stablehlo/transforms/StablehloRefineShapes.cpp +++ b/stablehlo/transforms/StablehloRefineShapes.cpp @@ -418,62 +418,6 @@ struct EvalCompareOpPattern : public OpRewritePattern { } }; -struct EvalComputeReshapeShapeOpPattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ComputeReshapeShapeOp op, - PatternRewriter& rewriter) const override { - auto resultType = op.getType(); - if (failed(validateResultTypeForEval(rewriter, op, resultType))) - return failure(); - - int64_t numElems; - if (failed(hlo::matchInt(op.getNumElements(), numElems))) - return rewriter.notifyMatchFailure( - op, "expected constant number of elements"); - - SmallVector dynShapeValues; - if (failed(hlo::matchInts(op.getDynamicShape(), dynShapeValues))) - return rewriter.notifyMatchFailure(op, "expected constant dynamic shape"); - - std::optional unspecifiedDimIdx; - int64_t dimProduct = 1; - constexpr int64_t kUnspecifiedDimSize = -1; - for (size_t i = 0; i < dynShapeValues.size(); ++i) { - if (dynShapeValues[i] == kUnspecifiedDimSize) { - if (unspecifiedDimIdx.has_value()) - return rewriter.notifyMatchFailure( - op, "multiple -1 values in dimensions is an undefined behavior"); - - unspecifiedDimIdx = i; - continue; - } - - dimProduct *= dynShapeValues[i]; - } - - if (numElems % dimProduct != 0) - return rewriter.notifyMatchFailure( - op, - "dimensions that can't evenly divide num elements is an undefined " - "behavior"); - - if (unspecifiedDimIdx.has_value()) - dynShapeValues[unspecifiedDimIdx.value()] = numElems / dimProduct; - - const auto resultBitWidth = resultType.getElementTypeBitWidth(); - auto result = - llvm::map_to_vector(dynShapeValues, [&](int64_t value) -> APSInt { - return 
APSInt(APInt(resultBitWidth, value), false); - }); - - rewriter.replaceOpWithNewOp(op, - getTensorAttr(resultType, result)); - - return success(); - } -}; - struct EvalConcatenateOpPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ConcatenateOp op, @@ -1278,7 +1222,6 @@ void populateStablehloRefineShapesPatterns(RewritePatternSet* patterns, patterns->add(context); patterns->add(context); patterns->add(context); - patterns->add(context); patterns->add(context); patterns->add(context); patterns->add(context);