Skip to content

Commit

Permalink
Optimize runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
ghpvnist committed Mar 28, 2023
1 parent 3b92a93 commit 1886019
Showing 1 changed file with 21 additions and 22 deletions.
43 changes: 21 additions & 22 deletions stablehlo/reference/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,17 +177,11 @@ TensorType inferSliceOpType(Type operandType,
/// Scatters the logical sequence [n, hw..., c] through `permutation`: element
/// `i` of that sequence is stored at `result[permutation[i]]`.
///
/// \param n           first element of the logical sequence.
/// \param hw          middle elements of the logical sequence.
/// \param c           last element of the logical sequence.
/// \param permutation destination index for each element of the sequence.
/// \returns a vector with `permutation.size()` entries holding the scattered
///          values.
///
/// Reports a fatal error when the sequence length (`hw.size() + 2`) differs
/// from `permutation.size()`.
template <typename T>
SmallVector<T> concatAndPermute(T n, SmallVector<T> hw, T c,
                                const Axes &permutation) {
  // Validate the length of [n, hw..., c] without materializing it.
  if (hw.size() + 2 != permutation.size())
    llvm::report_fatal_error(
        "Expect same size for permutation and the array to be permuted");

  SmallVector<T> result(permutation.size());
  // Write the two end elements first and the middle elements last, so the
  // final contents are determined by this single pass.
  result[permutation[0]] = n;
  result[permutation[permutation.size() - 1]] = c;
  for (auto [idx, value] : llvm::enumerate(hw))
    result[permutation[idx + 1]] = value;
  return result;
}

Expand Down Expand Up @@ -475,18 +469,21 @@ Tensor evalConvolutionOp(
auto outputSpatialIndexItEnd = IndexSpaceIterator(
Sizes(extractElements(result.getShape(), outputSpatialDimensions)),
std::nullopt);
for (; outputSpatialIndexIt != outputSpatialIndexItEnd;
++outputSpatialIndexIt) {
SmallVector<int64_t> lhsPaddingLow;
for (auto paddingPair : lhsPadding)
lhsPaddingLow.push_back(paddingPair.first);

auto paddedLhs =
evalPadOp(lhs, getZeroScalarTensor(result.getElementType()),
Sizes(lhsPaddingLow), Sizes(lhsBaseDilations),
inferPadOpType(lhsPadding, lhs.getType(),
result.getElementType(), lhsBaseDilations));
SmallVector<int64_t> lhsPaddingLow;
for (auto paddingPair : lhsPadding)
lhsPaddingLow.push_back(paddingPair.first);

auto inferredPadOpType = inferPadOpType(
lhsPadding, lhs.getType(), result.getElementType(), lhsBaseDilations);

auto zeroTensor = getZeroScalarTensor(result.getElementType());

auto paddedLhs = evalPadOp(lhs, zeroTensor, Sizes(lhsPaddingLow),
Sizes(lhsBaseDilations), inferredPadOpType);

for (; outputSpatialIndexIt != outputSpatialIndexItEnd;
++outputSpatialIndexIt) {
SmallVector<int64_t> lhsWindowStart;
for (auto [i, offset] : llvm::enumerate(concatAndPermute(
0L, llvm::to_vector(*outputSpatialIndexIt), 0L, lhsPermutation)))
Expand All @@ -507,8 +504,10 @@ Tensor evalConvolutionOp(
for (auto [i, isReverse] : llvm::enumerate(windowReversal))
if (isReverse) reverseDims.push_back(inputSpatialDimensions[i]);

auto reversedLhsWindow =
evalReverseOp(lhsWindow, Axes(reverseDims), lhsWindow.getType());
auto reversedLhsWindow = lhsWindow;
if (reverseDims.size() != 0)
reversedLhsWindow =
evalReverseOp(lhsWindow, Axes(reverseDims), lhsWindow.getType());

auto lhsContractingDimensions = llvm::to_vector(inputSpatialDimensions);
lhsContractingDimensions.push_back(
Expand Down

0 comments on commit 1886019

Please sign in to comment.