From 837536cba5304d4185bea83f707ce4d1a3b5cc79 Mon Sep 17 00:00:00 2001 From: mlevesquedion Date: Fri, 12 Apr 2024 09:44:15 -0700 Subject: [PATCH] Implement ConditionallySpeculatable for Sort (#2200) NOTE: This PR builds on top of https://github.com/openxla/stablehlo/pull/2193. See the 2nd commit for changes specific to this PR. --- stablehlo/dialect/StablehloOps.td | 3 +- stablehlo/tests/ops_speculatability.mlir | 40 ++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index a406e91dd66..74fba1ec3ad 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -2795,7 +2795,8 @@ def StableHLO_SetDimensionSizeOp: StableHLO_Op<"set_dimension_size", [Pure, } def StableHLO_SortOp : StableHLO_Op<"sort", - [RecursiveMemoryEffects, SameOperandsAndResultShape /*sort_c1, sort_c3*/, + [HLO_RecursivelySpeculatableIfAllInputsStatic, RecursiveMemoryEffects, + SameOperandsAndResultShape /*sort_c1, sort_c3*/, InferTensorType /*sort_c2*/]> { let summary = "Sort operation"; let description = [{ diff --git a/stablehlo/tests/ops_speculatability.mlir b/stablehlo/tests/ops_speculatability.mlir index c8aefe61636..4cb7b844efe 100644 --- a/stablehlo/tests/ops_speculatability.mlir +++ b/stablehlo/tests/ops_speculatability.mlir @@ -1489,6 +1489,46 @@ func.func @reduce_window(%static_arg: tensor<2x4xf64>, %dynamic_arg: tensor, %dynamic_arg: tensor) { + %recursively_speculatable_0:2 = "stablehlo.sort"(%static_arg, %static_arg) ({ + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): + %predicate = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + stablehlo.return %predicate : tensor + }) {dimension = 0 : i64, is_stable = true} : (tensor<2x4xf64>, tensor<2x4xf64>) -> (tensor, tensor) + %not_recursively_speculatable_0:2 = "stablehlo.sort"(%dynamic_arg, %static_arg) ({ + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): + %predicate = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + stablehlo.return %predicate : tensor + }) {dimension = 0 : i64, is_stable = true} : (tensor, tensor<2x4xf64>) -> (tensor, tensor) + %not_recursively_speculatable_1:2 = "stablehlo.sort"(%static_arg, %dynamic_arg) ({ + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): + %predicate = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + stablehlo.return %predicate : tensor + }) {dimension = 0 : i64, is_stable = true} : (tensor<2x4xf64>, tensor) -> (tensor, tensor) + %not_recursively_speculatable_2:2 = "stablehlo.sort"(%dynamic_arg, %dynamic_arg) ({ + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): + %predicate = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + stablehlo.return %predicate : tensor + }) {dimension = 0 : i64, is_stable = true} : (tensor, tensor) -> (tensor, tensor) + "hlo_test_speculatability.is_recursively_speculatable"(%recursively_speculatable_0) : (tensor) -> () + "hlo_test_speculatability.is_not_speculatable"(%not_recursively_speculatable_0) : (tensor) -> () + "hlo_test_speculatability.is_not_speculatable"(%not_recursively_speculatable_1) : (tensor) -> () + "hlo_test_speculatability.is_not_speculatable"(%not_recursively_speculatable_2) : (tensor) -> () + return +} + +// ----- + // Miscellaneous ops // -----