From 58328b8ce9d9af8d53aa110834835cc37ce99edf Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Wed, 15 Mar 2023 23:58:20 +0000 Subject: [PATCH] Address feedback --- stablehlo/reference/Ops.cpp | 442 ++++++++++++++++++------------------ stablehlo/reference/Ops.h | 16 +- 2 files changed, 232 insertions(+), 226 deletions(-) diff --git a/stablehlo/reference/Ops.cpp b/stablehlo/reference/Ops.cpp index 7b7defc1ec0..476f8603802 100644 --- a/stablehlo/reference/Ops.cpp +++ b/stablehlo/reference/Ops.cpp @@ -20,7 +20,6 @@ limitations under the License. #include "llvm/Support/Errc.h" #include "llvm/Support/Error.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Support/DebugStringHelper.h" #include "stablehlo/dialect/TypeInference.h" @@ -47,15 +46,142 @@ SmallVector extractElements(ArrayRef arr, } template -DenseIntElementsAttr getDenseIntElementsAttr(SmallVector shape, - Type elementType, T values) { +DenseIntElementsAttr getDenseIntElementsAttr( + Type elementType, T values, + std::optional> valuesShape) { + SmallVector shape = + valuesShape.has_value() + ? *valuesShape + : SmallVector({static_cast(values.size())}); return DenseIntElementsAttr::get(RankedTensorType::get(shape, elementType), values); } -SmallVector split(const Tensor &input, int64_t groupSize, - Axis splitDimension, MLIRContext *context) { - Builder builder = mlir::Builder(context); +TensorType inferConvolutionOpType( + TensorType lhsType, TensorType rhsType, ArrayRef windowStrides, + ArrayRef> padding, + ArrayRef lhsDilation, ArrayRef rhsDilation, + ArrayRef windowReversal, Axis inputBatchDimension, + Axis inputFeatureDimension, Axes inputSpatialDimensions, + Axis kernelInputFeatureDimension, Axis kernelOutputFeatureDimension, + Axes kernelSpatialDimensions, Axis outputBatchDimension, + Axis outputFeatureDimension, Axes outputSpatialDimensions, + int64_t featureGroupCount, int64_t batchGroupCount, + std::optional precisionConfig, TensorType resultType) { + Builder builder = mlir::Builder(lhsType.getContext()); + Type i64Type = builder.getI64Type(); + Type i1Type = builder.getI1Type(); + auto flattenPadding = [&](ArrayRef> padding) { + SmallVector paddingVector; + for (auto pair : padding) { + paddingVector.push_back(pair.first); + paddingVector.push_back(pair.second); + } + return paddingVector; + }; + SmallVector paddingShape{static_cast(padding.size()), 2}; + SmallVector inferredConvolutionType; + auto convolutionStatus = hlo::inferConvolutionOp( + /*location=*/{}, lhsType, rhsType, + getDenseIntElementsAttr(i64Type, windowStrides, {}), + getDenseIntElementsAttr(i64Type, flattenPadding(padding), paddingShape), + getDenseIntElementsAttr(i64Type, lhsDilation, {}), + getDenseIntElementsAttr(i64Type, rhsDilation, {}), + getDenseIntElementsAttr(i1Type, windowReversal, {}), + static_cast(inputBatchDimension), + static_cast(inputFeatureDimension), + ArrayRef(inputSpatialDimensions), + static_cast(kernelInputFeatureDimension), + static_cast(kernelOutputFeatureDimension), + ArrayRef(kernelSpatialDimensions), + static_cast(outputBatchDimension), + static_cast(outputFeatureDimension), + ArrayRef(outputSpatialDimensions), + /*featureGroupCount=*/1, batchGroupCount, + /*precisionConfig=*/{}, inferredConvolutionType); + if (failed(convolutionStatus)) + report_fatal_error( + invalidArgument("Could not infer ConvolutionOp's return type")); + return RankedTensorType::get(inferredConvolutionType[0].getDims(), + resultType.getElementType()); +} + +TensorType inferDotGeneralOpType(TensorType lhsType, TensorType rhsType, + ArrayRef lhsContractingDimensions, + ArrayRef rhsContractingDimensions) { + SmallVector inferredDotGeneralType; + auto dotGeneralStatus = hlo::inferDotGeneralOp( + /*location=*/{}, lhsType, rhsType, + /*lhsBatchingDimensions=*/{}, /*rhsBatchingDimensions*/ {}, + lhsContractingDimensions, rhsContractingDimensions, + /*precisionConfig=*/{}, inferredDotGeneralType); + if (failed(dotGeneralStatus)) + report_fatal_error( + invalidArgument("Could not infer DotGeneralOp's return type")); + return RankedTensorType::get(inferredDotGeneralType[0].getDims(), + lhsType.getElementType()); +} + +TensorType inferPadOpType(ArrayRef> lhsPadding, + Type operandType, Type resultElementType, + ArrayRef interiorPadding) { + Builder builder = mlir::Builder(operandType.getContext()); + Type i64Type = builder.getI64Type(); + SmallVector lhsPaddingLow; + SmallVector lhsPaddingHigh; + for (auto paddingPair : lhsPadding) { + lhsPaddingLow.push_back(paddingPair.first); + lhsPaddingHigh.push_back(paddingPair.second); + } + SmallVector inferredPadType; + auto padStatus = hlo::inferPadOp( + {}, operandType, RankedTensorType::get({}, resultElementType), + getDenseIntElementsAttr(i64Type, lhsPaddingLow, {}), + getDenseIntElementsAttr(i64Type, lhsPaddingHigh, {}), + getDenseIntElementsAttr(i64Type, interiorPadding, {}), inferredPadType); + if (failed(padStatus)) + report_fatal_error(invalidArgument("Could not infer PadOp's return type")); + return inferredPadType[0].cast(); +} + +TensorType inferSliceOpType(Type operandType, + SmallVector lhsWindowStart, + SmallVector limitIndices, + SmallVector lhsWindowDilations) { + Builder builder = mlir::Builder(operandType.getContext()); + Type i64Type = builder.getI64Type(); + SmallVector inferredSliceType; + auto sliceStatus = hlo::inferSliceOp( + {}, operandType, getDenseIntElementsAttr(i64Type, lhsWindowStart, {}), + getDenseIntElementsAttr(i64Type, limitIndices, {}), + getDenseIntElementsAttr(i64Type, lhsWindowDilations, {}), + inferredSliceType); + if (failed(sliceStatus)) + report_fatal_error( + invalidArgument("Could not infer SliceOp's return type")); + return inferredSliceType[0].cast(); +} + +template +SmallVector concatAndPermute(T n, SmallVector hw, T c, int64_t dimA, + SmallVector dimB, int64_t dimC) { + SmallVector permInput; + permInput.push_back(n); + permInput.append(hw.begin(), hw.end()); + permInput.push_back(c); + SmallVector permDims; + permDims.push_back(dimA); + permDims.append(dimB.begin(), dimB.end()); + permDims.push_back(dimC); + SmallVector permOutput(permDims.size()); + for (auto [idx, dim] : llvm::enumerate(permDims)) + permOutput[dim] = permInput[idx]; + return permOutput; +} + +SmallVector splitIntoNGroupsAlongDim(const Tensor &input, int64_t N, + Axis splitDimension) { + Builder builder = mlir::Builder(input.getType().getContext()); auto i64Type = builder.getI64Type(); auto getScalarTensor = [&](auto value) { @@ -63,12 +189,12 @@ SmallVector split(const Tensor &input, int64_t groupSize, DenseElementsAttr::get(RankedTensorType::get({}, i64Type), value)); }; auto splitInputShape = SmallVector(input.getShape()); - splitInputShape[splitDimension] /= groupSize; + splitInputShape[splitDimension] /= N; auto splitInputType = RankedTensorType::get(splitInputShape, input.getElementType()); SmallVector splitResults; - for (auto idx = 0; idx < groupSize; ++idx) { + for (auto idx = 0; idx < N; ++idx) { SmallVector inputStartIndices(input.getRank(), getScalarTensor(0L)); inputStartIndices[splitDimension] = getScalarTensor(idx * splitInputShape[splitDimension]); @@ -81,27 +207,23 @@ SmallVector split(const Tensor &input, int64_t groupSize, } Tensor getZeroScalarTensor(Type elementType) { - Tensor tensor; - if (isSupportedIntegerType(elementType)) { - tensor = makeTensor( + if (isSupportedIntegerType(elementType)) + return makeTensor( DenseElementsAttr::get(RankedTensorType::get({}, elementType), Element(elementType, 0L).getIntegerValue())); - } else if (isSupportedBooleanType(elementType)) { - tensor = makeTensor( + if (isSupportedBooleanType(elementType)) + return makeTensor( DenseElementsAttr::get(RankedTensorType::get({}, elementType), false)); - } else if (isSupportedFloatType(elementType)) { - tensor = makeTensor( + if (isSupportedFloatType(elementType)) + return makeTensor( DenseElementsAttr::get(RankedTensorType::get({}, elementType), Element(elementType, 0.0).getFloatValue())); - } else if (isSupportedComplexType(elementType)) { - tensor = makeTensor(DenseElementsAttr::get( + if (isSupportedComplexType(elementType)) + return makeTensor(DenseElementsAttr::get( RankedTensorType::get({}, elementType), Element(elementType, std::complex(0.0)).getComplexValue())); - } else { - report_fatal_error(invalidArgument("Unsupported element type: %s", - debugString(elementType).c_str())); - } - return tensor; + report_fatal_error(invalidArgument("Unsupported element type: %s", + debugString(elementType).c_str())); } } // namespace @@ -246,14 +368,14 @@ Tensor evalConvertOp(const Tensor &operand, TensorType resultType) { Tensor evalConvolutionOp( const Tensor &lhs, const Tensor &rhs, ArrayRef windowStrides, - SmallVector> padding, ArrayRef lhsDilation, - ArrayRef rhsDilation, ArrayRef windowReversal, - Axis inputBatchDimension, Axis inputFeatureDimension, - Axes inputSpatialDimensions, Axis kernelInputFeatureDimension, - Axis kernelOutputFeatureDimension, Axes kernelSpatialDimensions, - Axis outputBatchDimension, Axis outputFeatureDimension, - Axes outputSpatialDimensions, int64_t featureGroupCount, - int64_t batchGroupCount, TensorType resultType) { + ArrayRef> padding, + ArrayRef lhsDilation, ArrayRef rhsDilation, + ArrayRef windowReversal, Axis inputBatchDimension, + Axis inputFeatureDimension, Axes inputSpatialDimensions, + Axis kernelInputFeatureDimension, Axis kernelOutputFeatureDimension, + Axes kernelSpatialDimensions, Axis outputBatchDimension, + Axis outputFeatureDimension, Axes outputSpatialDimensions, + int64_t featureGroupCount, int64_t batchGroupCount, TensorType resultType) { Tensor result(resultType); // This check is necessary here as we are not looping over result tensor but @@ -261,51 +383,13 @@ Tensor evalConvolutionOp( // fail if the dimension size is zero. if (resultType.getNumElements() == 0) return result; - Builder builder = mlir::Builder(resultType.getContext()); - Type i64Type = builder.getI64Type(); - Type i1Type = builder.getI1Type(); - - auto flattenPadding = [&](auto padding) { - SmallVector paddingVector; - for (auto pair : padding) paddingVector.append(pair.begin(), pair.end()); - return paddingVector; - }; - if (featureGroupCount > 1) { - auto flatPadding = flattenPadding(padding); - auto lhses = split(lhs, featureGroupCount, inputFeatureDimension, - resultType.getContext()); - auto rhses = split(rhs, featureGroupCount, kernelOutputFeatureDimension, - resultType.getContext()); - SmallVector paddingShape{static_cast(padding.size()), 2}; + auto lhses = + splitIntoNGroupsAlongDim(lhs, featureGroupCount, inputFeatureDimension); + auto rhses = splitIntoNGroupsAlongDim(rhs, featureGroupCount, + kernelOutputFeatureDimension); SmallVector results; for (auto [left, right] : llvm::zip(lhses, rhses)) { - SmallVector inferredConvolutionType; - auto convolutionStatus = hlo::inferConvolutionOp( - /*location=*/{}, left.getType(), right.getType(), - getDenseIntElementsAttr({static_cast(windowStrides.size())}, - i64Type, windowStrides), - getDenseIntElementsAttr(paddingShape, i64Type, flatPadding), - getDenseIntElementsAttr({static_cast(lhsDilation.size())}, - i64Type, lhsDilation), - getDenseIntElementsAttr({static_cast(rhsDilation.size())}, - i64Type, rhsDilation), - getDenseIntElementsAttr({static_cast(windowReversal.size())}, - i1Type, windowReversal), - static_cast(inputBatchDimension), - static_cast(inputFeatureDimension), - ArrayRef(inputSpatialDimensions), - static_cast(kernelInputFeatureDimension), - static_cast(kernelOutputFeatureDimension), - ArrayRef(kernelSpatialDimensions), - static_cast(outputBatchDimension), - static_cast(outputFeatureDimension), - ArrayRef(outputSpatialDimensions), - /*featureGroupCount=*/1, batchGroupCount, - /*precisionConfig=*/{}, inferredConvolutionType); - if (failed(convolutionStatus)) - report_fatal_error( - invalidArgument("Could not infer ConvolutionOp's return type")); auto convolutionResult = evalConvolutionOp( left, right, windowStrides, padding, lhsDilation, rhsDilation, windowReversal, inputBatchDimension, inputFeatureDimension, @@ -313,48 +397,26 @@ Tensor evalConvolutionOp( kernelOutputFeatureDimension, kernelSpatialDimensions, outputBatchDimension, outputFeatureDimension, outputSpatialDimensions, /*featureGroupCount=*/1, batchGroupCount, - RankedTensorType::get(inferredConvolutionType[0].getDims(), - result.getElementType()) - .cast()); + inferConvolutionOpType( + left.getType(), right.getType(), windowStrides, padding, + lhsDilation, rhsDilation, windowReversal, inputBatchDimension, + inputFeatureDimension, inputSpatialDimensions, + kernelInputFeatureDimension, kernelOutputFeatureDimension, + kernelSpatialDimensions, outputBatchDimension, + outputFeatureDimension, outputSpatialDimensions, + /*featureGroupCount=*/1, batchGroupCount, + /*precisionConfig=*/{}, resultType)); results.push_back(convolutionResult); } return evalConcatenateOp(results, outputFeatureDimension, result.getType()); } if (batchGroupCount > 1) { - auto flatPadding = flattenPadding(padding); - auto lhses = split(lhs, batchGroupCount, inputBatchDimension, - resultType.getContext()); - auto rhses = split(rhs, batchGroupCount, kernelOutputFeatureDimension, - resultType.getContext()); - SmallVector paddingShape{static_cast(padding.size()), 2}; + auto lhses = + splitIntoNGroupsAlongDim(lhs, batchGroupCount, inputBatchDimension); + auto rhses = splitIntoNGroupsAlongDim(rhs, batchGroupCount, + kernelOutputFeatureDimension); SmallVector results; for (auto [left, right] : llvm::zip(lhses, rhses)) { - SmallVector inferredConvolutionType; - auto convolutionStatus = hlo::inferConvolutionOp( - /*location=*/{}, left.getType(), right.getType(), - getDenseIntElementsAttr({static_cast(windowStrides.size())}, - i64Type, windowStrides), - getDenseIntElementsAttr(paddingShape, i64Type, flatPadding), - getDenseIntElementsAttr({static_cast(lhsDilation.size())}, - i64Type, lhsDilation), - getDenseIntElementsAttr({static_cast(rhsDilation.size())}, - i64Type, rhsDilation), - getDenseIntElementsAttr({static_cast(windowReversal.size())}, - i1Type, windowReversal), - static_cast(inputBatchDimension), - static_cast(inputFeatureDimension), - ArrayRef(inputSpatialDimensions), - static_cast(kernelInputFeatureDimension), - static_cast(kernelOutputFeatureDimension), - ArrayRef(kernelSpatialDimensions), - static_cast(outputBatchDimension), - static_cast(outputFeatureDimension), - ArrayRef(outputSpatialDimensions), featureGroupCount, - /*batchGroupCount=*/1, - /*precisionConfig=*/{}, inferredConvolutionType); - if (failed(convolutionStatus)) - report_fatal_error( - invalidArgument("Could not infer ConvolutionOp's return type")); auto convolutionResult = evalConvolutionOp( left, right, windowStrides, padding, lhsDilation, rhsDilation, windowReversal, inputBatchDimension, inputFeatureDimension, @@ -362,113 +424,70 @@ Tensor evalConvolutionOp( kernelOutputFeatureDimension, kernelSpatialDimensions, outputBatchDimension, outputFeatureDimension, outputSpatialDimensions, featureGroupCount, /*batchGroupCount=*/1, - RankedTensorType::get(inferredConvolutionType[0].getDims(), - result.getElementType()) - .cast()); + inferConvolutionOpType( + left.getType(), right.getType(), windowStrides, padding, + lhsDilation, rhsDilation, windowReversal, inputBatchDimension, + inputFeatureDimension, inputSpatialDimensions, + kernelInputFeatureDimension, kernelOutputFeatureDimension, + kernelSpatialDimensions, outputBatchDimension, + outputFeatureDimension, outputSpatialDimensions, + featureGroupCount, /*batchGroupCount=*/1, + /*precisionConfig=*/{}, resultType)); results.push_back(convolutionResult); } return evalConcatenateOp(results, outputFeatureDimension, result.getType()); } - auto lhsShape = [&](auto n, auto hw, auto c) { - SmallVector permInput; - permInput.push_back(n); - permInput.append(hw.begin(), hw.end()); - permInput.push_back(c); - SmallVector permDims; - permDims.push_back(inputBatchDimension); - permDims.append(inputSpatialDimensions.begin(), - inputSpatialDimensions.end()); - permDims.push_back(inputFeatureDimension); - SmallVector permOutput(permDims.size()); - for (auto [idx, dim] : llvm::enumerate(permDims)) - permOutput[dim] = permInput[idx]; - return permOutput; - }; - auto lhsShapePadding = [&](auto n, auto hw, auto c) { - SmallVector> permInput; - permInput.push_back(n); - permInput.append(hw.begin(), hw.end()); - permInput.push_back(c); - SmallVector permDims; - permDims.push_back(inputBatchDimension); - permDims.append(inputSpatialDimensions.begin(), - inputSpatialDimensions.end()); - permDims.push_back(inputFeatureDimension); - SmallVector> permOutput(permDims.size()); - for (auto [idx, dim] : llvm::enumerate(permDims)) - permOutput[dim] = permInput[idx]; - return permOutput; - }; - - auto lhsWindowDimensions = - lhsShape(lhs.getShape()[inputBatchDimension], - extractElements(rhs.getShape(), kernelSpatialDimensions), - lhs.getShape()[inputFeatureDimension]); - auto lhsWindowStrides = lhsShape(1, windowStrides, 1); - auto lhsPadding = lhsShapePadding(SmallVector(2, 0), padding, - SmallVector(2, 0)); - auto lhsBaseDilations = lhsShape(0, Sizes(lhsDilation) - 1, 0); - auto lhsWindowDilations = lhsShape(1, rhsDilation, 1); + auto lhsWindowDimensions = concatAndPermute( + lhs.getShape()[inputBatchDimension], + extractElements(rhs.getShape(), kernelSpatialDimensions), + lhs.getShape()[inputFeatureDimension], inputBatchDimension, + inputSpatialDimensions, inputFeatureDimension); + auto lhsWindowStrides = concatAndPermute( + 1L, llvm::to_vector(windowStrides), 1L, inputBatchDimension, + inputSpatialDimensions, inputFeatureDimension); + auto lhsPadding = concatAndPermute( + {0, 0}, llvm::to_vector(padding), {0, 0}, inputBatchDimension, + inputSpatialDimensions, inputFeatureDimension); + auto lhsBaseDilations = + concatAndPermute(0L, Sizes(lhsDilation) - 1, 0L, inputBatchDimension, + inputSpatialDimensions, inputFeatureDimension); + auto lhsWindowDilations = concatAndPermute( + 1L, llvm::to_vector(rhsDilation), 1L, inputBatchDimension, + inputSpatialDimensions, inputFeatureDimension); auto outputSpacialIndexIt = IndexSpaceIterator( Sizes(extractElements(result.getShape(), outputSpatialDimensions)), Index(outputSpatialDimensions.size())); - auto outputSpacialIndexItEnd = IndexSpaceIterator( Sizes(extractElements(result.getShape(), outputSpatialDimensions)), std::nullopt); - auto resultElementType = result.getElementType(); for (; outputSpacialIndexIt != outputSpacialIndexItEnd; ++outputSpacialIndexIt) { SmallVector lhsPaddingLow; - SmallVector lhsPaddingHigh; - for (auto paddingPair : lhsPadding) { - lhsPaddingLow.push_back(paddingPair[0]); - lhsPaddingHigh.push_back(paddingPair[1]); - } - SmallVector inferredPadType; - auto padStatus = hlo::inferPadOp( - {}, lhs.getType(), RankedTensorType::get({}, resultElementType), - getDenseIntElementsAttr({static_cast(lhsPaddingLow.size())}, - i64Type, lhsPaddingLow), - getDenseIntElementsAttr({static_cast(lhsPaddingHigh.size())}, - i64Type, lhsPaddingHigh), - getDenseIntElementsAttr({static_cast(lhsBaseDilations.size())}, - i64Type, llvm::to_vector(lhsBaseDilations)), - inferredPadType); - if (failed(padStatus)) - report_fatal_error( - invalidArgument("Could not infer PadOp's return type")); - auto paddedLhs = evalPadOp(lhs, getZeroScalarTensor(resultElementType), - Sizes(lhsPaddingLow), Sizes(lhsBaseDilations), - inferredPadType[0].cast()); + for (auto paddingPair : lhsPadding) + lhsPaddingLow.push_back(paddingPair.first); + auto paddedLhs = + evalPadOp(lhs, getZeroScalarTensor(result.getElementType()), + Sizes(lhsPaddingLow), Sizes(lhsBaseDilations), + inferPadOpType(lhsPadding, lhs.getType(), + result.getElementType(), lhsBaseDilations)); SmallVector lhsWindowStart; - for (auto [i, offset] : - llvm::enumerate(lhsShape(0, *outputSpacialIndexIt, 0))) + for (auto [i, offset] : llvm::enumerate( + concatAndPermute(0L, llvm::to_vector(*outputSpacialIndexIt), 0L, + inputBatchDimension, inputSpatialDimensions, + inputFeatureDimension))) lhsWindowStart.push_back(lhsWindowStrides[i] * offset); SmallVector limitIndices; for (size_t i = 0; i < lhsWindowStart.size(); ++i) limitIndices.push_back(std::min( lhsWindowStart[i] + lhsWindowDimensions[i] * lhsWindowDilations[i], paddedLhs.getShape()[i])); - SmallVector inferredSliceType; - auto sliceStatus = hlo::inferSliceOp( - {}, paddedLhs.getType(), - getDenseIntElementsAttr({static_cast(lhsWindowStart.size())}, - i64Type, llvm::to_vector(lhsWindowStart)), - getDenseIntElementsAttr({static_cast(lhsWindowStart.size())}, - i64Type, llvm::to_vector(limitIndices)), - getDenseIntElementsAttr({static_cast(lhsWindowStart.size())}, - i64Type, llvm::to_vector(lhsWindowDilations)), - inferredSliceType); - if (failed(sliceStatus)) - report_fatal_error( - invalidArgument("Could not infer SliceOp's return type")); - auto lhsWindow = evalSliceOp(paddedLhs, Sizes(lhsWindowStart), Sizes(lhsWindowDilations), - inferredSliceType[0].cast()); + inferSliceOpType(paddedLhs.getType(), lhsWindowStart, + limitIndices, lhsWindowDilations)); + SmallVector reverseDims; for (auto [i, isReverse] : llvm::enumerate(windowReversal)) if (isReverse) reverseDims.push_back(inputSpatialDimensions[i]); @@ -481,32 +500,23 @@ Tensor evalConvolutionOp( auto rhsContractingDimensions = llvm::to_vector(kernelSpatialDimensions); rhsContractingDimensions.push_back( static_cast(kernelInputFeatureDimension)); - SmallVector inferredDotGeneralType; - auto dotGeneralStatus = hlo::inferDotGeneralOp( - /*location=*/{}, reversedLhsWindow.getType(), rhs.getType(), - /*lhsBatchingDimensions=*/{}, /*rhsBatchingDimensions*/ {}, - lhsContractingDimensions, rhsContractingDimensions, - /*precisionConfig=*/{}, inferredDotGeneralType); - if (failed(dotGeneralStatus)) - report_fatal_error( - invalidArgument("Could not infer DotGeneralOp's return type")); - auto dotProduct = evalDotGeneralOp( reversedLhsWindow, rhs, /*lhsBatchingDimensions=*/{}, /*rhsBatchingDimensions=*/{}, Axes(lhsContractingDimensions), Axes(rhsContractingDimensions), - RankedTensorType::get(inferredDotGeneralType[0].getDims(), - reversedLhsWindow.getElementType()) - .cast()); + inferDotGeneralOpType(reversedLhsWindow.getType(), rhs.getType(), + lhsContractingDimensions, + rhsContractingDimensions)); - auto resultShape = [&](auto it) { + auto resultShape = [&](auto resultOffsetIt) { SmallVector resultIndex; - int64_t outputSpacialIndexItIdx = 0; - int64_t indexer = 0; - for (auto i = 0; i < result.getRank(); ++i) { - if (llvm::find(outputSpatialDimensions, i) == + for (auto outputSpacialDimensionsIdx = 0, resultOffsetItIdx = 0, + outputSpacialIndexItIdx = 0; + outputSpacialDimensionsIdx < result.getRank(); + ++outputSpacialDimensionsIdx) { + if (llvm::find(outputSpatialDimensions, outputSpacialDimensionsIdx) == outputSpatialDimensions.end()) - resultIndex.push_back((*it)[indexer++]); + resultIndex.push_back((*resultOffsetIt)[resultOffsetItIdx++]); else resultIndex.push_back( (*outputSpacialIndexIt)[outputSpacialIndexItIdx++]); @@ -984,33 +994,29 @@ SmallVector eval( Tensor runtimeRhs = scope.find(convolutionOp.getRhs()); int64_t rank = runtimeLhs.getRank(); auto windowStrides = SmallVector(rank - 2, 1); - auto padding = SmallVector>(rank - 2, {0, 0}); - auto lhsDilation = SmallVector(rank - 2, 1); - auto rhsDilation = SmallVector(rank - 2, 1); - auto windowReversal = SmallVector(rank - 2, false); if (convolutionOp.getWindowStrides().has_value()) windowStrides = llvm::to_vector( convolutionOp.getWindowStridesAttr().getValues()); + auto padding = SmallVector>(rank - 2, {0, 0}); if (convolutionOp.getPadding().has_value()) { auto paddingOrErr = hlo::convertPaddingAttribute(convolutionOp.getPadding(), {}); if (failed(paddingOrErr)) report_fatal_error(invalidArgument("Invalid padding format found.")); - for (auto i = 0; i < rank - 2; ++i) { - padding[i][0] = (*paddingOrErr)[i].first; - padding[i][1] = (*paddingOrErr)[i].second; - } + padding = *paddingOrErr; } + auto lhsDilation = SmallVector(rank - 2, 1); if (convolutionOp.getLhsDilation().has_value()) lhsDilation = llvm::to_vector( convolutionOp.getLhsDilationAttr().getValues()); + auto rhsDilation = SmallVector(rank - 2, 1); if (convolutionOp.getRhsDilation().has_value()) rhsDilation = llvm::to_vector( convolutionOp.getRhsDilationAttr().getValues()); + auto windowReversal = SmallVector(rank - 2, false); if (convolutionOp.getWindowReversal().has_value()) windowReversal = llvm::to_vector( convolutionOp.getWindowReversalAttr().getValues()); - Tensor runtimeResult = evalConvolutionOp( runtimeLhs, runtimeRhs, windowStrides, padding, lhsDilation, rhsDilation, windowReversal, diff --git a/stablehlo/reference/Ops.h b/stablehlo/reference/Ops.h index 241cebe4ac4..5c01dde4f57 100644 --- a/stablehlo/reference/Ops.h +++ b/stablehlo/reference/Ops.h @@ -46,14 +46,14 @@ Tensor evalConstantOp(ElementsAttr value); Tensor evalConvertOp(const Tensor &operand, TensorType resultType); Tensor evalConvolutionOp( const Tensor &lhs, const Tensor &rhs, ArrayRef windowStrides, - SmallVector> padding, ArrayRef lhsDilation, - ArrayRef rhsDilation, ArrayRef windowReversal, - Axis inputBatchDimension, Axis inputFeatureDimension, - Axes inputSpatialDimensions, Axis kernelInputFeatureDimension, - Axis kernelOutputFeatureDimension, Axes kernelSpatialDimensions, - Axis outputBatchDimension, Axis outputFeatureDimension, - Axes outputSpatialDimensions, int64_t featureGroupCount, - int64_t batchGroupCount, TensorType resultType); + ArrayRef> padding, + ArrayRef lhsDilation, ArrayRef rhsDilation, + ArrayRef windowReversal, Axis inputBatchDimension, + Axis inputFeatureDimension, Axes inputSpatialDimensions, + Axis kernelInputFeatureDimension, Axis kernelOutputFeatureDimension, + Axes kernelSpatialDimensions, Axis outputBatchDimension, + Axis outputFeatureDimension, Axes outputSpatialDimensions, + int64_t featureGroupCount, int64_t batchGroupCount, TensorType resultType); Tensor evalCosineOp(const Tensor &operand, TensorType resultType); Tensor evalDivideOp(const Tensor &lhs, const Tensor &rhs, TensorType resultType);