diff --git a/stablehlo/dialect/StablehloOps.cpp b/stablehlo/dialect/StablehloOps.cpp index 309f3328907..062b7a2ae5a 100644 --- a/stablehlo/dialect/StablehloOps.cpp +++ b/stablehlo/dialect/StablehloOps.cpp @@ -2334,6 +2334,22 @@ LogicalResult ScatterOp::verify() { getScatterDimensionNumbers().getIndexVectorDim(), getUpdateComputation()); } +mlir::Speculation::Speculatability ScatterOp::getSpeculatability() { + // When unique_indices is true, if the scatter_indices are not unique, the + // behavior is undefined. + // A possible improvement would be to check if the scatter_indices are + // constant and if so, check if they are unique/sorted, and if so do not + // return NotSpeculatable. However, such a check could be somewhat costly and + // has unclear ROI. + if (getUniqueIndices() || getIndicesAreSorted()) + return mlir::Speculation::NotSpeculatable; + return llvm::all_of( + this->getOperation()->getOperandTypes(), + [](Type t) { return cast(t).hasStaticShape(); }) + ? mlir::Speculation::RecursivelySpeculatable + : mlir::Speculation::NotSpeculatable; +} + //===----------------------------------------------------------------------===// // WhileOp //===----------------------------------------------------------------------===// diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index 74fba1ec3ad..dba4b6ee1a2 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -2641,7 +2641,8 @@ def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape", [ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; } -def StableHLO_ScatterOp: StableHLO_Op<"scatter", [RecursiveMemoryEffects, +def StableHLO_ScatterOp: StableHLO_Op<"scatter", + [ConditionallySpeculatable, RecursiveMemoryEffects, SameVariadicOperandSize /*scatter_c5*/, DeclareOpInterfaceMethods /*scatter_c16, scater_c17*/]> { @@ -2685,6 +2686,11 @@ def StableHLO_ScatterOp: StableHLO_Op<"scatter", [RecursiveMemoryEffects, let results = (outs Variadic); let hasVerifier = 1; + + let extraClassDeclaration = commonClassDeclaration # [{ + /// Interface method for ConditionallySpeculatable. + mlir::Speculation::Speculatability getSpeculatability(); + }]; } def StableHLO_SelectOp: StableHLO_Op<"select", diff --git a/stablehlo/tests/ops_speculatability.mlir b/stablehlo/tests/ops_speculatability.mlir index 4cb7b844efe..6a06306fecd 100644 --- a/stablehlo/tests/ops_speculatability.mlir +++ b/stablehlo/tests/ops_speculatability.mlir @@ -1244,6 +1244,51 @@ func.func @select(%static_pred: tensor<2xi1>, %dynamic_pred: tensor, %stat // ----- +// CHECK-LABEL: func @concatenate +// CHECK-NEXT: return +func.func @concatenate(%static_arg: tensor<2x2xi64>, %first_dim_dynamic: tensor, %second_dim_dynamic: tensor<2x?xi64>, %dynamic_arg: tensor) { + %speculatable_0 = stablehlo.concatenate %static_arg, %static_arg, dim = 0 : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor + %speculatable_1 = stablehlo.concatenate %static_arg, %static_arg, dim = 1 : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor + %speculatable_2 = stablehlo.concatenate %static_arg, %first_dim_dynamic, dim = 0 : (tensor<2x2xi64>, tensor) -> tensor + %speculatable_3 = stablehlo.concatenate %second_dim_dynamic, %static_arg, dim = 1 : (tensor<2x?xi64>, tensor<2x2xi64>) -> tensor + + %speculatable_4 = stablehlo.concatenate %static_arg, %static_arg, dim = 0 : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<4x2xi64> + %speculatable_5 = stablehlo.concatenate %static_arg, %static_arg, dim = 1 : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x4xi64> + %not_speculatable_0 = stablehlo.concatenate %static_arg, %first_dim_dynamic, dim = 0 : (tensor<2x2xi64>, tensor) -> tensor<4x2xi64> + %not_speculatable_1 = stablehlo.concatenate %second_dim_dynamic, %static_arg, dim = 1 : (tensor<2x?xi64>, tensor<2x2xi64>) -> tensor<2x4xi64> + + %speculatable_6 = stablehlo.concatenate %first_dim_dynamic, %first_dim_dynamic, dim = 0 : (tensor, tensor) -> tensor + %not_speculatable_2 = stablehlo.concatenate %first_dim_dynamic, %first_dim_dynamic, dim = 0 : (tensor, tensor) -> tensor<4x?xi64> + %not_speculatable_3 = stablehlo.concatenate %first_dim_dynamic, %first_dim_dynamic, dim = 1 : (tensor, tensor) -> tensor + + %not_speculatable_4 = stablehlo.concatenate %second_dim_dynamic, %second_dim_dynamic, dim = 0 : (tensor<2x?xi64>, tensor<2x?xi64>) -> tensor + %not_speculatable_5 = stablehlo.concatenate %second_dim_dynamic, %second_dim_dynamic, dim = 1 : (tensor<2x?xi64>, tensor<2x?xi64>) -> tensor + %speculatable_7 = stablehlo.concatenate %second_dim_dynamic, %second_dim_dynamic, dim = 1 : (tensor<2x?xi64>, tensor<2x?xi64>) -> tensor + + %not_speculatable_6 = stablehlo.concatenate %dynamic_arg, %dynamic_arg, dim = 0 : (tensor, tensor) -> tensor + %not_speculatable_7 = stablehlo.concatenate %dynamic_arg, %dynamic_arg, dim = 1 : (tensor, tensor) -> tensor + + "hlo_test_speculatability.is_speculatable"(%speculatable_0) : (tensor) -> () + "hlo_test_speculatability.is_speculatable"(%speculatable_1) : (tensor) -> () + "hlo_test_speculatability.is_speculatable"(%speculatable_2) : (tensor) -> () + "hlo_test_speculatability.is_speculatable"(%speculatable_3) : (tensor) -> () + "hlo_test_speculatability.is_speculatable"(%speculatable_4) : (tensor<4x2xi64>) -> () + "hlo_test_speculatability.is_speculatable"(%speculatable_5) : (tensor<2x4xi64>) -> () + "hlo_test_speculatability.is_speculatable"(%speculatable_6) : (tensor) -> () + "hlo_test_speculatability.is_speculatable"(%speculatable_7) : (tensor) -> () + "hlo_test_speculatability.is_not_speculatable"(%not_speculatable_0) : (tensor<4x2xi64>) -> () + "hlo_test_speculatability.is_not_speculatable"(%not_speculatable_1) : (tensor<2x4xi64>) -> () + "hlo_test_speculatability.is_not_speculatable"(%not_speculatable_2) : (tensor<4x?xi64>) -> () + "hlo_test_speculatability.is_not_speculatable"(%not_speculatable_3) : (tensor) -> () + "hlo_test_speculatability.is_not_speculatable"(%not_speculatable_4) : (tensor) -> () + "hlo_test_speculatability.is_not_speculatable"(%not_speculatable_5) : (tensor) -> () + "hlo_test_speculatability.is_not_speculatable"(%not_speculatable_6) : (tensor) -> () + "hlo_test_speculatability.is_not_speculatable"(%not_speculatable_7) : (tensor) -> () + return +} + +// ----- + // CHECK-LABEL: func @gather // CHECK-NEXT: return func.func @gather( @@ -1303,46 +1348,92 @@ func.func @gather( return } -// CHECK-LABEL: func @concatenate -// CHECK-NEXT: return -func.func @concatenate(%static_arg: tensor<2x2xi64>, %first_dim_dynamic: tensor, %second_dim_dynamic: tensor<2x?xi64>, %dynamic_arg: tensor) { - %speculatable_0 = stablehlo.concatenate %static_arg, %static_arg, dim = 0 : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor - %speculatable_1 = stablehlo.concatenate %static_arg, %static_arg, dim = 1 : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor - %speculatable_2 = stablehlo.concatenate %static_arg, %first_dim_dynamic, dim = 0 : (tensor<2x2xi64>, tensor) -> tensor - %speculatable_3 = stablehlo.concatenate %second_dim_dynamic, %static_arg, dim = 1 : (tensor<2x?xi64>, tensor<2x2xi64>) -> tensor - - %speculatable_4 = stablehlo.concatenate %static_arg, %static_arg, dim = 0 : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<4x2xi64> - %speculatable_5 = stablehlo.concatenate %static_arg, %static_arg, dim = 1 : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x4xi64> - %not_speculatable_0 = stablehlo.concatenate %static_arg, %first_dim_dynamic, dim = 0 : (tensor<2x2xi64>, tensor) -> tensor<4x2xi64> - %not_speculatable_1 = stablehlo.concatenate %second_dim_dynamic, %static_arg, dim = 1 : (tensor<2x?xi64>, tensor<2x2xi64>) -> tensor<2x4xi64> - - %speculatable_6 = stablehlo.concatenate %first_dim_dynamic, %first_dim_dynamic, dim = 0 : (tensor, tensor) -> tensor - %not_speculatable_2 = stablehlo.concatenate %first_dim_dynamic, %first_dim_dynamic, dim = 0 : (tensor, tensor) -> tensor<4x?xi64> - %not_speculatable_3 = stablehlo.concatenate %first_dim_dynamic, %first_dim_dynamic, dim = 1 : (tensor, tensor) -> tensor - - %not_speculatable_4 = stablehlo.concatenate %second_dim_dynamic, %second_dim_dynamic, dim = 0 : (tensor<2x?xi64>, tensor<2x?xi64>) -> tensor - %not_speculatable_5 = stablehlo.concatenate %second_dim_dynamic, %second_dim_dynamic, dim = 1 : (tensor<2x?xi64>, tensor<2x?xi64>) -> tensor - %speculatable_7 = stablehlo.concatenate %second_dim_dynamic, %second_dim_dynamic, dim = 1 : (tensor<2x?xi64>, tensor<2x?xi64>) -> tensor - - %not_speculatable_6 = stablehlo.concatenate %dynamic_arg, %dynamic_arg, dim = 0 : (tensor, tensor) -> tensor - %not_speculatable_7 = stablehlo.concatenate %dynamic_arg, %dynamic_arg, dim = 1 : (tensor, tensor) -> tensor +// ----- - "hlo_test_speculatability.is_speculatable"(%speculatable_0) : (tensor) -> () - "hlo_test_speculatability.is_speculatable"(%speculatable_1) : (tensor) -> () - "hlo_test_speculatability.is_speculatable"(%speculatable_2) : (tensor) -> () - "hlo_test_speculatability.is_speculatable"(%speculatable_3) : (tensor) -> () - "hlo_test_speculatability.is_speculatable"(%speculatable_4) : (tensor<4x2xi64>) -> () - "hlo_test_speculatability.is_speculatable"(%speculatable_5) : (tensor<2x4xi64>) -> () - "hlo_test_speculatability.is_speculatable"(%speculatable_6) : (tensor) -> () - "hlo_test_speculatability.is_speculatable"(%speculatable_7) : (tensor) -> () - "hlo_test_speculatability.is_not_speculatable"(%not_speculatable_0) : (tensor<4x2xi64>) -> () - "hlo_test_speculatability.is_not_speculatable"(%not_speculatable_1) : (tensor<2x4xi64>) -> () - "hlo_test_speculatability.is_not_speculatable"(%not_speculatable_2) : (tensor<4x?xi64>) -> () - "hlo_test_speculatability.is_not_speculatable"(%not_speculatable_3) : (tensor) -> () - "hlo_test_speculatability.is_not_speculatable"(%not_speculatable_4) : (tensor) -> () - "hlo_test_speculatability.is_not_speculatable"(%not_speculatable_5) : (tensor) -> () - "hlo_test_speculatability.is_not_speculatable"(%not_speculatable_6) : (tensor) -> () - "hlo_test_speculatability.is_not_speculatable"(%not_speculatable_7) : (tensor) -> () +// CHECK-LABEL: func @scatter +// CHECK-NEXT: return +func.func @scatter( + %static_inputs: tensor<3x4x2xf64>, %static_indices: tensor<2x3x2xi64>, %static_updates: tensor<2x3x2x2xf64>, + %dynamic_inputs: tensor, %dynamic_indices: tensor, %dynamic_updates: tensor +) { + %recursively_speculatable_0 = "stablehlo.scatter"(%static_inputs, %static_indices, %static_updates) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + stablehlo.return %arg0 : tensor + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [2, 3], + inserted_window_dims = [0], + scatter_dims_to_operand_dims = [1, 0], + index_vector_dim = 2>, + indices_are_sorted = false, + unique_indices = false + } : (tensor<3x4x2xf64>, tensor<2x3x2xi64>, tensor<2x3x2x2xf64>) -> tensor + %not_speculatable_0 = "stablehlo.scatter"(%static_inputs, %static_indices, %static_updates) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + stablehlo.return %arg0 : tensor + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [2, 3], + inserted_window_dims = [0], + scatter_dims_to_operand_dims = [1, 0], + index_vector_dim = 2>, + indices_are_sorted = false, + unique_indices = true + } : (tensor<3x4x2xf64>, tensor<2x3x2xi64>, tensor<2x3x2x2xf64>) -> tensor + %not_speculatable_1 = "stablehlo.scatter"(%static_inputs, %static_indices, %static_updates) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + stablehlo.return %arg0 : tensor + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [2, 3], + inserted_window_dims = [0], + scatter_dims_to_operand_dims = [1, 0], + index_vector_dim = 2>, + indices_are_sorted = true, + unique_indices = false + } : (tensor<3x4x2xf64>, tensor<2x3x2xi64>, tensor<2x3x2x2xf64>) -> tensor + %not_speculatable_2 = "stablehlo.scatter"(%dynamic_inputs, %static_indices, %static_updates) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + stablehlo.return %arg0 : tensor + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [2, 3], + inserted_window_dims = [0], + scatter_dims_to_operand_dims = [1, 0], + index_vector_dim = 2>, + indices_are_sorted = false, + unique_indices = false + } : (tensor, tensor<2x3x2xi64>, tensor<2x3x2x2xf64>) -> tensor + %not_speculatable_3 = "stablehlo.scatter"(%static_inputs, %dynamic_indices, %static_updates) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + stablehlo.return %arg0 : tensor + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [2, 3], + inserted_window_dims = [0], + scatter_dims_to_operand_dims = [1, 0], + index_vector_dim = 2>, + indices_are_sorted = false, + unique_indices = false + } : (tensor<3x4x2xf64>, tensor, tensor<2x3x2x2xf64>) -> tensor + %not_speculatable_4 = "stablehlo.scatter"(%static_inputs, %static_indices, %dynamic_updates) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + stablehlo.return %arg0 : tensor + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [2, 3], + inserted_window_dims = [0], + scatter_dims_to_operand_dims = [1, 0], + index_vector_dim = 2>, + indices_are_sorted = false, + unique_indices = false + } : (tensor<3x4x2xf64>, tensor<2x3x2xi64>, tensor) -> tensor + "hlo_test_speculatability.is_recursively_speculatable"(%recursively_speculatable_0) : (tensor) -> () + "hlo_test_speculatability.is_not_speculatable"(%not_speculatable_0) : (tensor) -> () + "hlo_test_speculatability.is_not_speculatable"(%not_speculatable_1) : (tensor) -> () + "hlo_test_speculatability.is_not_speculatable"(%not_speculatable_2) : (tensor) -> () + "hlo_test_speculatability.is_not_speculatable"(%not_speculatable_3) : (tensor) -> () + "hlo_test_speculatability.is_not_speculatable"(%not_speculatable_4) : (tensor) -> () return }