Skip to content

Commit

Permalink
Implement ConditionallySpeculatable for Sort (openxla#2200)
Browse files Browse the repository at this point in the history
NOTE: This PR builds on top of
openxla#2193. See the 2nd commit for
changes specific to this PR.
  • Loading branch information
mlevesquedion authored Apr 12, 2024
1 parent 2fe1109 commit 837536c
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 1 deletion.
3 changes: 2 additions & 1 deletion stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2795,7 +2795,8 @@ def StableHLO_SetDimensionSizeOp: StableHLO_Op<"set_dimension_size", [Pure,
}

def StableHLO_SortOp : StableHLO_Op<"sort",
[RecursiveMemoryEffects, SameOperandsAndResultShape /*sort_c1, sort_c3*/,
[HLO_RecursivelySpeculatableIfAllInputsStatic, RecursiveMemoryEffects,
SameOperandsAndResultShape /*sort_c1, sort_c3*/,
InferTensorType /*sort_c2*/]> {
let summary = "Sort operation";
let description = [{
Expand Down
40 changes: 40 additions & 0 deletions stablehlo/tests/ops_speculatability.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1489,6 +1489,46 @@ func.func @reduce_window(%static_arg: tensor<2x4xf64>, %dynamic_arg: tensor<?x?x

// -----

// CHECK-LABEL: func @sort
// CHECK-NEXT: return
func.func @sort(%static_arg: tensor<2x4xf64>, %dynamic_arg: tensor<?x?xf64>) {
%recursively_speculatable_0:2 = "stablehlo.sort"(%static_arg, %static_arg) ({
^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>, %arg2: tensor<f64>, %arg3: tensor<f64>):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GT>
} : (tensor<f64>, tensor<f64>) -> tensor<i1>
stablehlo.return %predicate : tensor<i1>
}) {dimension = 0 : i64, is_stable = true} : (tensor<2x4xf64>, tensor<2x4xf64>) -> (tensor<?x?xf64>, tensor<?x?xf64>)
%not_recursively_speculatable_0:2 = "stablehlo.sort"(%dynamic_arg, %static_arg) ({
^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>, %arg2: tensor<f64>, %arg3: tensor<f64>):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GT>
} : (tensor<f64>, tensor<f64>) -> tensor<i1>
stablehlo.return %predicate : tensor<i1>
}) {dimension = 0 : i64, is_stable = true} : (tensor<?x?xf64>, tensor<2x4xf64>) -> (tensor<?x?xf64>, tensor<?x?xf64>)
%not_recursively_speculatable_1:2 = "stablehlo.sort"(%static_arg, %dynamic_arg) ({
^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>, %arg2: tensor<f64>, %arg3: tensor<f64>):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GT>
} : (tensor<f64>, tensor<f64>) -> tensor<i1>
stablehlo.return %predicate : tensor<i1>
}) {dimension = 0 : i64, is_stable = true} : (tensor<2x4xf64>, tensor<?x?xf64>) -> (tensor<?x?xf64>, tensor<?x?xf64>)
%not_recursively_speculatable_2:2 = "stablehlo.sort"(%dynamic_arg, %dynamic_arg) ({
^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>, %arg2: tensor<f64>, %arg3: tensor<f64>):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GT>
} : (tensor<f64>, tensor<f64>) -> tensor<i1>
stablehlo.return %predicate : tensor<i1>
}) {dimension = 0 : i64, is_stable = true} : (tensor<?x?xf64>, tensor<?x?xf64>) -> (tensor<?x?xf64>, tensor<?x?xf64>)
"hlo_test_speculatability.is_recursively_speculatable"(%recursively_speculatable_0) : (tensor<?x?xf64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_recursively_speculatable_0) : (tensor<?x?xf64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_recursively_speculatable_1) : (tensor<?x?xf64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_recursively_speculatable_2) : (tensor<?x?xf64>) -> ()
return
}

// -----

// Miscellaneous ops

// -----
Expand Down

0 comments on commit 837536c

Please sign in to comment.