Skip to content

Commit

Permalink
Implement ConditionallySpeculatable for Scatter (openxla#2195)
Browse files Browse the repository at this point in the history
I am mostly interested in shape mismatches so I left a TODO
to check if the indices are unique or sorted. In the meantime,
we will do the conservative thing, which is to assume that the
op is not speculatable (the indices may not be unique/sorted).

NOTE: This PR builds on top of
openxla#2193. See the 2nd commit for
changes specific to this PR.
  • Loading branch information
mlevesquedion authored Apr 12, 2024
1 parent 837536c commit c304904
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 40 deletions.
16 changes: 16 additions & 0 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2334,6 +2334,22 @@ LogicalResult ScatterOp::verify() {
getScatterDimensionNumbers().getIndexVectorDim(), getUpdateComputation());
}

mlir::Speculation::Speculatability ScatterOp::getSpeculatability() {
// When unique_indices is true, if the scatter_indices are not unique, the
// behavior is undefined.
// A possible improvement would be to check if the scatter_indices are
// constant and if so, check if they are unique/sorted, and if so do not
// return NotSpeculatable. However, such a check could be somewhat costly and
// has unclear ROI.
if (getUniqueIndices() || getIndicesAreSorted())
return mlir::Speculation::NotSpeculatable;
return llvm::all_of(
this->getOperation()->getOperandTypes(),
[](Type t) { return cast<RankedTensorType>(t).hasStaticShape(); })
? mlir::Speculation::RecursivelySpeculatable
: mlir::Speculation::NotSpeculatable;
}

//===----------------------------------------------------------------------===//
// WhileOp
//===----------------------------------------------------------------------===//
Expand Down
8 changes: 7 additions & 1 deletion stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2641,7 +2641,8 @@ def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape", [
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}

def StableHLO_ScatterOp: StableHLO_Op<"scatter", [RecursiveMemoryEffects,
def StableHLO_ScatterOp: StableHLO_Op<"scatter",
[ConditionallySpeculatable, RecursiveMemoryEffects,
SameVariadicOperandSize /*scatter_c5*/,
DeclareOpInterfaceMethods<InferTypeOpInterface> /*scatter_c16,
scater_c17*/]> {
Expand Down Expand Up @@ -2685,6 +2686,11 @@ def StableHLO_ScatterOp: StableHLO_Op<"scatter", [RecursiveMemoryEffects,
let results = (outs Variadic<HLO_Tensor>);

let hasVerifier = 1;

let extraClassDeclaration = commonClassDeclaration # [{
/// Interface method for ConditionallySpeculatable.
mlir::Speculation::Speculatability getSpeculatability();
}];
}

