Skip to content

Commit

Permalink
Address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
ghpvnist committed Mar 15, 2023
1 parent 09a3683 commit 9e21f9a
Show file tree
Hide file tree
Showing 10 changed files with 554 additions and 551 deletions.
55 changes: 27 additions & 28 deletions docs/spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -1943,42 +1943,41 @@ If `batch_group_count > 1`:
#### Constraints

<!-- markdownlint-disable line-length -->
* (C1) rank(`lhs`) $\ge$ 2.
* (C2) $N =$ rank(`lhs`) $=$ rank(`rhs`).
* (C3) element_type(`lhs`) $=$ element_type(`rhs`).
* (C4) size(`window_strides`) $= N - 2$.
* (C5) `window_strides[i]` $\gt 0$ for all i $\in$ [0, size(`window_strides`)).
* (C6) dim(`padding`, 0) $= N - 2$ and dim(`padding`, 1) = 2.
* (C7) size(`lhs_dilation`) $= N - 2$.
* (C8) `lhs_dilation[i]` $\gt 0$ for all i $\in$ [0, size(`lhs_dilation`)).
* (C9) size(`rhs_dilation`) $= N - 2$.
* (C10) `rhs_dilation[i]` $\gt 0$ for all i $\in$ [0, size(`rhs_dilation`)).
* (C11) size(`window_reversal`) $= N - 2$.
* (C12) `dim(lhs, input_batch_dimension) % batch_group_count = 0`.
* (C13) `dim(lhs, input_feature_dimension) % feature_group_count = 0`.
* (C14) size(`input_spatial_dimensions`) $= N - 2$.
* (C15) Given `input_dimensions = [input_batch_dimension] +
* (C1) $N =$ rank(`lhs`) $=$ rank(`rhs`).
* (C2) element_type(`lhs`) $=$ element_type(`rhs`).
* (C3) size(`window_strides`) $= N - 2$.
* (C4) `window_strides[i]` $\gt 0$ for all i $\in$ [0, size(`window_strides`)).
* (C5) dim(`padding`, 0) $= N - 2$ and dim(`padding`, 1) $= 2$.
* (C6) size(`lhs_dilation`) $= N - 2$.
* (C7) `lhs_dilation[i]` $\gt 0$ for all i $\in$ [0, size(`lhs_dilation`)).
* (C8) size(`rhs_dilation`) $= N - 2$.
* (C9) `rhs_dilation[i]` $\gt 0$ for all i $\in$ [0, size(`rhs_dilation`)).
* (C10) size(`window_reversal`) $= N - 2$.
* (C11) `dim(lhs, input_batch_dimension) % batch_group_count = 0`.
* (C12) `dim(lhs, input_feature_dimension) % feature_group_count = 0`.
* (C13) size(`input_spatial_dimensions`) $= N - 2$.
* (C14) Given `input_dimensions = [input_batch_dimension] +
input_spatial_dimensions + [input_feature_dimension]`:
* All dimensions in `input_dimensions` are unique.
* For any i $\in$ `input_dimensions`, 0 $\le$ i $\lt$ N.
* (C16) `dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count`.
* (C17) `dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0`.
* (C18) `dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0`.
* (C19) size(`kernel_spatial_dimensions`) $= N - 2$.
* (C20) Given `kernel_dimensions = kernel_spatial_dimensions +
* (C15) `dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count`.
* (C16) `dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0`.
* (C17) `dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0`.
* (C18) size(`kernel_spatial_dimensions`) $= N - 2$.
* (C19) Given `kernel_dimensions = kernel_spatial_dimensions +
[kernel_input_feature_dimension] + [kernel_output_feature_dimension]`:
* All dimensions in `kernel_dimensions` are unique.
* For any i $\in$ `kernel_dimensions`, 0 $\le$ i $\lt$ N.
* (C21) size(`output_spatial_dimensions`) $= N - 2$.
* (C22) Given `output_dimensions = [output_batch_dimension] +
* (C20) size(`output_spatial_dimensions`) $= N - 2$.
* (C21) Given `output_dimensions = [output_batch_dimension] +
output_spatial_dimensions + [output_feature_dimension]`:
* All dimensions in `output_dimensions` are unique.
* For any i $\in$ `output_dimensions`, 0 $\le$ i $\lt$ N.
* (C23) `feature_group_count > 0`.
* (C24) `batch_group_count > 0`.
* (C25) Either `feature_group_count` $= 1$ or `batch_group_count` $= 1$.
* (C26) size(`precision_config`) $=$ 2.
* (C27) For result_dim $\in$ [0, N), `dim(result, result_dim)` is given by:
* (C22) `feature_group_count > 0`.
* (C23) `batch_group_count > 0`.
* (C24) Either `feature_group_count` $= 1$ or `batch_group_count` $= 1$.
* (C25) size(`precision_config`) $=$ 2.
* (C26) For result_dim $\in$ [0, N), `dim(result, result_dim)` is given by:
* `dim(lhs, input_batch_dimension) / batch_group_count`, if `result_dim = output_batch_dimension`.
* `dim(rhs, kernel_output_feature_dimension)`, if `result_dim = output_feature_dimension`.
* `num_windows` otherwise, where:
Expand Down Expand Up @@ -2026,8 +2025,8 @@ If `batch_group_count > 1`:
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" are spatial dimensions.
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
feature_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
Expand Down
12 changes: 0 additions & 12 deletions output.mlir

This file was deleted.

4 changes: 2 additions & 2 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2044,10 +2044,10 @@ def StableHLO_ConvolutionOp : StableHLO_Op<"convolution", [Pure]> {
pad = [[0, 0], [0, 0]],
lhs_dilate = [2, 2],
rhs_dilate = [1, 1],
reverse = [false, false]
reverse = [0, 0]
} {
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
feature_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} :
(tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
Expand Down
61 changes: 29 additions & 32 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,18 +313,18 @@ verifyWindowAttributesAndInferWindowDimensions(
" to have same dimension-size as size of window dimensions (",
windowDimensions.size(), "), but got: ", attrSize, ".");
};
// convolution_c4
// convolution_c3
if (failed(verifySize(windowStrides.size(), "window-strides")))
return failure();
// convolution_c7
// convolution_c6
if (failed(verifySize(lhsDilation.size(), "base-dilation factors")))
return failure();
// convolution_c9
// convolution_c8
if (failed(verifySize(rhsDilation.size(), "window-dilation factors")))
return failure();
// convolution_c6
// convolution_c5
if (failed(verifySize(padding.size(), "padding-entries"))) return failure();
// convolution_c11
// convolution_c10
if (failed(verifySize(windowReversal.size(), "window-reversal")))
return failure();

Expand All @@ -339,21 +339,21 @@ verifyWindowAttributesAndInferWindowDimensions(
"-th window dimension, but got ", dim.size, ".");

if (!windowStrides.empty()) dim.stride = windowStrides[i];
// convolution_c5
// convolution_c4
if (dim.stride <= 0)
return emitOptionalError(
loc, "expects window to have positive stride for ", i,
"-th window dimension, but got ", dim.stride, ".");

if (!lhsDilation.empty()) dim.baseDilation = lhsDilation[i];
// convolution_c8
// convolution_c7
if (dim.baseDilation <= 0)
return emitOptionalError(
loc, "expects window to have positive base dilation factor for ", i,
"-th window dimension, but got ", dim.baseDilation, ".");

if (!rhsDilation.empty()) dim.windowDilation = rhsDilation[i];
// convolution_c10
// convolution_c9
if (dim.windowDilation <= 0)
return emitOptionalError(
loc, "expects window to have positive window dilation factor for ", i,
Expand Down Expand Up @@ -763,7 +763,7 @@ LogicalResult isSpatialDimensionsValid(
int64_t outputFeatureDimension, ArrayRef<int64_t> outputSpatialDimensions,
std::optional<Location> location) {
uint64_t spatialDimNum = inputSpatialDimensions.size();
// convolution_c19, convolution_c21
// convolution_c18, convolution_c20
if ((spatialDimNum != kernelSpatialDimensions.size()) ||
(spatialDimNum != outputSpatialDimensions.size()))
return emitOptionalError(location,
Expand Down Expand Up @@ -793,25 +793,25 @@ LogicalResult isSpatialDimensionsValid(

auto numDims = lhsType.cast<RankedTensorType>().getRank();
const auto inRange = [numDims](int64_t i) { return 0 <= i && i < numDims; };
// convolution_c15, convolution_c20, convolution_c22
// convolution_c14, convolution_c19, convolution_c21
if (!llvm::all_of(inputDimNums, inRange) ||
!llvm::all_of(windowDimNums, inRange) ||
!llvm::all_of(outputDimNums, inRange))
return emitOptionalError(location,
"expects input, kernel, and output "
"dimension-numbers to be in-range [0, ",
numDims, ").");
// convolution_c15
// convolution_c14
if (hasDuplicates(inputDimNums))
return emitOptionalError(
location, "expects input dimension-numbers to be unique, got {",
inputDimNums, "}.");
// convolution_c20
// convolution_c19
if (hasDuplicates(windowDimNums))
return emitOptionalError(
location, "expects kernel dimension-numbers to be unique, got {",
windowDimNums, "}.");
// convolution_c22
// convolution_c21
if (hasDuplicates(outputDimNums))
return emitOptionalError(
location, "expects output dimension-numbers to be unique, got {",
Expand Down Expand Up @@ -850,17 +850,17 @@ LogicalResult verifyConvolutionAttributes(
location)))
return failure();

// convolution_c23
// convolution_c22
if (featureGroupCount <= 0)
return emitOptionalError(
location, "expects feature_group_count to be a positive number, got ",
featureGroupCount, ".");
// convolution_c24
// convolution_c23
if (batchGroupCount <= 0)
return emitOptionalError(
location, "expects batch_group_count to be a positive number, got ",
batchGroupCount, ".");
// convolution_c25
// convolution_c24
if (batchGroupCount > 1 && featureGroupCount > 1)
return emitOptionalError(
location,
Expand All @@ -878,22 +878,22 @@ LogicalResult verifyConvolutionAttributes(
const int64_t kernelOutputFeatures =
rankedRhsType.getShape()[kernelOutputFeatureDimension];

// convolution_c12
// convolution_c11
if (!isDynamicDimSize(inputBatch) && inputBatch % batchGroupCount != 0)
return emitOptionalError(location, "expects input batch dimension (",
inputBatch,
") to be divisible by "
"batch_group_count. Got batch_group_count = ",
batchGroupCount, ".");
if (!isDynamicDimSize(inputFeatures)) {
// convolution_c13
// convolution_c12
if (inputFeatures % featureGroupCount != 0)
return emitOptionalError(location, "expects input feature dimension (",
inputFeatures,
") to be a multiple of feature_group_count. Got "
"feature_group_count = ",
featureGroupCount, ".");
// convolution_c16
// convolution_c15
if (!isDynamicDimSize(kernelInputFeatures) &&
inputFeatures / featureGroupCount != kernelInputFeatures)
return emitOptionalError(
Expand All @@ -904,14 +904,14 @@ LogicalResult verifyConvolutionAttributes(
"). Got feature_group_count = ", featureGroupCount, ".");
}
if (!isDynamicDimSize(kernelOutputFeatures)) {
// convolution_c17
// convolution_c16
if (kernelOutputFeatures % batchGroupCount != 0)
return emitOptionalError(
location, "expects output feature dimension size (",
kernelOutputFeatures,
") to be a multiple of batch_group_count. Got batch_group_count = ",
batchGroupCount, ".");
// convolution_c18
// convolution_c17
if (kernelOutputFeatures % featureGroupCount != 0)
return emitOptionalError(location,
"expects kernel output feature dimension (",
Expand All @@ -921,7 +921,7 @@ LogicalResult verifyConvolutionAttributes(
featureGroupCount, ".");
}

// convolution_c26
// convolution_c25
if (failed(verifyPrecisionConfig(location, precisionConfig)))
return failure();

Expand Down Expand Up @@ -1700,7 +1700,6 @@ LogicalResult inferConvertOp(
return success();
}

// TODO(b/232574102): Verify the element-type of return-value.
LogicalResult inferConvolutionOp(
std::optional<Location> location, Type lhsType, Type rhsType,
std::optional<DenseIntElementsAttr> windowStrides,
Expand All @@ -1723,20 +1722,20 @@ LogicalResult inferConvolutionOp(
return success();
}

// convolution_c1
// convolution_c14
int numDims = rankedLhsType.getRank();
if (numDims < 2)
return emitOptionalError(
location,
"expects convolution arguments to have >= 2 dimensions. Got: ",
rankedLhsType, " and ", rankedRhsType, ".");
// convolution_c2
// convolution_c1
if (numDims != rankedRhsType.getRank())
return emitOptionalError(location,
"expects convolution arguments to have same "
"number of dimensions. Got: ",
rankedLhsType, " and ", rankedRhsType, ".");
// convolution_c3
// convolution_c2
if (!isCompatibleForHloTypeInference(rankedLhsType.getElementType(),
rankedRhsType.getElementType()))
return emitOptionalError(
Expand All @@ -1752,7 +1751,7 @@ LogicalResult inferConvolutionOp(
outputSpatialDimensions, featureGroupCount, batchGroupCount,
precisionConfig)))
return failure();
// convolution_c14
// convolution_c13
if ((size_t)numDims != inputSpatialDimensions.size() + 2)
return emitOptionalError(location, "expects convolution arguments to have ",
inputSpatialDimensions.size() + 2,
Expand All @@ -1762,7 +1761,7 @@ LogicalResult inferConvolutionOp(
for (size_t i = 0; i < windowDimensions.size(); i++)
windowDimensions[i] = rankedRhsType.getShape()[kernelSpatialDimensions[i]];

// convolution_c6, convolution_i4
// convolution_c5, convolution_i4
auto paddingOrErr = convertPaddingAttribute(padding, location);
if (failed(paddingOrErr)) return failure();

Expand All @@ -1788,10 +1787,9 @@ LogicalResult inferConvolutionOp(
*rhsDilationOrErr, *windowReversalOrErr, location);
if (failed(windowOrErr)) return failure();

// convolution_c26
SmallVector<int64_t> outputDimensions(rankedLhsType.getShape().size(),
ShapedType::kDynamic);

// Infer the output spatial dimensions.
auto numSpatialDims = inputSpatialDimensions.size();
SmallVector<int64_t> inputSpatialDimVals(numSpatialDims);
for (int64_t i = 0; i < static_cast<int64_t>(numSpatialDims); ++i)
Expand All @@ -1803,7 +1801,6 @@ LogicalResult inferConvolutionOp(
for (int64_t i = 0; i < static_cast<int64_t>(windowOrErr->size()); ++i)
outputDimensions[outputSpatialDimensions[i]] = windowOutputShape[i];

// Infer the output-batch-dimension and output-feature-dimension.
const int64_t inputBatch = rankedLhsType.getShape()[inputBatchDimension];
const int64_t kernelOutputFeatures =
rankedRhsType.getShape()[kernelOutputFeatureDimension];
Expand Down Expand Up @@ -3256,7 +3253,7 @@ LogicalResult verifyConvolutionOp(

auto inferredShape = inferredReturnShapes[0];
auto shapedResultType = resultType.dyn_cast<ShapedType>();
// convolution_c27
// convolution_c26
if (inferredShape.hasRank() && shapedResultType.hasRank() &&
failed(verifyCompatibleShape(inferredShape.getDims(),
shapedResultType.getShape())))
Expand Down
5 changes: 5 additions & 0 deletions stablehlo/dialect/TypeInference.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ void reifyGatherDimSizes(int64_t resultRank,
ArrayRef<int64_t> startIndexMap,
int64_t indexVectorDim, SmallVectorImpl<Value>& shape);

// Convert a Nx2 dense int64 padding attribute to a list of tuples.
FailureOr<SmallVector<std::pair<int64_t, int64_t>>> convertPaddingAttribute(
std::optional<DenseIntElementsAttr> optionalAttr,
std::optional<Location> loc);

// Convert a 1D dense bool attribute to a list of values.
FailureOr<SmallVector<bool>> convertWindowReversalAttribute(
std::optional<DenseElementsAttr> optionalAttr, std::optional<Location> loc,
Expand Down
Loading

0 comments on commit 9e21f9a

Please sign in to comment.