Skip to content

Commit

Permalink
Allow CustomCallOp to have unranked tensor operands and results (open…
Browse files Browse the repository at this point in the history
…xla#2055)

Custom call op permits just about everything: Tuple, Tensor, Token,
Unranked, all quantization, etc. This change restores its ability to
operate on unranked tensors.

This was squashed in a merge conflict between openxla#2045 and openxla#2007. Adding a
testpoint to avoid this issue going forward.
  • Loading branch information
GleasonK authored Feb 27, 2024
1 parent 4a26dde commit 77a4b4c
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 3 deletions.
2 changes: 1 addition & 1 deletion stablehlo/dialect/Base.td
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def HLO_DimensionTensor : 1DTensorOf<[HLO_DimensionValue]>;
// TODO(b/326463552): Remove these when CHLO no longer needs unranked dynamism.
//===----------------------------------------------------------------------===//

def HLO_AnyTensor : TensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt]>;
def HLO_AnyTensor : TensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt]>;

def HLO_AnyPredTensor : TensorOf<[HLO_Pred]>;

Expand Down
4 changes: 2 additions & 2 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2182,7 +2182,7 @@ def StableHLO_CustomCallOp: StableHLO_Op<"custom_call",
}];

let arguments = (ins
Variadic<HLO_TensorOrPerAxisQuantizedTensorOrTokenOrTuple>:$inputs,
Variadic<HLO_CustomCallValue>:$inputs,
StrAttr:$call_target_name,
DefaultValuedOptionalAttr<BoolAttr, "false">:$has_side_effect,
DefaultValuedStrAttr<StrAttr, "">:$backend_config,
Expand All @@ -2202,7 +2202,7 @@ def StableHLO_CustomCallOp: StableHLO_Op<"custom_call",
"{}">:$output_operand_aliases
);

let results = (outs Variadic<HLO_TensorOrPerAxisQuantizedTensorOrTokenOrTuple>);
let results = (outs Variadic<HLO_CustomCallValue>);
let hasVerifier = 1;

let assemblyFormat = [{
Expand Down
9 changes: 9 additions & 0 deletions stablehlo/tests/ops_stablehlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4221,6 +4221,15 @@ func.func @custom_call_output_operand_alias(%arg0: tuple<tensor<1x1xf32>, tensor

// -----

// CHECK-LABEL: func @custom_call_unranked_types
func.func @custom_call_unranked_types(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: stablehlo.custom_call {{.*}} : (tensor<*xf32>) -> tensor<*xf32>
%0 = "stablehlo.custom_call"(%arg0) {call_target_name = "foo"} : (tensor<*xf32>) -> tensor<*xf32>
func.return %0 : tensor<*xf32>
}

// -----

// Test custom attribute printing/parsing.
// We really just need one op as holder, use module: this is the simplest top-level.

Expand Down

0 comments on commit 77a4b4c

Please sign in to comment.