def StableHLO_SelectOp: StableHLO_Op<"select",
Expand Down
169 changes: 130 additions & 39 deletions stablehlo/tests/ops_speculatability.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1244,6 +1244,51 @@ func.func @select(%static_pred: tensor<2xi1>, %dynamic_pred: tensor<?xi1>, %stat

// -----

// CHECK-LABEL: func @concatenate
// CHECK-NEXT: return
func.func @concatenate(%static_arg: tensor<2x2xi64>, %first_dim_dynamic: tensor<?x2xi64>, %second_dim_dynamic: tensor<2x?xi64>, %dynamic_arg: tensor<?x?xi64>) {
%speculatable_0 = stablehlo.concatenate %static_arg, %static_arg, dim = 0 : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<?x?xi64>
%speculatable_1 = stablehlo.concatenate %static_arg, %static_arg, dim = 1 : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<?x?xi64>
%speculatable_2 = stablehlo.concatenate %static_arg, %first_dim_dynamic, dim = 0 : (tensor<2x2xi64>, tensor<?x2xi64>) -> tensor<?x?xi64>
%speculatable_3 = stablehlo.concatenate %second_dim_dynamic, %static_arg, dim = 1 : (tensor<2x?xi64>, tensor<2x2xi64>) -> tensor<?x?xi64>

%speculatable_4 = stablehlo.concatenate %static_arg, %static_arg, dim = 0 : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<4x2xi64>
%speculatable_5 = stablehlo.concatenate %static_arg, %static_arg, dim = 1 : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x4xi64>
%not_speculatable_0 = stablehlo.concatenate %static_arg, %first_dim_dynamic, dim = 0 : (tensor<2x2xi64>, tensor<?x2xi64>) -> tensor<4x2xi64>
%not_speculatable_1 = stablehlo.concatenate %second_dim_dynamic, %static_arg, dim = 1 : (tensor<2x?xi64>, tensor<2x2xi64>) -> tensor<2x4xi64>

%speculatable_6 = stablehlo.concatenate %first_dim_dynamic, %first_dim_dynamic, dim = 0 : (tensor<?x2xi64>, tensor<?x2xi64>) -> tensor<?x?xi64>
%not_speculatable_2 = stablehlo.concatenate %first_dim_dynamic, %first_dim_dynamic, dim = 0 : (tensor<?x2xi64>, tensor<?x2xi64>) -> tensor<4x?xi64>
%not_speculatable_3 = stablehlo.concatenate %first_dim_dynamic, %first_dim_dynamic, dim = 1 : (tensor<?x2xi64>, tensor<?x2xi64>) -> tensor<?x?xi64>

%not_speculatable_4 = stablehlo.concatenate %second_dim_dynamic, %second_dim_dynamic, dim = 0 : (tensor<2x?xi64>, tensor<2x?xi64>) -> tensor<?x?xi64>
%not_speculatable_5 = stablehlo.concatenate %second_dim_dynamic, %second_dim_dynamic, dim = 1 : (tensor<2x?xi64>, tensor<2x?xi64>) -> tensor<?x4xi64>
%speculatable_7 = stablehlo.concatenate %second_dim_dynamic, %second_dim_dynamic, dim = 1 : (tensor<2x?xi64>, tensor<2x?xi64>) -> tensor<?x?xi64>

%not_speculatable_6 = stablehlo.concatenate %dynamic_arg, %dynamic_arg, dim = 0 : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
%not_speculatable_7 = stablehlo.concatenate %dynamic_arg, %dynamic_arg, dim = 1 : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>

"hlo_test_speculatability.is_speculatable"(%speculatable_0) : (tensor<?x?xi64>) -> ()
"hlo_test_speculatability.is_speculatable"(%speculatable_1) : (tensor<?x?xi64>) -> ()
"hlo_test_speculatability.is_speculatable"(%speculatable_2) : (tensor<?x?xi64>) -> ()
"hlo_test_speculatability.is_speculatable"(%speculatable_3) : (tensor<?x?xi64>) -> ()
"hlo_test_speculatability.is_speculatable"(%speculatable_4) : (tensor<4x2xi64>) -> ()
"hlo_test_speculatability.is_speculatable"(%speculatable_5) : (tensor<2x4xi64>) -> ()
"hlo_test_speculatability.is_speculatable"(%speculatable_6) : (tensor<?x?xi64>) -> ()
"hlo_test_speculatability.is_speculatable"(%speculatable_7) : (tensor<?x?xi64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_speculatable_0) : (tensor<4x2xi64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_speculatable_1) : (tensor<2x4xi64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_speculatable_2) : (tensor<4x?xi64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_speculatable_3) : (tensor<?x?xi64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_speculatable_4) : (tensor<?x?xi64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_speculatable_5) : (tensor<?x4xi64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_speculatable_6) : (tensor<?x?xi64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_speculatable_7) : (tensor<?x?xi64>) -> ()
return
}

// -----

// CHECK-LABEL: func @gather
// CHECK-NEXT: return
func.func @gather(
Expand Down Expand Up @@ -1303,46 +1348,92 @@ func.func @gather(
return
}

// CHECK-LABEL: func @concatenate
// CHECK-NEXT: return
func.func @concatenate(%static_arg: tensor<2x2xi64>, %first_dim_dynamic: tensor<?x2xi64>, %second_dim_dynamic: tensor<2x?xi64>, %dynamic_arg: tensor<?x?xi64>) {
%speculatable_0 = stablehlo.concatenate %static_arg, %static_arg, dim = 0 : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<?x?xi64>
%speculatable_1 = stablehlo.concatenate %static_arg, %static_arg, dim = 1 : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<?x?xi64>
%speculatable_2 = stablehlo.concatenate %static_arg, %first_dim_dynamic, dim = 0 : (tensor<2x2xi64>, tensor<?x2xi64>) -> tensor<?x?xi64>
%speculatable_3 = stablehlo.concatenate %second_dim_dynamic, %static_arg, dim = 1 : (tensor<2x?xi64>, tensor<2x2xi64>) -> tensor<?x?xi64>

%speculatable_4 = stablehlo.concatenate %static_arg, %static_arg, dim = 0 : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<4x2xi64>
%speculatable_5 = stablehlo.concatenate %static_arg, %static_arg, dim = 1 : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x4xi64>
%not_speculatable_0 = stablehlo.concatenate %static_arg, %first_dim_dynamic, dim = 0 : (tensor<2x2xi64>, tensor<?x2xi64>) -> tensor<4x2xi64>
%not_speculatable_1 = stablehlo.concatenate %second_dim_dynamic, %static_arg, dim = 1 : (tensor<2x?xi64>, tensor<2x2xi64>) -> tensor<2x4xi64>

%speculatable_6 = stablehlo.concatenate %first_dim_dynamic, %first_dim_dynamic, dim = 0 : (tensor<?x2xi64>, tensor<?x2xi64>) -> tensor<?x?xi64>
%not_speculatable_2 = stablehlo.concatenate %first_dim_dynamic, %first_dim_dynamic, dim = 0 : (tensor<?x2xi64>, tensor<?x2xi64>) -> tensor<4x?xi64>
%not_speculatable_3 = stablehlo.concatenate %first_dim_dynamic, %first_dim_dynamic, dim = 1 : (tensor<?x2xi64>, tensor<?x2xi64>) -> tensor<?x?xi64>

%not_speculatable_4 = stablehlo.concatenate %second_dim_dynamic, %second_dim_dynamic, dim = 0 : (tensor<2x?xi64>, tensor<2x?xi64>) -> tensor<?x?xi64>
%not_speculatable_5 = stablehlo.concatenate %second_dim_dynamic, %second_dim_dynamic, dim = 1 : (tensor<2x?xi64>, tensor<2x?xi64>) -> tensor<?x4xi64>
%speculatable_7 = stablehlo.concatenate %second_dim_dynamic, %second_dim_dynamic, dim = 1 : (tensor<2x?xi64>, tensor<2x?xi64>) -> tensor<?x?xi64>

%not_speculatable_6 = stablehlo.concatenate %dynamic_arg, %dynamic_arg, dim = 0 : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
%not_speculatable_7 = stablehlo.concatenate %dynamic_arg, %dynamic_arg, dim = 1 : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
// -----

"hlo_test_speculatability.is_speculatable"(%speculatable_0) : (tensor<?x?xi64>) -> ()
"hlo_test_speculatability.is_speculatable"(%speculatable_1) : (tensor<?x?xi64>) -> ()
"hlo_test_speculatability.is_speculatable"(%speculatable_2) : (tensor<?x?xi64>) -> ()
"hlo_test_speculatability.is_speculatable"(%speculatable_3) : (tensor<?x?xi64>) -> ()
"hlo_test_speculatability.is_speculatable"(%speculatable_4) : (tensor<4x2xi64>) -> ()
"hlo_test_speculatability.is_speculatable"(%speculatable_5) : (tensor<2x4xi64>) -> ()
"hlo_test_speculatability.is_speculatable"(%speculatable_6) : (tensor<?x?xi64>) -> ()
"hlo_test_speculatability.is_speculatable"(%speculatable_7) : (tensor<?x?xi64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_speculatable_0) : (tensor<4x2xi64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_speculatable_1) : (tensor<2x4xi64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_speculatable_2) : (tensor<4x?xi64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_speculatable_3) : (tensor<?x?xi64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_speculatable_4) : (tensor<?x?xi64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_speculatable_5) : (tensor<?x4xi64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_speculatable_6) : (tensor<?x?xi64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_speculatable_7) : (tensor<?x?xi64>) -> ()
// CHECK-LABEL: func @scatter
// CHECK-NEXT: return
func.func @scatter(
%static_inputs: tensor<3x4x2xf64>, %static_indices: tensor<2x3x2xi64>, %static_updates: tensor<2x3x2x2xf64>,
%dynamic_inputs: tensor<?x?x?xf64>, %dynamic_indices: tensor<?x?x?xi64>, %dynamic_updates: tensor<?x?x?x?xf64>
) {
%recursively_speculatable_0 = "stablehlo.scatter"(%static_inputs, %static_indices, %static_updates) ({
^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
stablehlo.return %arg0 : tensor<f64>
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [2, 3],
inserted_window_dims = [0],
scatter_dims_to_operand_dims = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<3x4x2xf64>, tensor<2x3x2xi64>, tensor<2x3x2x2xf64>) -> tensor<?x?x?xf64>
%not_speculatable_0 = "stablehlo.scatter"(%static_inputs, %static_indices, %static_updates) ({
^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
stablehlo.return %arg0 : tensor<f64>
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [2, 3],
inserted_window_dims = [0],
scatter_dims_to_operand_dims = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false,
unique_indices = true
} : (tensor<3x4x2xf64>, tensor<2x3x2xi64>, tensor<2x3x2x2xf64>) -> tensor<?x?x?xf64>
%not_speculatable_1 = "stablehlo.scatter"(%static_inputs, %static_indices, %static_updates) ({
^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
stablehlo.return %arg0 : tensor<f64>
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [2, 3],
inserted_window_dims = [0],
scatter_dims_to_operand_dims = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = true,
unique_indices = false
} : (tensor<3x4x2xf64>, tensor<2x3x2xi64>, tensor<2x3x2x2xf64>) -> tensor<?x?x?xf64>
%not_speculatable_2 = "stablehlo.scatter"(%dynamic_inputs, %static_indices, %static_updates) ({
^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
stablehlo.return %arg0 : tensor<f64>
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [2, 3],
inserted_window_dims = [0],
scatter_dims_to_operand_dims = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<?x?x?xf64>, tensor<2x3x2xi64>, tensor<2x3x2x2xf64>) -> tensor<?x?x?xf64>
%not_speculatable_3 = "stablehlo.scatter"(%static_inputs, %dynamic_indices, %static_updates) ({
^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
stablehlo.return %arg0 : tensor<f64>
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [2, 3],
inserted_window_dims = [0],
scatter_dims_to_operand_dims = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<3x4x2xf64>, tensor<?x?x?xi64>, tensor<2x3x2x2xf64>) -> tensor<?x?x?xf64>
%not_speculatable_4 = "stablehlo.scatter"(%static_inputs, %static_indices, %dynamic_updates) ({
^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
stablehlo.return %arg0 : tensor<f64>
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [2, 3],
inserted_window_dims = [0],
scatter_dims_to_operand_dims = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<3x4x2xf64>, tensor<2x3x2xi64>, tensor<?x?x?x?xf64>) -> tensor<?x?x?xf64>
"hlo_test_speculatability.is_recursively_speculatable"(%recursively_speculatable_0) : (tensor<?x?x?xf64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_speculatable_0) : (tensor<?x?x?xf64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_speculatable_1) : (tensor<?x?x?xf64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_speculatable_2) : (tensor<?x?x?xf64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_speculatable_3) : (tensor<?x?x?xf64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_speculatable_4) : (tensor<?x?x?xf64>) -> ()
return
}

Expand Down

0 comments on commit c304904

Please sign in to comment.