Skip to content

Commit

Permalink
Merge branch 'openxla:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
ghpvnist authored May 8, 2024
2 parents c7a536f + 8ba7728 commit 0d48845
Show file tree
Hide file tree
Showing 67 changed files with 939 additions and 870 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ cc_library(
],
strip_include_prefix = ".",
deps = [
":base",
":interpreter_ops_inc_gen",
":reference_numpy",
":reference_ops",
Expand Down
2 changes: 1 addition & 1 deletion docs/images/spec/gather.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion docs/images/spec/scatter.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
31 changes: 8 additions & 23 deletions docs/interpreter_status.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,23 +60,14 @@ interpreter supports resides in [hlo_expand_main.cc](https://github.com/openxla/

### Not in HLO

Apart from the specced ops, this category consists of 10 unspecced ops (see
[StableHLO Ops Categories](#stablehlo-ops-categories)) which are planed to be
moved out of StableHLO. Some of these ops have existing passes in
Apart from the specced ops, this category consists of 8 unspecced ops (see
[StableHLO Ops Categories](#stablehlo-ops-categories)) which are planned to be
moved out of StableHLO. Most of these ops have existing passes in
[mhlo](https://github.com/openxla/xla/tree/main/xla/mlir_hlo/mhlo/transforms) to
convert them to StableHLO equivalent ops. There are three ops the interpreter
does not support because there are no existing decompositions to StableHLO ops:

* `compute_reshape_shape`
* `cstr_reshapable`
* `trace`

`compute_reshape_shape` and `cstr_reshapable` ops are part of the ongoing
Dynamism work, and they are planned to be removed from StableHLO (see
[#1668](https://github.com/openxla/stablehlo/issues/1668)).

`trace` op is private to XLA and there no no users in JAX, PyTorch or TensorFlow
(see [#604](https://github.com/openxla/stablehlo/issues/604)).
convert them to StableHLO equivalent ops. There is one op the interpreter
does not support because there is no existing decomposition to StableHLO ops:
`trace`. `trace` op is private to XLA and there no users in JAX, PyTorch or
TensorFlow (see [#604](https://github.com/openxla/stablehlo/issues/604)).

<!-- markdownlint-disable line-length -->
The tool to convert remaining ops in this category to equivalent StableHLO ops
Expand Down Expand Up @@ -254,18 +245,12 @@ hlo-expand --triangular_solve_expander <path/to/hlo_module>
# broadcast
mlir-hlo-opt -mhlo-legalize-broadcast-to-broadcast-in-dim <path/to/input>

# compute_reshape_shape
# This op will be removed from StableHLO as part of Dynamism work (see #1668).

# create_token
mlir-hlo-opt -mhlo-legalize-create-token-to-after-all <path/to/input>

# cross-replica-sum
mlir-hlo-opt -mhlo-legalize-cross-replica-sum-to-all-reduce <path/to/input>

# cstr_reshapable
# This op will be removed from StableHLO as part of Dynamism work (see #1668).

# dot
mlir-hlo-opt -mhlo-legalize-dot-to-dot-general <path/to/input>

Expand Down Expand Up @@ -295,6 +280,6 @@ mlir-hlo-opt --canonicalize -mhlo-legalize-einsum-to-dot-general <path/to/input>
| Extensibility | custom_call, get_tuple_element, tuple | 3 |
| Miscellaneous | batch_norm_grad, batch_norm_inference, batch_norm_training, cholesky, constant, fft, iota, rng, rng_bit_generator, triangular_solve | 10 |
| Modularity | call, func, module, return | 4 |
| Not In HLO | broadcast, compute_reshape_shape, create_token, cross-replica-sum, cstr_reshapable, dot, einsum, torch_index_select, trace, unary_einsum | 10 |
| Not In HLO | broadcast, create_token, cross-replica-sum, dot, einsum, torch_index_select, trace, unary_einsum | 8 |
| Quantization | uniform_dequantize, uniform_quantize | 2 |
| Reduction | convolution, dot_general, reduce, reduce_window, select_and_scatter | 5 |
53 changes: 48 additions & 5 deletions docs/spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ constraints:
* For per-tensor quantization:
* No additional constraints.
* For per-axis quantization:
* (C12) `quantization_dimension < rank(self)`.
* (C13) `dim(self, quantization_dimension) = size(scales)`.
* (C13) `quantization_dimension < rank(self)`.
* (C14) `dim(self, quantization_dimension) = size(scales)`.

```ebnf
TokenType ::= 'token'
Expand Down Expand Up @@ -331,9 +331,8 @@ in StableHLO programs. In the meanwhile, here is the list of these operations:
([#3](https://github.com/openxla/stablehlo/issues/3)), and
`trace` ([#604](https://github.com/openxla/stablehlo/issues/604)).
* "Dynamism" category of StableHLO operations - they were bootstrapped from
MHLO, but we haven't specced them yet: `compute_reshape_shape`,
`cstr_reshapable`, `dynamic_broadcast_in_dim`, `dynamic_conv`,
`dynamic_gather`, `dynamic_iota`, `dynamic_pad`, `dynamic_reshape`,
MHLO,and we are in the process of speccing them: `dynamic_broadcast_in_dim`,
`dynamic_conv`, `dynamic_gather`, `dynamic_iota`, `dynamic_pad`, `dynamic_reshape`,
`real_dynamic_slice`, `set_dimension_size`
([#8](https://github.com/openxla/stablehlo/issues/8)).
* Shape computations, including `arith`, `shape` and `tensor` operations
Expand Down Expand Up @@ -2636,6 +2635,50 @@ planning to address this in

&nbsp;[More Examples](https://github.com/openxla/stablehlo/tree/main/stablehlo/tests/interpret/dot_general.mlir)

### dynamic_iota

#### Semantics

This operation is functionally identical to
[iota](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#iota)
op, but the result shape is specified dynamically via `output_shape`.

#### Inputs

| Label | Name | Type | Constraints |
|-------|------------------|------------------------------------------------------------------------------------|-------------|
| (I1) | `output_shape` | 1-dimensional tensor constant of type `si64` | (C1), (C2) |
| (I2) | `iota_dimension` | `si64` | (C1) |

#### Outputs

| Name | Type | Constraints |
|----------|-----------------------------------------------------------------------------------|-------------|
| `result` | tensor of integer, floating-point, or complex type or per-tensor quantized tensor | (C2) |

#### Constraints

* (C1) `0 <= iota_dimension < size(output_shape)`.
* (C2) `rank(result) = size(output_shape)`.

#### Examples

```mlir
%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
```

&nbsp;[More Examples](https://github.com/openxla/stablehlo/tree/main/stablehlo/tests/interpret/dynamic_iota.mlir)

### dynamic_slice

#### Semantics
Expand Down
4 changes: 1 addition & 3 deletions docs/status.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ one of the following tracking labels.
| compare | yes | yes | yes | yes | yes |
| complex | yes | yes | yes | yes | yes |
| composite | yes | yes | infeasible | yes | yes |
| compute_reshape_shape | no | revisit | no | yes | no |
| concatenate | yes | yes | yes | yes | yes |
| constant | yes | yes | yes | yes | yes |
| convert | yes | yes | infeasible | yes | yes |
Expand All @@ -75,15 +74,14 @@ one of the following tracking labels.
| count_leading_zeros | yes | yes | yes | yes | yes |
| create_token | no | yes\* | yes\* | yes | revisit |
| cross-replica-sum | no | revisit | yes\* | no | revisit |
| cstr_reshapable | no | revisit | no | yes | no |
| custom_call | yes | yes | infeasible | yes | yes |
| divide | yes | yes | yes | yes | yes |
| dot | no | revisit | infeasible | yes | revisit |
| dot_general | yes | revisit | infeasible | no | yes |
| dynamic_broadcast_in_dim | no | revisit | infeasible | no | no |
| dynamic_conv | no | revisit | no | no | no |
| dynamic_gather | no | revisit | revisit | no | no |
| dynamic_iota | no | revisit | infeasible | yes | no |
| dynamic_iota | yes | yes | infeasible | yes | revisit |
| dynamic_pad | no | revisit | no | yes | no |
| dynamic_reshape | no | revisit | infeasible | yes | no |
| dynamic_slice | yes | yes | yes | yes | yes |
Expand Down
Loading

0 comments on commit 0d48845

Please sign in to comment.