Remove TraceOp from StableHLO (openxla#2295)
Part of openxla#2283. This op doesn't appear to have any uses in frameworks or
compilers. Removing it altogether shouldn't cause any major issues, and since
this op was never specced, it is exempt from compatibility guarantees.

If it is used somewhere, I would recommend migrating to a custom_call.
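
For illustration, a minimal sketch of what that migration could look like. The
call target name "trace" and the use of backend_config to carry the tag are
assumptions, not an established convention; the key point is that a
side-effecting custom_call with no results survives dead-code elimination the
way trace did:

```mlir
// Previously (removed op):
//   stablehlo.trace %arg0, "In test code." : tensor<5x1x5xi32>
//
// Possible replacement: a side-effecting custom_call with no results.
// The "trace" target name and tag-in-backend_config are illustrative only.
"stablehlo.custom_call"(%arg0) {
  call_target_name = "trace",
  has_side_effect = true,
  backend_config = "In test code."
} : (tensor<5x1x5xi32>) -> ()
```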
GleasonK authored May 9, 2024
1 parent 9d4f27d commit e5b9c99
Showing 33 changed files with 7 additions and 184 deletions.
10 changes: 2 additions & 8 deletions docs/interpreter_status.md
@@ -64,10 +64,7 @@ Apart from the specced ops, this category consists of 8 unspecced ops (see
[StableHLO Ops Categories](#stablehlo-ops-categories)) which are planned to be
moved out of StableHLO. Most of these ops have existing passes in
[mhlo](https://github.com/openxla/xla/tree/main/xla/mlir_hlo/mhlo/transforms) to
-convert them to StableHLO equivalent ops. There is one op the interpreter
-does not support because there is no existing decomposition to StableHLO ops:
-`trace`. `trace` op is private to XLA and there no users in JAX, PyTorch or
-TensorFlow (see [#604](https://github.com/openxla/stablehlo/issues/604)).
+convert them to StableHLO equivalent ops.

<!-- markdownlint-disable line-length -->
The tool to convert remaining ops in this category to equivalent StableHLO ops
@@ -260,9 +257,6 @@ mlir-hlo-opt -mhlo-legalize-einsum-to-dot-general <path/to/input>
# torch_index_select
mlir-hlo-opt -mhlo-legalize-torch-index-select-to-gather <path/to/input>

-# trace
-# There are no current users of trace (see #604).
-
# unary_einsum
mlir-hlo-opt --canonicalize -mhlo-legalize-einsum-to-dot-general <path/to/input>
```
@@ -280,6 +274,6 @@ mlir-hlo-opt --canonicalize -mhlo-legalize-einsum-to-dot-general <path/to/input>
| Extensibility | custom_call, get_tuple_element, tuple | 3 |
| Miscellaneous | batch_norm_grad, batch_norm_inference, batch_norm_training, cholesky, constant, fft, iota, rng, rng_bit_generator, triangular_solve | 10 |
| Modularity | call, func, module, return | 4 |
-| Not In HLO | broadcast, create_token, cross-replica-sum, dot, einsum, torch_index_select, trace, unary_einsum | 8 |
+| Not In HLO | broadcast, create_token, cross-replica-sum, dot, einsum, torch_index_select, unary_einsum | 8 |
| Quantization | uniform_dequantize, uniform_quantize | 2 |
| Reduction | convolution, dot_general, reduce, reduce_window, select_and_scatter | 5 |
3 changes: 1 addition & 2 deletions docs/spec.md
@@ -328,8 +328,7 @@ in StableHLO programs. In the meanwhile, here is the list of these operations:
the StableHLO opset but have been later deemed to not fit it well:
`broadcast`, `create_token`, `cross-replica-sum`, `dot`, `einsum`,
`torch_index_select`, `unary_einsum`
-([#3](https://github.com/openxla/stablehlo/issues/3)), and
-`trace` ([#604](https://github.com/openxla/stablehlo/issues/604)).
+([#3](https://github.com/openxla/stablehlo/issues/3)).
* "Dynamism" category of StableHLO operations - they were bootstrapped from
MHLO,and we are in the process of speccing them: `dynamic_broadcast_in_dim`,
`dynamic_conv`, `dynamic_gather`, `real_dynamic_slice`, `set_dimension_size`.
1 change: 0 additions & 1 deletion docs/status.md
@@ -148,7 +148,6 @@ one of the following tracking labels.
| subtract | yes | yes | yes | yes | yes |
| tanh | yes | yes | yes | yes | yes |
| torch_index_select | no | revisit | no | no | revisit |
| trace | no | revisit | no | yes | revisit |
| transpose | yes | yes | yes | yes | yes |
| triangular_solve | yes | revisit | yes | no | revisit |
| tuple | yes | yes | yes | yes | yes |
22 changes: 0 additions & 22 deletions stablehlo/dialect/StablehloOps.td
@@ -2967,28 +2967,6 @@ def StableHLO_PadOp: StableHLO_ShapedInterfaceOp<"pad",
}];
}

def StableHLO_TraceOp: StableHLO_Op<"trace"> {
let summary = "Trace operation";
let description = [{
This operation is on its way out of StableHLO, so it is not included in
the StableHLO specification: https://github.com/openxla/stablehlo/issues/604.

It is not used by JAX, PyTorch or TensorFlow, so it looks like we should've
classified it as "Private to XLA" and not included it in StableHLO in the
first place. With that in mind, its semantics will not be documented here.

Example:
```mlir
stablehlo.trace %arg0, "In test code." : tensor<5x1x5xi32>
```
}];
let arguments = (ins
HLO_Tensor:$operand,
StrAttr:$tag
);
let assemblyFormat = "$operand `,` $tag attr-dict `:` type($operand)";
}

def StableHLO_TransposeOp: StableHLO_ShapedInterfaceOp<"transpose",
[ConditionallySpeculatable, NoMemoryEffect,
HLO_CompatibleOperandsAndResultElementType, /*transpose_c1*/
10 changes: 0 additions & 10 deletions stablehlo/dialect/VhloOps.td
@@ -1007,16 +1007,6 @@ def VHLO_TorchIndexSelectOpV1 : VHLO_Op<"torch_index_select_v1", "0.9.0", "curre
let results = (outs VHLO_AnyType:$result);
}

// TODO(#3): TraceOp is not part of the StableHLO spec.
// This operation is on its way out of StableHLO, so it is not included in
// the StableHLO specification.
def VHLO_TraceOpV1 : VHLO_Op<"trace_v1", "0.9.0", "current"> {
let arguments = (ins
VHLO_AnyType:$operand,
VHLO_AnyAttr:$tag
);
}

def VHLO_TransposeOpV1 : VHLO_Op<"transpose_v1", "0.9.0", "current"> {
let arguments = (ins
VHLO_AnyType:$operand,
2 changes: 0 additions & 2 deletions stablehlo/reference/Ops.cpp
@@ -977,8 +977,6 @@ SmallVector<InterpreterValue> eval(Region &region,
scope.add(op.getResult(), result);
} else if (isa<TorchIndexSelectOp>(operation)) {
failOnDecomposableOp(operation);
} else if (isa<TraceOp>(operation)) {
failOnDecomposableOp(operation);
} else if (auto op = dyn_cast<TransposeOp>(operation)) {
auto operand = scope.findTensor(op.getOperand());
auto permutation = Axes(op.getPermutation());
4 changes: 1 addition & 3 deletions stablehlo/tests/print_stablehlo.mlir
@@ -43,9 +43,7 @@ func.func @zero_input() -> !stablehlo.token {

// CHECK-LABEL: func @zero_output_ret2
func.func @zero_output_ret2(%arg0 : tensor<3xi64>) -> (tensor<3xi64>, tensor<3xi64>) {
-// CHECK: stablehlo.trace %arg0, "This is a test" : tensor<3xi64>
-// CHECK-NEXT: stablehlo.return %arg0, %arg0 : tensor<3xi64>, tensor<3xi64>
-"stablehlo.trace"(%arg0) {tag = "This is a test"} : (tensor<3xi64>) -> ()
+// CHECK: stablehlo.return %arg0, %arg0 : tensor<3xi64>, tensor<3xi64>
"stablehlo.return"(%arg0, %arg0) : (tensor<3xi64>, tensor<3xi64>) -> ()
}

11 changes: 0 additions & 11 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir
@@ -1962,17 +1962,6 @@ func.func @op_torch_index_select(%arg0: tensor<5x1x5xf32>, %arg1: tensor<2xi32>)
func.return %0 : tensor<2x1x5xf32>
}

// CHECK-LABEL: "op_trace"
func.func @op_trace(%arg0: tensor<f32>) {
// CHECK: "vhlo.trace_v1"(%arg0) <{
// CHECK-SAME: tag = #vhlo.string_v1<"foo">
// CHECK-SAME: }> : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
"stablehlo.trace"(%arg0) {
tag = "foo"
} : (tensor<f32>) -> ()
func.return
}

// CHECK-LABEL: "op_transpose"
func.func @op_transpose(%arg0: tensor<16x8xf32>) -> tensor<8x16xf32> {
// CHECK: "vhlo.transpose_v1"(%arg0) <{
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir.bc
11 changes: 0 additions & 11 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir
@@ -1962,17 +1962,6 @@ func.func @op_torch_index_select(%arg0: tensor<5x1x5xf32>, %arg1: tensor<2xi32>)
func.return %0 : tensor<2x1x5xf32>
}

// CHECK-LABEL: "op_trace"
func.func @op_trace(%arg0: tensor<f32>) {
// CHECK: "vhlo.trace_v1"(%arg0) <{
// CHECK-SAME: tag = #vhlo.string_v1<"foo">
// CHECK-SAME: }> : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
"stablehlo.trace"(%arg0) {
tag = "foo"
} : (tensor<f32>) -> ()
func.return
}

// CHECK-LABEL: "op_transpose"
func.func @op_transpose(%arg0: tensor<16x8xf32>) -> tensor<8x16xf32> {
// CHECK: "vhlo.transpose_v1"(%arg0) <{
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir.bc
11 changes: 0 additions & 11 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir
@@ -1962,17 +1962,6 @@ func.func @op_torch_index_select(%arg0: tensor<5x1x5xf32>, %arg1: tensor<2xi32>)
func.return %0 : tensor<2x1x5xf32>
}

// CHECK-LABEL: "op_trace"
func.func @op_trace(%arg0: tensor<f32>) {
// CHECK: "vhlo.trace_v1"(%arg0) <{
// CHECK-SAME: tag = #vhlo.string_v1<"foo">
// CHECK-SAME: }> : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
"stablehlo.trace"(%arg0) {
tag = "foo"
} : (tensor<f32>) -> ()
func.return
}

// CHECK-LABEL: "op_transpose"
func.func @op_transpose(%arg0: tensor<16x8xf32>) -> tensor<8x16xf32> {
// CHECK: "vhlo.transpose_v1"(%arg0) <{
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir.bc
11 changes: 0 additions & 11 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir
@@ -1962,17 +1962,6 @@ func.func @op_torch_index_select(%arg0: tensor<5x1x5xf32>, %arg1: tensor<2xi32>)
func.return %0 : tensor<2x1x5xf32>
}

// CHECK-LABEL: "op_trace"
func.func @op_trace(%arg0: tensor<f32>) {
// CHECK: "vhlo.trace_v1"(%arg0) <{
// CHECK-SAME: tag = #vhlo.string_v1<"foo">
// CHECK-SAME: }> : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
"stablehlo.trace"(%arg0) {
tag = "foo"
} : (tensor<f32>) -> ()
func.return
}

// CHECK-LABEL: "op_transpose"
func.func @op_transpose(%arg0: tensor<16x8xf32>) -> tensor<8x16xf32> {
// CHECK: "vhlo.transpose_v1"(%arg0) <{
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir.bc
11 changes: 0 additions & 11 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir
@@ -1962,17 +1962,6 @@ func.func @op_torch_index_select(%arg0: tensor<5x1x5xf32>, %arg1: tensor<2xi32>)
func.return %0 : tensor<2x1x5xf32>
}

// CHECK-LABEL: "op_trace"
func.func @op_trace(%arg0: tensor<f32>) {
// CHECK: "vhlo.trace_v1"(%arg0) <{
// CHECK-SAME: tag = #vhlo.string_v1<"foo">
// CHECK-SAME: }> : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
"stablehlo.trace"(%arg0) {
tag = "foo"
} : (tensor<f32>) -> ()
func.return
}

// CHECK-LABEL: "op_transpose"
func.func @op_transpose(%arg0: tensor<16x8xf32>) -> tensor<8x16xf32> {
// CHECK: "vhlo.transpose_v1"(%arg0) <{
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir.bc
11 changes: 0 additions & 11 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir
@@ -1968,17 +1968,6 @@ func.func @op_torch_index_select(%arg0: tensor<5x1x5xf32>, %arg1: tensor<2xi32>)
func.return %0 : tensor<2x1x5xf32>
}

// CHECK-LABEL: "op_trace"
func.func @op_trace(%arg0: tensor<f32>) {
// CHECK: "vhlo.trace_v1"(%arg0) <{
// CHECK-SAME: tag = #vhlo.string_v1<"foo">
// CHECK-SAME: }> : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
"stablehlo.trace"(%arg0) {
tag = "foo"
} : (tensor<f32>) -> ()
func.return
}

// CHECK-LABEL: "op_transpose"
func.func @op_transpose(%arg0: tensor<16x8xf32>) -> tensor<8x16xf32> {
// CHECK: "vhlo.transpose_v1"(%arg0) <{
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir.bc
11 changes: 0 additions & 11 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir
@@ -1980,17 +1980,6 @@ func.func @op_torch_index_select(%arg0: tensor<5x1x5xf32>, %arg1: tensor<2xi32>)
func.return %0 : tensor<2x1x5xf32>
}

// CHECK-LABEL: "op_trace"
func.func @op_trace(%arg0: tensor<f32>) {
// CHECK: "vhlo.trace_v1"(%arg0) <{
// CHECK-SAME: tag = #vhlo.string_v1<"foo">
// CHECK-SAME: }> : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
"stablehlo.trace"(%arg0) {
tag = "foo"
} : (tensor<f32>) -> ()
func.return
}

// CHECK-LABEL: "op_transpose"
func.func @op_transpose(%arg0: tensor<16x8xf32>) -> tensor<8x16xf32> {
// CHECK: "vhlo.transpose_v1"(%arg0) <{
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir.bc
11 changes: 0 additions & 11 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir
@@ -2106,17 +2106,6 @@ func.func @op_torch_index_select(%arg0: tensor<5x1x5xf32>, %arg1: tensor<2xi32>)
func.return %0 : tensor<2x1x5xf32>
}

// CHECK-LABEL: "op_trace"
func.func @op_trace(%arg0: tensor<f32>) {
// CHECK: "vhlo.trace_v1"(%arg0) <{
// CHECK-SAME: tag = #vhlo.string_v1<"foo">
// CHECK-SAME: }> : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
"stablehlo.trace"(%arg0) {
tag = "foo"
} : (tensor<f32>) -> ()
func.return
}

// CHECK-LABEL: "op_transpose"
func.func @op_transpose(%arg0: tensor<16x8xf32>) -> tensor<8x16xf32> {
// CHECK: "vhlo.transpose_v1"(%arg0) <{
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir.bc
11 changes: 0 additions & 11 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir
@@ -2106,17 +2106,6 @@ func.func @op_torch_index_select(%arg0: tensor<5x1x5xf32>, %arg1: tensor<2xi32>)
func.return %0 : tensor<2x1x5xf32>
}

// CHECK-LABEL: "op_trace"
func.func @op_trace(%arg0: tensor<f32>) {
// CHECK: "vhlo.trace_v1"(%arg0) <{
// CHECK-SAME: tag = #vhlo.string_v1<"foo">
// CHECK-SAME: }> : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
"stablehlo.trace"(%arg0) {
tag = "foo"
} : (tensor<f32>) -> ()
func.return
}

// CHECK-LABEL: "op_transpose"
func.func @op_transpose(%arg0: tensor<16x8xf32>) -> tensor<8x16xf32> {
// CHECK: "vhlo.transpose_v1"(%arg0) <{
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir.bc
11 changes: 0 additions & 11 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_19_0.mlir
@@ -2141,17 +2141,6 @@ func.func @op_torch_index_select(%arg0: tensor<5x1x5xf32>, %arg1: tensor<2xi32>)
func.return %0 : tensor<2x1x5xf32>
}

// CHECK-LABEL: "op_trace"
func.func @op_trace(%arg0: tensor<f32>) {
// CHECK: "vhlo.trace_v1"(%arg0) <{
// CHECK-SAME: tag = #vhlo.string_v1<"foo">
// CHECK-SAME: }> : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
"stablehlo.trace"(%arg0) {
tag = "foo"
} : (tensor<f32>) -> ()
func.return
}

// CHECK-LABEL: "op_transpose"
func.func @op_transpose(%arg0: tensor<16x8xf32>) -> tensor<8x16xf32> {
// CHECK: "vhlo.transpose_v1"(%arg0) <{
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_19_0.mlir.bc
11 changes: 0 additions & 11 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir
@@ -1962,17 +1962,6 @@ func.func @op_torch_index_select(%arg0: tensor<5x1x5xf32>, %arg1: tensor<2xi32>)
func.return %0 : tensor<2x1x5xf32>
}

// CHECK-LABEL: "op_trace"
func.func @op_trace(%arg0: tensor<f32>) {
// CHECK: "vhlo.trace_v1"(%arg0) <{
// CHECK-SAME: tag = #vhlo.string_v1<"foo">
// CHECK-SAME: }> : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
"stablehlo.trace"(%arg0) {
tag = "foo"
} : (tensor<f32>) -> ()
func.return
}

// CHECK-LABEL: "op_transpose"
func.func @op_transpose(%arg0: tensor<16x8xf32>) -> tensor<8x16xf32> {
// CHECK: "vhlo.transpose_v1"(%arg0) <{
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir.bc
11 changes: 0 additions & 11 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir
@@ -2141,17 +2141,6 @@ func.func @op_torch_index_select(%arg0: tensor<5x1x5xf32>, %arg1: tensor<2xi32>)
func.return %0 : tensor<2x1x5xf32>
}

// CHECK-LABEL: "op_trace"
func.func @op_trace(%arg0: tensor<f32>) {
// CHECK: "vhlo.trace_v1"(%arg0) <{
// CHECK-SAME: tag = #vhlo.string_v1<"foo">
// CHECK-SAME: }> : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
"stablehlo.trace"(%arg0) {
tag = "foo"
} : (tensor<f32>) -> ()
func.return
}

// CHECK-LABEL: "op_transpose"
func.func @op_transpose(%arg0: tensor<16x8xf32>) -> tensor<8x16xf32> {
// CHECK: "vhlo.transpose_v1"(%arg0) <{
1 change: 0 additions & 1 deletion stablehlo/transforms/MapStablehloToVhlo.h
@@ -154,7 +154,6 @@ MAP_STABLEHLO_TO_VHLO(SqrtOp, V1)
MAP_STABLEHLO_TO_VHLO(SubtractOp, V1)
MAP_STABLEHLO_TO_VHLO(TanhOp, V1)
MAP_STABLEHLO_TO_VHLO(TorchIndexSelectOp, V1)
MAP_STABLEHLO_TO_VHLO(TraceOp, V1)
MAP_STABLEHLO_TO_VHLO(TransposeOp, V1)
MAP_STABLEHLO_TO_VHLO(TriangularSolveOp, V1)
MAP_STABLEHLO_TO_VHLO(TupleOp, V1)
4 changes: 2 additions & 2 deletions stablehlo/transforms/StablehloInstrumentWithProbe.cpp
@@ -112,8 +112,8 @@ void StablehloInstrumentWithProbePass::runOnOperation() {
bool StablehloInstrumentWithProbePass::shouldProbeOp(Operation& op) const {
if (isa<ConstantOp>(op)) return false;

-// Operations that do not produce values should not be instrumented (ReturnOp,
-// TraceOp, etc.)
+// Operations that do not produce values should not be instrumented
+// (ReturnOp, CustomCallOp with no result, etc)
if (op.getNumResults() == 0) return false;

return true;
2 changes: 1 addition & 1 deletion stablehlo/transforms/StablehloLegalizeDeprecatedOps.cpp
@@ -45,7 +45,7 @@ struct StablehloLegalizeDeprecatedOpsPass final

if (failOnUnusedOps) {
// Deprecated ops to be removed with no replacements
-target->addIllegalOp<MapOp, RngOp, TraceOp>();
+target->addIllegalOp<MapOp, RngOp>();
}

target->addLegalDialect<StablehloDialect>();