diff --git a/stablehlo/reference/Index.h b/stablehlo/reference/Index.h
index 17d8e96c7d0..fddf6aa1d69 100644
--- a/stablehlo/reference/Index.h
+++ b/stablehlo/reference/Index.h
@@ -34,10 +34,7 @@ class IndexSpaceIterator {
   /// \name Constructor
   IndexSpaceIterator(Sizes shape, std::optional<Index> index)
       : shape_(shape), index_(index) {
-    if (index && !index->inBounds(shape))
-      llvm::report_fatal_error(
-          "Incompatible index and shape found while creating "
-          "an IndexSpaceIterator");
+    if (index && !index->inBounds(shape)) index_ = std::nullopt;
   }
 
   /// Get the current index.
diff --git a/stablehlo/reference/Ops.cpp b/stablehlo/reference/Ops.cpp
index 476f8603802..ef91e8f300d 100644
--- a/stablehlo/reference/Ops.cpp
+++ b/stablehlo/reference/Ops.cpp
@@ -31,6 +31,14 @@ namespace mlir {
 namespace stablehlo {
 namespace {
 
+static Type getI1Type(MLIRContext *context) {
+  return mlir::Builder(context).getI1Type();
+}
+
+static Type getI64Type(MLIRContext *context) {
+  return mlir::Builder(context).getI64Type();
+}
+
 Index evalIndices(ArrayRef<Tensor> runtimeIndices) {
   Index index(runtimeIndices.size());
   for (size_t i = 0; i < runtimeIndices.size(); ++i)
@@ -68,23 +76,21 @@ TensorType inferConvolutionOpType(
     Axis outputFeatureDimension, Axes outputSpatialDimensions,
     int64_t featureGroupCount, int64_t batchGroupCount,
    std::optional<ArrayAttr> precisionConfig, TensorType resultType) {
-  Builder builder = mlir::Builder(lhsType.getContext());
-  Type i64Type = builder.getI64Type();
-  Type i1Type = builder.getI1Type();
-  auto flattenPadding = [&](ArrayRef<std::pair<int64_t, int64_t>> padding) {
-    SmallVector<int64_t> paddingVector;
-    for (auto pair : padding) {
-      paddingVector.push_back(pair.first);
-      paddingVector.push_back(pair.second);
-    }
-    return paddingVector;
-  };
+  auto i64Type = getI64Type(lhsType.getContext());
+  auto i1Type = getI1Type(lhsType.getContext());
+
+  SmallVector<int64_t> paddingVector;
+  for (auto pair : padding) {
+    paddingVector.push_back(pair.first);
+    paddingVector.push_back(pair.second);
+  }
+
   SmallVector<int64_t> paddingShape{static_cast<int64_t>(padding.size()), 2};
   SmallVector<ShapedTypeComponents> inferredConvolutionType;
   auto convolutionStatus = hlo::inferConvolutionOp(
       /*location=*/{}, lhsType, rhsType,
       getDenseIntElementsAttr(i64Type, windowStrides, {}),
-      getDenseIntElementsAttr(i64Type, flattenPadding(padding), paddingShape),
+      getDenseIntElementsAttr(i64Type, paddingVector, paddingShape),
       getDenseIntElementsAttr(i64Type, lhsDilation, {}),
       getDenseIntElementsAttr(i64Type, rhsDilation, {}),
       getDenseIntElementsAttr(i1Type, windowReversal, {}),
@@ -122,25 +128,27 @@ TensorType inferDotGeneralOpType(TensorType lhsType, TensorType rhsType,
                                  lhsType.getElementType());
 }
 
-TensorType inferPadOpType(ArrayRef<std::pair<int64_t, int64_t>> lhsPadding,
-                          Type operandType, Type resultElementType,
+TensorType inferPadOpType(ArrayRef<std::pair<int64_t, int64_t>> padding,
+                          Type operandType, Type paddingValueElementType,
                           ArrayRef<int64_t> interiorPadding) {
-  Builder builder = mlir::Builder(operandType.getContext());
-  Type i64Type = builder.getI64Type();
   SmallVector<int64_t> lhsPaddingLow;
   SmallVector<int64_t> lhsPaddingHigh;
-  for (auto paddingPair : lhsPadding) {
+  for (auto paddingPair : padding) {
     lhsPaddingLow.push_back(paddingPair.first);
     lhsPaddingHigh.push_back(paddingPair.second);
   }
+
   SmallVector<Type> inferredPadType;
+  auto i64Type = getI64Type(operandType.getContext());
   auto padStatus = hlo::inferPadOp(
-      {}, operandType, RankedTensorType::get({}, resultElementType),
+      {}, operandType, RankedTensorType::get({}, paddingValueElementType),
       getDenseIntElementsAttr(i64Type, lhsPaddingLow, {}),
       getDenseIntElementsAttr(i64Type, lhsPaddingHigh, {}),
       getDenseIntElementsAttr(i64Type, interiorPadding, {}), inferredPadType);
+
   if (failed(padStatus))
     report_fatal_error(invalidArgument("Could not infer PadOp's return type"));
+
   return inferredPadType[0].cast<TensorType>();
 }
 
@@ -148,59 +156,62 @@ TensorType inferSliceOpType(Type operandType,
                             SmallVector<int64_t> lhsWindowStart,
                             SmallVector<int64_t> limitIndices,
                             SmallVector<int64_t> lhsWindowDilations) {
-  Builder builder = mlir::Builder(operandType.getContext());
-  Type i64Type = builder.getI64Type();
   SmallVector<Type> inferredSliceType;
+  auto i64Type = getI64Type(operandType.getContext());
   auto sliceStatus = hlo::inferSliceOp(
       {}, operandType, getDenseIntElementsAttr(i64Type, lhsWindowStart, {}),
       getDenseIntElementsAttr(i64Type, limitIndices, {}),
       getDenseIntElementsAttr(i64Type, lhsWindowDilations, {}),
       inferredSliceType);
+
   if (failed(sliceStatus))
     report_fatal_error(
         invalidArgument("Could not infer SliceOp's return type"));
+
   return inferredSliceType[0].cast<TensorType>();
 }
 
+// Returns `result` with the effect of applying `permutation`
+// (= [dimA] + dimsB + [dimC]) to `input` (= [n] + hw + [c]) such that
+// result[permutation[i]] = input[i].
 template <typename T>
-SmallVector<T> concatAndPermute(T n, SmallVector<T> hw, T c, int64_t dimA,
-                                SmallVector<int64_t> dimB, int64_t dimC) {
-  SmallVector<T> permInput;
-  permInput.push_back(n);
-  permInput.append(hw.begin(), hw.end());
-  permInput.push_back(c);
-  SmallVector<int64_t> permDims;
-  permDims.push_back(dimA);
-  permDims.append(dimB.begin(), dimB.end());
-  permDims.push_back(dimC);
-  SmallVector<T> permOutput(permDims.size());
-  for (auto [idx, dim] : llvm::enumerate(permDims))
-    permOutput[dim] = permInput[idx];
-  return permOutput;
+SmallVector<T> concatAndPermute(T n, SmallVector<T> hw, T c,
+                                const Axes &permutation) {
+  SmallVector<T> input;
+  input.push_back(n);
+  input.append(hw.begin(), hw.end());
+  input.push_back(c);
+
+  if (input.size() != permutation.size())
+    llvm::report_fatal_error(
+        "Expect same size for permutation and the array to be permuted");
+
+  SmallVector<T> result(permutation.size());
+  for (auto [idx, dim] : llvm::enumerate(permutation)) result[dim] = input[idx];
+  return result;
 }
 
-SmallVector<Tensor> splitIntoNGroupsAlongDim(const Tensor &input, int64_t N,
+SmallVector<Tensor> splitIntoNGroupsAlongDim(const Tensor &input,
+                                             int64_t nGroups,
                                              Axis splitDimension) {
-  Builder builder = mlir::Builder(input.getType().getContext());
-  auto i64Type = builder.getI64Type();
+  auto i64Type = getI64Type(input.getType().getContext());
   auto getScalarTensor = [&](auto value) {
-    return makeTensor(
-        DenseElementsAttr::get(RankedTensorType::get({}, i64Type), value));
+    return Tensor(RankedTensorType::get({}, i64Type), Element(i64Type, value));
   };
-  auto splitInputShape = SmallVector<int64_t>(input.getShape());
-  splitInputShape[splitDimension] /= N;
+
+  Sizes splitInputShape(input.getShape());
+  splitInputShape[splitDimension] /= nGroups;
   auto splitInputType =
       RankedTensorType::get(splitInputShape, input.getElementType());
   SmallVector<Tensor> splitResults;
-  for (auto idx = 0; idx < N; ++idx) {
+  for (auto idx = 0; idx < nGroups; ++idx) {
     SmallVector<Tensor> inputStartIndices(input.getRank(), getScalarTensor(0L));
     inputStartIndices[splitDimension] =
         getScalarTensor(idx * splitInputShape[splitDimension]);
-    Sizes strides(input.getRank(), 1);
-    auto resultTensor =
-        evalDynamicSliceOp(input, inputStartIndices, strides, splitInputType);
+
+    auto resultTensor = evalDynamicSliceOp(input, inputStartIndices,
+                                           splitInputShape, splitInputType);
     splitResults.push_back(resultTensor);
   }
   return splitResults;
@@ -208,20 +219,21 @@ SmallVector<Tensor> splitIntoNGroupsAlongDim(const Tensor &input, int64_t N,
 
 Tensor getZeroScalarTensor(Type elementType) {
   if (isSupportedIntegerType(elementType))
-    return makeTensor(
-        DenseElementsAttr::get(RankedTensorType::get({}, elementType),
-                               Element(elementType, 0L).getIntegerValue()));
+    return Tensor(RankedTensorType::get({}, elementType),
+                  Element(elementType, 0L));
+
   if (isSupportedBooleanType(elementType))
-    return makeTensor(
-        DenseElementsAttr::get(RankedTensorType::get({}, elementType), false));
+    return Tensor(RankedTensorType::get({}, elementType),
+                  Element(elementType, false));
+
   if (isSupportedFloatType(elementType))
-    return makeTensor(
-        DenseElementsAttr::get(RankedTensorType::get({}, elementType),
-                               Element(elementType, 0.0).getFloatValue()));
+    return Tensor(RankedTensorType::get({}, elementType),
+                  Element(elementType, 0.0));
+
   if (isSupportedComplexType(elementType))
-    return makeTensor(DenseElementsAttr::get(
-        RankedTensorType::get({}, elementType),
-        Element(elementType, std::complex<double>(0.0)).getComplexValue()));
+    return Tensor(RankedTensorType::get({}, elementType),
+                  Element(elementType, std::complex<double>(0.0)));
+
   report_fatal_error(invalidArgument("Unsupported element type: %s",
                                      debugString(elementType).c_str()));
 }
@@ -378,11 +390,6 @@ Tensor evalConvolutionOp(
     int64_t featureGroupCount, int64_t batchGroupCount, TensorType resultType) {
   Tensor result(resultType);
 
-  // This check is necessary here as we are not looping over result tensor but
-  // instead creating iterators using output spatial dimension size, which may
-  // fail if the dimension size is zero.
-  if (resultType.getNumElements() == 0) return result;
-
   if (featureGroupCount > 1) {
     auto lhses = splitIntoNGroupsAlongDim(lhs, featureGroupCount,
                                           inputFeatureDimension);
@@ -410,6 +417,7 @@ Tensor evalConvolutionOp(
     }
     return evalConcatenateOp(results, outputFeatureDimension, result.getType());
   }
+
   if (batchGroupCount > 1) {
     auto lhses =
         splitIntoNGroupsAlongDim(lhs, batchGroupCount, inputBatchDimension);
@@ -437,35 +445,42 @@ Tensor evalConvolutionOp(
     }
     return evalConcatenateOp(results, outputFeatureDimension, result.getType());
   }
-  auto lhsWindowDimensions = concatAndPermute(
-      lhs.getShape()[inputBatchDimension],
-      extractElements(rhs.getShape(), kernelSpatialDimensions),
-      lhs.getShape()[inputFeatureDimension], inputBatchDimension,
-      inputSpatialDimensions, inputFeatureDimension);
-  auto lhsWindowStrides = concatAndPermute(
-      1L, llvm::to_vector(windowStrides), 1L, inputBatchDimension,
-      inputSpatialDimensions, inputFeatureDimension);
-  auto lhsPadding = concatAndPermute(
-      {0, 0}, llvm::to_vector(padding), {0, 0}, inputBatchDimension,
-      inputSpatialDimensions, inputFeatureDimension);
+
+  Axes lhsPermutation;
+  lhsPermutation.push_back(inputBatchDimension);
+  lhsPermutation.append(inputSpatialDimensions.begin(),
+                        inputSpatialDimensions.end());
+  lhsPermutation.push_back(inputFeatureDimension);
+
+  auto lhsWindowDimensions =
+      concatAndPermute(lhs.getShape()[inputBatchDimension],
+                       extractElements(rhs.getShape(), kernelSpatialDimensions),
+                       lhs.getShape()[inputFeatureDimension], lhsPermutation);
+
+  auto lhsWindowStrides =
+      concatAndPermute(1L, llvm::to_vector(windowStrides), 1L, lhsPermutation);
+
+  auto lhsPadding = concatAndPermute({0, 0}, llvm::to_vector(padding), {0, 0},
+                                     lhsPermutation);
+
   auto lhsBaseDilations =
-      concatAndPermute(0L, Sizes(lhsDilation) - 1, 0L, inputBatchDimension,
-                       inputSpatialDimensions, inputFeatureDimension);
-  auto lhsWindowDilations = concatAndPermute(
-      1L, llvm::to_vector(rhsDilation), 1L, inputBatchDimension,
-      inputSpatialDimensions, inputFeatureDimension);
+      concatAndPermute(0L, Sizes(lhsDilation) - 1, 0L, lhsPermutation);
+
+  auto lhsWindowDilations =
+      concatAndPermute(1L, llvm::to_vector(rhsDilation), 1L, lhsPermutation);
 
-  auto outputSpacialIndexIt = IndexSpaceIterator(
+  auto outputSpatialIndexIt = IndexSpaceIterator(
       Sizes(extractElements(result.getShape(), outputSpatialDimensions)),
       Index(outputSpatialDimensions.size()));
-  auto outputSpacialIndexItEnd = IndexSpaceIterator(
+  auto outputSpatialIndexItEnd = IndexSpaceIterator(
       Sizes(extractElements(result.getShape(), outputSpatialDimensions)),
       std::nullopt);
-  for (; outputSpacialIndexIt != outputSpacialIndexItEnd;
-       ++outputSpacialIndexIt) {
+  for (; outputSpatialIndexIt != outputSpatialIndexItEnd;
+       ++outputSpatialIndexIt) {
     SmallVector<int64_t> lhsPaddingLow;
     for (auto paddingPair : lhsPadding)
       lhsPaddingLow.push_back(paddingPair.first);
+
     auto paddedLhs =
         evalPadOp(lhs, getZeroScalarTensor(result.getElementType()),
                   Sizes(lhsPaddingLow), Sizes(lhsBaseDilations),
@@ -473,16 +488,16 @@ Tensor evalConvolutionOp(
                             result.getElementType(), lhsBaseDilations));
 
     SmallVector<int64_t> lhsWindowStart;
-    for (auto [i, offset] : llvm::enumerate(
-             concatAndPermute(0L, llvm::to_vector(*outputSpacialIndexIt), 0L,
-                              inputBatchDimension, inputSpatialDimensions,
-                              inputFeatureDimension)))
+    for (auto [i, offset] : llvm::enumerate(concatAndPermute(
+             0L, llvm::to_vector(*outputSpatialIndexIt), 0L, lhsPermutation)))
       lhsWindowStart.push_back(lhsWindowStrides[i] * offset);
+
     SmallVector<int64_t> limitIndices;
     for (size_t i = 0; i < lhsWindowStart.size(); ++i)
       limitIndices.push_back(std::min(
          lhsWindowStart[i] + lhsWindowDimensions[i] * lhsWindowDilations[i],
          paddedLhs.getShape()[i]));
+
    auto lhsWindow =
        evalSliceOp(paddedLhs, Sizes(lhsWindowStart), Sizes(lhsWindowDilations),
                    inferSliceOpType(paddedLhs.getType(), lhsWindowStart,
@@ -491,15 +506,18 @@ Tensor evalConvolutionOp(
     SmallVector<int64_t> reverseDims;
     for (auto [i, isReverse] : llvm::enumerate(windowReversal))
       if (isReverse) reverseDims.push_back(inputSpatialDimensions[i]);
+
     auto reversedLhsWindow =
         evalReverseOp(lhsWindow, Axes(reverseDims), lhsWindow.getType());
 
     auto lhsContractingDimensions = llvm::to_vector(inputSpatialDimensions);
     lhsContractingDimensions.push_back(
         static_cast<int64_t>(inputFeatureDimension));
+
     auto rhsContractingDimensions = llvm::to_vector(kernelSpatialDimensions);
     rhsContractingDimensions.push_back(
         static_cast<int64_t>(kernelInputFeatureDimension));
+
     auto dotProduct = evalDotGeneralOp(
         reversedLhsWindow, rhs, /*lhsBatchingDimensions=*/{},
         /*rhsBatchingDimensions=*/{}, Axes(lhsContractingDimensions),
@@ -508,34 +526,26 @@ Tensor evalConvolutionOp(
                               lhsContractingDimensions,
                               rhsContractingDimensions));
 
-    auto resultShape = [&](auto resultOffsetIt) {
-      SmallVector<int64_t> resultIndex;
-      for (auto outputSpacialDimensionsIdx = 0, resultOffsetItIdx = 0,
-                outputSpacialIndexItIdx = 0;
-           outputSpacialDimensionsIdx < result.getRank();
-           ++outputSpacialDimensionsIdx) {
-        if (llvm::find(outputSpatialDimensions, outputSpacialDimensionsIdx) ==
-            outputSpatialDimensions.end())
-          resultIndex.push_back((*resultOffsetIt)[resultOffsetItIdx++]);
-        else
-          resultIndex.push_back(
-              (*outputSpacialIndexIt)[outputSpacialIndexItIdx++]);
-      }
-      return resultIndex;
-    };
-
-    SmallVector<int64_t> resultOffset;
+    Sizes resultNonSpatialDims;
     for (auto i = 0; i < result.getRank(); ++i)
       if (llvm::find(outputSpatialDimensions, i) ==
          outputSpatialDimensions.end())
-        resultOffset.push_back(result.getShape()[i]);
+        resultNonSpatialDims.push_back(result.getShape()[i]);
+
+    Axes resultPermutation;
+    resultPermutation.push_back(outputBatchDimension);
+    resultPermutation.append(outputSpatialDimensions.begin(),
+                             outputSpatialDimensions.end());
+    resultPermutation.push_back(outputFeatureDimension);
 
-    auto resultOffsetIt =
-        IndexSpaceIterator(Sizes(resultOffset), Index(resultOffset.size()));
+    auto resultNonSpatialIt = IndexSpaceIterator(
+        resultNonSpatialDims, Index(resultNonSpatialDims.size()));
     for (auto dotProductIt = dotProduct.index_begin();
          dotProductIt != dotProduct.index_end();
-         ++dotProductIt, ++resultOffsetIt) {
-      auto resultIndex = resultShape(resultOffsetIt);
+         ++dotProductIt, ++resultNonSpatialIt) {
+      auto resultIndex =
+          concatAndPermute((*resultNonSpatialIt)[0], *outputSpatialIndexIt,
+                           (*resultNonSpatialIt)[1], resultPermutation);
       result.set(Index(resultIndex), dotProduct.get(*dotProductIt));
     }
   }
diff --git a/stablehlo/reference/Tensor.cpp b/stablehlo/reference/Tensor.cpp
index f4510236ddf..3a4503449c5 100644
--- a/stablehlo/reference/Tensor.cpp
+++ b/stablehlo/reference/Tensor.cpp
@@ -93,6 +93,13 @@ Tensor::Tensor(TensorType type)
 Tensor::Tensor(TensorType type, AsmResourceBlob blob)
     : impl_(llvm::makeIntrusiveRefCnt<detail::Buffer>(type, std::move(blob))) {}
 
+Tensor::Tensor(TensorType type, const Element &element)
+    : impl_(llvm::makeIntrusiveRefCnt<detail::Buffer>(type)) {
+  for (auto indexIt = this->index_begin(); indexIt != this->index_end();
+       ++indexIt)
+    this->set(*indexIt, element);
+}
+
 Element Tensor::get(const Index &index) const {
   Type elementType = getType().getElementType();
   const char *elementPtr =
diff --git a/stablehlo/reference/Tensor.h b/stablehlo/reference/Tensor.h
index 60937cbd41b..e9f6cf21bde 100644
--- a/stablehlo/reference/Tensor.h
+++ b/stablehlo/reference/Tensor.h
@@ -72,6 +72,9 @@ class Tensor {
   Tensor();
   explicit Tensor(TensorType type);
   explicit Tensor(TensorType type, AsmResourceBlob blob);
+  /// This constructor initializes the tensor with all elements set to the
+  /// provided initial value. It is O(n) with respect to the tensor size.
+  explicit Tensor(TensorType type, const Element &initValue);
   Tensor(const Tensor &other) = default;
   /// @}
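
Note for reviewers, not part of the patch: the sketch below shows how the two API changes above compose, namely the new element-fill Tensor constructor and the IndexSpaceIterator constructor that now degrades an out-of-bounds start index to the end iterator instead of aborting. It is illustrative only; the main() scaffolding and header list are my assumptions, the snippet still needs the usual MLIR/StableHLO build setup, and it assumes Sizes/Index expose SmallVector's initializer-list constructors.

#include <cassert>
#include <optional>

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "stablehlo/reference/Element.h"
#include "stablehlo/reference/Index.h"
#include "stablehlo/reference/Tensor.h"

int main() {
  mlir::MLIRContext context;
  mlir::Builder builder(&context);
  auto f32 = builder.getF32Type();

  // New fill constructor: a 2x3 all-zeros tensor in one call. This is the
  // same pattern getZeroScalarTensor now uses, generalized to any shape.
  auto type = mlir::RankedTensorType::get({2, 3}, f32);
  mlir::stablehlo::Tensor zeros(type, mlir::stablehlo::Element(f32, 0.0));

  // Every element was initialized by the O(n) loop in the constructor.
  for (auto it = zeros.index_begin(); it != zeros.index_end(); ++it)
    assert(zeros.get(*it).getFloatValue().isZero());

  // Index.h change: an out-of-bounds index now collapses the iterator to
  // the end iterator instead of calling llvm::report_fatal_error.
  mlir::stablehlo::IndexSpaceIterator oob(mlir::stablehlo::Sizes({2, 3}),
                                          mlir::stablehlo::Index({7, 7}));
  mlir::stablehlo::IndexSpaceIterator end(mlir::stablehlo::Sizes({2, 3}),
                                          std::nullopt);
  assert(!(oob != end));
  return 0;
}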