diff --git a/BUILD.bazel b/BUILD.bazel
index 49ca5b4dcce..03754b5fd92 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -311,6 +311,7 @@ cc_library(
],
strip_include_prefix = ".",
deps = [
+ ":base",
":interpreter_ops_inc_gen",
":reference_numpy",
":reference_ops",
diff --git a/docs/images/spec/gather.svg b/docs/images/spec/gather.svg
index ad7414c56c1..d2891862c5f 100644
--- a/docs/images/spec/gather.svg
+++ b/docs/images/spec/gather.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
\ No newline at end of file
diff --git a/docs/images/spec/scatter.svg b/docs/images/spec/scatter.svg
index c257d692aeb..e0e6a72a2fb 100644
--- a/docs/images/spec/scatter.svg
+++ b/docs/images/spec/scatter.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
diff --git a/docs/interpreter_status.md b/docs/interpreter_status.md
index df68df9a4e8..3d1bf406b07 100644
--- a/docs/interpreter_status.md
+++ b/docs/interpreter_status.md
@@ -60,23 +60,14 @@ interpreter supports resides in [hlo_expand_main.cc](https://github.com/openxla/
### Not in HLO
-Apart from the specced ops, this category consists of 10 unspecced ops (see
-[StableHLO Ops Categories](#stablehlo-ops-categories)) which are planed to be
-moved out of StableHLO. Some of these ops have existing passes in
+Apart from the specced ops, this category consists of 8 unspecced ops (see
+[StableHLO Ops Categories](#stablehlo-ops-categories)) which are planned to be
+moved out of StableHLO. Most of these ops have existing passes in
[mhlo](https://github.com/openxla/xla/tree/main/xla/mlir_hlo/mhlo/transforms) to
-convert them to StableHLO equivalent ops. There are three ops the interpreter
-does not support because there are no existing decompositions to StableHLO ops:
-
-* `compute_reshape_shape`
-* `cstr_reshapable`
-* `trace`
-
-`compute_reshape_shape` and `cstr_reshapable` ops are part of the ongoing
-Dynamism work, and they are planned to be removed from StableHLO (see
-[#1668](https://github.com/openxla/stablehlo/issues/1668)).
-
-`trace` op is private to XLA and there no no users in JAX, PyTorch or TensorFlow
-(see [#604](https://github.com/openxla/stablehlo/issues/604)).
+convert them to StableHLO equivalent ops. There is one op the interpreter
+does not support because there is no existing decomposition to StableHLO ops:
+`trace`. The `trace` op is private to XLA and there are no users in JAX,
+PyTorch, or TensorFlow (see
+[#604](https://github.com/openxla/stablehlo/issues/604)).
The tool to convert remaining ops in this category to equivalent StableHLO ops
@@ -254,18 +245,12 @@ hlo-expand --triangular_solve_expander
# broadcast
mlir-hlo-opt -mhlo-legalize-broadcast-to-broadcast-in-dim
-# compute_reshape_shape
-# This op will be removed from StableHLO as part of Dynamism work (see #1668).
-
# create_token
mlir-hlo-opt -mhlo-legalize-create-token-to-after-all
# cross-replica-sum
mlir-hlo-opt -mhlo-legalize-cross-replica-sum-to-all-reduce
-# cstr_reshapable
-# This op will be removed from StableHLO as part of Dynamism work (see #1668).
-
# dot
mlir-hlo-opt -mhlo-legalize-dot-to-dot-general
@@ -295,6 +280,6 @@ mlir-hlo-opt --canonicalize -mhlo-legalize-einsum-to-dot-general
| Extensibility | custom_call, get_tuple_element, tuple | 3 |
| Miscellaneous | batch_norm_grad, batch_norm_inference, batch_norm_training, cholesky, constant, fft, iota, rng, rng_bit_generator, triangular_solve | 10 |
| Modularity | call, func, module, return | 4 |
-| Not In HLO | broadcast, compute_reshape_shape, create_token, cross-replica-sum, cstr_reshapable, dot, einsum, torch_index_select, trace, unary_einsum | 10 |
+| Not In HLO | broadcast, create_token, cross-replica-sum, dot, einsum, torch_index_select, trace, unary_einsum | 8 |
| Quantization | uniform_dequantize, uniform_quantize | 2 |
| Reduction | convolution, dot_general, reduce, reduce_window, select_and_scatter | 5 |
diff --git a/docs/spec.md b/docs/spec.md
index aa27e58cb65..95aaabfb79c 100644
--- a/docs/spec.md
+++ b/docs/spec.md
@@ -207,8 +207,8 @@ constraints:
* For per-tensor quantization:
* No additional constraints.
* For per-axis quantization:
- * (C12) `quantization_dimension < rank(self)`.
- * (C13) `dim(self, quantization_dimension) = size(scales)`.
+ * (C13) `quantization_dimension < rank(self)`.
+ * (C14) `dim(self, quantization_dimension) = size(scales)`.
```ebnf
TokenType ::= 'token'
@@ -331,9 +331,8 @@ in StableHLO programs. In the meanwhile, here is the list of these operations:
([#3](https://github.com/openxla/stablehlo/issues/3)), and
`trace` ([#604](https://github.com/openxla/stablehlo/issues/604)).
* "Dynamism" category of StableHLO operations - they were bootstrapped from
- MHLO, but we haven't specced them yet: `compute_reshape_shape`,
- `cstr_reshapable`, `dynamic_broadcast_in_dim`, `dynamic_conv`,
- `dynamic_gather`, `dynamic_iota`, `dynamic_pad`, `dynamic_reshape`,
+ MHLO, and we are in the process of speccing them: `dynamic_broadcast_in_dim`,
+ `dynamic_conv`, `dynamic_gather`, `dynamic_iota`, `dynamic_pad`, `dynamic_reshape`,
`real_dynamic_slice`, `set_dimension_size`
([#8](https://github.com/openxla/stablehlo/issues/8)).
* Shape computations, including `arith`, `shape` and `tensor` operations
@@ -2636,6 +2635,50 @@ planning to address this in
[More Examples](https://github.com/openxla/stablehlo/tree/main/stablehlo/tests/interpret/dot_general.mlir)
+### dynamic_iota
+
+#### Semantics
+
+This operation is functionally identical to
+[iota](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#iota)
+op, but the result shape is specified dynamically via `output_shape`.
+
+#### Inputs
+
+| Label | Name | Type | Constraints |
+|-------|------------------|------------------------------------------------------------------------------------|-------------|
+| (I1) | `output_shape` | 1-dimensional tensor constant of type `si64` | (C1), (C2) |
+| (I2) | `iota_dimension` | `si64` | (C1) |
+
+#### Outputs
+
+| Name | Type | Constraints |
+|----------|-----------------------------------------------------------------------------------|-------------|
+| `result` | tensor of integer, floating-point, or complex type or per-tensor quantized tensor | (C2) |
+
+#### Constraints
+
+* (C1) `0 <= iota_dimension < size(output_shape)`.
+* (C2) `rank(result) = size(output_shape)`.
+
+#### Examples
+
+```mlir
+%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
+%result = "stablehlo.dynamic_iota"(%output_shape) {
+ iota_dimension = 0 : i64
+} : (tensor<2xi64>) -> tensor<4x5xi64>
+// %result: [
+// [0, 0, 0, 0, 0],
+// [1, 1, 1, 1, 1],
+// [2, 2, 2, 2, 2],
+// [3, 3, 3, 3, 3]
+// ]
+```
+
+[More Examples](https://github.com/openxla/stablehlo/tree/main/stablehlo/tests/interpret/dynamic_iota.mlir)
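+The semantics above are simple enough to model in plain Python (an
+illustrative sketch, not the interpreter's implementation; `dynamic_iota`
+here is a hypothetical helper):

```python
from itertools import product

def dynamic_iota(output_shape, iota_dimension):
    # Model the result as a map from index tuples to values:
    # result[idx] = idx[iota_dimension], exactly as in the iota semantics.
    return {idx: idx[iota_dimension]
            for idx in product(*(range(d) for d in output_shape))}

# Mirrors the example: output_shape = [4, 5], iota_dimension = 0.
out = dynamic_iota([4, 5], 0)
print(out[(2, 3)])  # -> 2 (the coordinate along dimension 0)
```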
+
### dynamic_slice
#### Semantics
diff --git a/docs/status.md b/docs/status.md
index 85512cc6528..28cf6d318c7 100644
--- a/docs/status.md
+++ b/docs/status.md
@@ -66,7 +66,6 @@ one of the following tracking labels.
| compare | yes | yes | yes | yes | yes |
| complex | yes | yes | yes | yes | yes |
| composite | yes | yes | infeasible | yes | yes |
-| compute_reshape_shape | no | revisit | no | yes | no |
| concatenate | yes | yes | yes | yes | yes |
| constant | yes | yes | yes | yes | yes |
| convert | yes | yes | infeasible | yes | yes |
@@ -75,7 +74,6 @@ one of the following tracking labels.
| count_leading_zeros | yes | yes | yes | yes | yes |
| create_token | no | yes\* | yes\* | yes | revisit |
| cross-replica-sum | no | revisit | yes\* | no | revisit |
-| cstr_reshapable | no | revisit | no | yes | no |
| custom_call | yes | yes | infeasible | yes | yes |
| divide | yes | yes | yes | yes | yes |
| dot | no | revisit | infeasible | yes | revisit |
@@ -83,7 +81,7 @@ one of the following tracking labels.
| dynamic_broadcast_in_dim | no | revisit | infeasible | no | no |
| dynamic_conv | no | revisit | no | no | no |
| dynamic_gather | no | revisit | revisit | no | no |
-| dynamic_iota | no | revisit | infeasible | yes | no |
+| dynamic_iota | yes | yes | infeasible | yes | revisit |
| dynamic_pad | no | revisit | no | yes | no |
| dynamic_reshape | no | revisit | infeasible | yes | no |
| dynamic_slice | yes | yes | yes | yes | yes |
diff --git a/rfcs/20240311-gather-scatter-batching-dims.md b/rfcs/20240311-gather-scatter-batching-dims.md
new file mode 100644
index 00000000000..7f2f6933885
--- /dev/null
+++ b/rfcs/20240311-gather-scatter-batching-dims.md
@@ -0,0 +1,466 @@
+# [RFC] Add batching dims to `stablehlo.gather` and `stablehlo.scatter` specification
+
+Status: Review
+Initial version: 03/11/2024
+Last updated: 03/11/2024
+Discussion thread: TBD
+
+## Overview
+
+This RFC proposes adding `operand_batching_dims` and
+`start_indices_batching_dims` attributes to `stablehlo.gather`.
+`operand_batching_dims` refers to the dimensions of the `operand` that are
+treated as batch. `start_indices_batching_dims` refers to the dimensions of the
+`start_indices` that are treated as batch. The corresponding dimension sizes
+must be equal. The semantics are equivalent to concatenating the outputs of
+the gather applied to each corresponding slice of `operand` and
+`start_indices`.
+
+Similarly, this RFC proposes adding `input_batching_dims` and
+`scatter_indices_batching_dims` attributes to `stablehlo.scatter`.
+`input_batching_dims` refers to the dimensions of each tensor in `inputs` that
+are treated as batch. `scatter_indices_batching_dims` refers to the dimensions
+of the `scatter_indices` that are treated as batch.
+
+## Motivation
+
+StableHLO gather and scatter ops currently have no way of specifying batch
+dimensions that correspond between all operands and the result; only the
+indices tensor can have implicit batch dimensions, which exist only in the
+result/update tensor (`batch_dims` in the specification).
+
+This is important when the user wants a batched/vectorized version of
+gather/scatter across all operands and result (e.g., when using [jax.vmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html)).
+
+The current workaround is to use the implicit batch dimensions in the indices
+tensor (all dimensions but `index_vector_dim`) with `stablehlo.concatenate` and
+`stablehlo.iota` (and other ops like `stablehlo.clamp`) to mimic batch
+dimensions in the operand.
+
+This isn't ideal as it hides the fact that those are batch dimensions that
+correspond between the operands and result tensors. This information is crucial
+when doing sharding propagation between the operands and result tensors of the
+gather/scatter op, where partitioning batch dimensions across tensors is
+trivial, as they can be sharded in the same way with no communication needed.
+The above workaround requires pattern matching to identify those batch
+dimensions, which can be error prone and hard to maintain.
+
+This proposal is inspired by `lhs_batching_dims` and `rhs_batching_dims` of
+`stablehlo.dot_general`, which serve the same purpose.
+
+## Compatibility
+
+The new `stablehlo.gather` and `stablehlo.scatter` can be decomposed to the old
+ops by applying the workaround above. For `stablehlo.gather` this would mean
+making the `operand_batching_dims` as `collapsed_slice_dims` and
+`start_indices_batching_dims` as implicit batch dimensions in `start_indices`,
+by incrementing `index_vector_dim` by the size of `operand_batching_dims` (and
+updating `start_index_map` accordingly), and concatenating an iota for each
+batch dimension to the original `start_indices`.
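+The index-rewriting part of this decomposition can be sketched in Python
+(hypothetical helper, not a StableHLO API; index vectors are modeled as lists
+keyed by their position in `start_indices`):

```python
def add_iota_batch_index(start_indices, batch_dim):
    # Prepend the batch coordinate (an explicit iota along batch_dim) to
    # every index vector, mimicking the concatenated-iota decomposition of
    # start_indices_batching_dims.
    return {idx: [idx[batch_dim]] + vec for idx, vec in start_indices.items()}

# Index vectors at positions (b, k) gain the batch coordinate b in front.
indices = {(0, 0): [1, 2], (1, 0): [3, 4]}
print(add_iota_batch_index(indices, 0))
# -> {(0, 0): [0, 1, 2], (1, 0): [1, 3, 4]}
```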
+
+Within the backward compatibility window (assuming 6 months), old ops that
+are loaded will automatically get an empty tensor for these added attributes.
+
+## Alternatives considered
+
+We could do the workaround above (using concatenated iota for the indices
+tensor), but in addition mark the batch dimension in each tensor using an
+unregistered attribute, so that the information won't be lost and partitioning
+systems can use it to partition and propagate through these batch dimensions.
+However, unregistered attributes can be discarded at any time and are harder to
+maintain (e.g., if the iota is replaced by another op, it is unclear who is
+responsible for updating the unregistered attributes).
+
+Another option is to start with `stablehlo.dynamic_slice` and
+`stablehlo.dynamic_update_slice`, which are simpler ops that would suffice for a
+lot of the use cases (but not all) that we've encountered. Currently
+`stablehlo.gather` and `stablehlo.scatter` are used to get a vectorized version of
+`stablehlo.dynamic_slice` and `stablehlo.dynamic_update_slice` respectively.
+However, for the same reason that they can be expressed by gather/scatter, we
+propose to go with the more general solution, which addresses all use cases.
+
+## Naming
+
+As indicated above, the proposed naming of the new attributes is inspired by
+`lhs_batching_dims` and `rhs_batching_dims` of `stablehlo.dot_general`. Note
+that the gather op specification already uses the term `batch_dims`
+(`update_scatter_dims` for scatter op), to refer to all dimensions in the result
+tensor that aren't offset dimensions, which have corresponding dimensions in the
+start-indices tensor. The difference is that the proposed dimensions are
+explicit batching dimensions that exist in all operands and result, whereas the
+existing `batch_dims` are implicit batch dimensions (as they are derived by the
+offset dimensions) that exist only in the start-indices and result tensors.
+
+## Proposed Specification
+
+### gather
+
+#### Semantics
+
+Gathers slices from `operand` tensor from offsets specified in `start_indices`
+and produces a `result` tensor.
+
+The following diagram shows how elements in `result` map on elements in
+`operand` using a concrete example. The diagram picks a few example `result`
+indices and explains in detail which `operand` indices they correspond to.
+
+![gather](images/20240311-gather-scatter-batching-dims/gather.svg)
+
+More formally, `result[result_index] = operand[operand_index]` where:
+
+* `batch_dims = [d for d in axes(result) and d not in offset_dims]`.
+* `batch_index = result_index[batch_dims...]`.
+* `start_index` is defined as:
+ * `start_indices[bi0, ..., :, ..., biN]` where `bi` are individual elements in
+ `batch_index` and `:` is inserted at the `index_vector_dim` index, if
+ `index_vector_dim` < `rank(start_indices)`.
+ * `[start_indices[batch_index]]` otherwise.
+* For `d_operand` in `axes(operand)`,
+ * `full_start_index[d_operand] = clamp(start_index[d_start], 0,
+ dim(operand, d_operand) - slice_sizes[d_operand])`
+ if `d_operand = start_index_map[d_start]`.
+ * `full_start_index[d_operand] = 0` otherwise.
+* For `d_operand` in `axes(operand)`,
+ * `full_batching_index[d_operand] =
+ batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)]`
+ if `d_operand = operand_batching_dims[i_batching]` and
+ `d_start = start_indices_batching_dims[i_batching]`.
+ * `full_batching_index[d_operand] = 0` otherwise.
+* `offset_index = result_index[offset_dims...]`.
+* `full_offset_index = [oi0, ..., 0, ..., oiN]` where `oi` are individual
+ elements in `offset_index`, and `0` is inserted at indices from
+ `collapsed_slice_dims` and `operand_batching_dims`.
+* `operand_index = full_start_index + full_batching_index + full_offset_index`.
+
+If `indices_are_sorted` is `true` then the implementation can assume that
+`start_indices` are sorted with respect to `start_index_map`, otherwise the
+behavior is undefined. More formally, for all `i1 < i2` from `indices(result)`,
+`full_start_index(i1) <= full_start_index(i2)`.
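+The index arithmetic above can be sketched as a small Python model (names
+mirror the spec pseudocode; this is an illustrative reference, not the
+interpreter):

```python
def clamp(x, lo, hi):
    return max(lo, min(x, hi))

def operand_index(start_index, batch_index, offset_index, *,
                  start_index_map, operand_batching_dims,
                  start_indices_batching_dims, collapsed_slice_dims,
                  operand_shape, slice_sizes, index_vector_dim):
    rank = len(operand_shape)
    # full_start_index: scattered, clamped start coordinates.
    full_start = [0] * rank
    for d_start, d_operand in enumerate(start_index_map):
        full_start[d_operand] = clamp(
            start_index[d_start], 0,
            operand_shape[d_operand] - slice_sizes[d_operand])
    # full_batching_index: batch coordinates taken from batch_index.
    full_batching = [0] * rank
    for i, d_operand in enumerate(operand_batching_dims):
        d_start = start_indices_batching_dims[i]
        full_batching[d_operand] = batch_index[
            d_start - (0 if d_start < index_vector_dim else 1)]
    # full_offset_index: offsets with zeros at collapsed/batching dims.
    skip = set(collapsed_slice_dims) | set(operand_batching_dims)
    oi = iter(offset_index)
    full_offset = [0 if d in skip else next(oi) for d in range(rank)]
    return [s + b + o
            for s, b, o in zip(full_start, full_batching, full_offset)]

# One index from the example in this section:
# result[1, 0, 1, 0, 0] reads operand[0, 1, 2, 0].
print(operand_index([2, 1], (1, 0, 1), (0, 0),
                    start_index_map=[2, 1], operand_batching_dims=[0],
                    start_indices_batching_dims=[1], collapsed_slice_dims=[1],
                    operand_shape=(2, 3, 4, 2), slice_sizes=[1, 1, 2, 2],
                    index_vector_dim=3))  # -> [0, 1, 2, 0]
```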
+
+#### Inputs
+
+| Label | Name | Type | Constraints |
+|-------|-------------------------------|----------------------------------------------|--------------------------------------------|
+| (I1) | `operand` | tensor or per-tensor quantized tensor | (C1), (C8), (C11), (C17), (C19-C21), (C23) |
+| (I2) | `start_indices` | tensor of integer type | (C2-C3), (C14), (C17), (C22) |
+| (I3) | `offset_dims` | 1-dimensional tensor constant of type `si64` | (C1), (C4-C5), (C22) |
+| (I4) | `collapsed_slice_dims` | 1-dimensional tensor constant of type `si64` | (C1), (C6-C9), (C22) |
+| (I5) | `operand_batching_dims` | 1-dimensional tensor constant of type `si64` | (C1), (C6), (C10-C12), (C16-C18), (C22) |
+| (I6) | `start_indices_batching_dims` | 1-dimensional tensor constant of type `si64` | (C13-C17) |
+| (I7) | `start_index_map` | 1-dimensional tensor constant of type `si64` | (C3), (C18-C19) |
+| (I8) | `index_vector_dim` | constant of type `si64` | (C2-C3), (C15), (C22) |
+| (I9) | `slice_sizes` | 1-dimensional tensor constant of type `si64` | (C9), (C12), (C20-C22) |
+| (I10) | `indices_are_sorted` | constant of type `i1` | |
+
+#### Outputs
+
+| Name | Type | Constraints |
+|----------|---------------------------------------|-----------------|
+| `result` | tensor or per-tensor quantized tensor | (C5), (C22-C23) |
+
+#### Constraints
+
+* (C1) `rank(operand) = size(offset_dims) + size(collapsed_slice_dims) +
+ size(operand_batching_dims)`.
+* (C2) `0 <= index_vector_dim <= rank(start_indices)`.
+* (C3) `size(start_index_map) =
+ index_vector_dim < rank(start_indices) ?
+ dim(start_indices, index_vector_dim) : 1`.
+* (C4) `is_unique(offset_dims) and is_sorted(offset_dims)`.
+* (C5) `0 <= offset_dims < rank(result)`.
+* (C6) `is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))`.
+* (C7) `is_sorted(collapsed_slice_dims)`.
+* (C8) `0 <= collapsed_slice_dims < rank(operand)`.
+* (C9) `slice_sizes[collapsed_slice_dims...] <= 1`.
+* (C10) `is_sorted(operand_batching_dims)`.
+* (C11) `0 <= operand_batching_dims < rank(operand)`.
+* (C12) `slice_sizes[operand_batching_dims...] <= 1`.
+* (C13) `is_unique(start_indices_batching_dims)`.
+* (C14) `0 <= start_indices_batching_dims < rank(start_indices)`.
+* (C15) `index_vector_dim not in start_indices_batching_dims`.
+* (C16) `size(operand_batching_dims) = size(start_indices_batching_dims)`.
+* (C17) `dim(operand, operand_batching_dims...) =
+ dim(start_indices, start_indices_batching_dims...)`.
+* (C18) `is_unique(concatenate(start_index_map, operand_batching_dims))`.
+* (C19) `0 <= start_index_map < rank(operand)`.
+* (C20) `size(slice_sizes) = rank(operand)`.
+* (C21) `0 <= slice_sizes <= shape(operand)`.
+* (C22) `shape(result) = combine(batch_dim_sizes, offset_dim_sizes)` where:
+ * `batch_dim_sizes = shape(start_indices)` except that the dimension size
+ of `start_indices` corresponding to `index_vector_dim` is not included.
+ * `offset_dim_sizes = slice_sizes` except that the dimension sizes in
+ `slice_sizes` corresponding to `collapsed_slice_dims` and
+ `operand_batching_dims` are not included.
+ * `combine` puts `batch_dim_sizes` at axes corresponding to `batch_dims` and
+ `offset_dim_sizes` at axes corresponding to `offset_dims`.
+* (C23) `element_type(operand) = element_type(result)`.
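+As a sanity check, (C22) can be sketched in Python (an illustrative model of
+the `combine` description above, not normative):

```python
def gather_result_shape(start_indices_shape, index_vector_dim, slice_sizes,
                        collapsed_slice_dims, operand_batching_dims,
                        batch_dims, offset_dims):
    # batch_dim_sizes: shape(start_indices) minus the index_vector_dim entry.
    batch_sizes = [d for i, d in enumerate(start_indices_shape)
                   if i != index_vector_dim]
    # offset_dim_sizes: slice_sizes minus collapsed and batching dims.
    skip = set(collapsed_slice_dims) | set(operand_batching_dims)
    offset_sizes = [s for i, s in enumerate(slice_sizes) if i not in skip]
    # combine: place each size list at its corresponding result axes.
    shape = [None] * (len(batch_sizes) + len(offset_sizes))
    for d, s in zip(batch_dims, batch_sizes):
        shape[d] = s
    for d, s in zip(offset_dims, offset_sizes):
        shape[d] = s
    return shape

# The example in this section: start_indices 2x2x3x2, index_vector_dim = 3,
# slice_sizes [1, 1, 2, 2] -> result 2x2x3x2x2.
print(gather_result_shape((2, 2, 3, 2), 3, [1, 1, 2, 2],
                          [1], [0], [0, 1, 2], [3, 4]))  # -> [2, 2, 3, 2, 2]
```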
+
+#### Examples
+
+```mlir
+// %operand: [
+// [
+// [[1, 2], [3, 4], [5, 6], [7, 8]],
+// [[9, 10],[11, 12], [13, 14], [15, 16]],
+// [[17, 18], [19, 20], [21, 22], [23, 24]]
+// ],
+// [
+// [[25, 26], [27, 28], [29, 30], [31, 32]],
+// [[33, 34], [35, 36], [37, 38], [39, 40]],
+// [[41, 42], [43, 44], [45, 46], [47, 48]]
+// ]
+// ]
+// %start_indices: [
+// [
+// [[0, 0], [1, 0], [2, 1]],
+// [[0, 1], [1, 1], [0, 9]]
+// ],
+// [
+// [[0, 0], [2, 1], [2, 2]],
+// [[1, 2], [0, 1], [1, 0]]
+// ]
+// ]
+%result = "stablehlo.gather"(%operand, %start_indices) {
+ dimension_numbers = #stablehlo.gather<
+ offset_dims = [3, 4],
+ collapsed_slice_dims = [1],
+ operand_batching_dims = [0],
+ start_indices_batching_dims = [1],
+ start_index_map = [2, 1],
+ index_vector_dim = 3>,
+ slice_sizes = array<i64: 1, 1, 2, 2>,
+ indices_are_sorted = false
+} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
+// %result: [
+// [
+// [
+// [[1, 2], [3, 4]],
+// [[3, 4], [5, 6]],
+// [[13, 14], [15, 16]]
+// ],
+// [
+// [[33, 34], [35, 36]],
+// [[35, 36], [37, 38]],
+// [[41, 42], [43, 44]]
+// ]
+// ],
+// [
+// [
+// [[1, 2], [3, 4]],
+// [[13, 14], [15, 16]],
+// [[21, 22], [23, 24]]
+// ],
+// [
+// [[43, 44], [45, 46]],
+// [[33, 34], [35, 36]],
+// [[27, 28], [29, 30]]
+// ]
+// ]
+// ]
+```
+
+[More Examples](../stablehlo/tests/interpret/gather.mlir)
+
+### scatter
+
+#### Semantics
+
+Produces `results` tensors which are equal to `inputs` tensors except that
+several slices specified by `scatter_indices` are updated with the values
+`updates` using `update_computation`.
+
+The following diagram shows how elements in `updates...` map on elements in
+`results...` using a concrete example. The diagram picks a few example
+`updates...` indices and explains in detail which `results...` indices they
+correspond to.
+
+![](images/20240311-gather-scatter-batching-dims/scatter.svg)
+
+More formally, for all `update_index` in `index_space(updates[0])`:
+
+* `update_scatter_dims = [d for d in axes(updates[0]) and d not in
+ update_window_dims]`.
+* `update_scatter_index = update_index[update_scatter_dims...]`.
+* `start_index` is defined as:
+ * `scatter_indices[si0, ..., :, ..., siN]` where `si` are individual
+ elements in `update_scatter_index` and `:` is inserted at the
+ `index_vector_dim` index, if `index_vector_dim` <
+ `rank(scatter_indices)`.
+ * `[scatter_indices[update_scatter_index]]` otherwise.
+* For `d_input` in `axes(inputs[0])`,
+ * `full_start_index[d_input] = start_index[d_start]` if
+ `d_input = scatter_dims_to_operand_dims[d_start]`.
+ * `full_start_index[d_input] = 0` otherwise.
+* For `d_input` in `axes(inputs[0])`,
+ * `full_batching_index[d_input] =
+ update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]`
+ if `d_input = input_batching_dims[i_batching]` and
+ `d_start = scatter_indices_batching_dims[i_batching]`.
+ * `full_batching_index[d_input] = 0` otherwise.
+* `update_window_index = update_index[update_window_dims...]`.
+* `full_window_index = [wi0, ..., 0, ..., wiN]` where `wi` are individual
+ elements in `update_window_index`, and `0` is inserted at indices from
+ `inserted_window_dims` and `input_batching_dims`.
+* `result_index = full_start_index + full_batching_index + full_window_index`.
+
+Given that, `results = exec(schedule, inputs)`, where:
+
+* `schedule` is an implementation-defined permutation of
+ `index_space(updates[0])`.
+* `exec([update_index, ...], results) = exec([...], updated_results)` where:
+ * If `result_index` is in bounds for `shape(results...)`
+ * `updates_converted = to_destination_type(
+ updates...[update_index], type(func_inputs(update_computation)
+ [len(func_inputs(update_computation))//2:])... )`
+ * `updated_values = update_computation(results...[result_index],
+ updates_converted)`
+ * `updated_results` is a copy of `results` with `results...[result_index]`
+ set to `updated_values...`.
+ * Otherwise
+ * `updated_results = results`.
+* `exec([], results) = results`.
+
+If `indices_are_sorted` is `true` then the implementation can assume that
+`scatter_indices` are sorted with respect to `scatter_dims_to_operand_dims`,
+otherwise the behavior is undefined. More formally, for all `i1 < i2` from
+`indices(result)`, `full_start_index(i1)` <= `full_start_index(i2)`.
+
+If `unique_indices` is `true` then the implementation can assume that all
+`result_index` indices being scattered to are unique. If `unique_indices` is
+`true` but the indices being scattered to are not unique then the behavior is
+undefined.
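+The `exec` fold above can be sketched in Python for a single input/update
+pair (a simplified model: `results` maps index tuples to values, and
+`result_indices[u]` is the precomputed `result_index` for update index `u`,
+or `None` when out of bounds):

```python
def scatter_exec(results, updates, result_indices, update_computation):
    # schedule: here simply the insertion order of `updates`; the spec allows
    # any implementation-defined permutation of index_space(updates[0]).
    for u, value in updates.items():
        ri = result_indices.get(u)
        if ri is not None and ri in results:  # in-bounds check
            results[ri] = update_computation(results[ri], value)
    return results

res = scatter_exec({(0,): 1, (1,): 2},
                   {(0,): 10, (1,): 20, (2,): 30},
                   {(0,): (1,), (1,): (1,), (2,): None},
                   lambda a, b: a + b)
print(res)  # -> {(0,): 1, (1,): 32}
```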
+
+#### Inputs
+
+| Label | Name | Type | Constraints |
+|-------|---------------------------------------|------------------------------------------------------------|------------------------------------------------------------|
+| (I1) | `inputs` | variadic number of tensors or per-tensor quantized tensors | (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24) |
+| (I2) | `scatter_indices` | tensor of integer type | (C4), (C15), (C19), (C22) |
+| (I3) | `updates` | variadic number of tensors or per-tensor quantized tensors | (C3-C6), (C8) |
+| (I4) | `update_window_dims` | 1-dimensional tensor constant of type `si64` | (C2), (C4), (C7-C8) |
+| (I5) | `inserted_window_dims` | 1-dimensional tensor constant of type `si64` | (C2), (C4), (C9-C11) |
+| (I6) | `input_batching_dims` | 1-dimensional tensor constant of type `si64` | (C2), (C4), (C9), (C12-13), (C17-18), (C20) |
+| (I7) | `scatter_indices_batching_dims` | 1-dimensional tensor constant of type `si64` | (C14-C18) |
+| (I8) | `scatter_dims_to_operand_dims` | 1-dimensional tensor constant of type `si64` | (C19-C21) |
+| (I9) | `index_vector_dim` | constant of type `si64` | (C4), (C16), (C19), (C22) |
+| (I10) | `indices_are_sorted` | constant of type `i1` | |
+| (I11) | `unique_indices` | constant of type `i1` | |
+| (I12) | `update_computation` | function | (C23) |
+
+#### Outputs
+
+| Name | Type | Constraints |
+|-----------|------------------------------------------------------------|-------------|
+| `results` | variadic number of tensors or per-tensor quantized tensors | (C24-C25) |
+
+#### Constraints
+
+* (C1) `same(shape(inputs...))`.
+* (C2) `rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
+ + size(input_batching_dims)`.
+* (C3) `same(shape(updates...))`.
+* (C4) `shape(updates[0]) = combine(update_scatter_dim_sizes,
+ update_window_dim_sizes)` where:
+ * `update_scatter_dim_sizes = shape(scatter_indices)` except that
+ the dimension size of `scatter_indices` corresponding to
+ `index_vector_dim` is not included.
+ * `update_window_dim_sizes <= shape(inputs[0])` except that
+ the dimension sizes in `inputs[0]` corresponding to `inserted_window_dims`
+ and `input_batching_dims` are not included.
+ * `combine` puts `update_scatter_dim_sizes` at axes corresponding to
+ `update_scatter_dims` and `update_window_dim_sizes` at axes corresponding
+ to `update_window_dims`.
+* (C5) `0 < size(inputs) = size(updates) = N`.
+* (C6) `element_type(updates...) = element_type(inputs...)`.
+* (C7) `is_unique(update_window_dims) and is_sorted(update_window_dims)`.
+* (C8) `0 <= update_window_dims < rank(updates[0])`.
+* (C9) `is_unique(concatenate(inserted_window_dims, input_batching_dims))`.
+* (C10) `is_sorted(inserted_window_dims)`.
+* (C11) `0 <= inserted_window_dims < rank(inputs[0])`.
+* (C12) `is_sorted(input_batching_dims)`.
+* (C13) `0 <= input_batching_dims < rank(inputs[0])`.
+* (C14) `is_unique(scatter_indices_batching_dims)`.
+* (C15) `0 <= scatter_indices_batching_dims < rank(scatter_indices)`.
+* (C16) `index_vector_dim not in scatter_indices_batching_dims`.
+* (C17) `size(input_batching_dims) = size(scatter_indices_batching_dims)`.
+* (C18) `dim(inputs[0], input_batching_dims...) =
+ dim(scatter_indices, scatter_indices_batching_dims...)`.
+* (C19) `size(scatter_dims_to_operand_dims) =
+ index_vector_dim < rank(scatter_indices) ?
+ dim(scatter_indices, index_vector_dim) : 1`.
+* (C20) `is_unique(concatenate(scatter_dims_to_operand_dims,
+ input_batching_dims))`.
+* (C21) `0 <= scatter_dims_to_operand_dims < rank(inputs[0])`.
+* (C22) `0 <= index_vector_dim <= rank(scatter_indices)`.
+* (C23) `update_computation` has type `(tensor<E0>, ..., tensor<EN-1>,
+  tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)`,
+  where `is_promotable(element_type(inputs[i]), Ei)`.
+* (C24) `shape(inputs...) = shape(results...)`.
+* (C25) `element_type(results[i]) = Ei` for all `i` in `[0,N)`.
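+Similarly, the shape relation in (C4) can be sketched in Python; note the
+returned window sizes are the maxima (the constraint is `<=`, so actual
+update windows may be smaller):

```python
def scatter_updates_shape(scatter_indices_shape, index_vector_dim,
                          input_shape, inserted_window_dims,
                          input_batching_dims, update_scatter_dims,
                          update_window_dims):
    # update_scatter_dim_sizes: shape(scatter_indices) minus index_vector_dim.
    scatter_sizes = [d for i, d in enumerate(scatter_indices_shape)
                     if i != index_vector_dim]
    # Maximal update_window_dim_sizes: input dims minus inserted/batching dims.
    skip = set(inserted_window_dims) | set(input_batching_dims)
    window_sizes = [d for i, d in enumerate(input_shape) if i not in skip]
    # combine: place each size list at its corresponding updates axes.
    shape = [None] * (len(scatter_sizes) + len(window_sizes))
    for d, s in zip(update_scatter_dims, scatter_sizes):
        shape[d] = s
    for d, s in zip(update_window_dims, window_sizes):
        shape[d] = s
    return shape

# The example in this section: scatter_indices 2x2x3x2, index_vector_dim = 3,
# inputs 2x3x4x2 -> maximal updates shape 2x2x3x4x2 (actual updates 2x2x3x2x2).
print(scatter_updates_shape((2, 2, 3, 2), 3, (2, 3, 4, 2),
                            [1], [0], [0, 1, 2], [3, 4]))  # -> [2, 2, 3, 4, 2]
```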
+
+#### Examples
+
+```mlir
+// %input: [
+// [
+// [[1, 2], [3, 4], [5, 6], [7, 8]],
+// [[9, 10],[11, 12], [13, 14], [15, 16]],
+// [[17, 18], [19, 20], [21, 22], [23, 24]]
+// ],
+// [
+// [[25, 26], [27, 28], [29, 30], [31, 32]],
+// [[33, 34], [35, 36], [37, 38], [39, 40]],
+// [[41, 42], [43, 44], [45, 46], [47, 48]]
+// ]
+// ]
+// %scatter_indices: [
+// [
+// [[0, 0], [1, 0], [2, 1]],
+// [[0, 1], [1, 1], [0, 9]]
+// ],
+// [
+// [[0, 0], [2, 1], [2, 2]],
+// [[1, 2], [0, 1], [1, 0]]
+// ]
+// ]
+// %update: [
+// [
+// [[1, 1], [1, 1], [1, 1]],
+// [[1, 1], [1, 1], [1, 1]]
+// ],
+// [
+// [[1, 1], [1, 1], [1, 1]],
+// [[1, 1], [1, 1], [1, 1]]
+// ]
+// ]
+%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
+ ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
+ "stablehlo.return"(%0) : (tensor<i64>) -> ()
+}) {
+ scatter_dimension_numbers = #stablehlo.scatter<
+ update_window_dims = [3, 4],
+ inserted_window_dims = [1],
+ input_batching_dims = [0],
+ scatter_indices_batching_dims = [1],
+ scatter_dims_to_operand_dims = [2, 1],
+ index_vector_dim = 3>,
+ indices_are_sorted = false,
+ unique_indices = false
+} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
+// %result: [
+// [
+// [[3, 4], [6, 7], [6, 7], [7, 8]],
+// [[9, 10],[11, 12], [15, 16], [17, 18]],
+// [[17, 18], [19, 20], [22, 23], [24, 25]]
+// ],
+// [
+// [[25, 26], [28, 29], [30, 31], [31, 32]],
+// [[35, 36], [38, 39], [38, 39], [39, 40]],
+// [[41, 42], [44, 45], [46, 47], [47, 48]]
+// ]
+// ]
+```
+
+[More Examples](../stablehlo/tests/interpret/scatter.mlir)
diff --git a/rfcs/images/20240311-gather-scatter-batching-dims/gather.svg b/rfcs/images/20240311-gather-scatter-batching-dims/gather.svg
new file mode 100644
index 00000000000..d2891862c5f
--- /dev/null
+++ b/rfcs/images/20240311-gather-scatter-batching-dims/gather.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/rfcs/images/20240311-gather-scatter-batching-dims/scatter.svg b/rfcs/images/20240311-gather-scatter-batching-dims/scatter.svg
new file mode 100644
index 00000000000..b7581cf9031
--- /dev/null
+++ b/rfcs/images/20240311-gather-scatter-batching-dims/scatter.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/stablehlo/dialect/Base.cpp b/stablehlo/dialect/Base.cpp
index f938f2c233a..316cb1073ce 100644
--- a/stablehlo/dialect/Base.cpp
+++ b/stablehlo/dialect/Base.cpp
@@ -132,6 +132,7 @@ bool isCompatibleForHloTypeInference(TypeRange tp1, TypeRange tp2) {
}
bool isCompatibleForHloTypeInference(ArrayRef<int64_t> shape1, Type tp2) {
+ if (llvm::any_of(shape1, [&](int64_t x) { return x < 0; })) return false;
auto stp2 = dyn_cast<ShapedType>(tp2);
if (!stp2) return false;
return isCompatibleForHloTypeInference(
@@ -141,11 +142,7 @@ bool isCompatibleForHloTypeInference(ArrayRef<int64_t> shape1, Type tp2) {
bool isCompatibleForHloTypeInference(Value shape1, Type tp2) {
SmallVector<int64_t> shapeVec1;
if (!succeeded(matchInts(shape1, shapeVec1))) return true;
- if (llvm::any_of(shapeVec1, [&](int64_t x) { return x < 0; })) return false;
- auto stp2 = dyn_cast<ShapedType>(tp2);
- if (!stp2) return false;
- auto tp1 = RankedTensorType::get(shapeVec1, stp2.getElementType());
- return isCompatibleForHloTypeInference(tp1, tp2);
+ return isCompatibleForHloTypeInference(shapeVec1, tp2);
}
LogicalResult matchInt(Value value, int64_t& result) {
@@ -628,5 +625,75 @@ mlir::Speculation::Speculatability getShapedSpeculatability(
: mlir::Speculation::NotSpeculatable;
}
+bool isValidStablehloQuantizedElementType(Type elementType) {
+ auto quantizedElementType = dyn_cast<mlir::quant::QuantizedType>(elementType);
+ if (!quantizedElementType) return false;
+
+ int64_t storageTypeMin = quantizedElementType.getStorageTypeMin();
+ int64_t storageTypeMax = quantizedElementType.getStorageTypeMax();
+
+ SmallVector<int64_t> zeroPoints;
+ SmallVector<double> scales;
+ if (auto quantizedPerTensorElementType =
+ dyn_cast<mlir::quant::UniformQuantizedType>(elementType)) {
+ zeroPoints.push_back(quantizedPerTensorElementType.getZeroPoint());
+ scales.push_back(quantizedPerTensorElementType.getScale());
+ } else {
+ auto quantizedPerAxisElementType =
+ cast<mlir::quant::UniformQuantizedPerAxisType>(elementType);
+ zeroPoints.insert(zeroPoints.begin(),
+ quantizedPerAxisElementType.getZeroPoints().begin(),
+ quantizedPerAxisElementType.getZeroPoints().end());
+ scales.insert(scales.begin(),
+ quantizedPerAxisElementType.getScales().begin(),
+ quantizedPerAxisElementType.getScales().end());
+ }
+
+ // quantized_type_c5
+ auto maxPosFiniteNum =
+ APFloat::getLargest(quantizedElementType.getExpressedType()
+ .cast<FloatType>()
+ .getFloatSemantics())
+ .convertToDouble();
+ auto minPosFiniteNum =
+ APFloat::getSmallest(quantizedElementType.getExpressedType()
+ .cast<FloatType>()
+ .getFloatSemantics())
+ .convertToDouble();
+ if (llvm::any_of(scales, [&](double scale) {
+ return scale < minPosFiniteNum || scale > maxPosFiniteNum;
+ })) {
+ return false;
+ }
+
+ // quantized_type_c8, quantized_type_c9
+ if (llvm::any_of(zeroPoints, [&](int64_t zeroPoint) {
+ return storageTypeMin > zeroPoint || zeroPoint > storageTypeMax;
+ })) {
+ return false;
+ }
+
+ return true;
+}
+
+bool isValidQuantizedDimension(Type type) {
+ auto rankedType = dyn_cast<RankedTensorType>(type);
+ if (!rankedType) return true;
+
+ auto quantizedPerAxisElementType =
+ dyn_cast<quant::UniformQuantizedPerAxisType>(
+ rankedType.getElementType());
+
+ if (!quantizedPerAxisElementType) return true;
+
+ // quantized_type_c12, quantized_type_c13, quantized_type_c14
+ int64_t quantDim = quantizedPerAxisElementType.getQuantizedDimension();
+ int64_t numScales =
+ static_cast<int64_t>(quantizedPerAxisElementType.getScales().size());
+ return quantDim >= 0 && quantDim < rankedType.getRank() &&
+ (!rankedType.isDynamicDim(quantDim) &&
+ numScales == rankedType.getDimSize(quantDim));
+}
+
} // namespace hlo
} // namespace mlir
diff --git a/stablehlo/dialect/Base.h b/stablehlo/dialect/Base.h
index 7b045db7064..d1ae4d59219 100644
--- a/stablehlo/dialect/Base.h
+++ b/stablehlo/dialect/Base.h
@@ -90,6 +90,16 @@ bool isCompatibleForHloTypeInference(Value shape1, Type tp2);
// compatible with the given type for the purposes of HLO type inference.
bool isCompatibleForHloTypeInference(ArrayRef<int64_t> shape1, Type tp2);
+// Returns true if the given element type is a mlir::quant::QuantizedType and
+// follows the constraints on quantization parameters specified in the
+// StableHLO specification.
+bool isValidStablehloQuantizedElementType(Type elementType);
+
+// Returns true if the given type is a ranked per-axis quantized tensor type
+// and follows the constraints on the quantized dimension specified in the
+// StableHLO specification.
+bool isValidQuantizedDimension(Type type);
+
// TODO(zhouxin) Move type inference related methods to TypeInference.cpp
std::pair<int64_t, int64_t> inferConcatenatedDimAndBound(int64_t leftSize,
diff --git a/stablehlo/dialect/Base.td b/stablehlo/dialect/Base.td
index 29bbf0c1ef7..88fa81414b6 100644
--- a/stablehlo/dialect/Base.td
+++ b/stablehlo/dialect/Base.td
@@ -52,6 +52,9 @@ def HLO_Complex : Complex<AnyTypeOf<[F32, F64]>>;
// Quantized element type definitions.
//===----------------------------------------------------------------------===//
+def IsValidStablehloQuantizedElementType : CPred<"mlir::hlo::isValidStablehloQuantizedElementType($_self)">;
+def IsValidQuantizedDimension : CPred<"mlir::hlo::isValidQuantizedDimension($_self)">;
+
// TODO(b/230381284): Upstream width-specific uniform quantized element types.
class StableHLO_UniformQuantizedSignedInt<int width>
: Type<And<[CPred<"isa<mlir::quant::UniformQuantizedType>($_self)">,
@@ -59,7 +62,7 @@ class StableHLO_UniformQuantizedSignedInt
".getStorageTypeIntegralWidth() == " # width>,
CPred<"cast<mlir::IntegerType>(" #
"cast<mlir::quant::UniformQuantizedType>($_self).getStorageType()" #
- ").isSignless()">]>,
+ ").isSignless()">, IsValidStablehloQuantizedElementType]>,
"QI" # width # " type"> {
string name = "UniformQuantizedSignedInt";
int bitwidth = width;
@@ -71,7 +74,7 @@ class StableHLO_UniformQuantizedPerAxisSignedInt
".getStorageTypeIntegralWidth() == " # width>,
CPred<"cast<mlir::IntegerType>(" #
"cast<mlir::quant::UniformQuantizedPerAxisType>($_self).getStorageType()" #
- ").isSignless()">]>,
+ ").isSignless()">, IsValidStablehloQuantizedElementType]>,
"QI" # width # " type"> {
string name = "UniformQuantizedPerAxisSignedInt";
int bitwidth = width;
@@ -82,7 +85,7 @@ class StableHLO_UniformQuantizedUnsignedInt
CPred<"cast<mlir::quant::UniformQuantizedType>($_self)" #
".getStorageTypeIntegralWidth() == " # width>,
CPred<"!cast<mlir::quant::UniformQuantizedType>($_self)" #
- ".isSigned()">]>,
+ ".isSigned()">, IsValidStablehloQuantizedElementType]>,
"QUI" # width # " type"> {
string name = "UniformQuantizedUnsignedInt";
int bitwidth = width;
@@ -93,7 +96,7 @@ class StableHLO_UniformQuantizedPerAxisUnsignedInt
CPred<"cast<mlir::quant::UniformQuantizedPerAxisType>($_self)" #
".getStorageTypeIntegralWidth() == " # width>,
CPred<"!cast<mlir::quant::UniformQuantizedPerAxisType>($_self)" #
- ".isSigned()">]>,
+ ".isSigned()">, IsValidStablehloQuantizedElementType]>,
"QUI" # width # " type"> {
string name = "UniformQuantizedPerAxisUnsignedInt";
int bitwidth = width;
@@ -152,10 +155,12 @@ def HLO_Fp32Or64Tensor : RankedTensorOf<[HLO_Float32Or64]>;
def HLO_QuantizedIntTensor : RankedTensorOf<[HLO_QuantizedInt]>;
// per-axis quantized integer tensor type
-def HLO_PerAxisQuantizedIntTensor : RankedTensorOf<[HLO_PerAxisQuantizedInt]>;
+def HLO_PerAxisQuantizedIntTensor : RankedTensorOf<[HLO_PerAxisQuantizedInt],
+ [IsValidQuantizedDimension]>;
// per-tensor or per-axis quantized integer tensor type
-def HLO_QuantizedIntOrPerAxisQuantizedIntTensor : RankedTensorOf<[HLO_QuantizedInt, HLO_PerAxisQuantizedInt]>;
+def HLO_QuantizedIntOrPerAxisQuantizedIntTensor : RankedTensorOf<[HLO_QuantizedInt, HLO_PerAxisQuantizedInt],
+ [IsValidQuantizedDimension]>;
def HLO_PredTensor : RankedTensorOf<[HLO_Pred]>;
@@ -163,7 +168,8 @@ def HLO_Tensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_
def HLO_NonQuantizedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex]>;
-def HLO_TensorOrPerAxisQuantizedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt]>;
+def HLO_TensorOrPerAxisQuantizedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt],
+ [IsValidQuantizedDimension]>;
def HLO_ComplexTensor : RankedTensorOf<[HLO_Complex]>;
@@ -188,7 +194,8 @@ def HLO_DimensionTensor : 1DTensorOf<[HLO_DimensionValue]>;
// TODO(b/326463552): Remove these when CHLO no longer needs unranked dynamism.
//===----------------------------------------------------------------------===//
-def HLO_AnyTensor : TensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt]>;
+def HLO_AnyTensor : TensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt],
+ [IsValidQuantizedDimension]>;
def HLO_AnyPredTensor : TensorOf<[HLO_Pred]>;
@@ -248,8 +255,8 @@ def HLO_StaticDimensionTensor : RankedTensorOf<[HLO_DimensionValue], [HasStaticS
def HLO_StaticShapeTensor : StaticShapeTensorOf<[
HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt]>;
-def HLO_StaticShapeTensorOrPerAxisQuantizedTensor : StaticShapeTensorOf<[
- HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt]>;
+def HLO_StaticShapeTensorOrPerAxisQuantizedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt],
+ [IsValidQuantizedDimension, HasStaticShapePred], "statically shaped tensor">;
def HLO_StaticShapeTensorOrPerAxisQuantizedTensorOrToken : AnyTypeOf<[HLO_StaticShapeTensor, HLO_StaticShapeTensorOrPerAxisQuantizedTensor, HLO_Token]>;
diff --git a/stablehlo/dialect/ChloOps.cpp b/stablehlo/dialect/ChloOps.cpp
index f39d3065cdb..915cbc9a240 100644
--- a/stablehlo/dialect/ChloOps.cpp
+++ b/stablehlo/dialect/ChloOps.cpp
@@ -334,23 +334,6 @@ LogicalResult ConstantLikeOp::verify() {
return success();
}
-//===----------------------------------------------------------------------===//
-// MinimumBroadcastShapesOp
-//===----------------------------------------------------------------------===//
-LogicalResult MinimumBroadcastShapesOp::verify() {
- // Check that the number of operands matches the number of outputs.
- unsigned resultShapesCount = getResults().size();
- unsigned operandShapesCount = getShapes().size();
- if (operandShapesCount != resultShapesCount)
- return emitOpError() << "number of operand shapes (" << operandShapesCount
- << ") does not match number of result shapes ("
- << resultShapesCount << ")";
- if (operandShapesCount < 2)
- return emitOpError() << "number of operand shapes (" << operandShapesCount
- << ") should be >= 2";
- return success();
-}
-
LogicalResult ConstantLikeOp::inferReturnTypeComponents(
MLIRContext* /*context*/, std::optional location,
ValueShapeRange operands, DictionaryAttr attributes,
diff --git a/stablehlo/dialect/ChloOps.td b/stablehlo/dialect/ChloOps.td
index c901d8de136..fc7a7054d49 100644
--- a/stablehlo/dialect/ChloOps.td
+++ b/stablehlo/dialect/ChloOps.td
@@ -834,71 +834,4 @@ row for matrices).}]>:$k
}];
}
-//===----------------------------------------------------------------------===//
-// Helper ops
-//===----------------------------------------------------------------------===//
-
-def CHLO_MinimumBroadcastShapesOp :
- CHLO_Op<"minimum_broadcast_shapes", [Pure]> {
- string summary = "Minimizes the rank of two or more shapes to be broadcasted";
-
- string description = [{
- Given two or more 1D tensors representing shapes, returns one 1D tensor for
- each operand, where operand `i` corresponds to output `i`.
-
- The returned tensors have the property that they specify a shape which is a
- reshape of the corresponding input shape, and the broadcasted output shape
- (using shape::BroadcastOp) of the returned shapes is a reshape of the
- broadcasted output shape of the input shapes. Among all possibilities with
- this property, the one is chosen which minimizes the rank of each returned
- shape.
-
- The general idea of this op is that it can be used for ops which have a
- broadcasting semantic to operate on shapes with a possibly smaller rank
- while preserving equivalence of the computed values. After computing the
- result of the op using reshaped operands, the result can be reshaped to the
- result that would have been originally computed.
-
- Here is an example with two input shapes:
-
- ```mlir
- chlo.minimum_broadcast_shapes [1, 2, 3, 1, 2, 1],
- [1, 1, 1, 2, 3] -> [6, 2, 1], [2, 3]
- ```
-
- The broadcasted output shape of the operands is [1, 2, 3, 1, 2, 3], the
- broadcasted output shape of the outputs is [6, 2, 3]. These two shapes are
- reshapes of each other, and also each output is a reshape of the
- corresponding input.
- }];
-
- let arguments = (ins Variadic<1DTensorOf<[Index]>>:$shapes);
- let results = (outs Variadic<1DTensorOf<[Index]>>:$results);
-
- let assemblyFormat = "$shapes attr-dict `:` type($shapes) `->` type($results)";
-
- let hasVerifier = 1;
-}
-
-def CHLO_DynamicReshapeOp: CHLO_Op<"dynamic_reshape", [Pure,
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface>]> {
- let summary = "Reshape a tensor to a given, possibly dynamic, shape.";
- let description = [{
- Reshapes `operand` to `output_shape`. This allows a single dimension to be
- specified as -1, and it will be computed by the op to create a valid
- reshape.
-
- Requires:
- - The length of `output_shape` is equal to the rank of `result`.
- - The number of elements in `operand` (that is, the product of extents of
- its shape) is equal to the number of elements in `output_shape` (that is,
- the product of values in `output_shape` with one dimension possibly
- computed).
- - All shape values should be at least -1, and only one extent can be -1.
- }];
-
- let arguments = (ins HLO_AnyTensor:$operand, HLO_DimensionTensor:$output_shape);
- let results = (outs HLO_AnyTensor:$result);
-}
-
#endif // STABLEHLO_DIALECT_CHLO_OPS
diff --git a/stablehlo/dialect/StablehloAttrs.td b/stablehlo/dialect/StablehloAttrs.td
index 28c896eb2e1..0fa7cc9a130 100644
--- a/stablehlo/dialect/StablehloAttrs.td
+++ b/stablehlo/dialect/StablehloAttrs.td
@@ -30,7 +30,7 @@ def GenericDenseI64ArrayAttr : Attr setNameFn) {
mlir::TensorType type = getType();
- if (type.getElementType().isa()) {
+ if (isa(type.getElementType())) {
setNameFn(getResult(), "c");
} else {
setNameFn(getResult(), "cst");
diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td
index dabf8b3ef71..ed4d3a86dac 100644
--- a/stablehlo/dialect/StablehloOps.td
+++ b/stablehlo/dialect/StablehloOps.td
@@ -124,20 +124,27 @@ def StableHLO_IotaOp : StableHLO_Op<"iota", [Pure]> {
def StableHLO_DynamicIotaOp: StableHLO_ShapedInterfaceOp<"dynamic_iota", [ConditionallySpeculatable, NoMemoryEffect]> {
let summary = "DynamicIota operation";
let description = [{
- This operation is a work in progress, so it is not yet included in
- the StableHLO specification: https://github.com/openxla/stablehlo/issues/8.
+ Fills a `result` tensor with values in increasing order starting from zero
+ along the `iota_dimension` dimension.
Informally, this operation does the same thing as IotaOp except that the
result shape is specified dynamically via `output_shape`:
https://github.com/openxla/stablehlo/blob/main/docs/spec.md#iota
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_iota
+
Example:
```mlir
- %0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<1xindex>) -> tensor<4xi32>
+ %output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
+ %0 = stablehlo.dynamic_iota %output_shape, dim = 0 : (tensor<2xi64>) -> tensor<4x5xi64>
```
}];
- let arguments = (ins HLO_StaticDimensionTensor:$output_shape, I64Attr:$iota_dimension);
+ let arguments = (ins
+ HLO_StaticDimensionTensor:$output_shape /*dynamic_iota_i1*/,
+ I64Attr:$iota_dimension /*dynamic_iota_i2*/
+ );
let results = (outs HLO_Tensor:$result);
let hasVerifier = 1;
@@ -3519,58 +3526,4 @@ def StableHLO_DynamicConvOp : StableHLO_Op<"dynamic_conv",
let results = (outs HLO_Tensor);
}
-def StableHLO_ComputeReshapeShapeOp : StableHLO_Op<
- "compute_reshape_shape",
- [Pure, AllShapesMatch<["dynamic_shape", "result"]>]> {
- let summary = "ComputeReshapeShape operation";
- let description = [{
- This operation is a work in progress, so it is not yet included in
- the StableHLO specification: https://github.com/openxla/stablehlo/issues/8.
-
- Informally, this operation computes an output_shape for DynamicReshapeOp
- from the `num_elements` number of elements in an operand of DynamicReshapeOp
- and the `dynamic_shape` shape provided to TF's reshape:
- https://www.tensorflow.org/api_docs/python/tf/reshape
-
- For example, for `num_elements = 12` and `dynamic_shape = [2, -1]`,
- the `result` is going to be `[2, 6]`. If operands are not valid (e.g. if
- dimensions do not evenly divide the number of elements, or if there are
- multiple -1 values in dimensions), this leads to undefined behavior.
-
- Example:
- ```mlir
- %result = stablehlo.compute_reshape_shape %num_elements, %dynamic_shape
- : (index, tensor<2xi32>) -> tensor<2xi32>
- ```
- }];
-
- let arguments = (ins Index:$num_elements, 1DTensorOf<[AnyInteger, Index]>:$dynamic_shape);
- let results = (outs 1DTensorOf<[AnyInteger, Index]>:$result);
-
- let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
-}
-
-def StableHLO_CstrReshapableOp :
- StableHLO_Op<"cstr_reshapable", [Pure]> {
- let summary = "CstrReshapable operation";
- let description = [{
- This operation is a work in progress, so it is not yet included in
- the StableHLO specification: https://github.com/openxla/stablehlo/issues/8.
-
- Informally, this operation creates a witness on the constraint that
- ComputeReshapeShape would succeed with the provided operands.
-
- Example:
- ```mlir
- %result = stablehlo.cstr_reshapable %num_elements, %dynamic_shape
- : (index, tensor<3xi32>) -> !shape.witness
- ```
- }];
-
- let arguments = (ins Index:$num_elements, 1DTensorOf<[AnyInteger, Index]>:$dynamic_shape);
- let results = (outs Shape_WitnessType:$result);
-
- let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
-}
-
#endif // STABLEHLO_DIALECT_STABLEHLO_OPS
diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp
index 7ad992705d5..3973543437d 100644
--- a/stablehlo/dialect/TypeInference.cpp
+++ b/stablehlo/dialect/TypeInference.cpp
@@ -44,6 +44,7 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/Regex.h"
+#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/Attributes.h"
@@ -66,9 +67,11 @@ limitations under the License.
namespace mlir {
namespace hlo {
namespace {
+
//===----------------------------------------------------------------------===//
// Utils for quantization specific verifications
//===----------------------------------------------------------------------===//
+
template <typename T>
bool allQuantized(ArrayRef<Type> typeRange) {
return llvm::all_of(
@@ -468,6 +471,27 @@ LogicalResult verifyAddOp(std::optional<Location> location, Operation* op,
return success();
}
+// If the shape operand is constant, checks that it is compatible with the
+// result's shape. Emits an error if the shapes are incompatible.
+LogicalResult verifyShapeOperandIsCompatibleWithResultType(
+ std::optional<Location> loc, Value shapeOperand, Type resultType) {
+ if (SmallVector<int64_t> shape;
+ succeeded(matchInts(shapeOperand, shape)) &&
+ !isCompatibleForHloTypeInference(shape, resultType)) {
+ std::string str;
+ llvm::raw_string_ostream os(str);
+ llvm::interleaveComma(shape, os, [&](int64_t i) { os << i; });
+ return emitOptionalError(loc, "output shape [", os.str(),
+ "] is incompatible with return type of operation ",
+ resultType);
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Verifiers
+//===----------------------------------------------------------------------===//
+
LogicalResult verifyTransposeOp(std::optional<Location> location,
Type operandType, ArrayRef<int64_t> permutation,
Type resultType) {
@@ -490,7 +514,7 @@ LogicalResult verifyTransposeOp(std::optional<Location> location,
if (operandQDim != permutation[resultQDim])
return emitOptionalError(location, "operand quantization_dimension ",
operandQDim, " is not same as permutation[",
- resultQDim, "] ", permutation[resultQDim]);
+ resultQDim, "] = ", permutation[resultQDim]);
}
return success();
}
@@ -3356,9 +3380,9 @@ LogicalResult verifyBroadcastInDimOp(std::optional<Location> location,
auto resultQDim = resultQType.getQuantizedDimension();
if (resultQDim != broadcastDimensions[operandQDim])
return emitOptionalError(location, "result quantization_dimension ",
- resultQDim, " not same as broadcast_dimensions ",
- operandQDim, " (",
- broadcastDimensions[operandQDim], ")");
+ resultQDim, " not same as broadcast_dimensions[",
+ operandQDim,
+ "] = ", broadcastDimensions[operandQDim]);
if (operandType.getDimSize(operandQDim) == 1) {
for (int64_t j = 0; j != resultType.getDimSize(resultQDim); ++j) {
if (resultQType.getScales()[j] != operandQType.getScales()[0])
@@ -3760,11 +3784,9 @@ LogicalResult verifyDynamicBroadcastInDimOp(
}
}
- if (!isCompatibleForHloTypeInference(outputDimensions, resultType))
- return emitOptionalError(
- location,
- "output_dimensions are incompatible with return type of operation ",
- resultType);
+ if (failed(verifyShapeOperandIsCompatibleWithResultType(
+ location, outputDimensions, resultType)))
+ return failure();
return success();
}
@@ -3772,17 +3794,17 @@ LogicalResult verifyDynamicBroadcastInDimOp(
LogicalResult verifyDynamicIotaOp(std::optional<Location> location,
Value outputShape, int64_t iotaDimension,
Value result) {
- auto shape = cast<ShapedType>(result.getType());
+ auto resultType = cast<ShapedType>(result.getType());
- if (!isCompatibleForHloTypeInference(outputShape, shape))
- return emitOptionalError(
- location, "output_shape is incompatible with return type of operation ",
- result.getType());
-
- if (iotaDimension >= shape.getRank() || iotaDimension < 0)
+ // dynamic_iota_c1
+ if (iotaDimension >= resultType.getRank() || iotaDimension < 0)
return emitOptionalError(
location,
"iota dimension cannot go beyond the output rank or be negative.");
+ // dynamic_iota_c2
+ if (failed(verifyShapeOperandIsCompatibleWithResultType(location, outputShape,
+ resultType)))
+ return failure();
return success();
}
@@ -3850,18 +3872,9 @@ LogicalResult verifyDynamicReshapeOp(std::optional<Location> location,
}
}
- if (SmallVector<int64_t> shape;
- succeeded(matchInts(outputShape, shape)) &&
- !isCompatibleForHloTypeInference(shape, resultType)) {
- std::string str;
- llvm::raw_string_ostream os(str);
- os << "[";
- llvm::interleaveComma(shape, os, [&](int64_t i) { os << i; });
- os << "]";
- return emitOptionalError(location, "output_shape ", os.str(),
- " is incompatible with return type of operation ",
- resultType);
- }
+ if (failed(verifyShapeOperandIsCompatibleWithResultType(location, outputShape,
+ resultType)))
+ return failure();
return success();
}
diff --git a/stablehlo/dialect/VhloOps.td b/stablehlo/dialect/VhloOps.td
index 03cf7b8ec32..3c5617b7fdc 100644
--- a/stablehlo/dialect/VhloOps.td
+++ b/stablehlo/dialect/VhloOps.td
@@ -296,14 +296,6 @@ def VHLO_CompositeOpV1 : VHLO_Op<"composite_v1", "0.19.0", "current"> {
let results = (outs Variadic:$results);
}
-// TODO(#8): ComputeReshapeShapeOp is not part of the StableHLO spec.
-// This operation is a work in progress, so it is not yet included in
-// the StableHLO specification.
-def VHLO_ComputeReshapeShapeOpV1 : VHLO_Op<"compute_reshape_shape_v1", "0.9.0", "current"> {
- let arguments = (ins VHLO_AnyType:$num_elements, VHLO_AnyType:$dynamic_shape);
- let results = (outs VHLO_AnyType:$result);
-}
-
def VHLO_ConcatenateOpV1 : VHLO_Op<"concatenate_v1", "0.9.0", "current"> {
let arguments = (ins
Variadic:$inputs,
@@ -370,14 +362,6 @@ def VHLO_CrossReplicaSumOpV1 : VHLO_Op<"cross-replica-sum_v1", "0.9.0", "current
let results = (outs VHLO_AnyType:$result);
}
-// TODO(#8): CstrReshapableOp is not part of the StableHLO spec.
-// This operation is a work in progress, so it is not yet included in
-// the StableHLO specification.
-def VHLO_CstrReshapableOpV1 : VHLO_Op<"cstr_reshapable_v1", "0.9.0", "current"> {
- let results = (outs VHLO_AnyType:$result);
- let arguments = (ins VHLO_AnyType:$num_elements, VHLO_AnyType:$dynamic_shape);
-}
-
// TODO(#1187): api_version is different between VHLO and the StableHLO spec:
// in VHLO it's an enum, in the spec it's an si32.
// TODO(#629): operand_layouts/result_layouts are not part of the spec.
diff --git a/stablehlo/reference/CMakeLists.txt b/stablehlo/reference/CMakeLists.txt
index ba75441b464..b1d4f8c35e4 100644
--- a/stablehlo/reference/CMakeLists.txt
+++ b/stablehlo/reference/CMakeLists.txt
@@ -95,6 +95,7 @@ add_mlir_dialect_library(InterpreterOps
InterpreterOpsIncGen
LINK_LIBS PUBLIC
+ StablehloBase
StablehloReferenceValue
StablehloReferenceNumPy
StablehloReferenceOps
diff --git a/stablehlo/reference/InterpreterOps.cpp b/stablehlo/reference/InterpreterOps.cpp
index 2e20b86d235..ec502ad88a5 100644
--- a/stablehlo/reference/InterpreterOps.cpp
+++ b/stablehlo/reference/InterpreterOps.cpp
@@ -30,6 +30,7 @@ limitations under the License.
#include "mlir/IR/SymbolTable.h"
#include "mlir/Support/DebugStringHelper.h"
#include "mlir/Support/LLVM.h"
+#include "stablehlo/dialect/Base.h"
#include "stablehlo/reference/NumPy.h"
#include "stablehlo/reference/Ops.h"
#include "stablehlo/reference/ProcessGrid.h"
diff --git a/stablehlo/reference/Ops.cpp b/stablehlo/reference/Ops.cpp
index 7f3c263b773..d98c11b423a 100644
--- a/stablehlo/reference/Ops.cpp
+++ b/stablehlo/reference/Ops.cpp
@@ -579,6 +579,11 @@ SmallVector<InterpreterValue> eval(Region &region,
lhs, rhs, lhsBatchingDimensions, rhsBatchingDimensions,
lhsContractingDimensions, rhsContractingDimensions, op.getType());
scope.add(op.getResult(), result);
+ } else if (auto op = dyn_cast<DynamicIotaOp>(operation)) {
+ auto iotaDimension = op.getIotaDimension();
+ auto outputShape = scope.findTensor(op.getOutputShape());
+ auto result = dynamicIotaOp(iotaDimension, outputShape, op.getType());
+ scope.add(op.getResult(), result);
} else if (auto op = dyn_cast<DynamicSliceOp>(operation)) {
auto operand = scope.findTensor(op.getOperand());
auto startIndices = scope.findTensors(op.getStartIndices());
@@ -1576,6 +1581,14 @@ Tensor dotGeneralOp(const Tensor &lhs, const Tensor &rhs,
return result;
}
+Tensor dynamicIotaOp(Axis iotaDimension, const Tensor &outputShape,
+ ShapedType resultType) {
+ if (resultType.hasStaticShape()) return iotaOp(iotaDimension, resultType);
+
+ llvm::report_fatal_error(
+ "dynamic result types are not supported at the moment");
+}
+
Tensor dynamicSliceOp(const Tensor &operand, ArrayRef<Tensor> startIndices,
const Sizes &sliceSizes, ShapedType resultType) {
Tensor result(resultType);
diff --git a/stablehlo/reference/Ops.h b/stablehlo/reference/Ops.h
index af546e45ac9..bb014ffa1a6 100644
--- a/stablehlo/reference/Ops.h
+++ b/stablehlo/reference/Ops.h
@@ -91,6 +91,8 @@ Tensor dotGeneralOp(const Tensor &lhs, const Tensor &rhs,
const Axes &lhsContractingDimensions,
const Axes &rhsContractingDimensions,
ShapedType resultType);
+Tensor dynamicIotaOp(Axis iotaDimension, const Tensor &outputShape,
+ ShapedType resultType);
Tensor dynamicSliceOp(const Tensor &operand, ArrayRef<Tensor> startIndices,
const Sizes &sliceSizes, ShapedType resultType);
Tensor dynamicUpdateSliceOp(const Tensor &operand, const Tensor &update,
diff --git a/stablehlo/tests/BUILD.bazel b/stablehlo/tests/BUILD.bazel
index 7116e4a05fb..80e7c8030cb 100644
--- a/stablehlo/tests/BUILD.bazel
+++ b/stablehlo/tests/BUILD.bazel
@@ -32,6 +32,7 @@ cc_library(
strip_include_prefix = ".",
deps = [
":check_ops_inc_gen",
+ "//:base",
"//:reference_errors",
"//:reference_numpy",
"//:reference_tensor",
diff --git a/stablehlo/tests/CMakeLists.txt b/stablehlo/tests/CMakeLists.txt
index 676f9cdd925..84246c0ee7f 100644
--- a/stablehlo/tests/CMakeLists.txt
+++ b/stablehlo/tests/CMakeLists.txt
@@ -64,6 +64,7 @@ add_mlir_dialect_library(CheckOps
CheckOpsIncGen
LINK_LIBS PUBLIC
+ StablehloBase
StablehloReferenceNumPy
StablehloReferenceTensor
MLIRIR
diff --git a/stablehlo/tests/CheckOps.cpp b/stablehlo/tests/CheckOps.cpp
index c39d156793f..a4595ef94d5 100644
--- a/stablehlo/tests/CheckOps.cpp
+++ b/stablehlo/tests/CheckOps.cpp
@@ -23,6 +23,7 @@ limitations under the License.
#include "llvm/Support/Path.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/DebugStringHelper.h"
+#include "stablehlo/dialect/Base.h"
#include "stablehlo/reference/Errors.h"
#include "stablehlo/reference/NumPy.h"
#include "stablehlo/reference/Tensor.h"
diff --git a/stablehlo/tests/interpret/dynamic_iota.mlir b/stablehlo/tests/interpret/dynamic_iota.mlir
new file mode 100644
index 00000000000..5f9c79af3bd
--- /dev/null
+++ b/stablehlo/tests/interpret/dynamic_iota.mlir
@@ -0,0 +1,17 @@
+// RUN: stablehlo-translate --interpret -split-input-file %s
+
+func.func @dynamic_iota_op_test_si64_dim_0() {
+ %output_shape = stablehlo.constant dense<[3, 4]> : tensor<2xi64>
+ %0 = stablehlo.dynamic_iota %output_shape, dim = 0 : (tensor<2xi64>) -> tensor<3x4xi64>
+ check.expect_eq_const %0, dense<[[0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]]> : tensor<3x4xi64>
+ func.return
+}
+
+// -----
+
+func.func @dynamic_iota_op_test_si64_dim_1() {
+ %output_shape = stablehlo.constant dense<[3, 4]> : tensor<2xi64>
+ %0 = stablehlo.dynamic_iota %output_shape, dim = 1 : (tensor<2xi64>) -> tensor<3x4xi64>
+ check.expect_eq_const %0, dense<[[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]> : tensor<3x4xi64>
+ func.return
+}
diff --git a/stablehlo/tests/ops_chlo.mlir b/stablehlo/tests/ops_chlo.mlir
index bb3e77cfea5..16f399f0c37 100644
--- a/stablehlo/tests/ops_chlo.mlir
+++ b/stablehlo/tests/ops_chlo.mlir
@@ -73,33 +73,6 @@ func.func @constant_like(%arg0: tensor<1x2xi64>) -> (tensor<1x2xi32>) {
// -----
-// CHECK-LABEL: func @minimum_broadcast_shapes
-func.func @minimum_broadcast_shapes(%lhs: tensor, %rhs: tensor)
- -> (tensor, tensor) {
- %0, %1 = chlo.minimum_broadcast_shapes %lhs, %rhs :
- tensor, tensor -> tensor, tensor
- func.return %0, %1 : tensor, tensor
-}
-
-// -----
-
-func.func @minimum_broadcast_shapes_mismatch_operand_and_result_count(%lhs: tensor, %rhs: tensor) {
- // expected-error @+1{{number of operand shapes (2) does not match number of result shapes (1)}}
- %0 = chlo.minimum_broadcast_shapes %lhs, %rhs :
- tensor, tensor -> tensor
- func.return
-}
-
-// -----
-
-func.func @minimum_broadcast_shapes_one_operand(%arg: tensor) {
- // expected-error @+1{{number of operand shapes (1) should be >= 2}}
- %0 = chlo.minimum_broadcast_shapes %arg : tensor -> tensor
- func.return
-}
-
-// -----
-
func.func @top_k(%arg0 : tensor) {
// expected-error @+2 {{failed to infer returned types}}
// @expected-error @+1{{operand's rank must be at least 1}}
diff --git a/stablehlo/tests/ops_chlo_roundtrip.mlir b/stablehlo/tests/ops_chlo_roundtrip.mlir
index f7dc94cbdee..37b303878ae 100644
--- a/stablehlo/tests/ops_chlo_roundtrip.mlir
+++ b/stablehlo/tests/ops_chlo_roundtrip.mlir
@@ -396,27 +396,6 @@ func.func @chlo_top_k(%arg : tensor<16x16xi32>) -> (tensor<16x8xi32>, tensor<16x
return %1#0, %1#1 : tensor<16x8xi32>, tensor<16x8xi32>
}
-// CHECK-LABEL: func @chlo_minimum_broadcast_shapes(
-// CHECK-SAME: %[[A0:.*]]: tensor,
-// CHECK-SAME: %[[A1:.*]]: tensor
-// CHECK: %[[T:.*]]:2 = chlo.minimum_broadcast_shapes %[[A0]], %[[A1]] : tensor, tensor -> tensor, tensor
-// CHECK: return %[[T]]#0, %[[T]]#1 : tensor, tensor
-func.func @chlo_minimum_broadcast_shapes(%lhs: tensor, %rhs: tensor) -> (tensor, tensor) {
- %0, %1 = chlo.minimum_broadcast_shapes %lhs, %rhs :
- tensor, tensor -> tensor, tensor
- func.return %0, %1 : tensor, tensor
-}
-
-// CHECK-LABEL: func @chlo_reshape_dynamic(
-// CHECK-SAME: %[[A0:.*]]: tensor,
-// CHECK-SAME: %[[A1:.*]]: tensor<2xi32>
-// CHECK: %[[T:.*]] = "chlo.dynamic_reshape"(%[[A0]], %[[A1]]) : (tensor, tensor<2xi32>) -> tensor
-// CHECK: return %[[T]] : tensor
-func.func @chlo_reshape_dynamic(%arg0: tensor, %arg1: tensor<2xi32>) -> tensor {
- %0 = "chlo.dynamic_reshape"(%arg0, %arg1) : (tensor, tensor<2xi32>) -> tensor
- func.return %0 : tensor
-}
-
// CHECK-LABEL: func @chlo_erf_inv
// CHECK-SAME: %[[A0:.*0]]: tensor<16x16xf32>)
// CHECK: chlo.erf_inv %[[A0]] : tensor<16x16xf32> -> tensor<16x16xf32>
diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir
index dba67cdc266..22907a942a1 100644
--- a/stablehlo/tests/ops_stablehlo.mlir
+++ b/stablehlo/tests/ops_stablehlo.mlir
@@ -51,18 +51,18 @@ func.func @all_reduce_with_promotable_types(%operand: tensor) -> tensor>)
- -> tensor> {
+ -> tensor> {
%result = "stablehlo.all_reduce"(%operand) ({
- ^bb0(%arg0: tensor>, %arg1: tensor>):
- %0 = stablehlo.add %arg0, %arg1 : tensor>
- "stablehlo.return"(%0) : (tensor>) -> ()
+ ^bb0(%arg0: tensor>, %arg1: tensor>):
+ %0 = stablehlo.add %arg0, %arg1 : tensor>
+ "stablehlo.return"(%0) : (tensor>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle
- } : (tensor>) -> tensor>
+ } : (tensor>) -> tensor>
- func.return %result : tensor>
+ func.return %result : tensor>
}
// -----
@@ -334,16 +334,16 @@ func.func @reduce_scatter_with_promotable_types(%data: tensor<4x16xf32>) -> tens
// CHECK-LABEL: func @reduce_scatter_with_promotable_quantized_types
func.func @reduce_scatter_with_promotable_quantized_types(
%data: tensor<4x16x!quant.uniform>) ->
- tensor<4x4x!quant.uniform> {
+ tensor<4x4x!quant.uniform> {
%0 = "stablehlo.reduce_scatter"(%data) ({
- ^bb0(%arg2: tensor>, %arg3: tensor>):
- %1 = stablehlo.add %arg2, %arg3 : tensor>
- "stablehlo.return"(%1) : (tensor>) -> ()
+ ^bb0(%arg2: tensor>, %arg3: tensor>):
+ %1 = stablehlo.add %arg2, %arg3 : tensor>
+ "stablehlo.return"(%1) : (tensor>) -> ()
}) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
scatter_dimension = 1 : i64,
channel_handle = #stablehlo.channel_handle,
- use_global_device_ids} : (tensor<4x16x!quant.uniform>) -> tensor<4x4x!quant.uniform>
- func.return %0 : tensor<4x4x!quant.uniform>
+ use_global_device_ids} : (tensor<4x16x!quant.uniform>) -> tensor<4x4x!quant.uniform>
+ func.return %0 : tensor<4x4x!quant.uniform>
}
// -----
@@ -1026,7 +1026,7 @@ func.func @dynamic_broadcast_in_dim_shape_mismatch(%arg0: tensor<32xf32>, %shape
// -----
func.func @dynamic_broadcast_in_dim_output_dimensions_negative_size(%arg0: tensor<4xf32>) -> tensor<3x4xf32> {
- // @expected-error@+2 {{output_dimensions are incompatible with return type of operation 'tensor<3x4xf32>'}}
+ // @expected-error@+2 {{output shape [-1, 4] is incompatible with return type of operation 'tensor<3x4xf32>'}}
%0 = stablehlo.constant dense<[-1, 4]> : tensor<2xi64>
%1 = stablehlo.dynamic_broadcast_in_dim %arg0, %0, dims = [1] : (tensor<4xf32>, tensor<2xi64>) -> tensor<3x4xf32>
return %1 : tensor<3x4xf32>
@@ -1035,7 +1035,7 @@ func.func @dynamic_broadcast_in_dim_output_dimensions_negative_size(%arg0: tenso
// -----
func.func @dynamic_broadcast_in_dim_output_dimensions_mismatching_size(%arg0: tensor<4xf32>) -> tensor<3x4xf32> {
- // @expected-error@+2 {{output_dimensions are incompatible with return type of operation 'tensor<3x4xf32>'}}
+ // @expected-error@+2 {{output shape [1, 4] is incompatible with return type of operation 'tensor<3x4xf32>'}}
%0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64>
%1 = stablehlo.dynamic_broadcast_in_dim %arg0, %0, dims = [1] : (tensor<4xf32>, tensor<2xi64>) -> tensor<3x4xf32>
return %1 : tensor<3x4xf32>
@@ -1043,6 +1043,22 @@ func.func @dynamic_broadcast_in_dim_output_dimensions_mismatching_size(%arg0: te
// -----
+func.func @dynamic_broadcast_in_dim_output_dimensions_match_result(%arg0: tensor<4xf32>) -> tensor<3x4xf32> {
+ %0 = stablehlo.constant dense<[3, 4]> : tensor<2xi64>
+ %1 = stablehlo.dynamic_broadcast_in_dim %arg0, %0, dims = [1] : (tensor<4xf32>, tensor<2xi64>) -> tensor<3x4xf32>
+ return %1 : tensor<3x4xf32>
+}
+
+// -----
+
+func.func @dynamic_broadcast_in_dim_output_dimensions_compatible_with_result(%arg0: tensor<4xf32>) -> tensor {
+ %0 = stablehlo.constant dense<[3, 4]> : tensor<2xi64>
+ %1 = stablehlo.dynamic_broadcast_in_dim %arg0, %0, dims = [1] : (tensor<4xf32>, tensor<2xi64>) -> tensor
+ return %1 : tensor
+}
+
+// -----
+
func.func @dynamic_broadcast_in_dim_negative_size(%arg0: tensor<1xf32>, %shape: tensor<3xi64>) -> tensor<7x8x9xf32> {
// expected-error@+1 {{broadcast_dimensions contains invalid value -1 for result with rank 3}}
%0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = array} : (tensor<1xf32>, tensor<3xi64>) -> tensor<7x8x9xf32>
@@ -3174,7 +3190,7 @@ func.func @dynamic_reshape_incompatible_shapes(%arg0: tensor, %shape: ten
// -----
func.func @dynamic_reshape_output_shape_mismatching_size(%arg0: tensor<4xf32>) -> tensor<1x4xf32> {
- // expected-error@+2 {{output_shape [2, 2] is incompatible with return type of operation 'tensor<1x4xf32>'}}
+ // expected-error@+2 {{output shape [2, 2] is incompatible with return type of operation 'tensor<1x4xf32>'}}
%0 = stablehlo.constant dense<[2, 2]> : tensor<2xi64>
%1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<1x4xf32>
return %1 : tensor<1x4xf32>
@@ -3182,6 +3198,22 @@ func.func @dynamic_reshape_output_shape_mismatching_size(%arg0: tensor<4xf32>) -
// -----
+func.func @dynamic_reshape_output_shape_matches_result(%arg0: tensor<4xf32>) -> tensor<1x4xf32> {
+ %0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64>
+ %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<1x4xf32>
+ return %1 : tensor<1x4xf32>
+}
+
+// -----
+
+func.func @dynamic_reshape_output_shape_compatible_with_result(%arg0: tensor<4xf32>) -> tensor {
+ %0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64>
+ %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor
+ return %1 : tensor
+}
+
+// -----
+
func.func @dynamic_reshape_dynamic_output_shape(%arg0: tensor, %shape: tensor) -> tensor<1x4xf32> {
// expected-error@+1 {{op operand #1 must be statically shaped}}
%0 = "stablehlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor) -> tensor<1x4xf32>
@@ -4846,9 +4878,9 @@ func.func @quantized_dot_i8_per_axis(%arg0: tensor<2x2x!quant.uniform>
-func.func @quantized_dot_i4(%arg0: tensor<2x2x!quant.uniform>, %arg1: tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> {
- %0 = "stablehlo.dot"(%arg0, %arg1) : (tensor<2x2x!quant.uniform>, tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform>
- func.return %0: tensor<2x2x!quant.uniform>
+func.func @quantized_dot_i4(%arg0: tensor<2x2x!quant.uniform>, %arg1: tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> {
+ %0 = "stablehlo.dot"(%arg0, %arg1) : (tensor<2x2x!quant.uniform>, tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform>
+ func.return %0: tensor<2x2x!quant.uniform>
}
// -----
@@ -5125,7 +5157,7 @@ func.func @add_c4(%arg0: tensor<1x!quant.uniform>) {
func.func @add_c3(%arg0: tensor<1x!quant.uniform>) {
// expected-error@+1 {{mismatched operands and result quantization storage types}}
- %0 = "stablehlo.add"(%arg0, %arg0) : (tensor<1x!quant.uniform>, tensor<1x!quant.uniform>) -> tensor<1x!quant.uniform>
+ %0 = "stablehlo.add"(%arg0, %arg0) : (tensor<1x!quant.uniform>, tensor<1x!quant.uniform>) -> tensor<1x!quant.uniform>
func.return
}
@@ -5149,15 +5181,15 @@ func.func @add_c5(%arg0: tensor<1x!quant.uniform>) {
func.func @add_c6(%arg0: tensor<1x2x!quant.uniform>) {
// expected-error@+1 {{quantization_dimension of lhs and result are not same}}
- %0 = "stablehlo.add"(%arg0, %arg0) : (tensor<1x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform>
+ %0 = "stablehlo.add"(%arg0, %arg0) : (tensor<1x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform>
func.return
}
// -----
-func.func @add_c7(%arg0: tensor<1x2x!quant.uniform>, %arg1: tensor<1x2x!quant.uniform>) {
+func.func @add_c7(%arg0: tensor<1x2x!quant.uniform>, %arg1: tensor<1x2x!quant.uniform>) {
// expected-error@+1 {{quantization_dimension of rhs and result are not same}}
- %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<1x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform>
+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<1x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform>
func.return
}
@@ -5431,7 +5463,7 @@ func.func @dynamic_iota_invalid_iota_dimension_too_big() -> tensor {
// -----
func.func @dynamic_iota_output_shape_negative_size() -> tensor<4xf32> {
- // @expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<4xf32>'}}
+ // @expected-error@+2 {{output shape [-1] is incompatible with return type of operation 'tensor<4xf32>'}}
%0 = stablehlo.constant dense<[-1]> : tensor<1xi64>
%1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<4xf32>
func.return %1 : tensor<4xf32>
@@ -5440,7 +5472,7 @@ func.func @dynamic_iota_output_shape_negative_size() -> tensor<4xf32> {
// -----
func.func @dynamic_iota_output_shape_mismatching_size() -> tensor<4xf32> {
- // @expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<4xf32>'}}
+ // @expected-error@+2 {{output shape [1] is incompatible with return type of operation 'tensor<4xf32>'}}
%0 = stablehlo.constant dense<[1]> : tensor<1xi64>
%1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<4xf32>
func.return %1 : tensor<4xf32>
@@ -5448,6 +5480,22 @@ func.func @dynamic_iota_output_shape_mismatching_size() -> tensor<4xf32> {
// -----
+func.func @dynamic_iota_output_shape_matches_result() -> tensor<4xf32> {
+ %0 = stablehlo.constant dense<[4]> : tensor<1xi64>
+ %1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<4xf32>
+ func.return %1 : tensor<4xf32>
+}
+
+// -----
+
+func.func @dynamic_iota_output_shape_compatible_with_result() -> tensor {
+ %0 = stablehlo.constant dense<[4]> : tensor<1xi64>
+ %1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor
+ func.return %1 : tensor
+}
+
+// -----
+
func.func @first(%arg0: tensor, %arg1: tensor) -> tensor {
func.return %arg0 : tensor
}
diff --git a/stablehlo/tests/ops_stablehlo_quantized.mlir b/stablehlo/tests/ops_stablehlo_quantized.mlir
index d9a073f2c1b..a4ff81a48dc 100644
--- a/stablehlo/tests/ops_stablehlo_quantized.mlir
+++ b/stablehlo/tests/ops_stablehlo_quantized.mlir
@@ -29,7 +29,7 @@ func.func @ops_per_axis_quantization(
// %arg1 can be a per-axis Quantized
func.func @dot_general_per_axis_quantization(
%arg0: tensor<2x3x4x!quant.uniform>,
- %arg1: tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> {
+ %arg1: tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> {
%0 = "stablehlo.dot_general"(%arg0, %arg1) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
@@ -38,8 +38,8 @@ func.func @dot_general_per_axis_quantization(
rhs_contracting_dimensions = [1]
>
} : (tensor<2x3x4x!quant.uniform>,
- tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform>
- func.return %0 : tensor<2x4x5x!quant.uniform>
+ tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform>
+ func.return %0 : tensor<2x4x5x!quant.uniform>
}
// -----
@@ -856,30 +856,31 @@ func.func @broadcast_in_dim_c1_mismatch_zero_point(
// -----
func.func @broadcast_in_dim_c6(
- %arg0: tensor<1x2x1x!quant.uniform:f32:2, {0.1:-30, 0.5:-20}>>) {
- // expected-error@+1 {{result quantization_dimension 3 not same as broadcast_dimensions 2 (2)}}
+ %arg0: tensor<1x2x1x!quant.uniform:f32:2, {0.1:-30}>>) {
+ // expected-error@+1 {{result quantization_dimension 3 not same as broadcast_dimensions[2] = 2}}
%broadcast_in_dim = "stablehlo.broadcast_in_dim" (%arg0) {broadcast_dimensions = array
- } : (tensor<1x2x1x!quant.uniform:f32:2, {0.1:-30, 0.5:-20}>>) -> tensor<1x2x3x2x!quant.uniform:f32:3, {0.1:-30, 0.1:-30}>>
+ } : (tensor<1x2x1x!quant.uniform:f32:2, {0.1:-30}>>) ->
+ tensor<1x2x3x2x!quant.uniform:f32:3, {0.1:-30, 0.1:-30}>>
func.return
}
// -----
func.func @broadcast_in_dim_c6(
- %arg0: tensor<1x2x1x!quant.uniform:f32:2, {0.1:-30, 0.5:-20}>>) {
+ %arg0: tensor<1x2x1x!quant.uniform:f32:2, {0.1:-30}>>) {
// expected-error@+1 {{mismatch result scale 0 (2.000000e-01) and operand scale 0 (1.000000e-01)}}
%broadcast_in_dim = "stablehlo.broadcast_in_dim" (%arg0) {broadcast_dimensions = array
- } : (tensor<1x2x1x!quant.uniform:f32:2, {0.1:-30, 0.5:-20}>>) -> tensor<1x2x3x2x!quant.uniform:f32:3, {0.2:2, 0.5:-20}>>
+ } : (tensor<1x2x1x!quant.uniform:f32:2, {0.1:-30}>>) -> tensor<1x2x3x2x!quant.uniform:f32:3, {0.2:2, 0.5:-20}>>
func.return
}
// -----
func.func @broadcast_in_dim_c6(
- %arg0: tensor<1x2x1x!quant.uniform:f32:2, {0.1:-30, 0.5:-20}>>) {
+ %arg0: tensor<1x2x1x!quant.uniform:f32:2, {0.1:-30}>>) {
// expected-error@+1 {{mismatch result zero_point 1 (-20) and operand zero_point 0 (-30)}}
%broadcast_in_dim = "stablehlo.broadcast_in_dim" (%arg0) {broadcast_dimensions = array
- } : (tensor<1x2x1x!quant.uniform:f32:2, {0.1:-30, 0.5:-20}>>) -> tensor<1x2x3x2x!quant.uniform:f32:3, {0.1:-30, 0.1:-20}>>
+ } : (tensor<1x2x1x!quant.uniform:f32:2, {0.1:-30}>>) -> tensor<1x2x3x2x!quant.uniform:f32:3, {0.1:-30, 0.1:-20}>>
func.return
}
@@ -903,28 +904,28 @@ func.func @transpose_c1_mismatched_zp(%arg0: tensor<1x2x2x!quant.uniform
-func.func @transpose_c1_mismatched_scales(%arg0: tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30, 0.5:-20}>>) {
+func.func @transpose_c1_mismatched_scales(%arg0: tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30}>>) {
// expected-error@+1 {{expect same quantization scales and zero_points}}
%transpose = "stablehlo.transpose"(%arg0) {permutation = array
- } : (tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30, 0.5:-20}>>) -> tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30, 0.9:-20}>>
+ } : (tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30}>>) -> tensor<1x2x2x!quant.uniform:f32:0, {0.2:-30}>>
func.return
}
// -----
-func.func @transpose_c1_mismatched_zps(%arg0: tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30, 0.5:-20}>>) {
+func.func @transpose_c1_mismatched_zps(%arg0: tensor<1x2x2x!quant.uniform:f32:0, {0.1:-20}>>) {
// expected-error@+1 {{expect same quantization scales and zero_points}}
%transpose = "stablehlo.transpose"(%arg0) {permutation = array
- } : (tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30, 0.5:-20}>>) -> tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30, 0.5:-10}>>
+ } : (tensor<1x2x2x!quant.uniform:f32:0, {0.1:-20}>>) -> tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30}>>
func.return
}
// -----
-func.func @transpose_c4(%arg0: tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30, 0.5:-20}>>) {
- // expected-error@+1 {{operand quantization_dimension 0 is not same as permutation[1] 2}}
+func.func @transpose_c4(%arg0: tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30}>>) {
+ // expected-error@+1 {{operand quantization_dimension 0 is not same as permutation[1] = 2}}
%transpose = "stablehlo.transpose"(%arg0) {permutation = array
- } : (tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30, 0.5:-20}>>) -> tensor<1x2x2x!quant.uniform:f32:1, {0.1:-30, 0.5:-20}>>
+ } : (tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30}>>) -> tensor<1x2x2x!quant.uniform:f32:1, {0.1:-30, 0.5:-20}>>
func.return
}
@@ -956,7 +957,7 @@ func.func @reshape_c1(%arg0: tensor<1x2x2x!quant.uniform>){
func.func @reshape_c3_mismatch_qdim_size(%arg0: tensor<1x2x3x4x5x!quant.uniform>){
// expected-error@+1 {{expect same quantization dimension size for operand and result}}
- %reshape = "stablehlo.reshape" (%arg0) : (tensor<1x2x3x4x5x!quant.uniform>) -> tensor<2x3x20x!quant.uniform>
+ %reshape = "stablehlo.reshape" (%arg0) : (tensor<1x2x3x4x5x!quant.uniform>) -> tensor<2x3x20x!quant.uniform>
func.return
}
@@ -1111,7 +1112,7 @@ func.func @dot_general_c15_per_tensor(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x
func.func @dot_general_c15_per_axis(
%arg0: tensor<2x3x4x!quant.uniform>,
- %arg1: tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> {
+ %arg1: tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> {
// expected-error@+1 {{Zero points of rhs should be 0}}
%0 = "stablehlo.dot_general"(%arg0, %arg1) {
dot_dimension_numbers = #stablehlo.dot<
@@ -1121,15 +1122,15 @@ func.func @dot_general_c15_per_axis(
rhs_contracting_dimensions = [1]
>
} : (tensor<2x3x4x!quant.uniform>,
- tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform>
- func.return %0 : tensor<2x4x5x!quant.uniform>
+ tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform>
+ func.return %0 : tensor<2x4x5x!quant.uniform>
}
// -----
func.func @dot_general_c16(
%arg0: tensor<2x3x4x!quant.uniform>,
- %arg1: tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> {
+ %arg1: tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> {
// expected-error@+1 {{Quantization dimension of rhs should not be in the contracting dimension of rhs}}
%0 = "stablehlo.dot_general"(%arg0, %arg1) {
dot_dimension_numbers = #stablehlo.dot<
@@ -1139,8 +1140,8 @@ func.func @dot_general_c16(
rhs_contracting_dimensions = [0]
>
} : (tensor<2x3x4x!quant.uniform>,
- tensor<2x3x5x!quant.uniform>) -> tensor<3x4x5x!quant.uniform>
- func.return %0 : tensor<3x4x5x!quant.uniform>
+ tensor<2x3x5x!quant.uniform>) -> tensor<3x4x5x!quant.uniform>
+ func.return %0 : tensor<3x4x5x!quant.uniform>
}
// -----
@@ -1175,7 +1176,7 @@ func.func @dot_general_c18(%arg0: tensor<2x3x4x!quant.uniform>,
// -----
-func.func @dot_general_c19(%arg0: tensor<2x3x4x!quant.uniform>, %arg1: tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> {
+func.func @dot_general_c19(%arg0: tensor<2x3x4x!quant.uniform>, %arg1: tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> {
// expected-error@+1 {{mismatched rhs and result quantization granularity}}
%0 = "stablehlo.dot_general"(%arg0, %arg1) {
dot_dimension_numbers = #stablehlo.dot<
@@ -1184,8 +1185,8 @@ func.func @dot_general_c19(%arg0: tensor<2x3x4x!quant.uniform>,
lhs_contracting_dimensions = [1],
rhs_contracting_dimensions = [1]
>
- } : (tensor<2x3x4x!quant.uniform>, tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform>
- func.return %0 : tensor<2x4x5x!quant.uniform>
+ } : (tensor<2x3x4x!quant.uniform>, tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform>
+ func.return %0 : tensor<2x4x5x!quant.uniform>
}
// -----
@@ -1202,3 +1203,59 @@ func.func @dot_general_c20(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x5x!quant.
} : (tensor<2x3x4xf32>, tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5xf32>
func.return %0 : tensor<2x4x5xf32>
}
+
+// -----
+
+func.func @quantized_element_type_c8(%arg0: tensor<1x2x!quant.uniform:f32, 1.0:300>>) {
+ // expected-error-re@+1 {{operand #0 must be ranked tensor of {{.*}} 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer or 4/8/16/32-bit uniform quantized per axis signed integer or 4/8/16/32-bit uniform quantized per axis unsigned integer values, but got 'tensor<1x2x!quant.uniform>'}}
+ %0 = stablehlo.add %arg0, %arg0 : tensor<1x2x!quant.uniform:f32, 1.0:300>>
+ func.return
+}
+
+// -----
+
+func.func @quantized_element_type_c8(%arg0: tensor<1x2x!quant.uniform:f32, 1.0:-129>>) {
+ // expected-error-re@+1 {{operand #0 must be ranked tensor of {{.*}} 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer or 4/8/16/32-bit uniform quantized per axis signed integer or 4/8/16/32-bit uniform quantized per axis unsigned integer values, but got 'tensor<1x2x!quant.uniform>'}}
+ %0 = stablehlo.add %arg0, %arg0 : tensor<1x2x!quant.uniform:f32, 1.0:-129>>
+ func.return
+}
+
+// -----
+
+func.func @quantized_element_type_c5(%arg0: tensor<1x2x!quant.uniform>) {
+ // expected-error-re@+1 {{operand #0 must be ranked tensor of {{.*}} 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer or 4/8/16/32-bit uniform quantized per axis signed integer or 4/8/16/32-bit uniform quantized per axis unsigned integer values, but got 'tensor<1x2x!quant.uniform>'}}
+ %0 = stablehlo.add %arg0, %arg0 : tensor<1x2x!quant.uniform>
+ func.return
+}
+
+// -----
+
+func.func @quantized_element_type_c5(%arg0: tensor<1x2x!quant.uniform>) {
+ // expected-error-re@+1 {{operand #0 must be ranked tensor of {{.*}} 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer or 4/8/16/32-bit uniform quantized per axis signed integer or 4/8/16/32-bit uniform quantized per axis unsigned integer values, but got 'tensor<1x2x!quant.uniform>'}}
+ %0 = stablehlo.add %arg0, %arg0 : tensor<1x2x!quant.uniform>
+ func.return
+}
+
+// -----
+
+func.func @quantized_element_type_c12(%arg0: tensor<1x5x2x!quant.uniform:f32:-1, {0.1:-30, 0.1:-30}>>) {
+ // expected-error-re@+1 {{operand #0 must be ranked tensor of {{.*}} 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer or 4/8/16/32-bit uniform quantized per axis signed integer or 4/8/16/32-bit uniform quantized per axis unsigned integer values, but got 'tensor<1x5x2x!quant.uniform>'}}
+ %0 = stablehlo.add %arg0, %arg0 : tensor<1x5x2x!quant.uniform:f32:-1, {0.1:-30, 0.1:-30}>>
+ func.return
+}
+
+// -----
+
+func.func @quantized_element_type_c13(%arg0: tensor<1x5x2x!quant.uniform:f32:10, {0.1:-30, 0.1:-30}>>) {
+ // expected-error-re@+1 {{operand #0 must be ranked tensor of {{.*}} 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer or 4/8/16/32-bit uniform quantized per axis signed integer or 4/8/16/32-bit uniform quantized per axis unsigned integer values, but got 'tensor<1x5x2x!quant.uniform>'}}
+ %0 = stablehlo.add %arg0, %arg0 : tensor<1x5x2x!quant.uniform