Commit
Organize code based on recent PRs
ghpvnist committed Sep 13, 2023
1 parent dd052b7 commit 6375baf
Showing 3 changed files with 157 additions and 184 deletions.
80 changes: 51 additions & 29 deletions stablehlo/dialect/TypeInference.cpp
@@ -318,48 +318,55 @@ verifyWindowAttributesAndInferWindowDimensions(
" to have same dimension-size as size of window dimensions (",
windowDimensions.size(), "), but got: ", attrSize, ".");
};
-// convolution_c3, reduce_window_c6, select_and_scatter_c6
+// convolution_c2, reduce_window_c6, select_and_scatter_c6
if (failed(verifySize(windowStrides.size(), "window-strides")))
return failure();
-// convolution_c6, reduce_window_c8
+
+// convolution_c5, reduce_window_c8
if (failed(verifySize(lhsDilation.size(), "base-dilation factors")))
return failure();
-// convolution_c8, reduce_window_c10
+
+// convolution_c7, reduce_window_c10
if (failed(verifySize(rhsDilation.size(), "window-dilation factors")))
return failure();
-// convolution_c5, reduce_window_c12
+
+// convolution_c4, reduce_window_c12
if (failed(verifySize(padding.size(), "padding-entries"))) return failure();
-// convolution_c10
+
+// convolution_c9
if (failed(verifySize(windowReversal.size(), "window-reversal")))
return failure();

SmallVector<WindowDimension> window(windowDimensions.size());
for (size_t i = 0; i < windowDimensions.size(); i++) {
WindowDimension& dim = window[i];

dim.size = windowDimensions[i];

// reduce_window_c5, select_and_scatter_c5
if (!isDynamicDimSize(dim.size) && dim.size <= 0)
return emitOptionalError(loc,
"expects window to have positive value for ", i,
"-th window dimension, but got ", dim.size, ".");

if (!windowStrides.empty()) dim.stride = windowStrides[i];
-// convolution_c4, reduce_window_c7, select_and_scatter_c7
+
+// convolution_c3, reduce_window_c7, select_and_scatter_c7
if (dim.stride <= 0)
return emitOptionalError(
loc, "expects window to have positive stride for ", i,
"-th window dimension, but got ", dim.stride, ".");

if (!lhsDilation.empty()) dim.baseDilation = lhsDilation[i];
-// convolution_c7, reduce_window_c9
+
+// convolution_c6, reduce_window_c9
if (dim.baseDilation <= 0)
return emitOptionalError(
loc, "expects window to have positive base dilation factor for ", i,
"-th window dimension, but got ", dim.baseDilation, ".");

if (!rhsDilation.empty()) dim.windowDilation = rhsDilation[i];
-// convolution_c9, reduce_window_c11
+
+// convolution_c8, reduce_window_c11
if (dim.windowDilation <= 0)
return emitOptionalError(
loc, "expects window to have positive window dilation factor for ", i,
@@ -757,7 +764,7 @@ LogicalResult isSpatialDimensionsValid(
int64_t outputFeatureDimension, ArrayRef<int64_t> outputSpatialDimensions,
std::optional<Location> location) {
uint64_t spatialDimNum = inputSpatialDimensions.size();
-// convolution_c18, convolution_c20
+// convolution_c17, convolution_c19
if ((spatialDimNum != kernelSpatialDimensions.size()) ||
(spatialDimNum != outputSpatialDimensions.size()))
return emitOptionalError(location,
@@ -787,25 +794,28 @@ LogicalResult isSpatialDimensionsValid(

auto numDims = lhsType.cast<RankedTensorType>().getRank();
const auto inRange = [numDims](int64_t i) { return 0 <= i && i < numDims; };
-// convolution_c14, convolution_c19, convolution_c21
+// convolution_c13, convolution_c18, convolution_c20
if (!llvm::all_of(inputDimNums, inRange) ||
!llvm::all_of(windowDimNums, inRange) ||
!llvm::all_of(outputDimNums, inRange))
return emitOptionalError(location,
"expects input, kernel, and output "
"dimension-numbers to be in-range [0, ",
numDims, ").");
-// convolution_c14
+
+// convolution_c13
if (hasDuplicates(inputDimNums))
return emitOptionalError(
location, "expects input dimension-numbers to be unique, got {",
inputDimNums, "}.");
-// convolution_c19
+
+// convolution_c18
if (hasDuplicates(windowDimNums))
return emitOptionalError(
location, "expects kernel dimension-numbers to be unique, got {",
windowDimNums, "}.");
-// convolution_c21
+
+// convolution_c20
if (hasDuplicates(outputDimNums))
return emitOptionalError(
location, "expects output dimension-numbers to be unique, got {",
@@ -844,17 +854,19 @@ LogicalResult verifyConvolutionAttributes(
location)))
return failure();

-// convolution_c22
+// convolution_c21
if (featureGroupCount <= 0)
return emitOptionalError(
location, "expects feature_group_count to be a positive number, got ",
featureGroupCount, ".");
-// convolution_c23
+
+// convolution_c22
if (batchGroupCount <= 0)
return emitOptionalError(
location, "expects batch_group_count to be a positive number, got ",
batchGroupCount, ".");
-// convolution_c24
+
+// convolution_c23
if (batchGroupCount > 1 && featureGroupCount > 1)
return emitOptionalError(
location,
@@ -872,22 +884,24 @@ LogicalResult verifyConvolutionAttributes(
const int64_t kernelOutputFeatures =
rankedRhsType.getShape()[kernelOutputFeatureDimension];

-// convolution_c11
+// convolution_c10
if (!isDynamicDimSize(inputBatch) && inputBatch % batchGroupCount != 0)
return emitOptionalError(location, "expects input batch dimension (",
inputBatch,
") to be divisible by "
"batch_group_count. Got batch_group_count = ",
batchGroupCount, ".");
+
if (!isDynamicDimSize(inputFeatures)) {
-// convolution_c12
+// convolution_c11
if (inputFeatures % featureGroupCount != 0)
return emitOptionalError(location, "expects input feature dimension (",
inputFeatures,
") to be a multiple of feature_group_count. Got "
"feature_group_count = ",
featureGroupCount, ".");
-// convolution_c15
+
+// convolution_c14
if (!isDynamicDimSize(kernelInputFeatures) &&
inputFeatures / featureGroupCount != kernelInputFeatures)
return emitOptionalError(
@@ -897,15 +911,17 @@ LogicalResult verifyConvolutionAttributes(
kernelInputFeatures,
"). Got feature_group_count = ", featureGroupCount, ".");
}
+
if (!isDynamicDimSize(kernelOutputFeatures)) {
-// convolution_c16
+// convolution_c15
if (kernelOutputFeatures % batchGroupCount != 0)
return emitOptionalError(
location, "expects output feature dimension size (",
kernelOutputFeatures,
") to be a multiple of batch_group_count. Got batch_group_count = ",
batchGroupCount, ".");
-// convolution_c17
+
+// convolution_c16
if (kernelOutputFeatures % featureGroupCount != 0)
return emitOptionalError(location,
"expects kernel output feature dimension (",
Expand All @@ -915,7 +931,7 @@ LogicalResult verifyConvolutionAttributes(
featureGroupCount, ".");
}

-// convolution_c25
+// convolution_c24
if (failed(verifyPrecisionConfig(location, precisionConfig)))
return failure();

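A worked example of the divisibility constraints checked in the two hunks above, using the new constraint labels; the concrete values are illustrative, not from the diff. For feature_group_count = 2, an input feature dimension of 32 must split evenly into groups, and the kernel input feature dimension must equal one such group.

#include <cassert>
#include <cstdint>

int main() {
  const int64_t inputBatch = 8, inputFeatures = 32;
  const int64_t kernelInputFeatures = 16, kernelOutputFeatures = 64;
  const int64_t featureGroupCount = 2, batchGroupCount = 1;

  assert(inputBatch % batchGroupCount == 0);                         // convolution_c10
  assert(inputFeatures % featureGroupCount == 0);                    // convolution_c11
  assert(inputFeatures / featureGroupCount == kernelInputFeatures);  // convolution_c14
  assert(kernelOutputFeatures % batchGroupCount == 0);               // convolution_c15
  assert(kernelOutputFeatures % featureGroupCount == 0);             // convolution_c16
  return 0;
}
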
@@ -1714,20 +1730,22 @@ LogicalResult inferConvolutionOp(
return success();
}

-// convolution_c14
+// convolution_c13
int numDims = rankedLhsType.getRank();
if (numDims < 2)
return emitOptionalError(
location,
"expects convolution arguments to have >= 2 dimensions. Got: ",
rankedLhsType, " and ", rankedRhsType, ".");
+
// convolution_c1
if (numDims != rankedRhsType.getRank())
return emitOptionalError(location,
"expects convolution arguments to have same "
"number of dimensions. Got: ",
rankedLhsType, " and ", rankedRhsType, ".");
-// convolution_c2
+
+// convolution_c27
if (!isCompatibleForHloTypeInference(rankedLhsType.getElementType(),
rankedRhsType.getElementType()))
return emitOptionalError(
@@ -1743,7 +1761,8 @@ LogicalResult inferConvolutionOp(
outputSpatialDimensions, featureGroupCount, batchGroupCount,
precisionConfig)))
return failure();
-// convolution_c13
+
+// convolution_c12
if ((size_t)numDims != inputSpatialDimensions.size() + 2)
return emitOptionalError(location, "expects convolution arguments to have ",
inputSpatialDimensions.size() + 2,
@@ -1753,22 +1772,25 @@ LogicalResult inferConvolutionOp(
for (size_t i = 0; i < windowDimensions.size(); i++)
windowDimensions[i] = rankedRhsType.getShape()[kernelSpatialDimensions[i]];

-// convolution_c5, convolution_i4
+// convolution_c4, convolution_i4
auto paddingOrErr = convertPaddingAttribute(padding, location);
if (failed(paddingOrErr)) return failure();

// convolution_i3
auto windowStridesOrErr =
convert1DAttribute(windowStrides, location, "window_strides");
if (failed(windowStridesOrErr)) return failure();

// convolution_i5
auto lhsDilationOrErr =
convert1DAttribute(lhsDilation, location, "lhs_dilation");
if (failed(lhsDilationOrErr)) return failure();

// convolution_i6
auto rhsDilationOrErr =
convert1DAttribute(rhsDilation, location, "rhs_dilation");
if (failed(rhsDilationOrErr)) return failure();

// convolution_i7
auto windowReversalOrErr = convertWindowReversalAttribute(
windowReversal, location, "window_reversal");
@@ -1779,7 +1801,7 @@ LogicalResult inferConvolutionOp(
*rhsDilationOrErr, *windowReversalOrErr, location);
if (failed(windowOrErr)) return failure();

-// convolution_c26
+// convolution_c25, convolution_c26
SmallVector<int64_t> outputDimensions(rankedLhsType.getShape().size(),
ShapedType::kDynamic);
auto numSpatialDims = inputSpatialDimensions.size();
@@ -3312,7 +3334,7 @@ LogicalResult verifyConvolutionOp(

auto inferredShape = inferredReturnShapes[0];
auto shapedResultType = resultType.cast<ShapedType>();
-// convolution_c26
+// convolution_c25
if (inferredShape.hasRank() && shapedResultType.hasRank() &&
failed(verifyCompatibleShape(inferredShape.getDims(),
shapedResultType.getShape())))
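
The output dimensions that inferConvolutionOp fills in (convolution_c25/c26) follow the usual dilated-window arithmetic. A standalone sketch of that formula, under the assumption that StableHLO's window inference matches the XLA-style computation; this helper is illustrative, not from the diff.

#include <cstdint>

int64_t outputSpatialSize(int64_t inputSize, int64_t windowSize,
                          int64_t stride, int64_t padLow, int64_t padHigh,
                          int64_t lhsDilation, int64_t rhsDilation) {
  int64_t dilatedInput = (inputSize - 1) * lhsDilation + 1;    // base dilation
  int64_t dilatedWindow = (windowSize - 1) * rhsDilation + 1;  // window dilation
  int64_t padded = dilatedInput + padLow + padHigh;
  if (padded < dilatedWindow) return 0;
  return (padded - dilatedWindow) / stride + 1;  // floor division
}

// Example: a 3-wide window over a 224-wide input, stride 2, padding 1 on
// each side, no dilation: outputSpatialSize(224, 3, 2, 1, 1, 1, 1) == 112.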