Skip to content

Commit

Permalink
dynamic_broadcast_in_dim followup cleanup (openxla#2325)
Browse files Browse the repository at this point in the history
Separate constraints are not required, C8 is taking care of C9, C10.
Same verifier check is applicable.
  • Loading branch information
abhigunj authored May 11, 2024
1 parent 532a0ef commit 12fd0a9
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 10 deletions.
8 changes: 3 additions & 5 deletions docs/spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -2683,11 +2683,9 @@ If not specified, all dimensions are assumed to be possibly expanding.
zero_points(operand)[0] for i in
range(dim(result, quantization_dimension(result)))`.
* (C7) `size(output_dimensions) = rank(result)`.
* (C8) `is_unique(known_expanding_dimensions)`.
* (C9) `is_unique(known_non_expanding_dimensions)`.
* (C10) `is_unique(known_expanding_dimensions + known_non_expanding_dimensions)`.
* (C11) `0 <= known_expanding_dimensions < rank(operand)`.
* (C12) `0 <= known_non_expanding_dimensions < rank(operand)`.
* (C8) `is_unique(known_expanding_dimensions + known_non_expanding_dimensions)`.
* (C9) `0 <= known_expanding_dimensions < rank(operand)`.
* (C10) `0 <= known_non_expanding_dimensions < rank(operand)`.

#### Examples

Expand Down
5 changes: 2 additions & 3 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3991,14 +3991,13 @@ LogicalResult verifyDynamicBroadcastInDimOp(
collectExpansionBehaviorDims(knownExpandingDimensions);
collectExpansionBehaviorDims(knownNonexpandingDimensions);

// dynamic_broadcast_in_dim_c8, dynamic_broadcast_in_dim_c9,
// dynamic_broadcast_in_dim_c10
// dynamic_broadcast_in_dim_c8
if (knownExpansionBehavior.size() != numKnownExpansionBehavior)
return emitOptionalError(
location,
"duplicate expansion hint for at least one operand dimension");

// dynamic_broadcast_in_dim_c11, dynamic_broadcast_in_dim_c12
// dynamic_broadcast_in_dim_c9, dynamic_broadcast_in_dim_c10
for (int64_t i : knownExpansionBehavior)
if (i < 0 || i >= operandType.getRank())
return emitOptionalError(location, "hint for expanding dimension ", i,
Expand Down
4 changes: 2 additions & 2 deletions stablehlo/tests/ops_stablehlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1117,7 +1117,7 @@ func.func @dynamic_broadcast_in_dim_c7_output_dimensions_mismatching_size(%arg0:

// -----

func.func @dynamic_broadcast_in_dim_c8_c9_c10(%arg0: tensor<?x?xi32>, %shape: tensor<3xi64>) -> tensor<?x?x?xi32> {
func.func @dynamic_broadcast_in_dim_c8(%arg0: tensor<?x?xi32>, %shape: tensor<3xi64>) -> tensor<?x?x?xi32> {
// expected-error@+1 {{duplicate expansion hint for at least one operand dimension}}
%0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %shape) {
broadcast_dimensions = array<i64: 1, 2>,
Expand All @@ -1128,7 +1128,7 @@ func.func @dynamic_broadcast_in_dim_c8_c9_c10(%arg0: tensor<?x?xi32>, %shape: te

// -----

func.func @dynamic_broadcast_in_dim_c11_c12(%arg0: tensor<?x?xi32>, %shape: tensor<3xi64>) -> tensor<?x?x?xi32> {
func.func @dynamic_broadcast_in_dim_c9_c10(%arg0: tensor<?x?xi32>, %shape: tensor<3xi64>) -> tensor<?x?x?xi32> {
// expected-error@+1 {{hint for expanding dimension 3 does not refer to a valid operand dimension}}
%0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %shape) {
broadcast_dimensions = array<i64: 1, 2>,
Expand Down

0 comments on commit 12fd0a9

Please sign in to comment.