From 18860191a4957a07f5bf8fa8d179c288e8b3c600 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Tue, 28 Mar 2023 17:41:09 +0000 Subject: [PATCH] Optimize runtime --- stablehlo/reference/Ops.cpp | 43 ++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/stablehlo/reference/Ops.cpp b/stablehlo/reference/Ops.cpp index ef91e8f300d..06e2e6c6eef 100644 --- a/stablehlo/reference/Ops.cpp +++ b/stablehlo/reference/Ops.cpp @@ -177,17 +177,11 @@ TensorType inferSliceOpType(Type operandType, template SmallVector concatAndPermute(T n, SmallVector hw, T c, const Axes &permutation) { - SmallVector input; - input.push_back(n); - input.append(hw.begin(), hw.end()); - input.push_back(c); - - if (input.size() != permutation.size()) - llvm::report_fatal_error( - "Expect same size for permutation and the array to be permuted"); - SmallVector result(permutation.size()); - for (auto [idx, dim] : llvm::enumerate(permutation)) result[dim] = input[idx]; + result[permutation[0]] = n; + result[permutation[permutation.size() - 1]] = c; + for (uint64_t i = 1; i < permutation.size() - 1; ++i) + result[permutation[i]] = hw[i - 1]; return result; } @@ -475,18 +469,21 @@ Tensor evalConvolutionOp( auto outputSpatialIndexItEnd = IndexSpaceIterator( Sizes(extractElements(result.getShape(), outputSpatialDimensions)), std::nullopt); - for (; outputSpatialIndexIt != outputSpatialIndexItEnd; - ++outputSpatialIndexIt) { - SmallVector lhsPaddingLow; - for (auto paddingPair : lhsPadding) - lhsPaddingLow.push_back(paddingPair.first); - auto paddedLhs = - evalPadOp(lhs, getZeroScalarTensor(result.getElementType()), - Sizes(lhsPaddingLow), Sizes(lhsBaseDilations), - inferPadOpType(lhsPadding, lhs.getType(), - result.getElementType(), lhsBaseDilations)); + SmallVector lhsPaddingLow; + for (auto paddingPair : lhsPadding) + lhsPaddingLow.push_back(paddingPair.first); + auto inferredPadOpType = inferPadOpType( + lhsPadding, lhs.getType(), result.getElementType(), lhsBaseDilations); + + auto zeroTensor = getZeroScalarTensor(result.getElementType()); + + auto paddedLhs = evalPadOp(lhs, zeroTensor, Sizes(lhsPaddingLow), + Sizes(lhsBaseDilations), inferredPadOpType); + + for (; outputSpatialIndexIt != outputSpatialIndexItEnd; + ++outputSpatialIndexIt) { SmallVector lhsWindowStart; for (auto [i, offset] : llvm::enumerate(concatAndPermute( 0L, llvm::to_vector(*outputSpatialIndexIt), 0L, lhsPermutation))) @@ -507,8 +504,10 @@ Tensor evalConvolutionOp( for (auto [i, isReverse] : llvm::enumerate(windowReversal)) if (isReverse) reverseDims.push_back(inputSpatialDimensions[i]); - auto reversedLhsWindow = - evalReverseOp(lhsWindow, Axes(reverseDims), lhsWindow.getType()); + auto reversedLhsWindow = lhsWindow; + if (reverseDims.size() != 0) + reversedLhsWindow = + evalReverseOp(lhsWindow, Axes(reverseDims), lhsWindow.getType()); auto lhsContractingDimensions = llvm::to_vector(inputSpatialDimensions); lhsContractingDimensions.push_back(