From e0c6d5d4a7a6caa804d4d7bb6f6a6ca8df704a48 Mon Sep 17 00:00:00 2001 From: mlevesquedion Date: Fri, 3 May 2024 09:44:21 -0700 Subject: [PATCH] =?UTF-8?q?Use=20map=5Fto=5Fvector=20instead=20of=20map=5F?= =?UTF-8?q?range=20=E2=88=98=20to=5Fvector=20(#2281)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../transforms/StablehloLegalizeToLinalg.cpp | 15 ++--- stablehlo/dialect/BroadcastUtils.cpp | 4 +- stablehlo/dialect/StablehloOps.cpp | 16 ++--- stablehlo/dialect/TypeInference.cpp | 46 +++++++------- stablehlo/dialect/VhloOps.cpp | 8 +-- stablehlo/dialect/VhloTypes.cpp | 8 +-- stablehlo/dialect/VhloTypes.td | 4 +- stablehlo/reference/Api.cpp | 10 +-- stablehlo/reference/Ops.cpp | 8 +-- stablehlo/reference/ProcessGrid.cpp | 4 +- stablehlo/reference/Scope.cpp | 12 ++-- stablehlo/reference/Tensor.cpp | 62 +++++++++---------- .../transforms/ShapeLegalizeToStablehlo.cpp | 6 +- .../StablehloAggressiveSimplification.cpp | 6 +- .../transforms/StablehloRefineShapes.cpp | 6 +- 15 files changed, 108 insertions(+), 107 deletions(-) diff --git a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp index ad7954d7d2b..2cb05d181e8 100644 --- a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp +++ b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp @@ -2284,9 +2284,10 @@ struct PadOpNegativePaddingConversion final // Then slice according to the negative edge padding. Static shapes only for // now. if (!op.getType().hasStaticShape()) return failure(); - SmallVector sizes(llvm::map_range( - op.getType().getShape(), - [&](int64_t dim) { return rewriter.getIndexAttr(dim); })); + auto sizes = llvm::map_to_vector(op.getType().getShape(), + [&](int64_t dim) -> OpFoldResult { + return rewriter.getIndexAttr(dim); + }); SmallVector strides(sliceStarts.size(), rewriter.getIndexAttr(1)); rewriter.replaceOpWithNewOp(op, pad, sliceStarts, @@ -2506,10 +2507,10 @@ struct SetDimensionSizeConverter final rewriter.getIndexAttr(0)); SmallVector strides(resultType.getRank(), rewriter.getIndexAttr(1)); - SmallVector sizes(llvm::map_range( - resultType.getShape(), [&](int64_t dim) -> OpFoldResult { - return rewriter.getIndexAttr(dim); - })); + auto sizes = llvm::map_to_vector(resultType.getShape(), + [&](int64_t dim) -> OpFoldResult { + return rewriter.getIndexAttr(dim); + }); Value dimensionSize = rewriter.create(loc, setDimensionSizeOp.getSize()); sizes[setDimensionSizeOp.getDimension()] = diff --git a/stablehlo/dialect/BroadcastUtils.cpp b/stablehlo/dialect/BroadcastUtils.cpp index 130ec510157..4f15696092d 100644 --- a/stablehlo/dialect/BroadcastUtils.cpp +++ b/stablehlo/dialect/BroadcastUtils.cpp @@ -56,9 +56,9 @@ Value computeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs, Value computeNaryElementwiseBroadcastingResultExtents(Location loc, ValueRange operands, OpBuilder& builder) { - auto shapes = llvm::to_vector<4>(llvm::map_range(operands, [&](Value v) { + auto shapes = llvm::map_to_vector<4>(operands, [&](Value v) { return builder.createOrFold(loc, v); - })); + }); int64_t resultRank = 0; for (Value s : shapes) { diff --git a/stablehlo/dialect/StablehloOps.cpp b/stablehlo/dialect/StablehloOps.cpp index adc829d8014..5deff557dd6 100644 --- a/stablehlo/dialect/StablehloOps.cpp +++ b/stablehlo/dialect/StablehloOps.cpp @@ -1820,12 +1820,12 @@ void ReduceOp::build(OpBuilder&, OperationState& odsState, ValueRange inputs, odsState.attributes.getDictionary(odsState.getContext()), {}, odsState.regions); - SmallVector inputArgTensorTypes{ - llvm::map_range(adaptor.getInputs().getTypes(), - [](Type t) { return cast(t); })}; - SmallVector initValueTensorTypes{ - llvm::map_range(adaptor.getInitValues().getTypes(), - [](Type t) { return cast(t); })}; + auto inputArgTensorTypes = + llvm::map_to_vector(adaptor.getInputs().getTypes(), + [](Type t) { return cast(t); }); + auto initValueTensorTypes = + llvm::map_to_vector(adaptor.getInitValues().getTypes(), + [](Type t) { return cast(t); }); if (failed(hlo::verifyReduceOpInputsAndInferShape( odsState.location, inputArgTensorTypes, dimensions, newDimensions, @@ -3288,8 +3288,8 @@ ParseResult parseWindowAttributes(OpAsmParser& parser, Attribute& windowStrides, int64Parser)) return failure(); if (attributeName == "reverse") { - auto boolVector = llvm::to_vector<4>( - llvm::map_range(values, [](int64_t v) { return v != 0; })); + auto boolVector = + llvm::map_to_vector<4>(values, [](int64_t v) { return v != 0; }); windowReversal = DenseBoolArrayAttr::get(parser.getContext(), boolVector); } else { diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index 92d815bb777..7ad992705d5 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -798,9 +798,9 @@ FailureOr> getAccumulatorTypes( } Block& block = region.front(); - return llvm::to_vector( - llvm::map_range(block.getTerminator()->getOperands(), - [&](Value v) { return cast(v.getType()); })); + return llvm::map_to_vector( + block.getTerminator()->getOperands(), + [&](Value v) { return cast(v.getType()); }); } LogicalResult verifyReducerShape(std::optional loc, Block& block, @@ -1665,8 +1665,8 @@ LogicalResult inferAllReduceOp( std::optional location, ValueRange operands, Region& computation, SmallVectorImpl& inferredReturnShapes) { TypeRange inputTypes = operands.getTypes(); - SmallVector inputArgTensorTypes{ - llvm::map_range(inputTypes, [](Type t) { return cast(t); })}; + auto inputArgTensorTypes = llvm::map_to_vector( + inputTypes, [](Type t) { return cast(t); }); // all_reduce_c6, all_reduce_c7 auto accumulatorTypesOrErr = getAccumulatorTypes(location, computation); if (failed(accumulatorTypesOrErr)) return failure(); @@ -2675,8 +2675,8 @@ LogicalResult inferReduceOp( std::optional location, TypeRange inputTypes, ArrayRef dimensions, Region& body, SmallVectorImpl& inferredReturnShapes) { - SmallVector inputArgTensorTypes{ - llvm::map_range(inputTypes, [](Type t) { return cast(t); })}; + auto inputArgTensorTypes = llvm::map_to_vector( + inputTypes, [](Type t) { return cast(t); }); SmallVector newDimensions; Attribute encoding; @@ -2703,10 +2703,10 @@ LogicalResult inferReduceWindowOp( std::optional> windowDilations, std::optional padding, Region& body, SmallVectorImpl& inferredReturnShapes) { - SmallVector inputTypes{llvm::map_range( - inputs.getTypes(), [](Type t) { return cast(t); })}; - SmallVector initValueTypes{llvm::map_range( - initValues.getTypes(), [](Type t) { return cast(t); })}; + auto inputTypes = llvm::map_to_vector( + inputs.getTypes(), [](Type t) { return cast(t); }); + auto initValueTypes = llvm::map_to_vector( + initValues.getTypes(), [](Type t) { return cast(t); }); SmallVector windowDims; SmallVector inferredWindow; @@ -4009,10 +4009,10 @@ LogicalResult verifyRecvOp(HloDialectInterface* dialect, LogicalResult verifyReduceOp(std::optional location, ValueRange inputs, ValueRange initValues, ArrayRef dimensions, Region& body) { - SmallVector inputTypes{llvm::map_range( - inputs.getTypes(), [](Type t) { return cast(t); })}; - SmallVector initValueTypes{llvm::map_range( - initValues.getTypes(), [](Type t) { return cast(t); })}; + auto inputTypes = llvm::map_to_vector( + inputs.getTypes(), [](Type t) { return cast(t); }); + auto initValueTypes = llvm::map_to_vector( + initValues.getTypes(), [](Type t) { return cast(t); }); SmallVector newDimensions; Attribute encoding; @@ -4136,10 +4136,10 @@ LogicalResult verifyReduceWindowOp( std::optional> baseDilations, std::optional> windowDilations, std::optional padding, Region& body) { - SmallVector inputTypes{llvm::map_range( - inputs.getTypes(), [](Type t) { return cast(t); })}; - SmallVector initValueTypes{llvm::map_range( - initValues.getTypes(), [](Type t) { return cast(t); })}; + auto inputTypes = llvm::map_to_vector( + inputs.getTypes(), [](Type t) { return cast(t); }); + auto initValueTypes = llvm::map_to_vector( + initValues.getTypes(), [](Type t) { return cast(t); }); SmallVector windowDims; SmallVector inferredWindow; @@ -4277,10 +4277,10 @@ LogicalResult verifyScatterOp(std::optional location, auto numOperands = inputs.size(); auto scatterIndicesType = cast(scatterIndices.getType()); - SmallVector operandTypes = llvm::to_vector(llvm::map_range( - inputs.getTypes(), [](Type type) { return cast(type); })); - SmallVector updatesTypes = llvm::to_vector(llvm::map_range( - updates.getTypes(), [](Type type) { return cast(type); })); + auto operandTypes = llvm::map_to_vector( + inputs.getTypes(), [](Type type) { return cast(type); }); + auto updatesTypes = llvm::map_to_vector( + updates.getTypes(), [](Type type) { return cast(type); }); bool scatterIndicesTypeRanked = isa(scatterIndicesType); // scatter_c1 diff --git a/stablehlo/dialect/VhloOps.cpp b/stablehlo/dialect/VhloOps.cpp index 1d2ae0afa5d..528d6acbb8f 100644 --- a/stablehlo/dialect/VhloOps.cpp +++ b/stablehlo/dialect/VhloOps.cpp @@ -314,10 +314,10 @@ Type getVhloElementType(Type tensorType) { bool checkIfOperandAndResultElementTypesMatch(TypeRange operandTypes, TypeRange resultTypes) { - SmallVector inputElementTypes{llvm::map_range( - operandTypes, [](Type t) { return getVhloElementType(t); })}; - SmallVector resultElementTypes{llvm::map_range( - resultTypes, [](Type t) { return getVhloElementType(t); })}; + auto inputElementTypes = llvm::map_to_vector( + operandTypes, [](Type t) { return getVhloElementType(t); }); + auto resultElementTypes = llvm::map_to_vector( + resultTypes, [](Type t) { return getVhloElementType(t); }); return llvm::all_of( llvm::zip(inputElementTypes, resultElementTypes), diff --git a/stablehlo/dialect/VhloTypes.cpp b/stablehlo/dialect/VhloTypes.cpp index 247a696d34b..9688bbd75f5 100644 --- a/stablehlo/dialect/VhloTypes.cpp +++ b/stablehlo/dialect/VhloTypes.cpp @@ -124,8 +124,8 @@ void VhloTypeConverter::addBuiltinToVhloConversions() { Type convertedStorageType = convertType(type.getStorageType()); Type convertedExpressedType = convertType(type.getExpressedType()); if (!convertedStorageType || !convertedExpressedType) return {}; - SmallVector scales = llvm::to_vector(llvm::map_range( - type.getScales(), [](double scale) { return APFloat(scale); })); + auto scales = llvm::map_to_vector( + type.getScales(), [](double scale) { return APFloat(scale); }); return vhlo::UniformQuantizedPerAxisV1Type::get( type.getContext(), type.getFlags(), convertedStorageType, convertedExpressedType, type.getQuantizedDimension(), scales, @@ -239,9 +239,9 @@ void VhloTypeConverter::addVhloToBuiltinConversions() { Type convertedStorageType = convertType(type.getStorageType()); Type convertedExpressedType = convertType(type.getExpressedType()); if (!convertedStorageType || !convertedExpressedType) return {}; - SmallVector scales = llvm::to_vector(llvm::map_range( + auto scales = llvm::map_to_vector( type.getScales(), - [](const APFloat& scale) { return scale.convertToDouble(); })); + [](const APFloat& scale) { return scale.convertToDouble(); }); return quant::UniformQuantizedPerAxisType::get( type.getFlags(), convertedStorageType, convertedExpressedType, scales, type.getZeroPoints(), type.getQuantizedDimension(), diff --git a/stablehlo/dialect/VhloTypes.td b/stablehlo/dialect/VhloTypes.td index 7ce69c6add1..ff815e642b8 100644 --- a/stablehlo/dialect/VhloTypes.td +++ b/stablehlo/dialect/VhloTypes.td @@ -237,8 +237,8 @@ def VHLO_QuantizationScalesV1 : ArrayRefParameter<"::llvm::APFloat", "array of d return $_parser.parseFloat(scales.emplace_back()); }); if(failed(parseResult)) return failure(); - return llvm::to_vector(llvm::map_range( - scales, [](double scale) { return APFloat(scale); })); + return llvm::map_to_vector( + scales, [](double scale) { return APFloat(scale); }); }() }]; let printer = [{ diff --git a/stablehlo/reference/Api.cpp b/stablehlo/reference/Api.cpp index 2c1f67a29f1..d448a27b11e 100644 --- a/stablehlo/reference/Api.cpp +++ b/stablehlo/reference/Api.cpp @@ -205,18 +205,18 @@ FailureOr> evalModule( FailureOr> evalModule( ModuleOp module, ArrayRef inputs, const InterpreterConfiguration &config) { - SmallVector valueInputs = llvm::to_vector( - llvm::map_range(inputs, [](DenseElementsAttr attr) -> InterpreterValue { + SmallVector valueInputs = llvm::map_to_vector( + inputs, [](DenseElementsAttr attr) -> InterpreterValue { return InterpreterValue(makeTensor(attr)); - })); + }); auto values = evalModule(module, valueInputs, config); if (failed(values)) return failure(); - SmallVector results = llvm::to_vector(llvm::map_range( + SmallVector results = llvm::map_to_vector( values.value(), [](InterpreterValue val) -> DenseElementsAttr { return makeDenseElementsAttr(val.getTensor()); - })); + }); return results; } diff --git a/stablehlo/reference/Ops.cpp b/stablehlo/reference/Ops.cpp index f564066ebe1..7f3c263b773 100644 --- a/stablehlo/reference/Ops.cpp +++ b/stablehlo/reference/Ops.cpp @@ -140,8 +140,8 @@ SmallVector callOp(ArrayRef inputs, symbolTableCollection.getSymbolTable(op->getParentOfType()); auto func = symbolTable.lookupNearestSymbolFrom( op, StringAttr::get(op->getContext(), funcName)); - SmallVector values = llvm::to_vector(llvm::map_range( - inputs, [](const Tensor &t) { return InterpreterValue(t); })); + SmallVector values = llvm::map_to_vector( + inputs, [](const Tensor &t) { return InterpreterValue(t); }); return eval(func.getBody(), values, fallback, process, nullptr); } @@ -1031,9 +1031,9 @@ Tensor allGatherOp(const Tensor &operand, int64_t allGatherDim, auto rendezvousResult = process->rendezvous(*processGroup, channelId, operand); - SmallVector groupOperands(llvm::map_range( + auto groupOperands = llvm::map_to_vector( *processGroup, - [&](const ProcessId &id) { return rendezvousResult.lookup(id); })); + [&](const ProcessId &id) { return rendezvousResult.lookup(id); }); return concatenateOp(groupOperands, allGatherDim, resultType); } diff --git a/stablehlo/reference/ProcessGrid.cpp b/stablehlo/reference/ProcessGrid.cpp index 72cc4b3ec89..b02bca0d696 100644 --- a/stablehlo/reference/ProcessGrid.cpp +++ b/stablehlo/reference/ProcessGrid.cpp @@ -73,8 +73,8 @@ Tensor RendezvousResult::lookup(ProcessId processId) const { } SmallVector RendezvousResult::getSortedTensors() const { - return llvm::to_vector( - llvm::map_range(result_, [](const auto &pair) { return pair.second; })); + return llvm::map_to_vector(result_, + [](const auto &pair) { return pair.second; }); } //===----------------------------------------------------------------------===// diff --git a/stablehlo/reference/Scope.cpp b/stablehlo/reference/Scope.cpp index 9d5a0cac1f4..f9319159715 100644 --- a/stablehlo/reference/Scope.cpp +++ b/stablehlo/reference/Scope.cpp @@ -80,8 +80,8 @@ InterpreterValue Scope::find(Value ssaValue) const { } SmallVector Scope::find(ValueRange ssaValues) const { - return llvm::to_vector( - llvm::map_range(ssaValues, [&](Value value) { return find(value); })); + return llvm::map_to_vector(ssaValues, + [&](Value value) { return find(value); }); } Tensor Scope::findTensor(Value ssaValue) const { @@ -89,8 +89,8 @@ Tensor Scope::findTensor(Value ssaValue) const { } SmallVector Scope::findTensors(ValueRange ssaValues) const { - return llvm::to_vector(llvm::map_range( - ssaValues, [&](Value value) { return find(value).getTensor(); })); + return llvm::map_to_vector( + ssaValues, [&](Value value) { return find(value).getTensor(); }); } Token Scope::findToken(Value ssaValue) const { @@ -98,8 +98,8 @@ Token Scope::findToken(Value ssaValue) const { } SmallVector Scope::findTokens(ValueRange ssaValues) const { - return llvm::to_vector(llvm::map_range( - ssaValues, [&](Value value) { return find(value).getToken(); })); + return llvm::map_to_vector( + ssaValues, [&](Value value) { return find(value).getToken(); }); } Tuple Scope::findTuple(Value ssaValue) const { diff --git a/stablehlo/reference/Tensor.cpp b/stablehlo/reference/Tensor.cpp index f2e1bfcec5f..60506029726 100644 --- a/stablehlo/reference/Tensor.cpp +++ b/stablehlo/reference/Tensor.cpp @@ -385,10 +385,10 @@ Tensor makeTensor(DenseElementsAttr attr) { if (elementType.isFloat8E4M3B11FNUZ() || elementType.isFloat8E4M3FN() || elementType.isFloat8E4M3FNUZ() || elementType.isFloat8E5M2() || elementType.isFloat8E5M2FNUZ()) { - auto floatValues = llvm::to_vector(llvm::map_range( + auto floatValues = llvm::map_to_vector( attr.getValues(), [&](APFloat value) -> uint8_t { return value.bitcastToAPInt().getZExtValue(); - })); + }); // For f8E4M3B11FNUZ, f8E4M3FN, f8E4M3FNUZ, f8E5M2, and f8E5M2FNUZ // floating-point types, we use uint8_t as their storage type because there @@ -398,10 +398,10 @@ Tensor makeTensor(DenseElementsAttr attr) { } if (elementType.isF16() || elementType.isBF16()) { - auto floatValues = llvm::to_vector(llvm::map_range( + auto floatValues = llvm::map_to_vector( attr.getValues(), [&](APFloat value) -> uint16_t { return value.bitcastToAPInt().getZExtValue(); - })); + }); // For both f16 and bf16 floating-point types, we use uint16_t as their // storage type because there are no builtin types for those. @@ -411,85 +411,85 @@ Tensor makeTensor(DenseElementsAttr attr) { } if (elementType.isF32()) { - auto floatValues = llvm::to_vector(llvm::map_range( + auto floatValues = llvm::map_to_vector( attr.getValues(), - [&](APFloat value) -> float { return value.convertToFloat(); })); + [&](APFloat value) -> float { return value.convertToFloat(); }); return Tensor(type, HeapAsmResourceBlob::allocateAndCopyInferAlign( floatValues)); } if (elementType.isF64()) { - auto floatValues = llvm::to_vector(llvm::map_range( + auto floatValues = llvm::map_to_vector( attr.getValues(), - [&](APFloat value) -> double { return value.convertToDouble(); })); + [&](APFloat value) -> double { return value.convertToDouble(); }); return Tensor(type, HeapAsmResourceBlob::allocateAndCopyInferAlign( floatValues)); } // Handle signed integer types. if (elementType.isSignlessInteger(4) || elementType.isSignlessInteger(8)) { - auto intValues = llvm::to_vector(llvm::map_range( + auto intValues = llvm::map_to_vector( attr.getValues(), - [&](APInt value) -> int8_t { return value.getSExtValue(); })); + [&](APInt value) -> int8_t { return value.getSExtValue(); }); return Tensor(type, HeapAsmResourceBlob::allocateAndCopyInferAlign( intValues)); } if (elementType.isSignlessInteger(16)) { - auto intValues = llvm::to_vector(llvm::map_range( + auto intValues = llvm::map_to_vector( attr.getValues(), - [&](APInt value) -> int16_t { return value.getSExtValue(); })); + [&](APInt value) -> int16_t { return value.getSExtValue(); }); return Tensor(type, HeapAsmResourceBlob::allocateAndCopyInferAlign( intValues)); } if (elementType.isSignlessInteger(32)) { - auto intValues = llvm::to_vector(llvm::map_range( + auto intValues = llvm::map_to_vector( attr.getValues(), - [&](APInt value) -> int32_t { return value.getSExtValue(); })); + [&](APInt value) -> int32_t { return value.getSExtValue(); }); return Tensor(type, HeapAsmResourceBlob::allocateAndCopyInferAlign( intValues)); } if (elementType.isSignlessInteger(64)) { - auto intValues = llvm::to_vector(llvm::map_range( + auto intValues = llvm::map_to_vector( attr.getValues(), - [&](APInt value) -> int64_t { return value.getSExtValue(); })); + [&](APInt value) -> int64_t { return value.getSExtValue(); }); return Tensor(type, HeapAsmResourceBlob::allocateAndCopyInferAlign( intValues)); } // Handle unsigned integer types. if (elementType.isUnsignedInteger(4) || elementType.isUnsignedInteger(8)) { - auto intValues = llvm::to_vector(llvm::map_range( + auto intValues = llvm::map_to_vector( attr.getValues(), - [&](APInt value) -> uint8_t { return value.getZExtValue(); })); + [&](APInt value) -> uint8_t { return value.getZExtValue(); }); return Tensor(type, HeapAsmResourceBlob::allocateAndCopyInferAlign( intValues)); } if (elementType.isUnsignedInteger(16)) { - auto intValues = llvm::to_vector(llvm::map_range( + auto intValues = llvm::map_to_vector( attr.getValues(), - [&](APInt value) -> uint16_t { return value.getZExtValue(); })); + [&](APInt value) -> uint16_t { return value.getZExtValue(); }); return Tensor( type, HeapAsmResourceBlob::allocateAndCopyInferAlign(intValues)); } if (elementType.isUnsignedInteger(32)) { - auto intValues = llvm::to_vector(llvm::map_range( + auto intValues = llvm::map_to_vector( attr.getValues(), - [&](APInt value) -> uint32_t { return value.getZExtValue(); })); + [&](APInt value) -> uint32_t { return value.getZExtValue(); }); return Tensor( type, HeapAsmResourceBlob::allocateAndCopyInferAlign(intValues)); } if (elementType.isUnsignedInteger(64)) { - auto intValues = llvm::to_vector(llvm::map_range( + auto intValues = llvm::map_to_vector( attr.getValues(), - [&](APInt value) -> uint64_t { return value.getZExtValue(); })); + [&](APInt value) -> uint64_t { return value.getZExtValue(); }); return Tensor( type, HeapAsmResourceBlob::allocateAndCopyInferAlign(intValues)); @@ -497,9 +497,9 @@ Tensor makeTensor(DenseElementsAttr attr) { // Handle boolean type. if (isSupportedBooleanType(elementType)) { - auto boolValues = llvm::to_vector( - llvm::map_range(attr.getValues(), - [&](bool value) -> uint8_t { return value ? 1 : 0; })); + auto boolValues = llvm::map_to_vector( + attr.getValues(), + [&](bool value) -> uint8_t { return value ? 1 : 0; }); return Tensor(type, HeapAsmResourceBlob::allocateAndCopyInferAlign( boolValues)); } @@ -508,24 +508,24 @@ Tensor makeTensor(DenseElementsAttr attr) { if (isa(elementType)) { auto complexElemTy = cast(elementType).getElementType(); if (complexElemTy.isF32()) { - auto complexValues = llvm::to_vector(llvm::map_range( + auto complexValues = llvm::map_to_vector( attr.getValues>(), [&](std::complex value) -> std::complex { return std::complex(value.real().convertToFloat(), value.imag().convertToFloat()); - })); + }); return Tensor( type, HeapAsmResourceBlob::allocateAndCopyInferAlign>( complexValues)); } if (complexElemTy.isF64()) { - auto complexValues = llvm::to_vector(llvm::map_range( + auto complexValues = llvm::map_to_vector( attr.getValues>(), [&](std::complex value) -> std::complex { return std::complex(value.real().convertToDouble(), value.imag().convertToDouble()); - })); + }); return Tensor( type, HeapAsmResourceBlob::allocateAndCopyInferAlign>( diff --git a/stablehlo/transforms/ShapeLegalizeToStablehlo.cpp b/stablehlo/transforms/ShapeLegalizeToStablehlo.cpp index 24995a094cb..a44bad060a9 100644 --- a/stablehlo/transforms/ShapeLegalizeToStablehlo.cpp +++ b/stablehlo/transforms/ShapeLegalizeToStablehlo.cpp @@ -295,9 +295,9 @@ struct ConvertConstShapeOpPattern if (!operandType) return rewriter.notifyMatchFailure(op, "expected ranked operand"); - llvm::SmallVector shape{ - llvm::map_range(op.getShape().getValues(), - [](int64_t val) { return static_cast(val); })}; + auto shape = llvm::map_to_vector( + op.getShape().getValues(), + [](int64_t val) { return static_cast(val); }); auto newConst = rewriter.create( op.getLoc(), DenseElementsAttr::get( diff --git a/stablehlo/transforms/StablehloAggressiveSimplification.cpp b/stablehlo/transforms/StablehloAggressiveSimplification.cpp index e4a43943628..8daac7c027a 100644 --- a/stablehlo/transforms/StablehloAggressiveSimplification.cpp +++ b/stablehlo/transforms/StablehloAggressiveSimplification.cpp @@ -491,9 +491,9 @@ struct BroadcastInDimOpCanon final // Eliminate redundant nested BroadcastInDim. if (auto definingOp = operand.getDefiningOp()) { - auto newIndices = llvm::to_vector( - llvm::map_range(definingOp.getBroadcastDimensions(), - [&dims](int64_t dim) { return dims[dim]; })); + auto newIndices = + llvm::map_to_vector(definingOp.getBroadcastDimensions(), + [&dims](int64_t dim) { return dims[dim]; }); rewriter.replaceOpWithNewOp( op, type, definingOp.getOperand(), newIndices); return success(); diff --git a/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/transforms/StablehloRefineShapes.cpp index d9cb7c992bf..1c1ba528f21 100644 --- a/stablehlo/transforms/StablehloRefineShapes.cpp +++ b/stablehlo/transforms/StablehloRefineShapes.cpp @@ -462,10 +462,10 @@ struct EvalComputeReshapeShapeOpPattern dynShapeValues[unspecifiedDimIdx.value()] = numElems / dimProduct; const auto resultBitWidth = resultType.getElementTypeBitWidth(); - auto result = llvm::to_vector( - llvm::map_range(dynShapeValues, [&](int64_t value) -> APSInt { + auto result = + llvm::map_to_vector(dynShapeValues, [&](int64_t value) -> APSInt { return APSInt(APInt(resultBitWidth, value), false); - })); + }); rewriter.replaceOpWithNewOp(op, getTensorAttr(resultType, result));