From dea89ec13b518da568154c4cdb0cf79fc1803217 Mon Sep 17 00:00:00 2001 From: Chase Riley Roberts Date: Thu, 30 Nov 2023 11:55:38 +0900 Subject: [PATCH] Integration of collective_broadcast into spec (#1856) This is the first PR for RFC #1809. I did not add an interpreter implementation as @GleasonK specifically asked me to leave that for new staff joining his team. --- docs/spec.md | 72 +- stablehlo/dialect/StablehloAttrs.td | 5 +- stablehlo/dialect/StablehloOps.cpp | 35 +- stablehlo/dialect/StablehloOps.td | 36 + stablehlo/dialect/TypeInference.cpp | 29 + stablehlo/dialect/TypeInference.h | 3 + stablehlo/dialect/Version.h | 2 +- stablehlo/dialect/VhloDialect.td | 1 + stablehlo/dialect/VhloOps.td | 9 + .../stablehlo_legalize_to_vhlo.0_16_0.mlir | 2287 +++++++++++++++++ .../stablehlo_legalize_to_vhlo.0_16_0.mlir.bc | Bin 0 -> 16714 bytes .../tests/stablehlo_legalize_to_vhlo.mlir | 12 + ...o_to_version_downgrade_invalid.0_15_0.mlir | 9 + stablehlo/transforms/MapStablehloToVhlo.h | 1 + .../transforms/StablehloLegalizeToVhlo.cpp | 8 +- .../transforms/VhloLegalizeToStablehlo.cpp | 6 +- 16 files changed, 2492 insertions(+), 23 deletions(-) create mode 100644 stablehlo/tests/stablehlo_legalize_to_vhlo.0_16_0.mlir create mode 100644 stablehlo/tests/stablehlo_legalize_to_vhlo.0_16_0.mlir.bc create mode 100644 stablehlo/tests/vhlo_to_version_downgrade_invalid.0_15_0.mlir diff --git a/docs/spec.md b/docs/spec.md index 2503e7d59ce..ad3dfdc09ca 100644 --- a/docs/spec.md +++ b/docs/spec.md @@ -1647,6 +1647,68 @@ for this operation ([#560](https://github.com/openxla/stablehlo/issues/560)).  [More Examples](../stablehlo/tests/interpret_clamp.mlir) +### collective_broadcast + +#### Semantics + +Within each process group in the StableHLO process grid, send the value of the +`operand` tensor from the source process to the target processes and produce a +`result` tensor. + +The operation splits the StableHLO process grid into `process_groups` which is +defined as follows: + +* `cross_replica(replica_groups)` if `channel_id <= 0`. +* `cross_partition(replica_groups)` if `channel_id > 0`. + +Afterwards, `result@process` is given by: + +* `operand@process_groups[i, 0]` if there exists an `i` such that the process is + in `process_groups[i]`. +* `broadcast_in_dim(constant(0, element_type(result)), [], type(result))` + otherwise. + +#### Inputs + +| Label | Name | Type | Constraints | +|-------|-------------------------|------------------------------------------------------------------|-------------| +| (I1) | `operand` | tensor | (C3) | +| (I2) | `replica_groups` | variadic number of 1-dimensional tensor constants of type `si64` | (C1), (C2) | +| (I3) | `channel_id` | constant of type `si64` | | + +#### Outputs + +| Name | Type | Constraints | +|----------|--------|-------------| +| `result` | tensor | (C3) | + +#### Constraints + +* (C1) `is_unique(replica_groups)`. +* (C2) `0 <= replica_groups < N` where `N` is defined as: + * `num_replicas` if `cross_replica` is used. + * `num_partitions` if `cross_partition` is used. +* (C3) `type(result) = type(operand)`. 
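+
+If `channel_id > 0`, the operation broadcasts across partitions rather than
+replicas. The canonical `cross_replica` example appears in the Examples section
+below; as a complement, here is a minimal sketch of the `cross_partition` case
+(illustrative values only, not part of this change: it assumes
+`num_partitions = 4`, `num_replicas = 1`, and uses `handle = 1` so that
+`channel_id > 0`):
+
+```mlir
+// num_replicas: 1
+// num_partitions: 4
+// %operand@(0, 2): [[5, 6]]
+%result = "stablehlo.collective_broadcast"(%operand) {
+  replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
+  channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
+} : (tensor<1x2xi64>) -> tensor<1x2xi64>
+// %result@(0, 0): [[0, 0]]
+// %result@(0, 1): [[5, 6]]
+// %result@(0, 2): [[5, 6]]
+// %result@(0, 3): [[0, 0]]
+```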
+
+#### Examples
+
+```mlir
+// num_replicas: 4
+// num_partitions: 1
+// %operand@(0, 0): [[1, 2]]
+// %operand@(1, 0): [[3, 4]]
+// %operand@(2, 0): [[5, 6]]
+// %operand@(3, 0): [[7, 8]]
+%result = "stablehlo.collective_broadcast"(%operand) {
+  replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
+  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
+} : (tensor<1x2xi64>) -> tensor<1x2xi64>
+// %result@(0, 0): [[0, 0]]
+// %result@(1, 0): [[5, 6]]
+// %result@(2, 0): [[5, 6]]
+// %result@(3, 0): [[0, 0]]
+```
+
 ### collective_permute
 
 #### Semantics
@@ -5949,11 +6011,11 @@ order and what kind of synchronization is introduced by it, is TBD
 
 ### Collective ops
 
-There are five collective ops in StableHLO: `all_gather`, `all_reduce`,
-`all_to_all`, `collective_permute` and `reduce_scatter`. All these ops split
-the processes in the StableHLO process grid into **StableHLO process groups**
-and execute a joint computation within each process group, independently from
-other process groups.
+There are six collective ops in StableHLO: `all_gather`, `all_reduce`,
+`all_to_all`, `collective_broadcast`, `collective_permute`, and
+`reduce_scatter`. All these ops split the processes in the StableHLO process
+grid into **StableHLO process groups** and execute a joint computation within
+each process group, independently from other process groups.
 
 Within each process group, collective ops may introduce a synchronization
 barrier. Further formalization, e.g. elaborating on when exactly this
diff --git a/stablehlo/dialect/StablehloAttrs.td b/stablehlo/dialect/StablehloAttrs.td
index cd8556748a3..e8fdcd2c8e4 100644
--- a/stablehlo/dialect/StablehloAttrs.td
+++ b/stablehlo/dialect/StablehloAttrs.td
@@ -109,8 +109,9 @@ def StableHLO_OutputOperandAlias : AttrDef {
   let cppNamespace = "::mlir::stablehlo";
   let mnemonic = "channel_handle";
diff --git a/stablehlo/dialect/StablehloOps.cpp b/stablehlo/dialect/StablehloOps.cpp
index d11b11c09c2..cb9bde80f17 100644
--- a/stablehlo/dialect/StablehloOps.cpp
+++ b/stablehlo/dialect/StablehloOps.cpp
@@ -122,17 +122,6 @@ LogicalResult TypeExtensionsAttr::verifyEncoding(
     getBounds(), RankedTensorType::get(shape, elementType), emitError);
 }
 
-//===----------------------------------------------------------------------===//
-// CollectivePermuteOp
-//===----------------------------------------------------------------------===//
-
-void CollectivePermuteOp::build(OpBuilder& odsBuilder, OperationState& odsState,
-                                Type resultType, Value operand,
-                                DenseIntElementsAttr sourceTargetPairs) {
-  CollectivePermuteOp::build(odsBuilder, odsState, resultType, operand,
-                             sourceTargetPairs, /*channel_handle=*/nullptr);
-}
-
 //===----------------------------------------------------------------------===//
 // ReduceScatterOp
 //===----------------------------------------------------------------------===//
@@ -171,6 +160,7 @@ INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Atan2Op)
 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CbrtOp)
 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CeilOp)
 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ClzOp)
+INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CollectiveBroadcastOp)
 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CollectivePermuteOp)
 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CosineOp)
 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CrossReplicaSumOp)
@@ -807,10 +797,33 @@ LogicalResult AbsOp::inferReturnTypes(
   return hlo::inferAbsOp(location, adaptor.getOperand(), inferredReturnTypes);
 }
 
+//===----------------------------------------------------------------------===//
+// CollectiveBroadcastOp
+//===----------------------------------------------------------------------===//
+
+void CollectiveBroadcastOp::build(OpBuilder& odsBuilder,
+                                  OperationState& odsState, Type resultType,
+                                  Value operand,
+                                  DenseIntElementsAttr replica_groups) {
+  CollectiveBroadcastOp::build(odsBuilder, odsState, resultType, operand,
+                               replica_groups, /*channel_handle=*/nullptr);
+}
+
+LogicalResult CollectiveBroadcastOp::verify() {
+  return hlo::verifyCollectiveBroadcastOp(getLoc(), getReplicaGroups());
+}
+
 //===----------------------------------------------------------------------===//
 // CollectivePermuteOp
 //===----------------------------------------------------------------------===//
 
+void CollectivePermuteOp::build(OpBuilder& odsBuilder, OperationState& odsState,
+                                Type resultType, Value operand,
+                                DenseIntElementsAttr sourceTargetPairs) {
+  CollectivePermuteOp::build(odsBuilder, odsState, resultType, operand,
+                             sourceTargetPairs, /*channel_handle=*/nullptr);
+}
+
 LogicalResult CollectivePermuteOp::verify() {
   return hlo::verifyCollectivePermuteOp(getLoc(), getSourceTargetPairs());
 }
diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td
index 776a599a926..399ecc635d3 100644
--- a/stablehlo/dialect/StablehloOps.td
+++ b/stablehlo/dialect/StablehloOps.td
@@ -2020,6 +2020,42 @@ def StableHLO_ConcatenateOp : StableHLO_ShapedInterfaceOp<"concatenate",
   }];
 }
 
+
+def StableHLO_CollectiveBroadcastOp: StableHLO_Op<"collective_broadcast",
+    [HLO_CompatibleOperandsAndResultType /*collective_broadcast_c3*/]> {
+  let summary = "CollectiveBroadcast operation";
+  let description = [{
+    Within each process group in the process grid, send the value of the
+    `operand` tensor from the source process to the target processes and
+    produce a `result` tensor.
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#collective_broadcast
+
+    Example:
+    ```mlir
+    %result = "stablehlo.collective_broadcast"(%operand) {
+      replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
+      channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
+    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
+    ```
+  }];
+
+  let arguments = (ins
+    HLO_Tensor:$operand, /*collective_broadcast_i1*/
+    I64ElementsAttr:$replica_groups, /*collective_broadcast_i2*/
+    OptionalAttr<StableHLO_ChannelHandle>:$channel_handle /*collective_broadcast_i3*/
+  );
+  let results = (outs HLO_Tensor);
+  let hasVerifier = 1;
+
+  // channel_handle is only used for the SPMD partitioner, so we add a
+  // simplified builder method for convenience.
+  let builders = [
+    OpBuilder<(ins
+      "::mlir::Type":$result_type, "::mlir::Value":$operand,
+      "::mlir::DenseIntElementsAttr":$replica_groups)>];
+}
+
 def StableHLO_CollectivePermuteOp: StableHLO_Op<"collective_permute",
     [HLO_CompatibleOperandsAndResultType /*collective_permute_c5*/]> {
   let summary = "CollectivePermute operation";
diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp
index 77fae7b64aa..7dbfbd260da 100644
--- a/stablehlo/dialect/TypeInference.cpp
+++ b/stablehlo/dialect/TypeInference.cpp
@@ -3289,6 +3289,35 @@ LogicalResult verifyBroadcastInDimOp(std::optional<Location> location,
   return success();
 }
 
+LogicalResult verifyCollectiveBroadcastOp(std::optional<Location> location,
+                                          DenseIntElementsAttr replicaGroups) {
+  // collective_broadcast_i2
+  auto replicaGroupType = replicaGroups.getType().cast<RankedTensorType>();
+  if (replicaGroupType.getRank() != 2)
+    return emitOptionalError(
+        location, "replica groups should be a rank 2 tensor, ",
+        "but instead it is of rank ", replicaGroupType.getRank());
+
+  auto replicaIds = replicaGroups.getValues<int64_t>();
+  llvm::SmallSet<int64_t, 8> replicaIdsSeen;
+  for (int64_t replicaId : replicaIds) {
+    // collective_broadcast_c2
+    // We only check that it is not negative, as it is impossible
+    // to statically know `num_replicas` or `num_partitions`.
+    if (replicaId < 0)
+      return emitOptionalError(
+          location, "replica_groups values must be non-negative, but was given ",
+          replicaId);
+
+    // collective_broadcast_c1
+    if (!replicaIdsSeen.insert(replicaId).second)
+      return emitOptionalError(location, "replica id #", replicaId,
+                               " seen more than once");
+  }
+
+  return success();
+}
+
 LogicalResult verifyCollectivePermuteOp(
     std::optional<Location> location, DenseIntElementsAttr sourceTargetPairs) {
   auto type = sourceTargetPairs.getType().dyn_cast<RankedTensorType>();
diff --git a/stablehlo/dialect/TypeInference.h b/stablehlo/dialect/TypeInference.h
index b83573fb759..3c9dd83e326 100644
--- a/stablehlo/dialect/TypeInference.h
+++ b/stablehlo/dialect/TypeInference.h
@@ -388,6 +388,9 @@ LogicalResult verifyBroadcastInDimOp(std::optional<Location> location,
                                      DenseIntElementsAttr broadcastDimensions,
                                      Value result);
 
+LogicalResult verifyCollectiveBroadcastOp(std::optional<Location> location,
+                                          DenseIntElementsAttr replicaGroups);
+
 LogicalResult verifyCollectivePermuteOp(std::optional<Location> location,
                                         DenseIntElementsAttr sourceTargetPairs);
 
diff --git a/stablehlo/dialect/Version.h b/stablehlo/dialect/Version.h
index 07e85dcb965..0fdbae45979 100644
--- a/stablehlo/dialect/Version.h
+++ b/stablehlo/dialect/Version.h
@@ -38,7 +38,7 @@ class Version {
   static FailureOr<Version> fromString(llvm::StringRef versionRef);
 
   /// Return a Version representing the current VHLO dialect version.
-  static Version getCurrentVersion() { return Version(0, 15, 5); }
+  static Version getCurrentVersion() { return Version(0, 16, 0); }
 
   /// Return a Version representing the minimum supported VHLO dialect version.
   static Version getMinimumVersion() { return Version(0, 9, 0); }
diff --git a/stablehlo/dialect/VhloDialect.td b/stablehlo/dialect/VhloDialect.td
index 7503164829d..c38a57f02f8 100644
--- a/stablehlo/dialect/VhloDialect.td
+++ b/stablehlo/dialect/VhloDialect.td
@@ -33,6 +33,7 @@ def VHLO_Dialect : Dialect {
     0.12.0: MLIR bytecode version 1 => 3.
     0.14.0: MLIR bytecode version 3 => 5 (revised to 4 in #1827).
     0.15.0: MLIR bytecode version 5 => 6, use properties in VHLO.
+    0.16.0: Introduce `collective_broadcast` operation.
}]; let useDefaultAttributePrinterParser = 0; diff --git a/stablehlo/dialect/VhloOps.td b/stablehlo/dialect/VhloOps.td index f99ef309957..0456360e199 100644 --- a/stablehlo/dialect/VhloOps.td +++ b/stablehlo/dialect/VhloOps.td @@ -240,6 +240,15 @@ def VHLO_ClzOpV1 : VHLO_Op<"count_leading_zeros_v1", "0.9.0", "current"> { let results = (outs VHLO_AnyType:$result); } +def VHLO_CollectiveBroadcastOpV1 : VHLO_Op<"collective_broadcast_v1", "0.16.0", "current"> { + let arguments = (ins + VHLO_AnyType:$operand, + VHLO_AnyAttr:$replica_groups, + VHLO_AnyAttr:$channel_id + ); + let results = (outs VHLO_AnyType:$result); +} + def VHLO_CollectivePermuteOpV1 : VHLO_Op<"collective_permute_v1", "0.9.0", "current"> { let arguments = (ins VHLO_AnyType:$operand, diff --git a/stablehlo/tests/stablehlo_legalize_to_vhlo.0_16_0.mlir b/stablehlo/tests/stablehlo_legalize_to_vhlo.0_16_0.mlir new file mode 100644 index 00000000000..9d4b9ee0dda --- /dev/null +++ b/stablehlo/tests/stablehlo_legalize_to_vhlo.0_16_0.mlir @@ -0,0 +1,2287 @@ +// RUN: stablehlo-opt --mlir-print-op-generic %s.bc | FileCheck %s +// RUN: stablehlo-translate --deserialize %s.bc | stablehlo-translate --serialize --target=0.16.0 | stablehlo-opt --mlir-print-op-generic | FileCheck %s +// RUN: stablehlo-translate --deserialize %s.bc | stablehlo-opt > %t.0 +// RUN: stablehlo-opt --strip-debuginfo %s > %t.1 +// RUN: diff %t.0 %t.1 +// RUN: stablehlo-translate --serialize --target=0.16.0 --strip-debuginfo %s > %t.2 +// RUN: diff %s.bc %t.2 + + +// ============ ATTRIBUTES ============ + +// CHECK-LABEL: "attr_comparison_direction_eq" +func.func @attr_comparison_direction_eq(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_direction_ne" +func.func @attr_comparison_direction_ne(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_direction_ge" +func.func @attr_comparison_direction_ge(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_direction_gt" +func.func @attr_comparison_direction_gt(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_direction_le" +func.func @attr_comparison_direction_le(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_direction_lt" +func.func @attr_comparison_direction_lt(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_type_notype" +func.func @attr_comparison_type_notype(%arg0: tensor, %arg1: 
tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo + // CHECK: compare_type = #vhlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_type_float" +func.func @attr_comparison_type_float(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_type_totalorder" +func.func @attr_comparison_type_totalorder(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_type_signed" +func.func @attr_comparison_type_signed(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_type_unsigned" +func.func @attr_comparison_type_unsigned(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// ConvDimensionNumbers aka #stablehlo.conv is covered below. + +// CHECK-LABEL: "attr_custom_call_api_version_unspecified" +func.func @attr_custom_call_api_version_unspecified(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = #vhlo + api_version = 0 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_custom_call_api_version_original" +func.func @attr_custom_call_api_version_original(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = #vhlo + api_version = 1 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_custom_call_api_version_status_returning" +func.func @attr_custom_call_api_version_status_returning(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = #vhlo + api_version = 2 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_custom_call_api_version_status_returning_unified" +func.func @attr_custom_call_api_version_status_returning_unified(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = #vhlo + api_version = 3 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_dict" +// CHECK: #vhlo.dict_v1<{#vhlo.string_v1<"attr1"> = #vhlo.integer_v1<1 : i32>, #vhlo.string_v1<"attr2"> = #vhlo.integer_v1<2 : i32>} +func.func @attr_dict() attributes {stablehlo.attr = {attr1 = 1 : i32, attr2 = 2 : i32}} { + return +} + +// DotDimensionNumbers aka #stablehlo.dot is covered below. 
+ +// CHECK-LABEL: "attr_fft_type_fft" +func.func @attr_fft_type_fft(%arg0: tensor<16xcomplex>) -> tensor<16xcomplex> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = dense<16> : tensor<1xi64> + } : (tensor<16xcomplex>) -> tensor<16xcomplex> + func.return %0 : tensor<16xcomplex> +} + +// CHECK-LABEL: "attr_fft_type_ifft" +func.func @attr_fft_type_ifft(%arg0: tensor<16xcomplex>) -> tensor<16xcomplex> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = dense<16> : tensor<1xi64> + } : (tensor<16xcomplex>) -> tensor<16xcomplex> + func.return %0 : tensor<16xcomplex> +} + +// CHECK-LABEL: "attr_fft_type_rfft" +func.func @attr_fft_type_rfft(%arg0: tensor<16xf32>) -> tensor<9xcomplex> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = dense<16> : tensor<1xi64> + } : (tensor<16xf32>) -> tensor<9xcomplex> + func.return %0 : tensor<9xcomplex> +} + +// CHECK-LABEL: "attr_fft_type_irfft" +func.func @attr_fft_type_irfft(%arg0: tensor<9xcomplex>) -> tensor<16xf32> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = dense<16> : tensor<1xi64> + } : (tensor<9xcomplex>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// GatherDimensionNumbers aka #stablehlo.gather is covered below. + +// CHECK-LABEL: "attr_precision_config_default" +func.func @attr_precision_config_default(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + %0 = "stablehlo.dot"(%arg0, %arg1) { + // CHECK: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "attr_precision_config_high" +func.func @attr_precision_config_high(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + %0 = "stablehlo.dot"(%arg0, %arg1) { + // CHECK: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> + precision_config = [#stablehlo, #stablehlo] + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "attr_precision_config_highest" +func.func @attr_precision_config_highest(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + %0 = "stablehlo.dot"(%arg0, %arg1) { + // CHECK: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> + precision_config = [#stablehlo, #stablehlo] + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "attr_rng_algorithm_default" +func.func @attr_rng_algorithm_default(%arg0: tensor) -> (tensor, tensor) { + %0:2 = "stablehlo.rng_bit_generator"(%arg0) { + // CHECK: rng_algorithm = #vhlo + rng_algorithm = #stablehlo + } : (tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} + +// CHECK-LABEL: "attr_rng_algorithm_three_fry" +func.func @attr_rng_algorithm_three_fry(%arg0: tensor) -> (tensor, tensor) { + %0:2 = "stablehlo.rng_bit_generator"(%arg0) { + // CHECK: rng_algorithm = #vhlo + rng_algorithm = #stablehlo + } : (tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} + +// CHECK-LABEL: "attr_rng_algorithm_philox" +func.func @attr_rng_algorithm_philox(%arg0: tensor) -> (tensor, tensor) { + %0:2 = "stablehlo.rng_bit_generator"(%arg0) { + // CHECK: rng_algorithm = #vhlo + rng_algorithm = #stablehlo + } : (tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} + +// CHECK-LABEL: 
"attr_rng_distribution_uniform" +func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { + // CHECK: rng_distribution = #vhlo + rng_distribution = #stablehlo + } : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_rng_distribution_normal" +func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { + // CHECK: rng_distribution = #vhlo + rng_distribution = #stablehlo + } : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} + +// ScatterDimensionNumbers aka #stablehlo.scatter is covered below. + +// CHECK-LABEL: "attr_transpose_no_transpose" +func.func @attr_transpose_no_transpose(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { + left_side = true, + lower = true, + unit_diagonal = true, + // transpose_a = #vhlo, + transpose_a = #stablehlo + } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "attr_transpose_transpose" +func.func @attr_transpose_transpose(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { + left_side = true, + lower = true, + unit_diagonal = true, + // transpose_a = #vhlo, + transpose_a = #stablehlo + } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "attr_transpose_adjoint" +func.func @attr_transpose_adjoint(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { + left_side = true, + lower = true, + unit_diagonal = true, + // transpose_a = #vhlo, + transpose_a = #stablehlo + } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// TypeExtensionsAttr aka #stablehlo.type_extensions is covered below. 
+ +// CHECK-LABEL: "attr_type_extensions_bounds" +func.func @attr_type_extensions_bounds( + %arg0: tensor>) + -> tensor> { + // CHECK: "vhlo.return_v1"(%arg0) : (!vhlo.tensor_v1>) -> () + func.return %arg0 : tensor> +} + + +// ============ DEFAULTS ============ + +// CHECK-LABEL: "default_all_gather" +func.func @default_all_gather(%arg0: tensor<16x8xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.all_gather_v1"(%arg0) <{ + // CHECK-SAME: all_gather_dim = #vhlo.integer_v1<1 : i64> + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> + %0 = "stablehlo.all_gather"(%arg0) { + all_gather_dim = 1 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor<16x8xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "default_all_reduce" +func.func @default_all_reduce(%arg0: tensor) -> tensor { + // CHECK: "vhlo.all_reduce_v1"(%arg0) + // CHECK-SAME: <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + + %0 = "stablehlo.all_reduce"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "default_all_to_all" +func.func @default_all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { + // CHECK: "vhlo.all_to_all_v1"(%arg0) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: concat_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<1x4xi64>>, + // CHECK-SAME: split_count = #vhlo.integer_v1<4 : i64> + // CHECK-SAME: split_dimension = #vhlo.integer_v1<1 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<4x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x4x!vhlo.f32_v1> + %0 = "stablehlo.all_to_all"(%arg0) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 4 : i64, + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> + } : (tensor<4x16xf32>) -> tensor<16x4xf32> + func.return %0 : tensor<16x4xf32> +} + +// CHECK-LABEL: "default_cholesky" +func.func @default_cholesky(%arg0: tensor<1x16x16xf32>) -> tensor<1x16x16xf32> { + // CHECK: "vhlo.cholesky_v1"(%arg0) <{ + // CHECK-SAME: lower = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x16x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x16x16x!vhlo.f32_v1> + %0 = "stablehlo.cholesky"(%arg0) : (tensor<1x16x16xf32>) -> tensor<1x16x16xf32> + func.return %0 : tensor<1x16x16xf32> +} + +// CHECK-LABEL: "default_collective_permute" +func.func @default_collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { + // CHECK: "vhlo.collective_permute_v1"(%arg0) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: 
source_target_pairs = #vhlo.tensor_v1 : tensor<3x2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x8x!vhlo.f32_v1> + %0 = "stablehlo.collective_permute"(%arg0) { + source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64> + } : (tensor<16x8xf32>) -> tensor<16x8xf32> + func.return %0 : tensor<16x8xf32> +} + +// CHECK-LABEL: "default_collective_broadcast" +func.func @default_collective_broadcast(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { + // CHECK: "vhlo.collective_broadcast_v1"(%arg0) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<1x2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x8x!vhlo.f32_v1> + %0 = "stablehlo.collective_broadcast"(%arg0) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> + } : (tensor<16x8xf32>) -> tensor<16x8xf32> + func.return %0 : tensor<16x8xf32> +} + +// CHECK-LABEL: "default_compare" +func.func @default_compare(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.compare_v1"(%arg0, %arg1) <{ + // CHECK-SAME: compare_type = #vhlo, + // CHECK-SAME: comparison_direction = #vhlo + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "default_convolution" +func.func @default_convolution(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32> { + // CHECK: "vhlo.convolution_v1"(%arg0, %arg1) <{ + // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<2x2xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x6x6x16x!vhlo.f32_v1> + %0 = "stablehlo.convolution"(%arg0, %arg1) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64 + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32> + func.return %0 : tensor<1x6x6x16xf32> +} + +// CHECK-LABEL: "default_custom_call" +func.func @default_custom_call(%arg0: tensor) -> tensor { + // CHECK: "vhlo.custom_call_v1"(%arg0) <{ + // CHECK-SAME: api_version = #vhlo, + // CHECK-SAME: backend_config = #vhlo.string_v1<"">, + // CHECK-SAME: 
call_target_name = #vhlo.string_v1<"foo">, + // CHECK-SAME: called_computations = #vhlo.array_v1<[]>, + // CHECK-SAME: has_side_effect = #vhlo.bool_v1, + // CHECK-SAME: operand_layouts = #vhlo.array_v1<[]>, + // CHECK-SAME: output_operand_aliases = #vhlo.array_v1<[]> + // CHECK-SAME: result_layouts = #vhlo.array_v1<[]> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo" + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "default_dot_general" +func.func @default_dot_general(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) -> tensor<8x8x8xf32> { + // CHECK: "vhlo.dot_general_v1"(%arg0, %arg1) <{ + // CHECK-SAME: lhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: lhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: rhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<8x16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x8x!vhlo.f32_v1> + %0 = "stablehlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_batching_dimensions = [0], + rhs_contracting_dimensions = [1] + > + } : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> + func.return %0 : tensor<8x8x8xf32> +} + +// CHECK-LABEL: "default_dot" +func.func @default_dot(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + // CHECK: "vhlo.dot_v1"(%arg0, %arg1) <{ + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x!vhlo.f32_v1> + %0 = "stablehlo.dot"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "default_dynamic_broadcast_in_dim" +func.func @default_dynamic_broadcast_in_dim(%arg0: tensor, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "vhlo.dynamic_broadcast_in_dim_v1"(%arg0, %arg1) <{ + // CHECK-SAME: broadcast_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: known_expanding_dimensions = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: known_nonexpanding_dimensions = #vhlo.tensor_v1 : tensor<0xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { + broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> + } : (tensor, tensor<2xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "default_dynamic_conv" +func.func @default_dynamic_conv(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>, %arg2: tensor<4xi32>) -> tensor<1x?x?x16xf32> { + // CHECK: "vhlo.dynamic_conv_v1"(%arg0, %arg1, %arg2) <{ + // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: 
kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<2x2xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>, !vhlo.tensor_v1<4x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x?x?x16x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_conv"(%arg0, %arg1, %arg2) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64 + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<4xi32>) -> tensor<1x?x?x16xf32> + func.return %0 : tensor<1x?x?x16xf32> +} + +// CHECK-LABEL: "default_dynamic_gather" +func.func @default_dynamic_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>, %arg2 : tensor<3xi32>) -> tensor<1x5x8xf32> { + // CHECK: "vhlo.dynamic_gather_v1"(%arg0, %arg1, %arg2) <{ + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>, !vhlo.tensor_v1<3x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x8x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %arg2) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [0, 1], + start_index_map = [0, 1], + index_vector_dim = 2 + > + } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>, tensor<3xi32>) -> tensor<1x5x8xf32> + func.return %0 : tensor<1x5x8xf32> +} + +func.func @default_func(%arg0: tensor) -> tensor { + // CHECK: "vhlo.func_v1"() <{ + // CHECK-SAME: arg_attrs = #vhlo.array_v1<[]>, + // CHECK-SAME: function_type = #vhlo.type_v1) -> !vhlo.tensor_v1>>, + // CHECK-SAME: res_attrs = #vhlo.array_v1<[]>, + // CHECK-SAME: sym_name = #vhlo.string_v1<"default_func">, + // CHECK-SAME: sym_visibility = #vhlo.string_v1<""> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%arg0: !vhlo.tensor_v1): + // CHECK-NEXT: "vhlo.return_v1"(%arg0) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : () -> () + func.return %arg0 : tensor +} + +// CHECK-LABEL: "dynamic_gather" +func.func @dynamic_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>) -> tensor<1x5x1xf32> { + // CHECK: "vhlo.gather_v1"(%arg0, %arg1) <{ + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: slice_sizes = #vhlo.tensor_v1 : tensor<3xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x1x!vhlo.f32_v1> + %0 = 
"stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [0, 1], + start_index_map = [0, 1], + index_vector_dim = 2 + >, + slice_sizes = dense<1> : tensor<3xi64> + } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>) -> tensor<1x5x1xf32> + func.return %0 : tensor<1x5x1xf32> +} + +// CHECK-LABEL: "default_infeed" +func.func @default_infeed(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { + // CHECK: "vhlo.infeed_v1"(%arg0) <{ + // CHECK-SAME: infeed_config = #vhlo.string_v1<"">, + // CHECK-SAME{LITERAL}: layout = #vhlo.array_v1<[]> + // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) + %0:2 = "stablehlo.infeed"(%arg0) : (!stablehlo.token) -> (tensor, !stablehlo.token) + func.return %0#0, %0#1 : tensor, !stablehlo.token +} + +// CHECK-LABEL: "default_outfeed" +func.func @default_outfeed(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.outfeed_v1"(%arg0, %arg1) <{ + // CHECK-SAME: outfeed_config = #vhlo.string_v1<""> + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 + %0 = "stablehlo.outfeed"(%arg0, %arg1) : (tensor, !stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "default_recv" +func.func @default_recv(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { + // CHECK: "vhlo.recv_v1"(%arg0) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: channel_type = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) + %0:2 = "stablehlo.recv"(%arg0) { + channel_handle = #stablehlo.channel_handle + } : (!stablehlo.token) -> (tensor, !stablehlo.token) + func.return %0#0, %0#1 : tensor, !stablehlo.token +} + +// CHECK-LABEL: "default_send" +func.func @default_send(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.send_v1"(%arg0, %arg1) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: channel_type = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 + %0 = "stablehlo.send"(%arg0, %arg1) { + channel_handle = #stablehlo.channel_handle + } : (tensor, !stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "default_reduce_scatter" +func.func @default_reduce_scatter(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.reduce_scatter_v1"(%arg0) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: scatter_dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.reduce_scatter"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + scatter_dimension = 0 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor<16xf32>) 
-> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "default_reduce_window" +func.func @default_reduce_window(%arg0: tensor<2x17x31x7xf32>, %arg1: tensor) -> tensor<2x16x30x7xf32> { + // CHECK: "vhlo.reduce_window_v1"(%arg0, %arg1) <{ + // CHECK-SAME: base_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME{LITERAL}: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, + // CHECK-SAME: window_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: !vhlo.tensor_v1, %[[ARG3:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.maximum_v1"(%[[ARG2]], %[[ARG3]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<2x17x31x7x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<2x16x30x7x!vhlo.f32_v1> + %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = "stablehlo.maximum"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64> + } : (tensor<2x17x31x7xf32>, tensor) -> tensor<2x16x30x7xf32> + func.return %0 : tensor<2x16x30x7xf32> +} + +// CHECK-LABEL: "default_scatter" +func.func @default_scatter(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300xf32>) -> tensor<200x100x300xf32> { + // CHECK: "vhlo.scatter_v1"(%arg0, %arg1, %arg2) <{ + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: inserted_window_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: scatter_dims_to_operand_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: unique_indices = #vhlo.bool_v1, + // CHECK-SAME: update_window_dims = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: !vhlo.tensor_v1, %[[ARG4:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG3]], %[[ARG4]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<200x100x300x!vhlo.f32_v1> + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + > + } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32> + func.return %0 : tensor<200x100x300xf32> +} + +// CHECK-LABEL: "default_select_and_scatter" +func.func @default_select_and_scatter(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x23x23x64xf32>, %arg2: tensor) -> tensor<10x24x24x64xf32> { + // CHECK: "vhlo.select_and_scatter_v1"(%arg0, %arg1, %arg2) <{ + // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, + // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> + // CHECK-SAME: }> 
({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG31:arg.*]]: !vhlo.tensor_v1, %[[ARG41:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL11:.*]] = "vhlo.compare_v1"(%[[ARG31]], %[[ARG41]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL11]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }, { + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG32:arg.*]]: !vhlo.tensor_v1, %[[ARG42:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL12:.*]] = "vhlo.add_v1"(%[[ARG32]], %[[ARG42]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL12]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1>, !vhlo.tensor_v1<10x23x23x64x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1> + %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64> + } : (tensor<10x24x24x64xf32>, tensor<10x23x23x64xf32>, tensor) -> tensor<10x24x24x64xf32> + func.return %0 : tensor<10x24x24x64xf32> +} + +// CHECK-LABEL: "default_sort" +func.func @default_sort(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.sort_v1"(%arg0) <{ + // CHECK-SAME: dimension = #vhlo.integer_v1<-1 : i64> + // CHECK-SAME: is_stable = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.compare_v1"(%[[ARG1]], %[[ARG2]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.sort"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.compare"(%arg1, %arg2) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// ============ OPS ============ + +// CHECK-LABEL: "op_abs" +func.func @op_abs(%arg0: tensor) -> tensor { + // CHECK: "vhlo.abs_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.abs"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_add" +func.func @op_add(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_after_all" +func.func @op_after_all(%arg0: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.after_all_v1"(%arg0) : (!vhlo.token_v1) -> !vhlo.token_v1 + %0 = "stablehlo.after_all"(%arg0) : (!stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "op_all_gather" +func.func @op_all_gather(%arg0: tensor<16x8xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.all_gather_v1"(%arg0) <{ + // CHECK-SAME: all_gather_dim = #vhlo.integer_v1<1 : i64> + // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME{LITERAL}: 
replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> + %0 = "stablehlo.all_gather"(%arg0) { + all_gather_dim = 1 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids + } : (tensor<16x8xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "op_all_reduce" +func.func @op_all_reduce(%arg0: tensor) -> tensor { + // CHECK: "vhlo.all_reduce_v1"(%arg0) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.all_reduce"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_all_to_all" +func.func @op_all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { + // CHECK: "vhlo.all_to_all_v1"(%arg0) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: concat_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<1x4xi64>>, + // CHECK-SAME: split_count = #vhlo.integer_v1<4 : i64> + // CHECK-SAME: split_dimension = #vhlo.integer_v1<1 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<4x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x4x!vhlo.f32_v1> + %0 = "stablehlo.all_to_all"(%arg0) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 4 : i64, + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<4x16xf32>) -> tensor<16x4xf32> + func.return %0 : tensor<16x4xf32> +} + +// CHECK-LABEL: "op_and" +func.func @op_and(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.and_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.and"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_atan2" +func.func @op_atan2(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.atan2_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.atan2"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_batch_norm_grad" +func.func @op_batch_norm_grad(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: tensor<16x16x16x16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) { + // CHECK: "vhlo.batch_norm_grad_v1"(%arg0, %arg1, %arg2, %arg3, %arg4) <{ + // CHECK-SAME: epsilon = #vhlo.float_v1<1.000000e-03 : !vhlo.f32_v1>, + // CHECK-SAME: feature_index = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, 
!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>) -> (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) + %0:3 = "stablehlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4) { + epsilon = 0.001 : f32, + feature_index = 0 : i64 + } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16x16x16x16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) + func.return %0#0, %0#1, %0#2 : tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32> +} + +// CHECK-LABEL: "op_batch_norm_inference" +func.func @op_batch_norm_inference(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: tensor<16xf32>) -> tensor<16x16x16x16xf32> { + // CHECK: "vhlo.batch_norm_inference_v1"(%arg0, %arg1, %arg2, %arg3, %arg4) <{ + // CHECK-SAME: epsilon = #vhlo.float_v1<1.000000e-03 : !vhlo.f32_v1>, + // CHECK-SAME: feature_index = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1> + %0 = "stablehlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) { + epsilon = 0.001 : f32, + feature_index = 0 : i64 + } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<16x16x16x16xf32> + func.return %0 : tensor<16x16x16x16xf32> +} + +// CHECK-LABEL: "op_batch_norm_training" +func.func @op_batch_norm_training(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) { + // CHECK: "vhlo.batch_norm_training_v1"(%arg0, %arg1, %arg2) <{ + // CHECK-SAME: epsilon = #vhlo.float_v1<1.000000e-03 : !vhlo.f32_v1>, + // CHECK-SAME: feature_index = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) -> (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) + %0:3 = "stablehlo.batch_norm_training"(%arg0, %arg1, %arg2) { + epsilon = 0.001 : f32, + feature_index = 0 : i64 + } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) + func.return %0#0, %0#1, %0#2 : tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32> +} + +// CHECK-LABEL: "op_bitcast_convert" +func.func @op_bitcast_convert(%arg0: tensor) -> tensor { + // CHECK: "vhlo.bitcast_convert_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.bitcast_convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_broadcast_in_dim" +func.func @op_broadcast_in_dim(%arg0: tensor<16xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.broadcast_in_dim_v1"(%arg0) <{ + // CHECK-SAME: broadcast_dimensions = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> + %0 = "stablehlo.broadcast_in_dim"(%arg0) { + broadcast_dimensions = dense<1> : tensor<1xi64> + } : (tensor<16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "op_broadcast" +func.func @op_broadcast(%arg0: tensor<16xf32>) -> tensor<16x16xf32> { + // CHECK: 
"vhlo.broadcast_v1"(%arg0) <{ + // CHECK-SAME: broadcast_sizes = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> + %0 = "stablehlo.broadcast"(%arg0) { + broadcast_sizes = dense<16> : tensor<1xi64> + } : (tensor<16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "op_case" +func.func @op_case(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.case_v1"(%arg0) ({ + // CHECK-NEXT: "vhlo.return_v1"(%arg1) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.case"(%arg0) ({ + "stablehlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_cbrt" +func.func @op_cbrt(%arg0: tensor) -> tensor { + // CHECK: "vhlo.cbrt_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.cbrt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_ceil" +func.func @op_ceil(%arg0: tensor) -> tensor { + // CHECK: "vhlo.ceil_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.ceil"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_cholesky" +func.func @op_cholesky(%arg0: tensor<1x16x16xf32>) -> tensor<1x16x16xf32> { + // CHECK: "vhlo.cholesky_v1"(%arg0) <{ + // CHECK-SAME: lower = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x16x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x16x16x!vhlo.f32_v1> + %0 = "stablehlo.cholesky"(%arg0) { + lower = true + } : (tensor<1x16x16xf32>) -> tensor<1x16x16xf32> + func.return %0 : tensor<1x16x16xf32> +} + +// CHECK-LABEL: "op_clamp" +func.func @op_clamp(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: "vhlo.clamp_v1"(%arg0, %arg1, %arg2) : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.clamp"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_count_leading_zeros" +func.func @op_count_leading_zeros(%arg0: tensor) -> tensor { + // CHECK: "vhlo.count_leading_zeros_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.count_leading_zeros"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_collective_permute" +func.func @op_collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { + // CHECK: "vhlo.collective_permute_v1"(%arg0) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME{LITERAL}: source_target_pairs = #vhlo.tensor_v1 : tensor<3x2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x8x!vhlo.f32_v1> + %0 = "stablehlo.collective_permute"(%arg0) { + source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<16x8xf32>) -> tensor<16x8xf32> + func.return %0 : tensor<16x8xf32> +} + +// CHECK-LABEL: "op_compare" +func.func @op_compare(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.compare_v1"(%arg0, %arg1) <{ + // CHECK-SAME: compare_type = #vhlo, + // CHECK-SAME: comparison_direction = #vhlo + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_complex" +func.func @op_complex(%arg0: tensor, %arg1: tensor) -> tensor> { + // CHECK: "vhlo.complex_v1"(%arg0, 
%arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1> + %0 = "stablehlo.complex"(%arg0, %arg1) : (tensor, tensor) -> tensor> + func.return %0 : tensor> +} + +// CHECK-LABEL: "op_compute_reshape_shape" +func.func @op_compute_reshape_shape(%arg0: index, %arg1: tensor<1xindex>) -> tensor<1xindex> { + // CHECK: "vhlo.compute_reshape_shape_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1<1x!vhlo.index_v1> + %0 = "stablehlo.compute_reshape_shape"(%arg0, %arg1) : (index, tensor<1xindex>) -> tensor<1xindex> + func.return %0 : tensor<1xindex> +} + +// CHECK-LABEL: "op_concatenate" +func.func @op_concatenate(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.concatenate_v1"(%arg0, %arg1) <{ + // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x!vhlo.f32_v1>, !vhlo.tensor_v1<8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.concatenate"(%arg0, %arg1) { + dimension = 0 : i64 + } : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_constant" +func.func @op_constant(%arg0: tensor) -> tensor { + // CHECK: "vhlo.constant_v1"() <{ + // CHECK-SAME: value = #vhlo.tensor_v1 : tensor> + // CHECK-SAME: }> : () -> !vhlo.tensor_v1 + %0 = "stablehlo.constant"() { + value = dense<0.0> : tensor + } : () -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_convert" +func.func @op_convert(%arg0: tensor) -> tensor { + // CHECK: "vhlo.convert_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_convolution" +func.func @op_convolution(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x7x7x16xf32> { + // CHECK: "vhlo.convolution_v1"(%arg0, %arg1) <{ + // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<2x2xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x7x7x16x!vhlo.f32_v1> + %0 = "stablehlo.convolution"(%arg0, %arg1) { + window_strides = dense<2> : tensor<2xi64>, + padding = dense<1> : tensor<2x2xi64>, + lhs_dilation = dense<2> : tensor<2xi64>, + rhs_dilation = dense<2> : tensor<2xi64>, + window_reversal = dense : tensor<2xi1>, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, 
o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x7x7x16xf32> + func.return %0 : tensor<1x7x7x16xf32> +} + +// CHECK-LABEL: "op_cosine" +func.func @op_cosine(%arg0: tensor) -> tensor { + // CHECK: "vhlo.cosine_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.cosine"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_create_token" +func.func @op_create_token() -> !stablehlo.token { + // CHECK: "vhlo.create_token_v1"() : () -> !vhlo.token_v1 + %0 = "stablehlo.create_token"() : () -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "op_cross_replica_sum" +func.func @op_cross_replica_sum(%arg0: tensor) -> tensor { + // CHECK: "vhlo.cross-replica-sum_v1"(%arg0) <{ + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.cross-replica-sum"(%arg0) { + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_cstr_reshapable" +func.func @op_cstr_reshapable(%arg0: index, %arg1: tensor<1xindex>) -> !shape.witness { + // CHECK: "vhlo.cstr_reshapable_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.witness_v1 + %0 = "stablehlo.cstr_reshapable"(%arg0, %arg1) : (index, tensor<1xindex>) -> !shape.witness + func.return %0 : !shape.witness +} + +// CHECK-LABEL: "op_custom_call" +func.func @op_custom_call(%arg0: tensor) -> tensor { + // CHECK: "vhlo.custom_call_v1"(%arg0) <{ + // CHECK-SAME: api_version = #vhlo, + // CHECK-SAME: backend_config = #vhlo.string_v1<"\08\03\1A\02">, + // CHECK-SAME: call_target_name = #vhlo.string_v1<"foo">, + // CHECK-SAME: called_computations = #vhlo.array_v1<[#vhlo.string_v1<"foo">]>, + // CHECK-SAME: has_side_effect = #vhlo.bool_v1, + // CHECK-SAME: operand_layouts = #vhlo.array_v1<[#vhlo.tensor_v1 : tensor<0xindex>>]>, + // CHECK-SAME: output_operand_aliases = #vhlo.array_v1<[ + // CHECK-SAME: #vhlo.output_operand_alias_v1< + // CHECK-SAME: outputTupleIndices = [], + // CHECK-SAME: operandIndex = 0, + // CHECK-SAME: operandTupleIndices = []>]> + // CHECK-SAME: result_layouts = #vhlo.array_v1<[#vhlo.tensor_v1 : tensor<0xindex>>]> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + has_side_effect = true, + backend_config = "\08\03\1A\02", + api_version = 2 : i32, + called_computations = [@foo], + operand_layouts = [dense<> : tensor<0xindex>], + output_operand_aliases = [ + #stablehlo.output_operand_alias], + result_layouts = [dense<> : tensor<0xindex>] + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_divide" +func.func @op_divide(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.divide_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.divide"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_dot_general" +func.func @op_dot_general(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) -> tensor<8x8x8xf32> { + // CHECK: "vhlo.dot_general_v1"(%arg0, %arg1) <{ + // CHECK-SAME: lhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: lhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: precision_config = 
#vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: rhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<8x16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x8x!vhlo.f32_v1> + %0 = "stablehlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_batching_dimensions = [0], + rhs_contracting_dimensions = [1] + >, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> + func.return %0 : tensor<8x8x8xf32> +} + +// CHECK-LABEL: "op_dot" +func.func @op_dot(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + // CHECK: "vhlo.dot_v1"(%arg0, %arg1) <{ + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x!vhlo.f32_v1> + %0 = "stablehlo.dot"(%arg0, %arg1) { + precision_config = [#stablehlo, #stablehlo] + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "op_dynamic_broadcast_in_dim" +func.func @op_dynamic_broadcast_in_dim(%arg0: tensor, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "vhlo.dynamic_broadcast_in_dim_v1"(%arg0, %arg1) <{ + // CHECK-SAME: broadcast_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: known_expanding_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: known_nonexpanding_dimensions = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { + broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, + known_expanding_dimensions = dense<0> : tensor<1xi64>, + known_nonexpanding_dimensions = dense<1> : tensor<1xi64> + } : (tensor, tensor<2xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_dynamic_conv" +func.func @op_dynamic_conv(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>, %arg2: tensor<4xi32>) -> tensor<1x?x?x16xf32> { + // CHECK: "vhlo.dynamic_conv_v1"(%arg0, %arg1, %arg2) <{ + // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<2x2xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : 
(!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>, !vhlo.tensor_v1<4x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x?x?x16x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_conv"(%arg0, %arg1, %arg2) { + window_strides = dense<2> : tensor<2xi64>, + padding = dense<1> : tensor<2x2xi64>, + lhs_dilation = dense<2> : tensor<2xi64>, + rhs_dilation = dense<2> : tensor<2xi64>, + window_reversal = dense : tensor<2xi1>, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<4xi32>) -> tensor<1x?x?x16xf32> + func.return %0 : tensor<1x?x?x16xf32> +} + +// CHECK-LABEL: "op_dynamic_gather" +func.func @op_dynamic_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>, %arg2 : tensor<3xi32>) -> tensor<1x5x8xf32> { + // CHECK: "vhlo.dynamic_gather_v1"(%arg0, %arg1, %arg2) <{ + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>, !vhlo.tensor_v1<3x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x8x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %arg2) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [0, 1], + start_index_map = [0, 1], + index_vector_dim = 2 + >, + indices_are_sorted = true + } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>, tensor<3xi32>) -> tensor<1x5x8xf32> + func.return %0 : tensor<1x5x8xf32> +} + +// CHECK-LABEL: "op_dynamic_iota" +func.func @op_dynamic_iota(%arg0: tensor<1xindex>) -> tensor { + // CHECK: "vhlo.dynamic_iota_v1"(%arg0) <{ + // CHECK-SAME: iota_dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_iota"(%arg0) { + iota_dimension = 0 : i64 + } : (tensor<1xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_dynamic_pad" +func.func @op_dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tensor<1xindex>, %arg3: tensor<1xindex>, %arg4: tensor<1xindex>) -> tensor { + // CHECK: "vhlo.dynamic_pad_v1"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_pad"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor, tensor, tensor<1xindex>, tensor<1xindex>, tensor<1xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_dynamic_reshape" +func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor { + // CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_dynamic_slice" +func.func @op_dynamic_slice(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor<4xf32> { + // CHECK: "vhlo.dynamic_slice_v1"(%arg0, %arg1) <{ + // CHECK-SAME: slice_sizes = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<4x!vhlo.f32_v1> + %0 = 
"stablehlo.dynamic_slice"(%arg0, %arg1) { + slice_sizes = dense<4> : tensor<1xi64> + } : (tensor<16xf32>, tensor) -> tensor<4xf32> + func.return %0 : tensor<4xf32> +} + +// CHECK-LABEL: "op_dynamic_update_slice" +func.func @op_dynamic_update_slice(%arg0: tensor<16xf32>, %arg1: tensor<4xf32>, %arg2: tensor) -> tensor<16xf32> { + // CHECK: "vhlo.dynamic_update_slice_v1"(%arg0, %arg1, %arg2) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<4x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_update_slice"(%arg0, %arg1, %arg2) : (tensor<16xf32>, tensor<4xf32>, tensor) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_einsum" +func.func @op_einsum(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + // CHECK: "vhlo.einsum_v1"(%arg0, %arg1) <{ + // CHECK-SAME: einsum_config = #vhlo.string_v1<"ab,bc->ac"> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x!vhlo.f32_v1> + %0 = "stablehlo.einsum"(%arg0, %arg1) { + einsum_config = "ab,bc->ac" + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "op_exponential_minus_one" +func.func @op_exponential_minus_one(%arg0: tensor) -> tensor { + // CHECK: "vhlo.exponential_minus_one_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.exponential_minus_one"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_exponential" +func.func @op_exponential(%arg0: tensor) -> tensor { + // CHECK: "vhlo.exponential_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.exponential"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_fft" +func.func @op_fft(%arg0: tensor<16xcomplex>) -> tensor<16xcomplex> { + // CHECK: "vhlo.fft_v1"(%arg0) <{ + // CHECK-SAME: fft_length = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: fft_type = #vhlo + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.complex_v1>) -> !vhlo.tensor_v1<16x!vhlo.complex_v1> + %0 = "stablehlo.fft"(%arg0) { + fft_type = #stablehlo, + fft_length = dense<16> : tensor<1xi64> + } : (tensor<16xcomplex>) -> tensor<16xcomplex> + func.return %0 : tensor<16xcomplex> +} + +// CHECK-LABEL: "op_floor" +func.func @op_floor(%arg0: tensor) -> tensor { + // CHECK: "vhlo.floor_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.floor"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +func.func private @op_func(%arg0: tensor {stablehlo.arg = "0"}) -> (tensor {stablehlo.result = "0"}) { + // CHECK: "vhlo.func_v1"() <{ + // CHECK-SAME: arg_attrs = #vhlo.array_v1<[#vhlo.dict_v1<{#vhlo.string_v1<"stablehlo.arg"> = #vhlo.string_v1<"0">}>]>, + // CHECK-SAME: function_type = #vhlo.type_v1) -> !vhlo.tensor_v1>>, + // CHECK-SAME: res_attrs = #vhlo.array_v1<[#vhlo.dict_v1<{#vhlo.string_v1<"stablehlo.result"> = #vhlo.string_v1<"0">}>]>, + // CHECK-SAME: sym_name = #vhlo.string_v1<"op_func">, + // CHECK-SAME: sym_visibility = #vhlo.string_v1<"private"> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%arg0: !vhlo.tensor_v1): + // CHECK-NEXT: "vhlo.return_v1"(%arg0) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : () -> () + + func.return %arg0 : tensor +} + +// CHECK-LABEL: "op_gather" +func.func @op_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>) -> tensor<1x5x1xf32> { + // CHECK: "vhlo.gather_v1"(%arg0, %arg1) <{ + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // 
CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: slice_sizes = #vhlo.tensor_v1 : tensor<3xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x1x!vhlo.f32_v1> + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [0, 1], + start_index_map = [0, 1], + index_vector_dim = 2 + >, + slice_sizes = dense<1> : tensor<3xi64>, + indices_are_sorted = true + } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>) -> tensor<1x5x1xf32> + func.return %0 : tensor<1x5x1xf32> +} + +// CHECK-LABEL: "op_get_dimension_size" +func.func @op_get_dimension_size(%arg0: tensor) -> tensor { + // CHECK: "vhlo.get_dimension_size_v1"(%arg0) <{ + // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.get_dimension_size"(%arg0) { + dimension = 0 : i64 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_get_tuple_element" +func.func @op_get_tuple_element(%arg0: tuple, tensor>) -> tensor { + // CHECK: "vhlo.get_tuple_element_v1"(%arg0) <{ + // CHECK-SAME: index = #vhlo.integer_v1<0 : i32> + // CHECK-SAME: }> : (!vhlo.tuple_v1, !vhlo.tensor_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.get_tuple_element"(%arg0) { + index = 0 : i32 + } : (tuple, tensor>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_if" +func.func @op_if(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: "vhlo.if_v1"(%arg0) ({ + // CHECK-NEXT: "vhlo.return_v1"(%arg1) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }, { + // CHECK-NEXT: "vhlo.return_v1"(%arg2) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.if"(%arg0) ({ + "stablehlo.return"(%arg1) : (tensor) -> () + }, { + "stablehlo.return"(%arg2) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_imag" +func.func @op_imag(%arg0: tensor>) -> tensor { + // CHECK: "vhlo.imag_v1"(%arg0) : (!vhlo.tensor_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.imag"(%arg0) : (tensor>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_infeed" +func.func @op_infeed(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { + // CHECK: "vhlo.infeed_v1"(%arg0) <{ + // CHECK-SAME: infeed_config = #vhlo.string_v1<"foo">, + // CHECK-SAME{LITERAL}: layout = #vhlo.array_v1<[#vhlo.array_v1<[]>]> + // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) + %0:2 = "stablehlo.infeed"(%arg0) { + infeed_config = "foo", + layout = [[]] + } : (!stablehlo.token) -> (tensor, !stablehlo.token) + func.return %0#0, %0#1 : tensor, !stablehlo.token +} + +// CHECK-LABEL: "op_iota" +func.func @op_iota() -> tensor<16xf32> { + // CHECK: "vhlo.iota_v1"() <{ + // CHECK-SAME: iota_dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : () -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.iota"() { + iota_dimension = 0 : i64 + } : () -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_is_finite" +func.func @op_is_finite(%arg0: tensor) -> tensor { + // CHECK: "vhlo.is_finite_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.is_finite"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_log" 
+func.func @op_log(%arg0: tensor) -> tensor { + // CHECK: "vhlo.log_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.log"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_log_plus_one" +func.func @op_log_plus_one(%arg0: tensor) -> tensor { + // CHECK: "vhlo.log_plus_one_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.log_plus_one"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_logistic" +func.func @op_logistic(%arg0: tensor) -> tensor { + // CHECK: "vhlo.logistic_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.logistic"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_map" +func.func @op_map(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.map_v1"(%arg0) <{ + // CHECK-SAME: dimensions = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.abs_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.map"(%arg0) ({ + ^bb0(%arg1: tensor): + %1 = "stablehlo.abs"(%arg1) : (tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + dimensions = dense<0> : tensor<1xi64> + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_maximum" +func.func @op_maximum(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.maximum_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.maximum"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_minimum" +func.func @op_minimum(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.minimum_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.minimum"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_multiply" +func.func @op_multiply(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.multiply_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.multiply"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_negate" +func.func @op_negate(%arg0: tensor) -> tensor { + // CHECK: "vhlo.negate_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.negate"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_not" +func.func @op_not(%arg0: tensor) -> tensor { + // CHECK: "vhlo.not_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.not"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_optimization_barrier" +func.func @op_optimization_barrier(%arg0: tensor) -> tensor { + // CHECK: "vhlo.optimization_barrier_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.optimization_barrier"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_or" +func.func @op_or(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.or_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.or"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_outfeed" +func.func @op_outfeed(%arg0: tensor, %arg1: !stablehlo.token) -> 
!stablehlo.token { + // CHECK: "vhlo.outfeed_v1"(%arg0, %arg1) <{ + // CHECK-SAME: outfeed_config = #vhlo.string_v1<"foo"> + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 + %0 = "stablehlo.outfeed"(%arg0, %arg1) { + outfeed_config = "foo" + } : (tensor, !stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "op_pad" +func.func @op_pad(%arg0: tensor<8xf32>, %arg1: tensor) -> tensor<16xf32> { + // CHECK: "vhlo.pad_v1"(%arg0, %arg1) <{ + // CHECK-SAME: edge_padding_high = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: edge_padding_low = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: interior_padding = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.pad"(%arg0, %arg1) { + edge_padding_high = dense<4> : tensor<1xi64>, + edge_padding_low = dense<4> : tensor<1xi64>, + interior_padding = dense<0> : tensor<1xi64> + } : (tensor<8xf32>, tensor) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_popcnt" +func.func @op_popcnt(%arg0: tensor) -> tensor { + // CHECK: "vhlo.popcnt_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.popcnt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_power" +func.func @op_power(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.power_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.power"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_real_dynamic_slice" +func.func @op_real_dynamic_slice(%arg0: tensor, %arg1: tensor<1xindex>, %arg2: tensor<1xindex>, %arg3: tensor<1xindex>) -> tensor { + // CHECK: "vhlo.real_dynamic_slice_v1"(%arg0, %arg1, %arg2, %arg3) : (!vhlo.tensor_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.real_dynamic_slice"(%arg0, %arg1, %arg2, %arg3) : (tensor, tensor<1xindex>, tensor<1xindex>, tensor<1xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_real" +func.func @op_real(%arg0: tensor>) -> tensor { + // CHECK: "vhlo.real_v1"(%arg0) : (!vhlo.tensor_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.real"(%arg0) : (tensor>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_recv" +func.func @op_recv(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { + // CHECK: "vhlo.recv_v1"(%arg0) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: channel_type = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) + %0:2 = "stablehlo.recv"(%arg0) { + channel_handle = #stablehlo.channel_handle, + is_host_transfer = true + } : (!stablehlo.token) -> (tensor, !stablehlo.token) + func.return %0#0, %0#1 : tensor, !stablehlo.token +} + +// CHECK-LABEL: "op_reduce" +func.func @op_reduce(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor { + %0 = "stablehlo.reduce"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + dimensions = dense<0> : tensor<1xi64> + } : (tensor<16xf32>, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_reduce_precision" +func.func @op_reduce_precision(%arg0: tensor) -> tensor { + // CHECK: 
"vhlo.reduce_precision_v1"(%arg0) <{ + // CHECK-SAME: exponent_bits = #vhlo.integer_v1<8 : i32> + // CHECK-SAME: mantissa_bits = #vhlo.integer_v1<10 : i32> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.reduce_precision"(%arg0) { + exponent_bits = 8 : i32, + mantissa_bits = 10 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_reduce_scatter" +func.func @op_reduce_scatter(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.reduce_scatter_v1"(%arg0) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: scatter_dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.reduce_scatter"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + scatter_dimension = 0 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_reduce_window" +func.func @op_reduce_window(%arg0: tensor<2x17x31x7xf32>, %arg1: tensor) -> tensor<2x9x16x7xf32> { + // CHECK: "vhlo.reduce_window_v1"(%arg0, %arg1) <{ + // CHECK-SAME: base_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME{LITERAL}: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, + // CHECK-SAME: window_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: !vhlo.tensor_v1, %[[ARG3:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.maximum_v1"(%[[ARG2]], %[[ARG3]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<2x17x31x7x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<2x9x16x7x!vhlo.f32_v1> + %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = "stablehlo.maximum"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>, + base_dilations = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_dilations = dense<[1, 2, 2, 1]> : tensor<4xi64>, + padding = dense<[[0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64> + } : (tensor<2x17x31x7xf32>, tensor) -> tensor<2x9x16x7xf32> + func.return %0 : tensor<2x9x16x7xf32> +} + +// CHECK-LABEL: "op_remainder" +func.func @op_remainder(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.remainder_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.remainder"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: 
"op_replica_id" +func.func @op_replica_id() -> tensor { + // CHECK: "vhlo.replica_id_v1"() : () -> !vhlo.tensor_v1 + %0 = "stablehlo.replica_id"() : () -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_partition_id" +func.func @op_partition_id() -> tensor { + // CHECK: "vhlo.partition_id_v1"() : () -> !vhlo.tensor_v1 + %0 = "stablehlo.partition_id"() : () -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_reshape" +func.func @op_reshape(%arg0: tensor<16xf32>) -> tensor<4x4xf32> { + // CHECK: "vhlo.reshape_v1"(%arg0) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<4x4x!vhlo.f32_v1> + %0 = "stablehlo.reshape"(%arg0) : (tensor<16xf32>) -> tensor<4x4xf32> + func.return %0 : tensor<4x4xf32> +} + +// CHECK-LABEL: "op_return" +func.func @op_return(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.case_v1"(%arg0) ({ + // CHECK-NEXT: "vhlo.return_v1"(%arg1) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.case"(%arg0) ({ + "stablehlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_reverse" +func.func @op_reverse(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.reverse_v1"(%arg0) <{ + // CHECK-SAME: dimensions = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.reverse"(%arg0) { + dimensions = dense<0> : tensor<1xi64> + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_rng_bit_generator" +func.func @op_rng_bit_generator(%arg0: tensor) -> (tensor, tensor) { + // CHECK: "vhlo.rng_bit_generator_v1"(%arg0) <{ + // CHECK-SAME: rng_algorithm = #vhlo + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> (!vhlo.tensor_v1, !vhlo.tensor_v1) + %0:2 = "stablehlo.rng_bit_generator"(%arg0) { + rng_algorithm = #stablehlo + } : (tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} + +// CHECK-LABEL: "op_rng" +func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: "vhlo.rng_v1"(%arg0, %arg1, %arg2) <{ + // CHECK-SAME: rng_distribution = #vhlo + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { + rng_distribution = #stablehlo + } : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_round_nearest_afz" +func.func @op_round_nearest_afz(%arg0: tensor) -> tensor { + // CHECK: "vhlo.round_nearest_afz_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.round_nearest_afz"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_round_nearest_even" +func.func @op_round_nearest_even(%arg0: tensor) -> tensor { + // CHECK: "vhlo.round_nearest_even_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.round_nearest_even"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_rsqrt" +func.func @op_rsqrt(%arg0: tensor) -> tensor { + // CHECK: "vhlo.rsqrt_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.rsqrt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_scatter" +func.func @op_scatter(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300xf32>) -> tensor<200x100x300xf32> { + // CHECK: "vhlo.scatter_v1"(%arg0, %arg1, %arg2) <{ + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: 
indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: inserted_window_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: scatter_dims_to_operand_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: unique_indices = #vhlo.bool_v1, + // CHECK-SAME: update_window_dims = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: !vhlo.tensor_v1, %[[ARG4:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG3]], %[[ARG4]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<200x100x300x!vhlo.f32_v1> + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32> + func.return %0 : tensor<200x100x300xf32> +} + +// CHECK-LABEL: "op_select_and_scatter" +func.func @op_select_and_scatter(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<12x13x13x66xf32>, %arg2: tensor) -> tensor<10x24x24x64xf32> { + // CHECK: "vhlo.select_and_scatter_v1"(%arg0, %arg1, %arg2) <{ + // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, + // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG31:arg.*]]: !vhlo.tensor_v1, %[[ARG41:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL11:.*]] = "vhlo.compare_v1"(%[[ARG31]], %[[ARG41]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL11]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }, { + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG32:arg.*]]: !vhlo.tensor_v1, %[[ARG42:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL12:.*]] = "vhlo.add_v1"(%[[ARG32]], %[[ARG42]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL12]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1>, !vhlo.tensor_v1<12x13x13x66x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1> + %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>, + padding = dense<1> : tensor<4x2xi64> + } : (tensor<10x24x24x64xf32>, tensor<12x13x13x66xf32>, tensor) -> tensor<10x24x24x64xf32> + func.return %0 : tensor<10x24x24x64xf32> +} + +// CHECK-LABEL: "op_select" +func.func @op_select(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor 
{ + // CHECK: "vhlo.select_v1"(%arg0, %arg1, %arg2) : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_send" +func.func @op_send(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.send_v1"(%arg0, %arg1) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: channel_type = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 + %0 = "stablehlo.send"(%arg0, %arg1) { + channel_handle = #stablehlo.channel_handle, + is_host_transfer = true + } : (tensor, !stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "op_set_dimension_size" +func.func @op_set_dimension_size(%arg0: tensor, %arg1: tensor) -> tensor<16xf32> { + // CHECK: "vhlo.set_dimension_size_v1"(%arg0, %arg1) <{ + // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.set_dimension_size"(%arg0, %arg1) { + dimension = 0 : i64 + } : (tensor, tensor) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_shift_left" +func.func @op_shift_left(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.shift_left_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.shift_left"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_shift_right_arithmetic" +func.func @op_shift_right_arithmetic(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.shift_right_arithmetic_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.shift_right_arithmetic"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_shift_right_logical" +func.func @op_shift_right_logical(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.shift_right_logical_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.shift_right_logical"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_sign" +func.func @op_sign(%arg0: tensor) -> tensor { + // CHECK: "vhlo.sign_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.sign"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_sine" +func.func @op_sine(%arg0: tensor) -> tensor { + // CHECK: "vhlo.sine_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.sine"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_slice" +func.func @op_slice(%arg0: tensor<16xf32>) -> tensor<4xf32> { + // CHECK: "vhlo.slice_v1"(%arg0) <{ + // CHECK-SAME: limit_indices = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: start_indices = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: strides = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<4x!vhlo.f32_v1> + %0 = "stablehlo.slice"(%arg0) { + start_indices = dense<0> : tensor<1xi64>, + limit_indices = dense<4> : tensor<1xi64>, + strides = dense<1> : tensor<1xi64> + } : (tensor<16xf32>) -> tensor<4xf32> + func.return %0 : tensor<4xf32> +} + +// CHECK-LABEL: "op_sort" +func.func @op_sort(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: 
"vhlo.sort_v1"(%arg0) <{ + // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: is_stable = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.compare_v1"(%[[ARG1]], %[[ARG2]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.sort"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.compare"(%arg1, %arg2) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + dimension = 0 : i64, + is_stable = true + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_sqrt" +func.func @op_sqrt(%arg0: tensor) -> tensor { + // CHECK: "vhlo.sqrt_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.sqrt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_subtract" +func.func @op_subtract(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.subtract_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.subtract"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_tanh" +func.func @op_tanh(%arg0: tensor) -> tensor { + // CHECK: "vhlo.tanh_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.tanh"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_torch_index_select" +func.func @op_torch_index_select(%arg0: tensor<5x1x5xf32>, %arg1: tensor<2xi32>) -> tensor<2x1x5xf32> { + // CHECK: "vhlo.torch_index_select_v1"(%arg0, %arg1) <{ + // CHECK-SAME: batch_dims = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: dim = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<5x1x5x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<2x1x5x!vhlo.f32_v1> + %0 = "stablehlo.torch_index_select"(%arg0, %arg1) { + dim = 0 : i64, + batch_dims = 0 : i64 + } : (tensor<5x1x5xf32>, tensor<2xi32>) -> tensor<2x1x5xf32> + func.return %0 : tensor<2x1x5xf32> +} + +// CHECK-LABEL: "op_trace" +func.func @op_trace(%arg0: tensor) { + // CHECK: "vhlo.trace_v1"(%arg0) <{ + // CHECK-SAME: tag = #vhlo.string_v1<"foo"> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> () + "stablehlo.trace"(%arg0) { + tag = "foo" + } : (tensor) -> () + func.return +} + +// CHECK-LABEL: "op_transpose" +func.func @op_transpose(%arg0: tensor<16x8xf32>) -> tensor<8x16xf32> { + // CHECK: "vhlo.transpose_v1"(%arg0) <{ + // CHECK-SAME: permutation = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x16x!vhlo.f32_v1> + %0 = "stablehlo.transpose"(%arg0) { + permutation = dense<[1, 0]> : tensor<2xi64> + } : (tensor<16x8xf32>) -> tensor<8x16xf32> + func.return %0 : tensor<8x16xf32> +} + +// CHECK-LABEL: "op_triangular_solve" +func.func @op_triangular_solve(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.triangular_solve_v1"(%arg0, %arg1) <{ + // CHECK-SAME: left_side = #vhlo.bool_v1, + // CHECK-SAME: lower = #vhlo.bool_v1, + // CHECK-SAME: transpose_a = #vhlo, + // CHECK-SAME: unit_diagonal = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x16x!vhlo.f32_v1>) -> 
!vhlo.tensor_v1<16x16x!vhlo.f32_v1> + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { + left_side = true, + lower = true, + unit_diagonal = true, + transpose_a = #stablehlo + } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "op_tuple" +func.func @op_tuple(%arg0: tensor) -> tuple> { + // CHECK: "vhlo.tuple_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tuple_v1> + %0 = "stablehlo.tuple"(%arg0) : (tensor) -> tuple> + func.return %0 : tuple> +} + +// CHECK-LABEL: "op_unary_einsum" +func.func @op_unary_einsum(%arg0: tensor<8x16xf32>) -> tensor<8xf32> { + // CHECK: "vhlo.unary_einsum_v1"(%arg0) <{ + // CHECK-SAME: einsum_config = #vhlo.string_v1<"ab->a"> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x!vhlo.f32_v1> + %0 = "stablehlo.unary_einsum"(%arg0) { + einsum_config = "ab->a" + } : (tensor<8x16xf32>) -> tensor<8xf32> + func.return %0 : tensor<8xf32> +} + +// CHECK-LABEL: "op_uniform_dequantize" +func.func @op_uniform_dequantize(%arg0: tensor>) -> tensor { + // CHECK: "vhlo.uniform_dequantize_v1"(%arg0) : (!vhlo.tensor_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.uniform_dequantize"(%arg0) : (tensor>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_uniform_quantize" +func.func @op_uniform_quantize(%arg0: tensor) -> tensor> { + // CHECK: "vhlo.uniform_quantize_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1> + %0 = "stablehlo.uniform_quantize"(%arg0) : (tensor) -> tensor> + func.return %0 : tensor> +} + +// CHECK-LABEL: "op_while" +func.func @op_while(%arg0: tensor) -> tensor { + // CHECK: "vhlo.while_v1"(%arg0) ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }, { + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1) + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.while"(%arg0) ({ + ^bb0(%arg1: tensor): + "stablehlo.return"(%arg1) : (tensor) -> () + }, { + ^bb0(%arg1: tensor): + "stablehlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0: tensor +} + +// CHECK-LABEL: "op_xor" +func.func @op_xor(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.xor_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.xor"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// ============ TYPES ============ + +// CHECK-LABEL: "type_i1" +func.func @type_i1(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.and_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.and"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i4" +func.func @type_i4(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i8" +func.func @type_i8(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i16" +func.func @type_i16(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, 
!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i32" +func.func @type_i32(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i64" +func.func @type_i64(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui4" +func.func @type_ui4(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui8" +func.func @type_ui8(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui16" +func.func @type_ui16(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui32" +func.func @type_ui32(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui64" +func.func @type_ui64(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E4M3FN" +func.func @type_f8E4M3FN(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E5M2" +func.func @type_f8E5M2(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E4M3FNUZ" +func.func @type_f8E4M3FNUZ(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E4M3B11FNUZ" +func.func @type_f8E4M3B11FNUZ(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E5M2FNUZ" +func.func @type_f8E5M2FNUZ(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : 
(tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_bf16" +func.func @type_bf16(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f16" +func.func @type_f16(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f32" +func.func @type_f32(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f64" +func.func @type_f64(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_complex_f32" +func.func @type_complex_f32(%arg0: tensor>, %arg1: tensor>) -> tensor> { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1>, !vhlo.tensor_v1>) -> !vhlo.tensor_v1> + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor>, tensor>) -> tensor> + func.return %0 : tensor> +} + +// CHECK-LABEL: "type_complex_f64" +func.func @type_complex_f64(%arg0: tensor>, %arg1: tensor>) -> tensor> { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1>, !vhlo.tensor_v1>) -> !vhlo.tensor_v1> + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor>, tensor>) -> tensor> + func.return %0 : tensor> +} + +// CHECK-LABEL: "type_dynamism_ranked" +func.func @type_dynamism_ranked(%arg0: tensor) -> tensor { + // CHECK: "vhlo.abs_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.abs"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_dynamism_unranked" +func.func @type_dynamism_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { + // CHECK: "vhlo.abs_v1"(%arg0) : (!vhlo.unranked_tensor_v1) -> !vhlo.unranked_tensor_v1 + %0 = "stablehlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// CHECK-LABEL: "type_quantization" +func.func @type_quantization(%arg0: tensor>, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor>, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: function_type = #vhlo.type_v1 !vhlo.token_v1>> +// CHECK-LABEL: "type_token_callee" +func.func @type_token_callee(%arg0: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.return_v1"(%arg0) : (!vhlo.token_v1) -> () + return %arg0 : !stablehlo.token +} + +// CHECK: function_type = #vhlo.type_v1 !vhlo.token_v1>> +// CHECK-LABEL: "type_token_caller" +func.func @type_token_caller(%arg0: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.call_v1"(%arg0) <{callee = #vhlo.string_v1<"type_token_callee">} + // CHECK-SAME: (!vhlo.token_v1) -> !vhlo.token_v1 + %0 = func.call @type_token_callee(%arg0) : (!stablehlo.token) -> !stablehlo.token + return %0 : !stablehlo.token +} + +// CHECK-LABEL: "type_tuple" +func.func @type_tuple(%arg0: tuple>) -> tuple { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo" + // CHECK: 
+
+// CHECK-LABEL: "type_tuple"
+func.func @type_tuple(%arg0: tuple<tuple<tensor<f32>>>) -> tuple<!stablehlo.token> {
+  %0 = "stablehlo.custom_call"(%arg0) {
+    call_target_name = "foo"
+  // CHECK: (!vhlo.tuple_v1<!vhlo.tuple_v1<!vhlo.tensor_v1<!vhlo.f32_v1>>>) -> !vhlo.tuple_v1<!vhlo.token_v1>
+  } : (tuple<tuple<tensor<f32>>>) -> tuple<!stablehlo.token>
+  return %0 : tuple<!stablehlo.token>
+}
diff --git a/stablehlo/tests/stablehlo_legalize_to_vhlo.0_16_0.mlir.bc b/stablehlo/tests/stablehlo_legalize_to_vhlo.0_16_0.mlir.bc
new file mode 100644
index 0000000000000000000000000000000000000000..a618e80959bf61c840fbfc35a3dabc1da07949e3
GIT binary patch
literal 16714
[base85-encoded MLIR bytecode payload, 16714 bytes, elided]
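The bytecode file above is the serialized counterpart of the 0.16.0 test file; its payload is opaque here, so the sketch below records only how such artifacts are typically produced. The exact invocation is an assumption based on the usual VHLO serialization workflow (`stablehlo-translate` with `--serialize` and `--target`), not something this patch specifies:

```mlir
// Assumed regeneration workflow for the .mlir.bc artifact (hedged sketch):
//   stablehlo-translate --serialize --target=0.16.0 \
//     stablehlo/tests/stablehlo_legalize_to_vhlo.0_16_0.mlir \
//     > stablehlo/tests/stablehlo_legalize_to_vhlo.0_16_0.mlir.bc
// Pinning the 0.16.0 wire format lets future releases verify that
// collective_broadcast_v1 still deserializes identically.
```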
diff --git a/stablehlo/tests/stablehlo_legalize_to_vhlo.mlir b/stablehlo/tests/stablehlo_legalize_to_vhlo.mlir
index 8ed753778a6..baae56d33b3 100644
--- a/stablehlo/tests/stablehlo_legalize_to_vhlo.mlir
+++ b/stablehlo/tests/stablehlo_legalize_to_vhlo.mlir
@@ -406,6 +406,18 @@ func.func @default_collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf3
   func.return %0 : tensor<16x8xf32>
 }
 
"default_collective_broadcast" +func.func @default_collective_broadcast(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { + // CHECK: "vhlo.collective_broadcast_v1"(%arg0) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<1x2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x8x!vhlo.f32_v1> + %0 = "stablehlo.collective_broadcast"(%arg0) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> + } : (tensor<16x8xf32>) -> tensor<16x8xf32> + func.return %0 : tensor<16x8xf32> +} + // CHECK-LABEL: "default_compare" func.func @default_compare(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: "vhlo.compare_v1"(%arg0, %arg1) <{ diff --git a/stablehlo/tests/vhlo_to_version_downgrade_invalid.0_15_0.mlir b/stablehlo/tests/vhlo_to_version_downgrade_invalid.0_15_0.mlir new file mode 100644 index 00000000000..d3417f65cb0 --- /dev/null +++ b/stablehlo/tests/vhlo_to_version_downgrade_invalid.0_15_0.mlir @@ -0,0 +1,9 @@ +// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo --vhlo-to-version='target=0.15.0' --verify-diagnostics --split-input-file %s + +func.func @default_collective_broadcast(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { + // expected-error @+1 {{failed to legalize operation 'vhlo.collective_broadcast_v1' that was explicitly marked illegal}} + %0 = "stablehlo.collective_broadcast"(%arg0) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> + } : (tensor<16x8xf32>) -> tensor<16x8xf32> + func.return %0 : tensor<16x8xf32> +} diff --git a/stablehlo/transforms/MapStablehloToVhlo.h b/stablehlo/transforms/MapStablehloToVhlo.h index d305fadff5c..93795399009 100644 --- a/stablehlo/transforms/MapStablehloToVhlo.h +++ b/stablehlo/transforms/MapStablehloToVhlo.h @@ -68,6 +68,7 @@ MAP_STABLEHLO_TO_VHLO(CeilOp, V1) MAP_STABLEHLO_TO_VHLO(CholeskyOp, V1) MAP_STABLEHLO_TO_VHLO(ClampOp, V1) MAP_STABLEHLO_TO_VHLO(ClzOp, V1) +MAP_STABLEHLO_TO_VHLO(CollectiveBroadcastOp, V1) MAP_STABLEHLO_TO_VHLO(CollectivePermuteOp, V1) MAP_STABLEHLO_TO_VHLO(CompareOp, V1) MAP_STABLEHLO_TO_VHLO(ComplexOp, V1) diff --git a/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stablehlo/transforms/StablehloLegalizeToVhlo.cpp index 8f2edd91c87..22aee60ab00 100644 --- a/stablehlo/transforms/StablehloLegalizeToVhlo.cpp +++ b/stablehlo/transforms/StablehloLegalizeToVhlo.cpp @@ -499,7 +499,9 @@ SpecialResult convertSpecial(const OpConversionPattern& pattern, std::is_same::value || std::is_same::value) { + stablehlo::ReduceScatterOp>::value || + std::is_same::value) { if (stablehloName == "channel_handle") return convertChannelId(pattern, stablehloAttr, vhloAttrs); if (stablehloName == "use_global_device_ids") @@ -581,7 +583,9 @@ LogicalResult addDefaults(const OpConversionPattern& pattern, } if constexpr (std::is_same::value || std::is_same::value) { + stablehlo::CollectivePermuteOp>::value || + std::is_same::value) { if (!stablehloOp.getChannelHandleAttr()) addDefaultAttr("channel_id", builder.getI64IntegerAttr(0)); } diff --git a/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stablehlo/transforms/VhloLegalizeToStablehlo.cpp index c5a999ae30b..c15dbafb781 100644 --- a/stablehlo/transforms/VhloLegalizeToStablehlo.cpp +++ b/stablehlo/transforms/VhloLegalizeToStablehlo.cpp @@ -436,7 +436,8 @@ SpecialResult convertSpecial(const OpConversionPattern& pattern, std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value) { + std::is_same::value || + std::is_same::value) { if (vhloName == 
"channel_id") { stablehloName = StringAttr::get(pattern.getContext(), "channel_handle"); stablehloAttr = convertChannelId(vhloAttr, pattern.getTypeConverter()); @@ -550,7 +551,8 @@ LogicalResult removeDefaults(const OpConversionPattern& pattern, eraseAttrs(vhloAttrs, "use_global_device_ids"); } if constexpr (std::is_same::value || - std::is_same::value) { + std::is_same::value || + std::is_same::value) { if (isInteger(vhloOp.getChannelIdAttr(), 0)) eraseAttrs(vhloAttrs, "channel_id"); }