From ac8ed4cf6f79b698f2cfc4c2e1d2ef6a0c78cff3 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 22 Apr 2024 10:05:51 -0700 Subject: [PATCH] Rename variables to simplify refactoring (#2245) This is part of ongoing efforts to sync reference interpreter closer to the spec #1049 --- stablehlo/reference/Ops.cpp | 1010 +++++++++++++++++------------------ 1 file changed, 494 insertions(+), 516 deletions(-) diff --git a/stablehlo/reference/Ops.cpp b/stablehlo/reference/Ops.cpp index 2d683c5612a..df501966ed4 100644 --- a/stablehlo/reference/Ops.cpp +++ b/stablehlo/reference/Ops.cpp @@ -320,24 +320,24 @@ SmallVector eval(Region ®ion, Scope scope(parent); scope.add(block.getArguments(), args); - for (Operation &op : block) { - if (auto absOp = dyn_cast(op)) { - auto operand = scope.findTensor(absOp.getOperand()); - auto result = evalAbsOp(operand, absOp.getType()); - scope.add(absOp.getResult(), result); - } else if (auto addOp = dyn_cast(op)) { - auto lhs = scope.findTensor(addOp.getLhs()); - auto rhs = scope.findTensor(addOp.getRhs()); - auto result = evalAddOp(lhs, rhs, addOp.getType()); - scope.add(addOp.getResult(), result); - } else if (auto afterAllOp = dyn_cast(op)) { - auto inputs = scope.findTokens(afterAllOp.getInputs()); - auto result = evalAfterAllOp(inputs, afterAllOp->getContext()); - scope.add(afterAllOp.getResult(), result); - } else if (auto allGatherOp = dyn_cast(op)) { - auto operand = scope.findTensor(allGatherOp.getOperand()); - - auto replicaGroupsAttr = allGatherOp.getReplicaGroups(); + for (Operation &operation : block) { + if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto result = evalAbsOp(operand, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto lhs = scope.findTensor(op.getLhs()); + auto rhs = scope.findTensor(op.getRhs()); + auto result = evalAddOp(lhs, rhs, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto inputs = scope.findTokens(op.getInputs()); + auto result = evalAfterAllOp(inputs, op->getContext()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + + auto replicaGroupsAttr = op.getReplicaGroups(); auto replicaGroupsShape = replicaGroupsAttr.getShapedType().getShape(); SmallVector> replicaGroups(replicaGroupsShape[0]); auto replicaGroupsIt = replicaGroupsAttr.getValues().begin(); @@ -346,29 +346,28 @@ SmallVector eval(Region ®ion, replicaGroup.push_back(*replicaGroupsIt); ChannelId channelId = 0; - if (auto channelHandle = allGatherOp.getChannelHandle()) + if (auto channelHandle = op.getChannelHandle()) channelId = channelHandle->getHandle(); auto result = evalAllGatherOp( - operand, allGatherOp.getAllGatherDim(), replicaGroups, channelId, - allGatherOp.getUseGlobalDeviceIds(), process, allGatherOp.getType()); - scope.add(allGatherOp.getResult(), result); - } else if (auto allReduceOp = dyn_cast(op)) { - auto operand = scope.findTensor(allReduceOp.getOperand()); - auto replicaGroups = getReplicaGroups(allReduceOp.getReplicaGroups()); + operand, op.getAllGatherDim(), replicaGroups, channelId, + op.getUseGlobalDeviceIds(), process, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto replicaGroups = getReplicaGroups(op.getReplicaGroups()); ChannelId channelId = 0; - if (auto channelHandle = allReduceOp.getChannelHandle()) + if (auto channelHandle = op.getChannelHandle()) channelId = channelHandle->getHandle(); - auto result = evalAllReduceOp(operand, replicaGroups, channelId, - allReduceOp.getUseGlobalDeviceIds(), - allReduceOp.getComputation(), process, - scope, allReduceOp.getType()); - scope.add(allReduceOp.getResult(), result); - } else if (auto allToAllOp = dyn_cast(op)) { - auto operand = scope.findTensor(allToAllOp.getOperand()); - auto replicaGroupsAttr = allToAllOp.getReplicaGroups(); + auto result = evalAllReduceOp( + operand, replicaGroups, channelId, op.getUseGlobalDeviceIds(), + op.getComputation(), process, scope, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto replicaGroupsAttr = op.getReplicaGroups(); auto replicaGroupsShape = replicaGroupsAttr.getShapedType().getShape(); SmallVector> replicaGroups(replicaGroupsShape[0]); auto replicaGroupsIt = replicaGroupsAttr.getValues().begin(); @@ -377,78 +376,75 @@ SmallVector eval(Region ®ion, replicaGroup.push_back(*replicaGroupsIt); ChannelId channelId = 0; - if (auto channelHandle = allToAllOp.getChannelHandle()) + if (auto channelHandle = op.getChannelHandle()) channelId = channelHandle->getHandle(); - auto result = evalAllToAllOp(operand, allToAllOp.getSplitDimension(), - allToAllOp.getConcatDimension(), - allToAllOp.getSplitCount(), replicaGroups, - channelId, process, allToAllOp.getType()); - scope.add(allToAllOp.getResult(), result); - } else if (auto andOp = dyn_cast(op)) { - auto lhs = scope.findTensor(andOp.getLhs()); - auto rhs = scope.findTensor(andOp.getRhs()); - auto result = evalAndOp(lhs, rhs, andOp.getType()); - scope.add(andOp.getResult(), result); - } else if (auto atan2Op = dyn_cast(op)) { - auto lhs = scope.findTensor(atan2Op.getLhs()); - auto rhs = scope.findTensor(atan2Op.getRhs()); - auto result = evalAtan2Op(lhs, rhs, atan2Op.getType()); - scope.add(atan2Op.getResult(), result); - } else if (auto batchNormGradOp = dyn_cast(op)) { - failOnDecomposableOp(op); - } else if (auto batchNormInferenceOp = dyn_cast(op)) { - failOnDecomposableOp(op); - } else if (auto batchNormTrainingOp = dyn_cast(op)) { - failOnDecomposableOp(op); - } else if (auto bitcastConvertOp = dyn_cast(op)) { - auto operand = scope.findTensor(bitcastConvertOp.getOperand()); - auto result = evalBitcastConvertOp(operand, bitcastConvertOp.getType()); - scope.add(bitcastConvertOp.getResult(), result); - } else if (auto broadcastInDimOp = dyn_cast(op)) { - auto operand = scope.findTensor(broadcastInDimOp.getOperand()); - auto broadcastDimensions = - Axes(broadcastInDimOp.getBroadcastDimensions()); - auto result = evalBroadcastInDimOp(operand, broadcastDimensions, - broadcastInDimOp.getType()); - scope.add(broadcastInDimOp.getResult(), result); - } else if (isa(op)) { - failOnDecomposableOp(op); - } else if (auto callOp = dyn_cast(op)) { - auto operands = scope.findTensors(callOp.getOperands()); + auto result = evalAllToAllOp( + operand, op.getSplitDimension(), op.getConcatDimension(), + op.getSplitCount(), replicaGroups, channelId, process, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto lhs = scope.findTensor(op.getLhs()); + auto rhs = scope.findTensor(op.getRhs()); + auto result = evalAndOp(lhs, rhs, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto lhs = scope.findTensor(op.getLhs()); + auto rhs = scope.findTensor(op.getRhs()); + auto result = evalAtan2Op(lhs, rhs, op.getType()); + scope.add(op.getResult(), result); + } else if (isa(operation)) { + failOnDecomposableOp(operation); + } else if (isa(operation)) { + failOnDecomposableOp(operation); + } else if (isa(operation)) { + failOnDecomposableOp(operation); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto result = evalBitcastConvertOp(operand, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto broadcastDimensions = Axes(op.getBroadcastDimensions()); + auto result = + evalBroadcastInDimOp(operand, broadcastDimensions, op.getType()); + scope.add(op.getResult(), result); + } else if (isa(operation)) { + failOnDecomposableOp(operation); + } else if (auto op = dyn_cast(operation)) { + auto operands = scope.findTensors(op.getOperands()); auto results = - evalCallOp(operands, fallback, process, &op, callOp.getCallee()); - scope.add(callOp.getResults(), results); - } else if (auto caseOp = dyn_cast(op)) { - auto index = scope.findTensor(caseOp.getIndex()); - auto branches = caseOp.getBranches(); + evalCallOp(operands, fallback, process, &operation, op.getCallee()); + scope.add(op.getResults(), results); + } else if (auto op = dyn_cast(operation)) { + auto index = scope.findTensor(op.getIndex()); + auto branches = op.getBranches(); auto results = evalCaseOp(index, branches, process, scope); - scope.add(caseOp.getResults(), results); - } else if (auto cbrtOp = dyn_cast(op)) { - auto operand = scope.findTensor(cbrtOp.getOperand()); - auto result = evalCbrtOp(operand, cbrtOp.getType()); - scope.add(cbrtOp.getResult(), result); - } else if (auto ceilOp = dyn_cast(op)) { - auto operand = scope.findTensor(ceilOp.getOperand()); - auto result = evalCeilOp(operand, ceilOp.getType()); - scope.add(ceilOp.getResult(), result); - } else if (auto choleskyOp = dyn_cast(op)) { - failOnDecomposableOp(op); - } else if (auto clampOp = dyn_cast(op)) { - auto min = scope.findTensor(clampOp.getMin()); - auto operand = scope.findTensor(clampOp.getOperand()); - auto max = scope.findTensor(clampOp.getMax()); - auto result = evalClampOp(min, operand, max, clampOp.getType()); - scope.add(clampOp.getResult(), result); - } else if (auto clzOp = dyn_cast(op)) { - auto operand = scope.findTensor(clzOp.getOperand()); - auto result = evalClzOp(operand, clzOp.getType()); - scope.add(clzOp.getResult(), result); - } else if (auto collectiveBroadcastOp = - dyn_cast(op)) { - auto operand = scope.findTensor(collectiveBroadcastOp.getOperand()); - - auto replicaGroupsAttr = collectiveBroadcastOp.getReplicaGroups(); + scope.add(op.getResults(), results); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto result = evalCbrtOp(operand, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto result = evalCeilOp(operand, op.getType()); + scope.add(op.getResult(), result); + } else if (isa(operation)) { + failOnDecomposableOp(operation); + } else if (auto op = dyn_cast(operation)) { + auto min = scope.findTensor(op.getMin()); + auto operand = scope.findTensor(op.getOperand()); + auto max = scope.findTensor(op.getMax()); + auto result = evalClampOp(min, operand, max, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto result = evalClzOp(operand, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + + auto replicaGroupsAttr = op.getReplicaGroups(); auto replicaGroupsShape = replicaGroupsAttr.getShapedType().getShape(); SmallVector> replicaGroups(replicaGroupsShape[0]); auto replicaGroupsIt = replicaGroupsAttr.getValues().begin(); @@ -457,16 +453,16 @@ SmallVector eval(Region ®ion, replicaGroup.push_back(*replicaGroupsIt); ChannelId channelId = 0; - if (auto channelHandle = collectiveBroadcastOp.getChannelHandle()) + if (auto channelHandle = op.getChannelHandle()) channelId = channelHandle->getHandle(); auto result = evalCollectiveBroadcastOp(operand, replicaGroups, channelId, process); - scope.add(collectiveBroadcastOp.getResult(), result); - } else if (auto collectivePermuteOp = dyn_cast(op)) { - auto operand = scope.findTensor(collectivePermuteOp.getOperand()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); - auto sourceTargetPairsAttr = collectivePermuteOp.getSourceTargetPairs(); + auto sourceTargetPairsAttr = op.getSourceTargetPairs(); SmallVector> sourceTargetPairs( sourceTargetPairsAttr.getNumElements() / 2); auto sourceTargetPairsIt = @@ -477,52 +473,51 @@ SmallVector eval(Region ®ion, } ChannelId channelId = 0; - if (auto channelHandle = collectivePermuteOp.getChannelHandle()) + if (auto channelHandle = op.getChannelHandle()) channelId = channelHandle->getHandle(); auto result = evalCollectivePermuteOp(operand, sourceTargetPairs, channelId, process); - scope.add(collectivePermuteOp.getResult(), result); - } else if (auto compareOp = dyn_cast(op)) { - auto lhs = scope.findTensor(compareOp.getLhs()); - auto rhs = scope.findTensor(compareOp.getRhs()); - auto comparisonDirection = compareOp.getComparisonDirection(); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto lhs = scope.findTensor(op.getLhs()); + auto rhs = scope.findTensor(op.getRhs()); + auto comparisonDirection = op.getComparisonDirection(); + auto result = evalCompareOp(lhs, rhs, comparisonDirection, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto lhs = scope.findTensor(op.getLhs()); + auto rhs = scope.findTensor(op.getRhs()); + auto result = evalComplexOp(lhs, rhs, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operands = scope.findTensors(op.getOperands()); + auto results = evalCallOp(operands, fallback, process, &operation, + op.getDecomposition()); + scope.add(op.getResults(), results); + } else if (auto op = dyn_cast(operation)) { + auto operands = scope.findTensors(op.getOperands()); auto result = - evalCompareOp(lhs, rhs, comparisonDirection, compareOp.getType()); - scope.add(compareOp.getResult(), result); - } else if (auto complexOp = dyn_cast(op)) { - auto lhs = scope.findTensor(complexOp.getLhs()); - auto rhs = scope.findTensor(complexOp.getRhs()); - auto result = evalComplexOp(lhs, rhs, complexOp.getType()); - scope.add(complexOp.getResult(), result); - } else if (auto compositeOp = dyn_cast(op)) { - auto operands = scope.findTensors(compositeOp.getOperands()); - auto results = evalCallOp(operands, fallback, process, &op, - compositeOp.getDecomposition()); - scope.add(compositeOp.getResults(), results); - } else if (auto concatenateOp = dyn_cast(op)) { - auto operands = scope.findTensors(concatenateOp.getOperands()); - auto result = evalConcatenateOp(operands, concatenateOp.getDimension(), - concatenateOp.getType()); - scope.add(concatenateOp.getResult(), result); - } else if (auto constantOp = dyn_cast(op)) { - auto result = evalConstantOp(constantOp.getValue()); - scope.add(constantOp.getResult(), result); - } else if (auto convertOp = dyn_cast(op)) { - auto operand = scope.findTensor(convertOp.getOperand()); - auto result = evalConvertOp(operand, convertOp.getType()); - scope.add(convertOp.getResult(), result); - } else if (auto convolutionOp = dyn_cast(op)) { - auto lhs = scope.findTensor(convolutionOp.getLhs()); - auto rhs = scope.findTensor(convolutionOp.getRhs()); + evalConcatenateOp(operands, op.getDimension(), op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto result = evalConstantOp(op.getValue()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto result = evalConvertOp(operand, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto lhs = scope.findTensor(op.getLhs()); + auto rhs = scope.findTensor(op.getRhs()); auto rank = lhs.getRank(); SmallVector windowStrides(rank - 2, 1); - if (auto windowStridesAttr = convolutionOp.getWindowStrides()) + if (auto windowStridesAttr = op.getWindowStrides()) windowStrides = SmallVector(windowStridesAttr.value()); SmallVector> padding(rank - 2, {0, 0}); - if (auto paddingAttr = convolutionOp.getPaddingAttr()) { + if (auto paddingAttr = op.getPaddingAttr()) { auto paddingOrErr = hlo::convertPaddingAttribute(paddingAttr, {}); if (failed(paddingOrErr)) report_fatal_error(invalidArgument("Invalid padding format found.")); @@ -530,18 +525,18 @@ SmallVector eval(Region ®ion, } SmallVector lhsDilation(rank - 2, 1); - if (auto lhsDilationAttr = convolutionOp.getLhsDilation()) + if (auto lhsDilationAttr = op.getLhsDilation()) lhsDilation = SmallVector(lhsDilationAttr.value()); SmallVector rhsDilation(rank - 2, 1); - if (auto rhsDilationAttr = convolutionOp.getRhsDilation()) + if (auto rhsDilationAttr = op.getRhsDilation()) rhsDilation = SmallVector(rhsDilationAttr.value()); SmallVector windowReversal(rank - 2, false); - if (auto windowReversalAttr = convolutionOp.getWindowReversal()) + if (auto windowReversalAttr = op.getWindowReversal()) windowReversal = SmallVector(windowReversalAttr.value()); - auto dimensionNumbers = convolutionOp.getDimensionNumbers(); + auto dimensionNumbers = op.getDimensionNumbers(); auto result = evalConvolutionOp( lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation, windowReversal, dimensionNumbers.getInputBatchDimension(), @@ -553,252 +548,242 @@ SmallVector eval(Region ®ion, dimensionNumbers.getOutputBatchDimension(), dimensionNumbers.getOutputFeatureDimension(), Axes(dimensionNumbers.getOutputSpatialDimensions()), - convolutionOp.getFeatureGroupCount(), - convolutionOp.getBatchGroupCount(), convolutionOp.getType()); - scope.add(convolutionOp.getResult(), result); - } else if (auto cosineOp = dyn_cast(op)) { - auto operand = scope.findTensor(cosineOp.getOperand()); - auto result = evalCosineOp(operand, cosineOp.getType()); - scope.add(cosineOp.getResult(), result); - } else if (isa(op)) { - failOnDecomposableOp(op); - } else if (isa(op)) { - failOnDecomposableOp(op); - } else if (auto divideOp = dyn_cast(op)) { - auto lhs = scope.findTensor(divideOp.getLhs()); - auto rhs = scope.findTensor(divideOp.getRhs()); - auto result = evalDivideOp(lhs, rhs, divideOp.getType()); - scope.add(divideOp.getResult(), result); - } else if (isa(op)) { - failOnDecomposableOp(op); - } else if (auto dotGeneralOp = dyn_cast(op)) { - auto lhs = scope.findTensor(dotGeneralOp.getLhs()); - auto rhs = scope.findTensor(dotGeneralOp.getRhs()); - auto lhsBatchingDimensions = Axes( - dotGeneralOp.getDotDimensionNumbers().getLhsBatchingDimensions()); - auto rhsBatchingDimensions = Axes( - dotGeneralOp.getDotDimensionNumbers().getRhsBatchingDimensions()); - auto lhsContractingDimensions = Axes( - dotGeneralOp.getDotDimensionNumbers().getLhsContractingDimensions()); - auto rhsContractingDimensions = Axes( - dotGeneralOp.getDotDimensionNumbers().getRhsContractingDimensions()); + op.getFeatureGroupCount(), op.getBatchGroupCount(), op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto result = evalCosineOp(operand, op.getType()); + scope.add(op.getResult(), result); + } else if (isa(operation)) { + failOnDecomposableOp(operation); + } else if (isa(operation)) { + failOnDecomposableOp(operation); + } else if (auto op = dyn_cast(operation)) { + auto lhs = scope.findTensor(op.getLhs()); + auto rhs = scope.findTensor(op.getRhs()); + auto result = evalDivideOp(lhs, rhs, op.getType()); + scope.add(op.getResult(), result); + } else if (isa(operation)) { + failOnDecomposableOp(operation); + } else if (auto op = dyn_cast(operation)) { + auto lhs = scope.findTensor(op.getLhs()); + auto rhs = scope.findTensor(op.getRhs()); + auto lhsBatchingDimensions = + Axes(op.getDotDimensionNumbers().getLhsBatchingDimensions()); + auto rhsBatchingDimensions = + Axes(op.getDotDimensionNumbers().getRhsBatchingDimensions()); + auto lhsContractingDimensions = + Axes(op.getDotDimensionNumbers().getLhsContractingDimensions()); + auto rhsContractingDimensions = + Axes(op.getDotDimensionNumbers().getRhsContractingDimensions()); + auto result = evalDotGeneralOp( + lhs, rhs, lhsBatchingDimensions, rhsBatchingDimensions, + lhsContractingDimensions, rhsContractingDimensions, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto startIndices = scope.findTensors(op.getStartIndices()); + auto sliceSizes = Sizes(op.getSliceSizes()); auto result = - evalDotGeneralOp(lhs, rhs, lhsBatchingDimensions, - rhsBatchingDimensions, lhsContractingDimensions, - rhsContractingDimensions, dotGeneralOp.getType()); - scope.add(dotGeneralOp.getResult(), result); - } else if (auto dynamicSliceOp = dyn_cast(op)) { - auto operand = scope.findTensor(dynamicSliceOp.getOperand()); - auto startIndices = scope.findTensors(dynamicSliceOp.getStartIndices()); - auto sliceSizes = Sizes(dynamicSliceOp.getSliceSizes()); - auto result = evalDynamicSliceOp(operand, startIndices, sliceSizes, - dynamicSliceOp.getType()); - scope.add(dynamicSliceOp.getResult(), result); - } else if (auto dynamicUpdateSliceOp = dyn_cast(op)) { - auto operand = scope.findTensor(dynamicUpdateSliceOp.getOperand()); - auto update = scope.findTensor(dynamicUpdateSliceOp.getUpdate()); - auto startIndices = - scope.findTensors(dynamicUpdateSliceOp.getStartIndices()); - auto result = evalDynamicUpdateSliceOp(operand, update, startIndices, - dynamicUpdateSliceOp.getType()); - scope.add(dynamicUpdateSliceOp.getResult(), result); - } else if (isa(op)) { - failOnDecomposableOp(op); - } else if (auto expOp = dyn_cast(op)) { - auto operand = scope.findTensor(expOp.getOperand()); - auto result = evalExponentialOp(operand, expOp.getType()); - scope.add(expOp.getResult(), result); - } else if (auto expm1Op = dyn_cast(op)) { - auto operand = scope.findTensor(expm1Op.getOperand()); - auto result = evalExpm1Op(operand, expm1Op.getType()); - scope.add(expm1Op.getResult(), result); - } else if (auto floorOp = dyn_cast(op)) { - auto operand = scope.findTensor(floorOp.getOperand()); - auto result = evalFloorOp(operand, floorOp.getType()); - scope.add(floorOp.getResult(), result); - } else if (auto gatherOp = dyn_cast(op)) { - auto operand = scope.findTensor(gatherOp.getOperand()); - auto startIndices = scope.findTensor(gatherOp.getStartIndices()); - auto result = evalGatherOp( - operand, startIndices, - Axes(gatherOp.getDimensionNumbers().getOffsetDims()), - Axes(gatherOp.getDimensionNumbers().getCollapsedSliceDims()), - Axes(gatherOp.getDimensionNumbers().getStartIndexMap()), - Axis(gatherOp.getDimensionNumbers().getIndexVectorDim()), - Sizes(gatherOp.getSliceSizes()), gatherOp.getIndicesAreSorted(), - gatherOp.getType()); - scope.add(gatherOp.getResult(), result); - } else if (auto getDimensionSizeOp = dyn_cast(op)) { - auto operand = scope.findTensor(getDimensionSizeOp.getOperand()); - auto dimension = getDimensionSizeOp.getDimension(); - auto result = evalGetDimensionSizeOp(operand, dimension, - getDimensionSizeOp.getType()); - scope.add(getDimensionSizeOp.getResult(), result); - } else if (auto getTupleElementOp = dyn_cast(op)) { - auto operand = scope.findTuple(getTupleElementOp.getOperand()); + evalDynamicSliceOp(operand, startIndices, sliceSizes, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto update = scope.findTensor(op.getUpdate()); + auto startIndices = scope.findTensors(op.getStartIndices()); auto result = - evalGetTupleElementOp(operand, getTupleElementOp.getIndex()); - scope.add(getTupleElementOp.getResult(), result); - } else if (auto ifOp = dyn_cast(op)) { - auto pred = scope.findTensor(ifOp.getPred()); - auto &trueBranch = ifOp.getTrueBranch(); - auto &falseBranch = ifOp.getFalseBranch(); + evalDynamicUpdateSliceOp(operand, update, startIndices, op.getType()); + scope.add(op.getResult(), result); + } else if (isa(operation)) { + failOnDecomposableOp(operation); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto result = evalExponentialOp(operand, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto result = evalExpm1Op(operand, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto result = evalFloorOp(operand, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto startIndices = scope.findTensor(op.getStartIndices()); + auto result = evalGatherOp( + operand, startIndices, Axes(op.getDimensionNumbers().getOffsetDims()), + Axes(op.getDimensionNumbers().getCollapsedSliceDims()), + Axes(op.getDimensionNumbers().getStartIndexMap()), + Axis(op.getDimensionNumbers().getIndexVectorDim()), + Sizes(op.getSliceSizes()), op.getIndicesAreSorted(), op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto dimension = op.getDimension(); + auto result = evalGetDimensionSizeOp(operand, dimension, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTuple(op.getOperand()); + auto result = evalGetTupleElementOp(operand, op.getIndex()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto pred = scope.findTensor(op.getPred()); + auto &trueBranch = op.getTrueBranch(); + auto &falseBranch = op.getFalseBranch(); auto results = evalIfOp(pred, trueBranch, falseBranch, process, scope); - scope.add(ifOp.getResults(), results); - } else if (auto imagOp = dyn_cast(op)) { - auto operand = scope.findTensor(imagOp.getOperand()); - auto result = evalImagOp(operand, imagOp.getType()); - scope.add(imagOp.getResult(), result); - } else if (auto infeedOp = dyn_cast(op)) { - auto token = scope.findToken(infeedOp.getToken()); + scope.add(op.getResults(), results); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto result = evalImagOp(operand, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto token = scope.findToken(op.getToken()); auto results = evalInfeedOp(token, process, region, scope); - scope.add(infeedOp.getResults(), results); - } else if (auto iotaOp = dyn_cast(op)) { - auto iotaDimension = iotaOp.getIotaDimension(); - auto result = evalIotaOp(iotaDimension, iotaOp.getType()); - scope.add(iotaOp.getResult(), result); - } else if (auto isFiniteOp = dyn_cast(op)) { - auto operand = scope.findTensor(isFiniteOp.getOperand()); - auto result = evalIsFiniteOp(operand, isFiniteOp.getType()); - scope.add(isFiniteOp.getResult(), result); - } else if (auto log1pOp = dyn_cast(op)) { - auto operand = scope.findTensor(log1pOp.getOperand()); - auto result = evalLog1pOp(operand, log1pOp.getType()); - scope.add(log1pOp.getResult(), result); - } else if (auto logOp = dyn_cast(op)) { - auto operand = scope.findTensor(logOp.getOperand()); - auto result = evalLogOp(operand, logOp.getType()); - scope.add(logOp.getResult(), result); - } else if (auto logisticOp = dyn_cast(op)) { - auto operand = scope.findTensor(logisticOp.getOperand()); - auto result = evalLogisticOp(operand, logisticOp.getType()); - scope.add(logisticOp.getResult(), result); - } else if (auto mapOp = dyn_cast(op)) { - auto inputs = scope.findTensors(mapOp.getInputs()); - auto &computation = mapOp.getComputation(); + scope.add(op.getResults(), results); + } else if (auto op = dyn_cast(operation)) { + auto iotaDimension = op.getIotaDimension(); + auto result = evalIotaOp(iotaDimension, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto result = evalIsFiniteOp(operand, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto result = evalLog1pOp(operand, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto result = evalLogOp(operand, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto result = evalLogisticOp(operand, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto inputs = scope.findTensors(op.getInputs()); + auto &computation = op.getComputation(); auto result = - evalMapOp(inputs, computation, process, scope, mapOp.getType()); - scope.add(mapOp.getResult(), result); - } else if (auto maxOp = dyn_cast(op)) { - auto lhs = scope.findTensor(maxOp.getLhs()); - auto rhs = scope.findTensor(maxOp.getRhs()); - auto result = evalMaxOp(lhs, rhs, maxOp.getType()); - scope.add(maxOp.getResult(), result); - } else if (auto minOp = dyn_cast(op)) { - auto lhs = scope.findTensor(minOp.getLhs()); - auto rhs = scope.findTensor(minOp.getRhs()); - auto result = evalMinOp(lhs, rhs, minOp.getType()); - scope.add(minOp.getResult(), result); - } else if (auto multiplyOp = dyn_cast(op)) { - auto lhs = scope.findTensor(multiplyOp.getLhs()); - auto rhs = scope.findTensor(multiplyOp.getRhs()); - auto result = evalMultiplyOp(lhs, rhs, multiplyOp.getType()); - scope.add(multiplyOp.getResult(), result); - } else if (auto negOp = dyn_cast(op)) { - auto operand = scope.findTensor(negOp.getOperand()); - auto result = evalNegOp(operand, negOp.getType()); - scope.add(negOp.getResult(), result); - } else if (auto notOp = dyn_cast(op)) { - auto operand = scope.findTensor(notOp.getOperand()); - auto result = evalNotOp(operand, notOp.getType()); - scope.add(notOp.getResult(), result); - } else if (auto optimizationBarrierOp = - dyn_cast(op)) { - auto operand = scope.find(optimizationBarrierOp.getOperand()); + evalMapOp(inputs, computation, process, scope, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto lhs = scope.findTensor(op.getLhs()); + auto rhs = scope.findTensor(op.getRhs()); + auto result = evalMaxOp(lhs, rhs, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto lhs = scope.findTensor(op.getLhs()); + auto rhs = scope.findTensor(op.getRhs()); + auto result = evalMinOp(lhs, rhs, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto lhs = scope.findTensor(op.getLhs()); + auto rhs = scope.findTensor(op.getRhs()); + auto result = evalMultiplyOp(lhs, rhs, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto result = evalNegOp(operand, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto result = evalNotOp(operand, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.find(op.getOperand()); auto results = evalOptimizationBarrierOp(operand); - scope.add(optimizationBarrierOp.getResults(), results); - } else if (auto orOp = dyn_cast(op)) { - auto lhs = scope.findTensor(orOp.getLhs()); - auto rhs = scope.findTensor(orOp.getRhs()); - auto result = evalOrOp(lhs, rhs, orOp.getType()); - scope.add(orOp.getResult(), result); - } else if (auto outfeedOp = dyn_cast(op)) { - auto inputs = scope.findTensors(outfeedOp.getInputs()); - auto token = scope.findToken(outfeedOp.getToken()); + scope.add(op.getResults(), results); + } else if (auto op = dyn_cast(operation)) { + auto lhs = scope.findTensor(op.getLhs()); + auto rhs = scope.findTensor(op.getRhs()); + auto result = evalOrOp(lhs, rhs, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto inputs = scope.findTensors(op.getInputs()); + auto token = scope.findToken(op.getToken()); auto result = evalOutfeedOp(inputs, token, process); - scope.add(outfeedOp.getResult(), result); - } else if (auto padOp = dyn_cast(op)) { - auto operand = scope.findTensor(padOp.getOperand()); - auto paddingValue = scope.findTensor(padOp.getPaddingValue()); - auto edgePaddingLow = Sizes(padOp.getEdgePaddingLow()); - auto interiorPadding = Sizes(padOp.getInteriorPadding()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto paddingValue = scope.findTensor(op.getPaddingValue()); + auto edgePaddingLow = Sizes(op.getEdgePaddingLow()); + auto interiorPadding = Sizes(op.getInteriorPadding()); auto result = evalPadOp(operand, paddingValue, edgePaddingLow, - interiorPadding, padOp.getType()); - scope.add(padOp.getResult(), result); - } else if (auto partitionIdOp = dyn_cast(op)) { + interiorPadding, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { auto result = evalPartitionIdOp(process, op.getContext()); - scope.add(partitionIdOp.getResult(), result); - } else if (auto populationCountOp = dyn_cast(op)) { - auto operand = scope.findTensor(populationCountOp.getOperand()); - auto result = evalPopulationCountOp(operand, populationCountOp.getType()); - scope.add(populationCountOp.getResult(), result); - } else if (auto powerOp = dyn_cast(op)) { - auto lhs = scope.findTensor(powerOp.getLhs()); - auto rhs = scope.findTensor(powerOp.getRhs()); - auto result = evalPowerOp(lhs, rhs, powerOp.getType()); - scope.add(powerOp.getResult(), result); - } else if (auto realOp = dyn_cast(op)) { - auto operand = scope.findTensor(realOp.getOperand()); - auto result = evalRealOp(operand, realOp.getType()); - scope.add(realOp.getResult(), result); - } else if (auto recvOp = dyn_cast(op)) { - auto token = scope.findToken(recvOp.getToken()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto result = evalPopulationCountOp(operand, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto lhs = scope.findTensor(op.getLhs()); + auto rhs = scope.findTensor(op.getRhs()); + auto result = evalPowerOp(lhs, rhs, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto result = evalRealOp(operand, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto token = scope.findToken(op.getToken()); ChannelId channelId = 0; - if (auto channelHandle = recvOp.getChannelHandle()) + if (auto channelHandle = op.getChannelHandle()) channelId = channelHandle.getHandle(); auto results = evalRecvOp(token, channelId, process); - scope.add(recvOp.getResults(), results); - } else if (auto reduceOp = dyn_cast(op)) { - auto inputs = scope.findTensors(reduceOp.getInputs()); - auto initValues = scope.findTensors(reduceOp.getInitValues()); + scope.add(op.getResults(), results); + } else if (auto op = dyn_cast(operation)) { + auto inputs = scope.findTensors(op.getInputs()); + auto initValues = scope.findTensors(op.getInitValues()); SmallVector resultTypes; - for (auto resultType : reduceOp.getResultTypes()) + for (auto resultType : op.getResultTypes()) resultTypes.push_back(cast(resultType)); - auto results = - evalReduceOp(inputs, initValues, Axes(reduceOp.getDimensions()), - reduceOp.getBody(), process, scope, resultTypes); - scope.add(reduceOp.getResults(), results); - } else if (auto reducePrecisionOp = dyn_cast(op)) { - auto operand = scope.findTensor(reducePrecisionOp.getOperand()); - int32_t exponentBits = reducePrecisionOp.getExponentBits(); - int32_t mantissaBits = reducePrecisionOp.getMantissaBits(); + auto results = evalReduceOp(inputs, initValues, Axes(op.getDimensions()), + op.getBody(), process, scope, resultTypes); + scope.add(op.getResults(), results); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + int32_t exponentBits = op.getExponentBits(); + int32_t mantissaBits = op.getMantissaBits(); auto result = evalReducePrecisionOp(operand, exponentBits, mantissaBits, - reducePrecisionOp.getType()); - scope.add(reducePrecisionOp.getResult(), result); - } else if (auto reduceScatterOp = dyn_cast(op)) { - auto operand = scope.findTensor(reduceScatterOp.getOperand()); - int64_t scatterDimension = reduceScatterOp.getScatterDimension(); - auto replicaGroups = getReplicaGroups(reduceScatterOp.getReplicaGroups()); + op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + int64_t scatterDimension = op.getScatterDimension(); + auto replicaGroups = getReplicaGroups(op.getReplicaGroups()); ChannelId channelId = 0; - if (auto channelHandle = reduceScatterOp.getChannelHandle()) + if (auto channelHandle = op.getChannelHandle()) channelId = channelHandle->getHandle(); auto result = evalReduceScatterOp( operand, scatterDimension, replicaGroups, channelId, - reduceScatterOp.getUseGlobalDeviceIds(), - reduceScatterOp.getComputation(), process, scope, - reduceScatterOp.getType()); - scope.add(reduceScatterOp.getResult(), result); - } else if (auto reduceWindowOp = dyn_cast(op)) { - auto inputs = scope.findTensors(reduceWindowOp.getInputs()); - auto initValues = scope.findTensors(reduceWindowOp.getInitValues()); + op.getUseGlobalDeviceIds(), op.getComputation(), process, scope, + op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto inputs = scope.findTensors(op.getInputs()); + auto initValues = scope.findTensors(op.getInitValues()); int64_t rank = inputs[0].getRank(); Sizes windowStrides(rank, 1); - if (auto windowStridesAttr = reduceWindowOp.getWindowStrides()) + if (auto windowStridesAttr = op.getWindowStrides()) windowStrides = Sizes(*windowStridesAttr); Sizes baseDilations(rank, 1); - if (auto baseDilationsAttr = reduceWindowOp.getBaseDilations()) + if (auto baseDilationsAttr = op.getBaseDilations()) baseDilations = Sizes(*baseDilationsAttr); Sizes windowDilations(rank, 1); - if (auto windowDilationsAttr = reduceWindowOp.getWindowDilations()) + if (auto windowDilationsAttr = op.getWindowDilations()) windowDilations = Sizes(*windowDilationsAttr); Sizes paddingLow(rank, 0), paddingHigh(rank, 0); - if (auto paddingAttr = reduceWindowOp.getPadding()) { + if (auto paddingAttr = op.getPadding()) { auto paddingOrErr = hlo::convertPaddingAttribute(paddingAttr, {}); if (failed(paddingOrErr)) report_fatal_error(invalidArgument("Invalid padding format found.")); @@ -809,87 +794,86 @@ SmallVector eval(Region ®ion, } SmallVector resultTypes; - for (auto resultType : reduceWindowOp.getResultTypes()) + for (auto resultType : op.getResultTypes()) resultTypes.push_back(cast(resultType)); auto results = evalReduceWindowOp( - inputs, initValues, Sizes(reduceWindowOp.getWindowDimensions()), - windowStrides, baseDilations, windowDilations, paddingLow, - paddingHigh, reduceWindowOp.getBody(), process, scope, resultTypes); - scope.add(reduceWindowOp.getResults(), results); - } else if (auto remOp = dyn_cast(op)) { - auto lhs = scope.findTensor(remOp.getLhs()); - auto rhs = scope.findTensor(remOp.getRhs()); - auto result = evalRemOp(lhs, rhs, remOp.getType()); - scope.add(remOp.getResult(), result); - } else if (auto replicaIdOp = dyn_cast(op)) { + inputs, initValues, Sizes(op.getWindowDimensions()), windowStrides, + baseDilations, windowDilations, paddingLow, paddingHigh, op.getBody(), + process, scope, resultTypes); + scope.add(op.getResults(), results); + } else if (auto op = dyn_cast(operation)) { + auto lhs = scope.findTensor(op.getLhs()); + auto rhs = scope.findTensor(op.getRhs()); + auto result = evalRemOp(lhs, rhs, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { auto result = evalReplicaIdOp(process, op.getContext()); - scope.add(replicaIdOp.getResult(), result); - } else if (auto reshapeOp = dyn_cast(op)) { - auto operand = scope.findTensor(reshapeOp.getOperand()); - auto result = evalReshapeOp(operand, reshapeOp.getType()); - scope.add(reshapeOp.getResult(), result); - } else if (auto returnOp = dyn_cast(op)) { - return scope.find(returnOp.getOperands()); - } else if (auto returnOp = dyn_cast(op)) { - return scope.find(returnOp.getResults()); - } else if (auto reverseOp = dyn_cast(op)) { - auto operand = scope.findTensor(reverseOp.getOperand()); - auto dimensions = Axes(reverseOp.getDimensions()); - auto result = evalReverseOp(operand, dimensions, reverseOp.getType()); - scope.add(reverseOp.getResult(), result); - } else if (isa(op)) { - failOnDecomposableOp(op); - } else if (isa(op)) { - failOnDecomposableOp(op); - } else if (auto roundNearestEvenOp = dyn_cast(op)) { - auto operand = scope.findTensor(roundNearestEvenOp.getOperand()); - auto result = - evalRoundNearestEvenOp(operand, roundNearestEvenOp.getType()); - scope.add(roundNearestEvenOp.getResult(), result); - } else if (auto roundOp = dyn_cast(op)) { - auto operand = scope.findTensor(roundOp.getOperand()); - auto result = evalRoundOp(operand, roundOp.getType()); - scope.add(roundOp.getResult(), result); - } else if (auto rsqrtOp = dyn_cast(op)) { - auto operand = scope.findTensor(rsqrtOp.getOperand()); - auto result = evalRsqrtOp(operand, rsqrtOp.getType()); - scope.add(rsqrtOp.getResult(), result); - } else if (auto scatterOp = dyn_cast(op)) { - auto inputs = scope.findTensors(scatterOp.getInputs()); - auto scatterIndices = scope.findTensor(scatterOp.getScatterIndices()); - auto updates = scope.findTensors(scatterOp.getUpdates()); - auto scatterDimensionNumbers = scatterOp.getScatterDimensionNumbers(); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto result = evalReshapeOp(operand, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + return scope.find(op.getOperands()); + } else if (auto op = dyn_cast(operation)) { + return scope.find(op.getResults()); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto dimensions = Axes(op.getDimensions()); + auto result = evalReverseOp(operand, dimensions, op.getType()); + scope.add(op.getResult(), result); + } else if (isa(operation)) { + failOnDecomposableOp(operation); + } else if (isa(operation)) { + failOnDecomposableOp(operation); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto result = evalRoundNearestEvenOp(operand, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto result = evalRoundOp(operand, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto result = evalRsqrtOp(operand, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto inputs = scope.findTensors(op.getInputs()); + auto scatterIndices = scope.findTensor(op.getScatterIndices()); + auto updates = scope.findTensors(op.getUpdates()); + auto scatterDimensionNumbers = op.getScatterDimensionNumbers(); Axes updateWindowDims(scatterDimensionNumbers.getUpdateWindowDims()); Axes insertedWindowDims(scatterDimensionNumbers.getInsertedWindowDims()); Axes scatterDimsToOperandDims( scatterDimensionNumbers.getScatterDimsToOperandDims()); Axis indexVectorDim(scatterDimensionNumbers.getIndexVectorDim()); - auto &updateComputation = scatterOp.getUpdateComputation(); - SmallVector resultTypes(scatterOp->getResultTypes()); + auto &updateComputation = op.getUpdateComputation(); + SmallVector resultTypes(op->getResultTypes()); auto results = evalScatterOp( inputs, scatterIndices, updates, updateWindowDims, insertedWindowDims, scatterDimsToOperandDims, indexVectorDim, updateComputation, process, scope, resultTypes); - scope.add(scatterOp.getResults(), results); - } else if (auto selectAndScatterOp = dyn_cast(op)) { - auto operand = scope.findTensor(selectAndScatterOp.getOperand()); - auto source = scope.findTensor(selectAndScatterOp.getSource()); - auto initValue = scope.findTensor(selectAndScatterOp.getInitValue()); + scope.add(op.getResults(), results); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto source = scope.findTensor(op.getSource()); + auto initValue = scope.findTensor(op.getInitValue()); auto rank = operand.getRank(); Sizes windowDimensions(rank, 1); - if (auto windowDimensionsAttr = selectAndScatterOp.getWindowDimensions()) + if (auto windowDimensionsAttr = op.getWindowDimensions()) windowDimensions.assign(windowDimensionsAttr->begin(), windowDimensionsAttr->end()); Sizes windowStrides(rank, 1); - if (auto windowStridesAttr = selectAndScatterOp.getWindowStrides()) + if (auto windowStridesAttr = op.getWindowStrides()) windowStrides.assign(windowStridesAttr->begin(), windowStridesAttr->end()); Sizes paddingLow(rank, 0); - if (auto padding = selectAndScatterOp.getPadding()) { + if (auto padding = op.getPadding()) { auto paddingOrErr = hlo::convertPaddingAttribute(padding, {}); if (failed(paddingOrErr)) report_fatal_error(invalidArgument("Invalid padding format found.")); @@ -898,114 +882,108 @@ SmallVector eval(Region ®ion, } } - auto result = evalSelectAndScatterOp( - operand, source, initValue, windowDimensions, windowStrides, - paddingLow, selectAndScatterOp.getSelect(), - selectAndScatterOp.getScatter(), process, scope, - selectAndScatterOp.getType()); - scope.add(selectAndScatterOp.getResult(), result); - } else if (auto selectOp = dyn_cast(op)) { - auto pred = scope.findTensor(selectOp.getPred()); - auto onTrue = scope.findTensor(selectOp.getOnTrue()); - auto onFalse = scope.findTensor(selectOp.getOnFalse()); - auto result = evalSelectOp(pred, onTrue, onFalse, selectOp.getType()); - scope.add(selectOp.getResult(), result); - } else if (auto sendOp = dyn_cast(op)) { - auto inputs = scope.findTensors(sendOp.getInputs()); - auto token = scope.findToken(sendOp.getToken()); + auto result = + evalSelectAndScatterOp(operand, source, initValue, windowDimensions, + windowStrides, paddingLow, op.getSelect(), + op.getScatter(), process, scope, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto pred = scope.findTensor(op.getPred()); + auto onTrue = scope.findTensor(op.getOnTrue()); + auto onFalse = scope.findTensor(op.getOnFalse()); + auto result = evalSelectOp(pred, onTrue, onFalse, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto inputs = scope.findTensors(op.getInputs()); + auto token = scope.findToken(op.getToken()); ChannelId channelId = 0; - if (auto channelHandle = sendOp.getChannelHandle()) + if (auto channelHandle = op.getChannelHandle()) channelId = channelHandle.getHandle(); auto result = evalSendOp(inputs, token, channelId, process); - scope.add(sendOp.getResult(), result); - } else if (auto shiftLeftOp = dyn_cast(op)) { - auto lhs = scope.findTensor(shiftLeftOp.getLhs()); - auto rhs = scope.findTensor(shiftLeftOp.getRhs()); - auto result = evalShiftLeftOp(lhs, rhs, shiftLeftOp.getType()); - scope.add(shiftLeftOp.getResult(), result); - } else if (auto shiftRightArithmeticOp = - dyn_cast(op)) { - auto lhs = scope.findTensor(shiftRightArithmeticOp.getLhs()); - auto rhs = scope.findTensor(shiftRightArithmeticOp.getRhs()); - auto result = evalShiftRightArithmeticOp( - lhs, rhs, shiftRightArithmeticOp.getType()); - scope.add(shiftRightArithmeticOp.getResult(), result); - } else if (auto shiftRightLogicalOp = dyn_cast(op)) { - auto lhs = scope.findTensor(shiftRightLogicalOp.getLhs()); - auto rhs = scope.findTensor(shiftRightLogicalOp.getRhs()); - auto result = - evalShiftRightLogicalOp(lhs, rhs, shiftRightLogicalOp.getType()); - scope.add(shiftRightLogicalOp.getResult(), result); - } else if (auto signOp = dyn_cast(op)) { - auto operand = scope.findTensor(signOp.getOperand()); - auto result = evalSignOp(operand, signOp.getType()); - scope.add(signOp.getResult(), result); - } else if (auto sineOp = dyn_cast(op)) { - auto operand = scope.findTensor(sineOp.getOperand()); - auto result = evalSineOp(operand, sineOp.getType()); - scope.add(sineOp.getResult(), result); - } else if (auto sliceOp = dyn_cast(op)) { - auto operand = scope.findTensor(sliceOp.getOperand()); - auto startIndices = Sizes(sliceOp.getStartIndices()); - auto strides = Sizes(sliceOp.getStrides()); - auto result = - evalSliceOp(operand, startIndices, strides, sliceOp.getType()); - scope.add(sliceOp.getResult(), result); - } else if (auto sortOp = dyn_cast(op)) { - auto operands = scope.findTensors(sortOp.getInputs()); - auto dimension = sortOp.getDimension(); - auto isStable = sortOp.getIsStable(); - auto &comparator = sortOp.getComparator(); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto lhs = scope.findTensor(op.getLhs()); + auto rhs = scope.findTensor(op.getRhs()); + auto result = evalShiftLeftOp(lhs, rhs, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto lhs = scope.findTensor(op.getLhs()); + auto rhs = scope.findTensor(op.getRhs()); + auto result = evalShiftRightArithmeticOp(lhs, rhs, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto lhs = scope.findTensor(op.getLhs()); + auto rhs = scope.findTensor(op.getRhs()); + auto result = evalShiftRightLogicalOp(lhs, rhs, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto result = evalSignOp(operand, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto result = evalSineOp(operand, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto startIndices = Sizes(op.getStartIndices()); + auto strides = Sizes(op.getStrides()); + auto result = evalSliceOp(operand, startIndices, strides, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operands = scope.findTensors(op.getInputs()); + auto dimension = op.getDimension(); + auto isStable = op.getIsStable(); + auto &comparator = op.getComparator(); auto results = evalSortOp(operands, dimension, isStable, comparator, process, scope); - scope.add(sortOp.getResults(), results); - } else if (auto sqrtOp = dyn_cast(op)) { - auto operand = scope.findTensor(sqrtOp.getOperand()); - auto result = evalSqrtOp(operand, sqrtOp.getType()); - scope.add(sqrtOp.getResult(), result); - } else if (auto subtractOp = dyn_cast(op)) { - auto lhs = scope.findTensor(subtractOp.getLhs()); - auto rhs = scope.findTensor(subtractOp.getRhs()); - auto result = evalSubtractOp(lhs, rhs, subtractOp.getType()); - scope.add(subtractOp.getResult(), result); - } else if (auto tanhOp = dyn_cast(op)) { - auto operand = scope.findTensor(tanhOp.getOperand()); - auto result = evalTanhOp(operand, tanhOp.getType()); - scope.add(tanhOp.getResult(), result); - } else if (isa(op)) { - failOnDecomposableOp(op); - } else if (isa(op)) { - failOnDecomposableOp(op); - } else if (auto transposeOp = dyn_cast(op)) { - auto operand = scope.findTensor(transposeOp.getOperand()); - auto permutation = Axes(transposeOp.getPermutation()); - auto result = - evalTransposeOp(operand, permutation, transposeOp.getType()); - scope.add(transposeOp.getResult(), result); - } else if (isa(op)) { - failOnDecomposableOp(op); - } else if (auto tupleOp = dyn_cast(op)) { - auto val = scope.find(tupleOp.getVal()); - auto result = evalTupleOp(val, cast(tupleOp.getType())); - scope.add(tupleOp.getResult(), result); - } else if (isa(op)) { - failOnDecomposableOp(op); - } else if (auto whileOp = dyn_cast(op)) { - auto operand = scope.find(whileOp.getOperand()); - auto &cond = whileOp.getCond(); - auto &body = whileOp.getBody(); + scope.add(op.getResults(), results); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto result = evalSqrtOp(operand, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto lhs = scope.findTensor(op.getLhs()); + auto rhs = scope.findTensor(op.getRhs()); + auto result = evalSubtractOp(lhs, rhs, op.getType()); + scope.add(op.getResult(), result); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto result = evalTanhOp(operand, op.getType()); + scope.add(op.getResult(), result); + } else if (isa(operation)) { + failOnDecomposableOp(operation); + } else if (isa(operation)) { + failOnDecomposableOp(operation); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.findTensor(op.getOperand()); + auto permutation = Axes(op.getPermutation()); + auto result = evalTransposeOp(operand, permutation, op.getType()); + scope.add(op.getResult(), result); + } else if (isa(operation)) { + failOnDecomposableOp(operation); + } else if (auto op = dyn_cast(operation)) { + auto val = scope.find(op.getVal()); + auto result = evalTupleOp(val, cast(op.getType())); + scope.add(op.getResult(), result); + } else if (isa(operation)) { + failOnDecomposableOp(operation); + } else if (auto op = dyn_cast(operation)) { + auto operand = scope.find(op.getOperand()); + auto &cond = op.getCond(); + auto &body = op.getBody(); auto results = evalWhileOp(operand, cond, body, fallback, process, scope); - scope.add(whileOp.getResults(), results); - } else if (auto xorOp = dyn_cast(op)) { - auto lhs = scope.findTensor(xorOp.getLhs()); - auto rhs = scope.findTensor(xorOp.getRhs()); - auto result = evalXorOp(lhs, rhs, xorOp.getType()); - scope.add(xorOp.getResult(), result); + scope.add(op.getResults(), results); + } else if (auto op = dyn_cast(operation)) { + auto lhs = scope.findTensor(op.getLhs()); + auto rhs = scope.findTensor(op.getRhs()); + auto result = evalXorOp(lhs, rhs, op.getType()); + scope.add(op.getResult(), result); } else { if (!fallback) - report_fatal_error( - invalidArgument("Unsupported op: %s", debugString(op).c_str())); - auto status = (*fallback)(op, scope, process); + report_fatal_error(invalidArgument("Unsupported op: %s", + debugString(operation).c_str())); + auto status = (*fallback)(operation, scope, process); if (status) llvm::report_fatal_error(std::move(status)); } }