Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ghpvnist committed Mar 28, 2023
1 parent 58328b8 commit 3b92a93
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 112 deletions.
5 changes: 1 addition & 4 deletions stablehlo/reference/Index.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,7 @@ class IndexSpaceIterator {
/// \name Constructor
IndexSpaceIterator(Sizes shape, std::optional<Index> index)
: shape_(shape), index_(index) {
if (index && !index->inBounds(shape))
llvm::report_fatal_error(
"Incompatible index and shape found while creating "
"an IndexSpaceIterator");
if (index && !index->inBounds(shape)) index_ = std::nullopt;
}

/// Get the current index.
Expand Down
226 changes: 118 additions & 108 deletions stablehlo/reference/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ namespace mlir {
namespace stablehlo {
namespace {

static Type getI1Type(MLIRContext *context) {
return mlir::Builder(context).getI1Type();
}

static Type getI64Type(MLIRContext *context) {
return mlir::Builder(context).getI64Type();
}

Index evalIndices(ArrayRef<Tensor> runtimeIndices) {
Index index(runtimeIndices.size());
for (size_t i = 0; i < runtimeIndices.size(); ++i)
Expand Down Expand Up @@ -68,23 +76,21 @@ TensorType inferConvolutionOpType(
Axis outputFeatureDimension, Axes outputSpatialDimensions,
int64_t featureGroupCount, int64_t batchGroupCount,
std::optional<ArrayAttr> precisionConfig, TensorType resultType) {
Builder builder = mlir::Builder(lhsType.getContext());
Type i64Type = builder.getI64Type();
Type i1Type = builder.getI1Type();
auto flattenPadding = [&](ArrayRef<std::pair<int64_t, int64_t>> padding) {
SmallVector<int64_t> paddingVector;
for (auto pair : padding) {
paddingVector.push_back(pair.first);
paddingVector.push_back(pair.second);
}
return paddingVector;
};
auto i64Type = getI64Type(lhsType.getContext());
auto i1Type = getI1Type(lhsType.getContext());

SmallVector<int64_t> paddingVector;
for (auto pair : padding) {
paddingVector.push_back(pair.first);
paddingVector.push_back(pair.second);
}

SmallVector<int64_t> paddingShape{static_cast<int64_t>(padding.size()), 2};
SmallVector<ShapedTypeComponents> inferredConvolutionType;
auto convolutionStatus = hlo::inferConvolutionOp(
/*location=*/{}, lhsType, rhsType,
getDenseIntElementsAttr(i64Type, windowStrides, {}),
getDenseIntElementsAttr(i64Type, flattenPadding(padding), paddingShape),
getDenseIntElementsAttr(i64Type, paddingVector, paddingShape),
getDenseIntElementsAttr(i64Type, lhsDilation, {}),
getDenseIntElementsAttr(i64Type, rhsDilation, {}),
getDenseIntElementsAttr(i1Type, windowReversal, {}),
Expand Down Expand Up @@ -122,106 +128,112 @@ TensorType inferDotGeneralOpType(TensorType lhsType, TensorType rhsType,
lhsType.getElementType());
}

TensorType inferPadOpType(ArrayRef<std::pair<int64_t, int64_t>> lhsPadding,
Type operandType, Type resultElementType,
TensorType inferPadOpType(ArrayRef<std::pair<int64_t, int64_t>> padding,
Type operandType, Type paddingValueElementType,
ArrayRef<int64_t> interiorPadding) {
Builder builder = mlir::Builder(operandType.getContext());
Type i64Type = builder.getI64Type();
SmallVector<int64_t> lhsPaddingLow;
SmallVector<int64_t> lhsPaddingHigh;
for (auto paddingPair : lhsPadding) {
for (auto paddingPair : padding) {
lhsPaddingLow.push_back(paddingPair.first);
lhsPaddingHigh.push_back(paddingPair.second);
}

SmallVector<Type> inferredPadType;
auto i64Type = getI64Type(operandType.getContext());
auto padStatus = hlo::inferPadOp(
{}, operandType, RankedTensorType::get({}, resultElementType),
{}, operandType, RankedTensorType::get({}, paddingValueElementType),
getDenseIntElementsAttr(i64Type, lhsPaddingLow, {}),
getDenseIntElementsAttr(i64Type, lhsPaddingHigh, {}),
getDenseIntElementsAttr(i64Type, interiorPadding, {}), inferredPadType);

if (failed(padStatus))
report_fatal_error(invalidArgument("Could not infer PadOp's return type"));

return inferredPadType[0].cast<TensorType>();
}

TensorType inferSliceOpType(Type operandType,
SmallVector<int64_t> lhsWindowStart,
SmallVector<int64_t> limitIndices,
SmallVector<int64_t> lhsWindowDilations) {
Builder builder = mlir::Builder(operandType.getContext());
Type i64Type = builder.getI64Type();
SmallVector<Type> inferredSliceType;
auto i64Type = getI64Type(operandType.getContext());
auto sliceStatus = hlo::inferSliceOp(
{}, operandType, getDenseIntElementsAttr(i64Type, lhsWindowStart, {}),
getDenseIntElementsAttr(i64Type, limitIndices, {}),
getDenseIntElementsAttr(i64Type, lhsWindowDilations, {}),
inferredSliceType);

if (failed(sliceStatus))
report_fatal_error(
invalidArgument("Could not infer SliceOp's return type"));

return inferredSliceType[0].cast<TensorType>();
}

// Returns `result` with the effect of applying `permutation`
// (= [dimA] + dimsB + [dimC]) to `input` (= [n] + hw + [c]) such that
// result[permutation[i]] = input[i].
template <typename T>
SmallVector<T> concatAndPermute(T n, SmallVector<T> hw, T c, int64_t dimA,
SmallVector<int64_t> dimB, int64_t dimC) {
SmallVector<T> permInput;
permInput.push_back(n);
permInput.append(hw.begin(), hw.end());
permInput.push_back(c);
SmallVector<int64_t> permDims;
permDims.push_back(dimA);
permDims.append(dimB.begin(), dimB.end());
permDims.push_back(dimC);
SmallVector<T> permOutput(permDims.size());
for (auto [idx, dim] : llvm::enumerate(permDims))
permOutput[dim] = permInput[idx];
return permOutput;
SmallVector<T> concatAndPermute(T n, SmallVector<T> hw, T c,
const Axes &permutation) {
SmallVector<T> input;
input.push_back(n);
input.append(hw.begin(), hw.end());
input.push_back(c);

if (input.size() != permutation.size())
llvm::report_fatal_error(
"Expect same size for permutation and the array to be permuted");

SmallVector<T> result(permutation.size());
for (auto [idx, dim] : llvm::enumerate(permutation)) result[dim] = input[idx];
return result;
}

SmallVector<Tensor> splitIntoNGroupsAlongDim(const Tensor &input, int64_t N,
SmallVector<Tensor> splitIntoNGroupsAlongDim(const Tensor &input,
int64_t nGroups,
Axis splitDimension) {
Builder builder = mlir::Builder(input.getType().getContext());
auto i64Type = builder.getI64Type();

auto getScalarTensor = [&](auto value) {
return makeTensor(
DenseElementsAttr::get(RankedTensorType::get({}, i64Type), value));
auto i64Type = getI64Type(input.getType().getContext());
return Tensor(RankedTensorType::get({}, i64Type), Element(i64Type, value));
};
auto splitInputShape = SmallVector<int64_t>(input.getShape());
splitInputShape[splitDimension] /= N;

Sizes splitInputShape(input.getShape());
splitInputShape[splitDimension] /= nGroups;

auto splitInputType =
RankedTensorType::get(splitInputShape, input.getElementType());
SmallVector<Tensor> splitResults;
for (auto idx = 0; idx < N; ++idx) {
for (auto idx = 0; idx < nGroups; ++idx) {
SmallVector<Tensor> inputStartIndices(input.getRank(), getScalarTensor(0L));
inputStartIndices[splitDimension] =
getScalarTensor(idx * splitInputShape[splitDimension]);
Sizes strides(input.getRank(), 1);
auto resultTensor =
evalDynamicSliceOp(input, inputStartIndices, strides, splitInputType);

auto resultTensor = evalDynamicSliceOp(input, inputStartIndices,
splitInputShape, splitInputType);
splitResults.push_back(resultTensor);
}
return splitResults;
}

Tensor getZeroScalarTensor(Type elementType) {
if (isSupportedIntegerType(elementType))
return makeTensor(
DenseElementsAttr::get(RankedTensorType::get({}, elementType),
Element(elementType, 0L).getIntegerValue()));
return Tensor(RankedTensorType::get({}, elementType),
Element(elementType, 0L));

if (isSupportedBooleanType(elementType))
return makeTensor(
DenseElementsAttr::get(RankedTensorType::get({}, elementType), false));
return Tensor(RankedTensorType::get({}, elementType),
Element(elementType, false));

if (isSupportedFloatType(elementType))
return makeTensor(
DenseElementsAttr::get(RankedTensorType::get({}, elementType),
Element(elementType, 0.0).getFloatValue()));
return Tensor(RankedTensorType::get({}, elementType),
Element(elementType, 0.0));

if (isSupportedComplexType(elementType))
return makeTensor(DenseElementsAttr::get(
RankedTensorType::get({}, elementType),
Element(elementType, std::complex<double>(0.0)).getComplexValue()));
return Tensor(RankedTensorType::get({}, elementType),
Element(elementType, std::complex<double>(0.0)));

report_fatal_error(invalidArgument("Unsupported element type: %s",
debugString(elementType).c_str()));
}
Expand Down Expand Up @@ -378,11 +390,6 @@ Tensor evalConvolutionOp(
int64_t featureGroupCount, int64_t batchGroupCount, TensorType resultType) {
Tensor result(resultType);

// This check is necessary here as we are not looping over result tensor but
// instead creating iterators using output spatial dimension size, which may
// fail if the dimension size is zero.
if (resultType.getNumElements() == 0) return result;

if (featureGroupCount > 1) {
auto lhses =
splitIntoNGroupsAlongDim(lhs, featureGroupCount, inputFeatureDimension);
Expand Down Expand Up @@ -410,6 +417,7 @@ Tensor evalConvolutionOp(
}
return evalConcatenateOp(results, outputFeatureDimension, result.getType());
}

if (batchGroupCount > 1) {
auto lhses =
splitIntoNGroupsAlongDim(lhs, batchGroupCount, inputBatchDimension);
Expand Down Expand Up @@ -437,52 +445,59 @@ Tensor evalConvolutionOp(
}
return evalConcatenateOp(results, outputFeatureDimension, result.getType());
}
auto lhsWindowDimensions = concatAndPermute(
lhs.getShape()[inputBatchDimension],
extractElements(rhs.getShape(), kernelSpatialDimensions),
lhs.getShape()[inputFeatureDimension], inputBatchDimension,
inputSpatialDimensions, inputFeatureDimension);
auto lhsWindowStrides = concatAndPermute(
1L, llvm::to_vector(windowStrides), 1L, inputBatchDimension,
inputSpatialDimensions, inputFeatureDimension);
auto lhsPadding = concatAndPermute(
{0, 0}, llvm::to_vector(padding), {0, 0}, inputBatchDimension,
inputSpatialDimensions, inputFeatureDimension);

Axes lhsPermutation;
lhsPermutation.push_back(inputBatchDimension);
lhsPermutation.append(inputSpatialDimensions.begin(),
inputSpatialDimensions.end());
lhsPermutation.push_back(inputFeatureDimension);

auto lhsWindowDimensions =
concatAndPermute(lhs.getShape()[inputBatchDimension],
extractElements(rhs.getShape(), kernelSpatialDimensions),
lhs.getShape()[inputFeatureDimension], lhsPermutation);

auto lhsWindowStrides =
concatAndPermute(1L, llvm::to_vector(windowStrides), 1L, lhsPermutation);

auto lhsPadding = concatAndPermute({0, 0}, llvm::to_vector(padding), {0, 0},
lhsPermutation);

auto lhsBaseDilations =
concatAndPermute(0L, Sizes(lhsDilation) - 1, 0L, inputBatchDimension,
inputSpatialDimensions, inputFeatureDimension);
auto lhsWindowDilations = concatAndPermute(
1L, llvm::to_vector(rhsDilation), 1L, inputBatchDimension,
inputSpatialDimensions, inputFeatureDimension);
concatAndPermute(0L, Sizes(lhsDilation) - 1, 0L, lhsPermutation);

auto lhsWindowDilations =
concatAndPermute(1L, llvm::to_vector(rhsDilation), 1L, lhsPermutation);

auto outputSpacialIndexIt = IndexSpaceIterator(
auto outputSpatialIndexIt = IndexSpaceIterator(
Sizes(extractElements(result.getShape(), outputSpatialDimensions)),
Index(outputSpatialDimensions.size()));
auto outputSpacialIndexItEnd = IndexSpaceIterator(
auto outputSpatialIndexItEnd = IndexSpaceIterator(
Sizes(extractElements(result.getShape(), outputSpatialDimensions)),
std::nullopt);
for (; outputSpacialIndexIt != outputSpacialIndexItEnd;
++outputSpacialIndexIt) {
for (; outputSpatialIndexIt != outputSpatialIndexItEnd;
++outputSpatialIndexIt) {
SmallVector<int64_t> lhsPaddingLow;
for (auto paddingPair : lhsPadding)
lhsPaddingLow.push_back(paddingPair.first);

auto paddedLhs =
evalPadOp(lhs, getZeroScalarTensor(result.getElementType()),
Sizes(lhsPaddingLow), Sizes(lhsBaseDilations),
inferPadOpType(lhsPadding, lhs.getType(),
result.getElementType(), lhsBaseDilations));

SmallVector<int64_t> lhsWindowStart;
for (auto [i, offset] : llvm::enumerate(
concatAndPermute(0L, llvm::to_vector(*outputSpacialIndexIt), 0L,
inputBatchDimension, inputSpatialDimensions,
inputFeatureDimension)))
for (auto [i, offset] : llvm::enumerate(concatAndPermute(
0L, llvm::to_vector(*outputSpatialIndexIt), 0L, lhsPermutation)))
lhsWindowStart.push_back(lhsWindowStrides[i] * offset);

SmallVector<int64_t> limitIndices;
for (size_t i = 0; i < lhsWindowStart.size(); ++i)
limitIndices.push_back(std::min(
lhsWindowStart[i] + lhsWindowDimensions[i] * lhsWindowDilations[i],
paddedLhs.getShape()[i]));

auto lhsWindow =
evalSliceOp(paddedLhs, Sizes(lhsWindowStart), Sizes(lhsWindowDilations),
inferSliceOpType(paddedLhs.getType(), lhsWindowStart,
Expand All @@ -491,15 +506,18 @@ Tensor evalConvolutionOp(
SmallVector<int64_t> reverseDims;
for (auto [i, isReverse] : llvm::enumerate(windowReversal))
if (isReverse) reverseDims.push_back(inputSpatialDimensions[i]);

auto reversedLhsWindow =
evalReverseOp(lhsWindow, Axes(reverseDims), lhsWindow.getType());

auto lhsContractingDimensions = llvm::to_vector(inputSpatialDimensions);
lhsContractingDimensions.push_back(
static_cast<int64_t>(inputFeatureDimension));

auto rhsContractingDimensions = llvm::to_vector(kernelSpatialDimensions);
rhsContractingDimensions.push_back(
static_cast<int64_t>(kernelInputFeatureDimension));

auto dotProduct = evalDotGeneralOp(
reversedLhsWindow, rhs, /*lhsBatchingDimensions=*/{},
/*rhsBatchingDimensions=*/{}, Axes(lhsContractingDimensions),
Expand All @@ -508,34 +526,26 @@ Tensor evalConvolutionOp(
lhsContractingDimensions,
rhsContractingDimensions));

auto resultShape = [&](auto resultOffsetIt) {
SmallVector<int64_t> resultIndex;
for (auto outputSpacialDimensionsIdx = 0, resultOffsetItIdx = 0,
outputSpacialIndexItIdx = 0;
outputSpacialDimensionsIdx < result.getRank();
++outputSpacialDimensionsIdx) {
if (llvm::find(outputSpatialDimensions, outputSpacialDimensionsIdx) ==
outputSpatialDimensions.end())
resultIndex.push_back((*resultOffsetIt)[resultOffsetItIdx++]);
else
resultIndex.push_back(
(*outputSpacialIndexIt)[outputSpacialIndexItIdx++]);
}
return resultIndex;
};

SmallVector<int64_t> resultOffset;
Sizes resultNonSpatialDims;
for (auto i = 0; i < result.getRank(); ++i)
if (llvm::find(outputSpatialDimensions, i) ==
outputSpatialDimensions.end())
resultOffset.push_back(result.getShape()[i]);
resultNonSpatialDims.push_back(result.getShape()[i]);

Axes resultPermutation;
resultPermutation.push_back(outputBatchDimension);
resultPermutation.append(outputSpatialDimensions.begin(),
outputSpatialDimensions.end());
resultPermutation.push_back(outputFeatureDimension);

auto resultOffsetIt =
IndexSpaceIterator(Sizes(resultOffset), Index(resultOffset.size()));
auto resultNonSpatialIt = IndexSpaceIterator(
resultNonSpatialDims, Index(resultNonSpatialDims.size()));
for (auto dotProductIt = dotProduct.index_begin();
dotProductIt != dotProduct.index_end();
++dotProductIt, ++resultOffsetIt) {
auto resultIndex = resultShape(resultOffsetIt);
++dotProductIt, ++resultNonSpatialIt) {
auto resultIndex =
concatAndPermute((*resultNonSpatialIt)[0], *outputSpatialIndexIt,
(*resultNonSpatialIt)[1], resultPermutation);
result.set(Index(resultIndex), dotProduct.get(*dotProductIt));
}
}
Expand Down
Loading

0 comments on commit 3b92a93

Please sign in to comment.