Skip to content

Commit

Permalink
Optimize runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
ghpvnist committed Sep 13, 2023
1 parent 19d111a commit dd052b7
Showing 1 changed file with 16 additions and 20 deletions.
36 changes: 16 additions & 20 deletions stablehlo/reference/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,17 +286,11 @@ ShapedType inferSliceOpType(Type operandType,
template <typename T>
SmallVector<T> concatAndPermute(T n, SmallVector<T> hw, T c,
const Axes &permutation) {
SmallVector<T> input;
input.push_back(n);
input.append(hw.begin(), hw.end());
input.push_back(c);

if (input.size() != permutation.size())
llvm::report_fatal_error(
"Expect same size for permutation and the array to be permuted");

SmallVector<T> result(permutation.size());
for (auto [idx, dim] : llvm::enumerate(permutation)) result[dim] = input[idx];
result[permutation[0]] = n;
result[permutation[permutation.size() - 1]] = c;
for (uint64_t i = 1; i < permutation.size() - 1; ++i)
result[permutation[i]] = hw[i - 1];
return result;
}

Expand Down Expand Up @@ -1441,18 +1435,20 @@ Tensor evalConvolutionOp(
auto outputSpatialIndexItEnd = IndexSpaceIterator(
Sizes(extractElements(result.getShape(), outputSpatialDimensions)),
std::nullopt);
for (; outputSpatialIndexIt != outputSpatialIndexItEnd;
++outputSpatialIndexIt) {
SmallVector<int64_t> lhsPaddingLow;
for (auto paddingPair : lhsPadding)
lhsPaddingLow.push_back(paddingPair.first);

auto paddedLhs =
evalPadOp(lhs, makeScalar(convert(result.getElementType(), 0.0)),
Sizes(lhsPaddingLow), Sizes(lhsBaseDilations),
inferPadOpType(lhsPadding, lhs.getType(),
result.getElementType(), lhsBaseDilations));
SmallVector<int64_t> lhsPaddingLow;
for (auto paddingPair : lhsPadding)
lhsPaddingLow.push_back(paddingPair.first);

auto inferredPadOpType = inferPadOpType(
lhsPadding, lhs.getType(), result.getElementType(), lhsBaseDilations);

auto paddedLhs = evalPadOp(
lhs, makeScalar(convert(result.getElementType(), 0.0)),
Sizes(lhsPaddingLow), Sizes(lhsBaseDilations), inferredPadOpType);

for (; outputSpatialIndexIt != outputSpatialIndexItEnd;
++outputSpatialIndexIt) {
SmallVector<int64_t> lhsWindowStart;
for (auto [i, offset] : llvm::enumerate(concatAndPermute(
0L, llvm::to_vector(*outputSpatialIndexIt), 0L, lhsPermutation)))
Expand Down

0 comments on commit dd052b7

Please sign in to comment.