Address remaining feedback from openxla#2312 (openxla#2327)
ghpvnist authored May 13, 2024
1 parent 2d35f55 commit 966e4fb
Showing 13 changed files with 355 additions and 648 deletions.
168 changes: 58 additions & 110 deletions docs/spec.md
@@ -2296,11 +2296,11 @@ For hybrid quantized types, performs `hybrid_dequantize_then_op(
// ]
// ]]
//
-// %rhs : [
-//   [[[1]], [[1]], [[1]]],
-//   [[[1]], [[1]], [[1]]],
-//   [[[1]], [[1]], [[1]]]
-// ]
+// %rhs: [
+//   [[[1]], [[1]], [[1]]],
+//   [[[1]], [[1]], [[1]]],
+//   [[[1]], [[1]], [[1]]]
+// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strides = array<i64: 4, 4>,
padding = dense<0> : tensor<2x2xi64>,
@@ -2706,87 +2706,17 @@ If not specified, all dimensions are assumed to be possibly expanding.

#### Semantics

-Computes dot products between windows of `lhs` and slices of `rhs` and produces
-`result`. The following diagram shows how elements in `result` are computed from
-`lhs` and `rhs` using a concrete example.
-
-![convolution](images/spec/convolution.svg)
-
-More formally, consider the following reframing of the inputs in terms of `lhs`
-in order to be able to express windows of `lhs`. Additionally, padding is
-specified dynamically via `d_padding`:
-
-<!-- markdownlint-disable line-length -->
-* `lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))`.
-* `lhs_window_strides = lhs_shape(1, window_strides, 1)`.
-* `lhs_padding = lhs_shape([0, 0], padding, [0, 0])`.
-* `lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)`.
-* `lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)`.
-
-This reframing uses the following helper functions:
-
-* `lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])`.
-* `result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])`.
-* `permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]` where `j[d] = i[permutation[d]]`.
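The `permute` in these removed helpers is a scatter, not a gather: input element `d` lands at output position `permutation[d]`. A minimal C++ sketch of `permute` and `lhs_shape` (a hedged illustration using LLVM's ADT containers; the function names mirror the spec pseudocode and are not part of the commit):

```cpp
#include <cstdint>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

// permute([j0, ..., jR-1], permutation) = [i0, ..., iR-1] where
// j[d] = i[permutation[d]]: element d of the input is scattered to
// output position permutation[d].
llvm::SmallVector<int64_t> permute(llvm::ArrayRef<int64_t> j,
                                   llvm::ArrayRef<int64_t> permutation) {
  llvm::SmallVector<int64_t> i(j.size());
  for (size_t d = 0; d < j.size(); ++d) i[permutation[d]] = j[d];
  return i;
}

// lhs_shape(n, hw, c): place n at input_batch_dimension, hw at
// input_spatial_dimensions, and c at input_feature_dimension.
llvm::SmallVector<int64_t> lhsShape(
    int64_t n, llvm::ArrayRef<int64_t> hw, int64_t c,
    int64_t inputBatchDimension,
    llvm::ArrayRef<int64_t> inputSpatialDimensions,
    int64_t inputFeatureDimension) {
  llvm::SmallVector<int64_t> values{n};
  values.append(hw.begin(), hw.end());
  values.push_back(c);
  llvm::SmallVector<int64_t> perm{inputBatchDimension};
  perm.append(inputSpatialDimensions.begin(), inputSpatialDimensions.end());
  perm.push_back(inputFeatureDimension);
  return permute(values, perm);
}
```

For the NHWC layout used in the examples below (`batch = 0`, `spatial = [1, 2]`, `feature = 3`), the permutation is the identity, so `lhs_shape(1, [4, 4], 1)` is simply `[1, 4, 4, 1]`.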

-If `feature_group_count = 1` and `batch_group_count = 1`, then for all
-`output_spatial_index` in `index_space(dim(result, output_spatial_dimensions...))`,
-`result[result_shape(:, output_spatial_index, :)] = dot_product` where:
-
-* `padding_value = constant(0, element_type(lhs))`.
-* `padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)`.
-* `lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides`.
-* `lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)`.
-* `reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])`.
-  This feature appears to be unused, so in the future we are planning to remove
-  it ([#1181](https://github.com/openxla/stablehlo/issues/1181)).
-* `dot_product = dot_general(reversed_lhs_window, rhs,
-  lhs_batching_dimensions=[],
-  lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension],
-  rhs_batching_dimensions=[],
-  rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])`.
-
-If `feature_group_count > 1`:
-
-* `lhses = split(lhs, feature_group_count, input_feature_dimension)`.
-* `rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)`.
-* `results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)`.
-* `result = concatenate(results, output_feature_dimension)`.
-
-If `batch_group_count > 1`:
-
-* `lhses = split(lhs, batch_group_count, input_batch_dimension)`.
-* `rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)`.
-* `results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)`.
-* `result = concatenate(results, output_feature_dimension)`.
-<!-- markdownlint-enable line-length -->

-For quantized types, performs `dequantize_op_quantize(
-    lambda lhs, rhs: convolution(lhs, rhs, d_padding, window_strides,
-    lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
-    input_feature_dimension, input_spatial_dimensions,
-    kernel_input_feature_dimension, kernel_output_feature_dimension,
-    kernel_spatial_dimensions, output_batch_dimension,
-    output_feature_dimension, output_spatial_dimensions,
-    feature_group_count, batch_group_count, precision_config), lhs, rhs,
-    type(result))`.
-
-For hybrid quantized types, performs `hybrid_dequantize_then_op(
-    lambda lhs, rhs: convolution(lhs, rhs, d_padding, window_strides,
-    lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
-    input_feature_dimension, input_spatial_dimensions,
-    kernel_input_feature_dimension, kernel_output_feature_dimension,
-    kernel_spatial_dimensions, output_batch_dimension,
-    output_feature_dimension, output_spatial_dimensions,
-    feature_group_count, batch_group_count, precision_config), lhs, rhs)`.
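The `dequantize_op_quantize` wrapper referenced above follows a simple pattern: dequantize the operands, evaluate the op in the expressed type, and requantize the result (`hybrid_dequantize_then_op` is the same minus the final requantization). A hedged, self-contained C++ sketch for the per-tensor case; `QuantParams` and the `i8` clamping bounds are illustrative assumptions, not the reference interpreter's API:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <functional>
#include <vector>

struct QuantParams {
  double scale;
  int64_t zeroPoint;
};

// dequantize_op_quantize(op, lhs, rhs, type(result)): dequantize the
// operands, run the float computation, requantize into the result type.
std::vector<int8_t> dequantizeOpQuantize(
    const std::function<std::vector<float>(const std::vector<float> &,
                                           const std::vector<float> &)> &op,
    const std::vector<int8_t> &lhs, QuantParams lhsQ,
    const std::vector<int8_t> &rhs, QuantParams rhsQ, QuantParams resultQ) {
  auto dequantize = [](const std::vector<int8_t> &t, QuantParams q) {
    std::vector<float> out(t.size());
    for (size_t i = 0; i < t.size(); ++i)
      out[i] = static_cast<float>((t[i] - q.zeroPoint) * q.scale);
    return out;
  };
  std::vector<float> expressed =
      op(dequantize(lhs, lhsQ), dequantize(rhs, rhsQ));
  std::vector<int8_t> result(expressed.size());
  for (size_t i = 0; i < expressed.size(); ++i) {
    // Round to nearest, shift by the zero point, clamp to i8 storage.
    int64_t q = std::llround(expressed[i] / resultQ.scale) + resultQ.zeroPoint;
    result[i] = static_cast<int8_t>(std::clamp<int64_t>(q, -128, 127));
  }
  return result;
}
```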
+This operation is functionally identical to
+[convolution](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convolution)
+op, but the padding is specified dynamically via `d_padding`.

#### Inputs

| Label | Name | Type | Constraints |
|-------|-----------------------------------|--------------------------------------------------------------|-----------------------------------------------------------|
| (I1) | `lhs` | tensor or per-tensor quantized tensor | (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31), (C33) |
| (I2) | `rhs` | tensor or quantized tensor | (C1), (C14-C16), (C26-C28), (C30-C33) |
-| (I3) | `d_padding` | 2-dimensional tensor of type `si64` | (C4) |
+| (I3) | `d_padding` | 2-dimensional tensor of integer type | (C4) |
| (I4) | `window_strides` | 1-dimensional tensor constant of type `si64` | (C2-C3) |
| (I5) | `lhs_dilation` | 1-dimensional tensor constant of type `si64` | (C5-C6) |
| (I6) | `rhs_dilation` | 1-dimensional tensor constant of type `si64` | (C7-C8) |
@@ -2846,22 +2776,34 @@ For hybrid quantized types, performs `hybrid_dequantize_then_op(
* (C22) `0 < batch_group_count`.
* (C23) `feature_group_count = 1 or batch_group_count = 1`.
* (C24) `size(precision_config) = 2`.
-* (C25) `rank(result) = N`.
+* (C25) `dim(result, result_dim)` is defined as:
+  * `dim(lhs, input_batch_dimension) / batch_group_count` if `result_dim = output_batch_dimension`.
+  * `dim(rhs, kernel_output_feature_dimension)` if `result_dim = output_feature_dimension`.
+  * `num_windows` otherwise, where:
+    * `output_spatial_dimensions[spatial_dim] = result_dim`.
+    * `lhs_dim = input_spatial_dimensions[spatial_dim]`.
+    * `rhs_dim = kernel_spatial_dimensions[spatial_dim]`.
+    * `dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1`.
+    * `padded_input_shape[lhs_dim] = d_padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + d_padding[spatial_dim, 1]`.
+    * `dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1`.
+    * `is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]`.
+    * `num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1`.
+* (C26) `rank(result) = N`.
* If the operation uses non-quantized tensors:
-* (C26) `element_type(lhs) = element_type(rhs) = element_type(result)`.
+* (C27) `element_type(lhs) = element_type(rhs) = element_type(result)`.
* If the operation uses quantized tensors:
-* (C27) `is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)`.
-* (C28) If `is_per_axis_quantized(rhs)`,
+* (C28) `is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)`.
+* (C29) If `is_per_axis_quantized(rhs)`,
then `quantization_dimension(rhs) = kernel_output_feature_dimension`.
-* (C29) If `is_per_axis_quantized(result)`, then
+* (C30) If `is_per_axis_quantized(result)`, then
`quantization_dimension(result) = output_feature_dimension`.
* If `is_quantized(lhs)`:
-* (C30) `storage_type(lhs) = storage_type(rhs)`.
-* (C31) `expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)`.
-* (C32) If `is_per_tensor_quantized(rhs)`, then
+* (C31) `storage_type(lhs) = storage_type(rhs)`.
+* (C32) `expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)`.
+* (C33) If `is_per_tensor_quantized(rhs)`, then
`is_per_tensor_quantized(result)`.
* If `!is_quantized(lhs)`:
-* (C33) `element_type(lhs) = expressed_type(rhs) = element_type(result)`.
+* (C34) `element_type(lhs) = expressed_type(rhs) = element_type(result)`.
<!-- markdownlint-enable line-length -->
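The new constraint (C25) is the output-shape rule, and its spatial case reduces to a short computation. A minimal C++ transcription (names mirror the spec pseudocode; not part of the commit):

```cpp
#include <cstdint>

// Output spatial size for one dimension, per (C25): dilate the input,
// apply the dynamic padding, dilate the window, then count strided
// placements of the window inside the padded input.
int64_t numWindows(int64_t lhsDim, int64_t lhsDilation, int64_t padLow,
                   int64_t padHigh, int64_t rhsDim, int64_t rhsDilation,
                   int64_t windowStride) {
  int64_t dilatedInputShape = lhsDim == 0 ? 0 : (lhsDim - 1) * lhsDilation + 1;
  int64_t paddedInputShape = padLow + dilatedInputShape + padHigh;
  int64_t dilatedWindowShape = rhsDim == 0 ? 0 : (rhsDim - 1) * rhsDilation + 1;
  bool isEmptyWindow =
      paddedInputShape == 0 || dilatedWindowShape > paddedInputShape;
  if (isEmptyWindow) return 0;
  return (paddedInputShape - dilatedWindowShape) / windowStride + 1;
}

// For the example below (lhsDim = 4, lhsDilation = 2, padding 1/1,
// rhsDim = 3, rhsDilation = 1, stride = 4):
// numWindows(4, 2, 1, 1, 3, 1, 4) == 2, giving the 2x2 output.
```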

#### Examples
@@ -2874,30 +2816,36 @@ For hybrid quantized types, performs `hybrid_dequantize_then_op(
// [[12], [13], [16], [17]]
// ]]
//
-// %rhs : [
+// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
// %d_padding: [[1, 1],
// [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %d_padding) {
window_strides = array<i64: 4, 4>,
-padding = dense<0> : tensor<2x2xi64>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
-// In the StableHLO dialect, dimension numbers are encoded via:
-// `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
-// "b" is batch dimension, "f" is feature dimension,
-// "i" is input feature dimension, "o" is output feature dimension,
-// "0/1/etc" are spatial dimensions.
-dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
-batch_group_count = 1 : i64,
+dimension_numbers = #stablehlo.conv<raw
+  input_batch_dimension = 0,
+  input_feature_dimension = 3,
+  input_spatial_dimensions = [0, 1],
+  kernel_input_feature_dimension = 2,
+  kernel_output_feature_dimension = 3,
+  kernel_spatial_dimensions = [0, 1],
+  output_batch_dimension = 0,
+  output_feature_dimension = 3,
+  output_spatial_dimensions = [1, 2]
+>,
feature_group_count = 1 : i64,
+batch_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
-} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
+} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
-// [[10], [26]],
-// [[46], [62]]
+// [[1], [5]],
+// [[10], [14]]
// ]]
```
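As a sanity check on the updated `%result` values, here is a hedged standalone C++ reproduction of this example. The `%lhs` literal is cut off by the hunk boundary above; the values used here are taken from the corresponding convolution example in spec.md (they also reproduce that example's old `[[10], [26]], [[46], [62]]` output when the padding is dropped):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // %lhs spatial plane of the 1x4x4x1 tensor (batch/feature dims dropped).
  const int64_t lhs[4][4] = {
      {1, 2, 5, 6}, {3, 4, 7, 8}, {10, 11, 14, 15}, {12, 13, 16, 17}};
  const int64_t lhsDilation = 2, stride = 4, pad = 1, kernel = 3;
  // lhs_dilation = 2: 4x4 -> 7x7 with zeros interleaved; d_padding = 1
  // on every edge: 7x7 -> 9x9.
  std::vector<std::vector<int64_t>> padded(9, std::vector<int64_t>(9, 0));
  for (int i = 0; i < 4; ++i)
    for (int j = 0; j < 4; ++j)
      padded[pad + lhsDilation * i][pad + lhsDilation * j] = lhs[i][j];
  // Slide the all-ones 3x3 kernel (%rhs) with stride 4 -> 2x2 result.
  for (int oi = 0; oi < 2; ++oi) {
    for (int oj = 0; oj < 2; ++oj) {
      int64_t acc = 0;
      for (int ki = 0; ki < kernel; ++ki)
        for (int kj = 0; kj < kernel; ++kj)
          acc += padded[oi * stride + ki][oj * stride + kj];
      printf("%lld ", static_cast<long long>(acc));
    }
    printf("\n");
  }
  // Prints:
  // 1 5
  // 10 14
}
```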

@@ -3002,10 +2950,10 @@ op, but the result shape is specified dynamically via `output_shape`.

#### Inputs

-| Label | Name             | Type                                          | Constraints |
-|-------|------------------|----------------------------------------------|-------------|
-| (I1)  | `output_shape`   | 1-dimensional tensor of integer type          | (C1), (C2)  |
-| (I2)  | `iota_dimension` | `si64`                                        | (C1)        |
+| Label | Name             | Type                                  | Constraints |
+|-------|------------------|--------------------------------------|-------------|
+| (I1)  | `output_shape`   | 1-dimensional tensor of integer type | (C1), (C2)  |
+| (I2)  | `iota_dimension` | `si64`                                | (C1)        |

#### Outputs

@@ -3105,10 +3053,10 @@ op, but the result shape is specified dynamically via `output_shape`.

#### Inputs

-| Label | Name           | Type                                          | Constraints |
-|-------|----------------|----------------------------------------------|-------------|
-| (I1)  | `operand`      | tensor or quantized tensor                    | (C1-C3)     |
-| (I2)  | `output_shape` | 1-dimensional tensor of integer type          | (C4)        |
+| Label | Name           | Type                                  | Constraints |
+|-------|----------------|--------------------------------------|-------------|
+| (I1)  | `operand`      | tensor or quantized tensor            | (C1-C3)     |
+| (I2)  | `output_shape` | 1-dimensional tensor of integer type  | (C4)        |

#### Outputs

@@ -3142,7 +3090,7 @@ op, but the result shape is specified dynamically via `output_shape`.

```mlir
// %operand: [[1, 2, 3], [4, 5, 6]]
-%output_shape = stablehlo.constant dense<[3, 2]> : tensor<2xi64>
+// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]
```
4 changes: 2 additions & 2 deletions docs/status.md
@@ -79,8 +79,8 @@ one of the following tracking labels.
| dot | no | revisit | infeasible | yes | revisit |
| dot_general | yes | revisit | infeasible | no | yes |
| dynamic_broadcast_in_dim | yes | yes | infeasible | yes | revisit |
-| dynamic_conv             | yes        | yes     | infeasible   | yes        | revisit  |
-| dynamic_gather           | yes        | yes     | infeasible   | yes        | revisit  |
+| dynamic_conv             | yes        | yes     | infeasible   | revisit    | revisit  |
+| dynamic_gather           | yes        | yes     | infeasible   | no         | revisit  |
| dynamic_iota | yes | yes | infeasible | yes | revisit |
| dynamic_pad | yes | yes | infeasible | yes | revisit |
| dynamic_reshape | yes | yes | infeasible | yes | revisit |
3 changes: 3 additions & 0 deletions stablehlo/dialect/Base.td
@@ -142,6 +142,9 @@ def HLO_Token : Type<CPred<"isa<TokenType>($_self)">, "token">;
// Any integer tensor types
def HLO_IntTensor : RankedTensorOf<[HLO_Int]>;

+// Any integer tensor type with rank 2.
+def HLO_2DIntTensor : TensorRankOf<[HLO_Int], [2]>;
+
// Any integer tensor type with rank 0 (i.e. representing a single integer).
def HLO_ScalarIntTensor : 0DTensorOf<[HLO_Int]>;

19 changes: 6 additions & 13 deletions stablehlo/dialect/StablehloOps.td
@@ -3492,19 +3492,21 @@ def StableHLO_DynamicConvOp : StableHLO_Op<"dynamic_conv",
[HLO_SpeculatableIfAllInputsStaticAndShapeConstant, NoMemoryEffect]> {
let summary = "DynamicConv operation";
let description = [{
-Computes dot products between windows of `lhs` and slices of `rhs` and
-produces `result`.
+This operation is functionally identical to
+[convolution](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convolution)
+op, but the padding is specified dynamically via `d_padding`.

Example:
```mlir
+%d_padding = stablehlo.constant dense<2> : tensor<2x2xi64>
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %d_padding) {
window_strides = array<i64: 4, 4>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
+feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
-feature_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
```
@@ -3513,7 +3515,7 @@ def StableHLO_DynamicConvOp : StableHLO_Op<"dynamic_conv",
let arguments = (ins
HLO_Tensor:$lhs, /*dynamic_conv_i1*/
HLO_Tensor:$rhs, /*dynamic_conv_i2*/
-    HLO_Tensor:$d_padding, /*dynamic_conv_i3*/
+    HLO_2DIntTensor:$d_padding, /*dynamic_conv_i3*/
OptionalAttr<GenericDenseI64ArrayAttr>:$window_strides, /*dynamic_conv_i4*/
OptionalAttr<GenericDenseI64ArrayAttr>:$lhs_dilation, /*dynamic_conv_i5*/
OptionalAttr<GenericDenseI64ArrayAttr>:$rhs_dilation, /*dynamic_conv_i6*/
@@ -3532,15 +3534,6 @@ def StableHLO_DynamicConvOp : StableHLO_Op<"dynamic_conv",
return reversal.has_value() && llvm::any_of(reversal.value(), [](bool v) { return v; });
}
}];

-  let assemblyFormat = [{
-    `(`operands`)`
-    `dim_numbers` `=` custom<ConvolutionDimensions>($dimension_numbers) `,`
-    `window` `=` `{` custom<WindowAttributes>($window_strides,
-      $lhs_dilation, $rhs_dilation,
-      $window_reversal) `}`
-    attr-dict `:` functional-type(operands, results)
-  }];
}

#endif // STABLEHLO_DIALECT_STABLEHLO_OPS
2 changes: 2 additions & 0 deletions stablehlo/dialect/VhloOps.td
@@ -453,6 +453,8 @@ def VHLO_DynamicConvOpV1 : VHLO_Op<"dynamic_conv_v1", "0.9.0", "0.19.0"> {
let results = (outs VHLO_AnyType:$result);
}

+// `d_padding` should be used instead of `padding` for dynamic convolution, so
+// `padding` is removed for clarity.
def VHLO_DynamicConvOpV2 : VHLO_Op<"dynamic_conv_v2", "0.20.0", "current"> {
let arguments = (ins
VHLO_AnyType:$lhs,
12 changes: 9 additions & 3 deletions stablehlo/reference/Ops.cpp
@@ -55,6 +55,13 @@ Index evalIndex(Tensor tensor) {
return result;
}

+template <typename T>
+SmallVector<T> extractAttributeOrDefault(std::optional<ArrayRef<T>> attr,
+                                         int64_t size, T value) {
+  if (attr.has_value()) return llvm::to_vector(attr.value());
+  return SmallVector<T>(size, value);
+}
+
Tensor dotGeneralOp(const Tensor &lhs, const Tensor &rhs,
const Axes &lhsContractingDimensions,
const Axes &rhsContractingDimensions) {
@@ -519,9 +526,8 @@ SmallVector<InterpreterValue> eval(Region &region,
auto rhs = scope.findTensor(op.getRhs());
auto rank = lhs.getRank();

-    SmallVector<int64_t> windowStrides(rank - 2, 1);
-    if (auto windowStridesAttr = op.getWindowStrides())
-      windowStrides = SmallVector<int64_t>(windowStridesAttr.value());
+    SmallVector<int64_t> windowStrides = extractAttributeOrDefault<int64_t>(
+        op.getWindowStrides(), rank - 2, 1);

SmallVector<std::pair<int64_t, int64_t>> padding(rank - 2, {0, 0});
if (auto paddingAttr = op.getPaddingAttr()) {
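The hunk above is truncated by the diff view. For reference, the new `extractAttributeOrDefault` helper can be exercised standalone; a hedged sketch assuming only LLVM's ADT headers (the `main` harness is illustrative, not repository code):

```cpp
#include <cstdio>
#include <optional>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

template <typename T>
llvm::SmallVector<T> extractAttributeOrDefault(
    std::optional<llvm::ArrayRef<T>> attr, int64_t size, T value) {
  if (attr.has_value()) return llvm::to_vector(attr.value());
  return llvm::SmallVector<T>(size, value);
}

int main() {
  // Attribute absent (e.g. window_strides omitted): default-filled vector.
  auto defaults = extractAttributeOrDefault<int64_t>(std::nullopt, 2, 1);
  printf("%lld %lld\n", (long long)defaults[0], (long long)defaults[1]);  // 1 1
  // Attribute present: its contents are copied through.
  const int64_t strides[] = {4, 4};
  auto given = extractAttributeOrDefault<int64_t>(
      llvm::ArrayRef<int64_t>(strides), 2, 1);
  printf("%lld %lld\n", (long long)given[0], (long long)given[1]);  // 4 4
}
```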
