From dd052b7b83536e0d6e4f067e61debedf30a315b3 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Tue, 28 Mar 2023 17:41:09 +0000 Subject: [PATCH] Optimize runtime --- stablehlo/reference/Ops.cpp | 36 ++++++++++++++++-------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/stablehlo/reference/Ops.cpp b/stablehlo/reference/Ops.cpp index e4146349151..9c6f09dfb28 100644 --- a/stablehlo/reference/Ops.cpp +++ b/stablehlo/reference/Ops.cpp @@ -286,17 +286,11 @@ ShapedType inferSliceOpType(Type operandType, template SmallVector concatAndPermute(T n, SmallVector hw, T c, const Axes &permutation) { - SmallVector input; - input.push_back(n); - input.append(hw.begin(), hw.end()); - input.push_back(c); - - if (input.size() != permutation.size()) - llvm::report_fatal_error( - "Expect same size for permutation and the array to be permuted"); - SmallVector result(permutation.size()); - for (auto [idx, dim] : llvm::enumerate(permutation)) result[dim] = input[idx]; + result[permutation[0]] = n; + result[permutation[permutation.size() - 1]] = c; + for (uint64_t i = 1; i < permutation.size() - 1; ++i) + result[permutation[i]] = hw[i - 1]; return result; } @@ -1441,18 +1435,20 @@ Tensor evalConvolutionOp( auto outputSpatialIndexItEnd = IndexSpaceIterator( Sizes(extractElements(result.getShape(), outputSpatialDimensions)), std::nullopt); - for (; outputSpatialIndexIt != outputSpatialIndexItEnd; - ++outputSpatialIndexIt) { - SmallVector lhsPaddingLow; - for (auto paddingPair : lhsPadding) - lhsPaddingLow.push_back(paddingPair.first); - auto paddedLhs = - evalPadOp(lhs, makeScalar(convert(result.getElementType(), 0.0)), - Sizes(lhsPaddingLow), Sizes(lhsBaseDilations), - inferPadOpType(lhsPadding, lhs.getType(), - result.getElementType(), lhsBaseDilations)); + SmallVector lhsPaddingLow; + for (auto paddingPair : lhsPadding) + lhsPaddingLow.push_back(paddingPair.first); + + auto inferredPadOpType = inferPadOpType( + lhsPadding, lhs.getType(), result.getElementType(), lhsBaseDilations); + + auto paddedLhs = evalPadOp( + lhs, makeScalar(convert(result.getElementType(), 0.0)), + Sizes(lhsPaddingLow), Sizes(lhsBaseDilations), inferredPadOpType); + for (; outputSpatialIndexIt != outputSpatialIndexItEnd; + ++outputSpatialIndexIt) { SmallVector lhsWindowStart; for (auto [i, offset] : llvm::enumerate(concatAndPermute( 0L, llvm::to_vector(*outputSpatialIndexIt), 0L, lhsPermutation)))