Skip to content

Commit

Permalink
Use map_to_vector instead of map_range ∘ to_vector (openxla#2281)
Browse files Browse the repository at this point in the history
  • Loading branch information
mlevesquedion authored May 3, 2024
1 parent c0132de commit e0c6d5d
Show file tree
Hide file tree
Showing 15 changed files with 108 additions and 107 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2284,9 +2284,10 @@ struct PadOpNegativePaddingConversion final
// Then slice according to the negative edge padding. Static shapes only for
// now.
if (!op.getType().hasStaticShape()) return failure();
SmallVector<OpFoldResult> sizes(llvm::map_range(
op.getType().getShape(),
[&](int64_t dim) { return rewriter.getIndexAttr(dim); }));
auto sizes = llvm::map_to_vector(op.getType().getShape(),
[&](int64_t dim) -> OpFoldResult {
return rewriter.getIndexAttr(dim);
});
SmallVector<OpFoldResult> strides(sliceStarts.size(),
rewriter.getIndexAttr(1));
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(op, pad, sliceStarts,
Expand Down Expand Up @@ -2506,10 +2507,10 @@ struct SetDimensionSizeConverter final
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(resultType.getRank(),
rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> sizes(llvm::map_range(
resultType.getShape(), [&](int64_t dim) -> OpFoldResult {
return rewriter.getIndexAttr(dim);
}));
auto sizes = llvm::map_to_vector(resultType.getShape(),
[&](int64_t dim) -> OpFoldResult {
return rewriter.getIndexAttr(dim);
});
Value dimensionSize =
rewriter.create<tensor::ExtractOp>(loc, setDimensionSizeOp.getSize());
sizes[setDimensionSizeOp.getDimension()] =
Expand Down
4 changes: 2 additions & 2 deletions stablehlo/dialect/BroadcastUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ Value computeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs,
Value computeNaryElementwiseBroadcastingResultExtents(Location loc,
ValueRange operands,
OpBuilder& builder) {
auto shapes = llvm::to_vector<4>(llvm::map_range(operands, [&](Value v) {
auto shapes = llvm::map_to_vector<4>(operands, [&](Value v) {
return builder.createOrFold<shape::ShapeOfOp>(loc, v);
}));
});

int64_t resultRank = 0;
for (Value s : shapes) {
Expand Down
16 changes: 8 additions & 8 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1820,12 +1820,12 @@ void ReduceOp::build(OpBuilder&, OperationState& odsState, ValueRange inputs,
odsState.attributes.getDictionary(odsState.getContext()), {},
odsState.regions);

SmallVector<ShapedType> inputArgTensorTypes{
llvm::map_range(adaptor.getInputs().getTypes(),
[](Type t) { return cast<ShapedType>(t); })};
SmallVector<ShapedType> initValueTensorTypes{
llvm::map_range(adaptor.getInitValues().getTypes(),
[](Type t) { return cast<ShapedType>(t); })};
auto inputArgTensorTypes =
llvm::map_to_vector(adaptor.getInputs().getTypes(),
[](Type t) { return cast<ShapedType>(t); });
auto initValueTensorTypes =
llvm::map_to_vector(adaptor.getInitValues().getTypes(),
[](Type t) { return cast<ShapedType>(t); });

if (failed(hlo::verifyReduceOpInputsAndInferShape(
odsState.location, inputArgTensorTypes, dimensions, newDimensions,
Expand Down Expand Up @@ -3288,8 +3288,8 @@ ParseResult parseWindowAttributes(OpAsmParser& parser, Attribute& windowStrides,
int64Parser))
return failure();
if (attributeName == "reverse") {
auto boolVector = llvm::to_vector<4>(
llvm::map_range(values, [](int64_t v) { return v != 0; }));
auto boolVector =
llvm::map_to_vector<4>(values, [](int64_t v) { return v != 0; });
windowReversal =
DenseBoolArrayAttr::get(parser.getContext(), boolVector);
} else {
Expand Down
46 changes: 23 additions & 23 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -798,9 +798,9 @@ FailureOr<SmallVector<ShapedType>> getAccumulatorTypes(
}

Block& block = region.front();
return llvm::to_vector(
llvm::map_range(block.getTerminator()->getOperands(),
[&](Value v) { return cast<ShapedType>(v.getType()); }));
return llvm::map_to_vector(
block.getTerminator()->getOperands(),
[&](Value v) { return cast<ShapedType>(v.getType()); });
}

LogicalResult verifyReducerShape(std::optional<Location> loc, Block& block,
Expand Down Expand Up @@ -1665,8 +1665,8 @@ LogicalResult inferAllReduceOp(
std::optional<Location> location, ValueRange operands, Region& computation,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
TypeRange inputTypes = operands.getTypes();
SmallVector<ShapedType> inputArgTensorTypes{
llvm::map_range(inputTypes, [](Type t) { return cast<ShapedType>(t); })};
auto inputArgTensorTypes = llvm::map_to_vector(
inputTypes, [](Type t) { return cast<ShapedType>(t); });
// all_reduce_c6, all_reduce_c7
auto accumulatorTypesOrErr = getAccumulatorTypes(location, computation);
if (failed(accumulatorTypesOrErr)) return failure();
Expand Down Expand Up @@ -2675,8 +2675,8 @@ LogicalResult inferReduceOp(
std::optional<Location> location, TypeRange inputTypes,
ArrayRef<int64_t> dimensions, Region& body,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
SmallVector<ShapedType> inputArgTensorTypes{
llvm::map_range(inputTypes, [](Type t) { return cast<ShapedType>(t); })};
auto inputArgTensorTypes = llvm::map_to_vector(
inputTypes, [](Type t) { return cast<ShapedType>(t); });

SmallVector<int64_t> newDimensions;
Attribute encoding;
Expand All @@ -2703,10 +2703,10 @@ LogicalResult inferReduceWindowOp(
std::optional<ArrayRef<int64_t>> windowDilations,
std::optional<DenseIntElementsAttr> padding, Region& body,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
SmallVector<ShapedType> inputTypes{llvm::map_range(
inputs.getTypes(), [](Type t) { return cast<ShapedType>(t); })};
SmallVector<ShapedType> initValueTypes{llvm::map_range(
initValues.getTypes(), [](Type t) { return cast<ShapedType>(t); })};
auto inputTypes = llvm::map_to_vector(
inputs.getTypes(), [](Type t) { return cast<ShapedType>(t); });
auto initValueTypes = llvm::map_to_vector(
initValues.getTypes(), [](Type t) { return cast<ShapedType>(t); });

SmallVector<int64_t> windowDims;
SmallVector<WindowDimension> inferredWindow;
Expand Down Expand Up @@ -4009,10 +4009,10 @@ LogicalResult verifyRecvOp(HloDialectInterface* dialect,
LogicalResult verifyReduceOp(std::optional<Location> location,
ValueRange inputs, ValueRange initValues,
ArrayRef<int64_t> dimensions, Region& body) {
SmallVector<ShapedType> inputTypes{llvm::map_range(
inputs.getTypes(), [](Type t) { return cast<ShapedType>(t); })};
SmallVector<ShapedType> initValueTypes{llvm::map_range(
initValues.getTypes(), [](Type t) { return cast<ShapedType>(t); })};
auto inputTypes = llvm::map_to_vector(
inputs.getTypes(), [](Type t) { return cast<ShapedType>(t); });
auto initValueTypes = llvm::map_to_vector(
initValues.getTypes(), [](Type t) { return cast<ShapedType>(t); });

SmallVector<int64_t> newDimensions;
Attribute encoding;
Expand Down Expand Up @@ -4136,10 +4136,10 @@ LogicalResult verifyReduceWindowOp(
std::optional<ArrayRef<int64_t>> baseDilations,
std::optional<ArrayRef<int64_t>> windowDilations,
std::optional<DenseIntElementsAttr> padding, Region& body) {
SmallVector<ShapedType> inputTypes{llvm::map_range(
inputs.getTypes(), [](Type t) { return cast<ShapedType>(t); })};
SmallVector<ShapedType> initValueTypes{llvm::map_range(
initValues.getTypes(), [](Type t) { return cast<ShapedType>(t); })};
auto inputTypes = llvm::map_to_vector(
inputs.getTypes(), [](Type t) { return cast<ShapedType>(t); });
auto initValueTypes = llvm::map_to_vector(
initValues.getTypes(), [](Type t) { return cast<ShapedType>(t); });

SmallVector<int64_t> windowDims;
SmallVector<WindowDimension> inferredWindow;
Expand Down Expand Up @@ -4277,10 +4277,10 @@ LogicalResult verifyScatterOp(std::optional<Location> location,
auto numOperands = inputs.size();
auto scatterIndicesType = cast<ShapedType>(scatterIndices.getType());

SmallVector<ShapedType, 1> operandTypes = llvm::to_vector(llvm::map_range(
inputs.getTypes(), [](Type type) { return cast<ShapedType>(type); }));
SmallVector<ShapedType, 1> updatesTypes = llvm::to_vector(llvm::map_range(
updates.getTypes(), [](Type type) { return cast<ShapedType>(type); }));
auto operandTypes = llvm::map_to_vector(
inputs.getTypes(), [](Type type) { return cast<ShapedType>(type); });
auto updatesTypes = llvm::map_to_vector(
updates.getTypes(), [](Type type) { return cast<ShapedType>(type); });
bool scatterIndicesTypeRanked = isa<RankedTensorType>(scatterIndicesType);

// scatter_c1
Expand Down
8 changes: 4 additions & 4 deletions stablehlo/dialect/VhloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,10 +314,10 @@ Type getVhloElementType(Type tensorType) {

bool checkIfOperandAndResultElementTypesMatch(TypeRange operandTypes,
TypeRange resultTypes) {
SmallVector<Type> inputElementTypes{llvm::map_range(
operandTypes, [](Type t) { return getVhloElementType(t); })};
SmallVector<Type> resultElementTypes{llvm::map_range(
resultTypes, [](Type t) { return getVhloElementType(t); })};
auto inputElementTypes = llvm::map_to_vector(
operandTypes, [](Type t) { return getVhloElementType(t); });
auto resultElementTypes = llvm::map_to_vector(
resultTypes, [](Type t) { return getVhloElementType(t); });

return llvm::all_of(
llvm::zip(inputElementTypes, resultElementTypes),
Expand Down
8 changes: 4 additions & 4 deletions stablehlo/dialect/VhloTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ void VhloTypeConverter::addBuiltinToVhloConversions() {
Type convertedStorageType = convertType(type.getStorageType());
Type convertedExpressedType = convertType(type.getExpressedType());
if (!convertedStorageType || !convertedExpressedType) return {};
SmallVector<APFloat> scales = llvm::to_vector(llvm::map_range(
type.getScales(), [](double scale) { return APFloat(scale); }));
auto scales = llvm::map_to_vector(
type.getScales(), [](double scale) { return APFloat(scale); });
return vhlo::UniformQuantizedPerAxisV1Type::get(
type.getContext(), type.getFlags(), convertedStorageType,
convertedExpressedType, type.getQuantizedDimension(), scales,
Expand Down Expand Up @@ -239,9 +239,9 @@ void VhloTypeConverter::addVhloToBuiltinConversions() {
Type convertedStorageType = convertType(type.getStorageType());
Type convertedExpressedType = convertType(type.getExpressedType());
if (!convertedStorageType || !convertedExpressedType) return {};
SmallVector<double> scales = llvm::to_vector(llvm::map_range(
auto scales = llvm::map_to_vector(
type.getScales(),
[](const APFloat& scale) { return scale.convertToDouble(); }));
[](const APFloat& scale) { return scale.convertToDouble(); });
return quant::UniformQuantizedPerAxisType::get(
type.getFlags(), convertedStorageType, convertedExpressedType, scales,
type.getZeroPoints(), type.getQuantizedDimension(),
Expand Down
4 changes: 2 additions & 2 deletions stablehlo/dialect/VhloTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,8 @@ def VHLO_QuantizationScalesV1 : ArrayRefParameter<"::llvm::APFloat", "array of d
return $_parser.parseFloat(scales.emplace_back());
});
if(failed(parseResult)) return failure();
return llvm::to_vector(llvm::map_range(
scales, [](double scale) { return APFloat(scale); }));
return llvm::map_to_vector(
scales, [](double scale) { return APFloat(scale); });
}()
}];
let printer = [{
Expand Down
10 changes: 5 additions & 5 deletions stablehlo/reference/Api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,18 +205,18 @@ FailureOr<SmallVector<InterpreterValue>> evalModule(
FailureOr<SmallVector<DenseElementsAttr>> evalModule(
ModuleOp module, ArrayRef<DenseElementsAttr> inputs,
const InterpreterConfiguration &config) {
SmallVector<InterpreterValue> valueInputs = llvm::to_vector(
llvm::map_range(inputs, [](DenseElementsAttr attr) -> InterpreterValue {
SmallVector<InterpreterValue> valueInputs = llvm::map_to_vector(
inputs, [](DenseElementsAttr attr) -> InterpreterValue {
return InterpreterValue(makeTensor(attr));
}));
});

auto values = evalModule(module, valueInputs, config);
if (failed(values)) return failure();

SmallVector<DenseElementsAttr> results = llvm::to_vector(llvm::map_range(
SmallVector<DenseElementsAttr> results = llvm::map_to_vector(
values.value(), [](InterpreterValue val) -> DenseElementsAttr {
return makeDenseElementsAttr(val.getTensor());
}));
});

return results;
}
Expand Down
8 changes: 4 additions & 4 deletions stablehlo/reference/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ SmallVector<InterpreterValue> callOp(ArrayRef<Tensor> inputs,
symbolTableCollection.getSymbolTable(op->getParentOfType<ModuleOp>());
auto func = symbolTable.lookupNearestSymbolFrom<func::FuncOp>(
op, StringAttr::get(op->getContext(), funcName));
SmallVector<InterpreterValue> values = llvm::to_vector(llvm::map_range(
inputs, [](const Tensor &t) { return InterpreterValue(t); }));
SmallVector<InterpreterValue> values = llvm::map_to_vector(
inputs, [](const Tensor &t) { return InterpreterValue(t); });
return eval(func.getBody(), values, fallback, process, nullptr);
}

Expand Down Expand Up @@ -1031,9 +1031,9 @@ Tensor allGatherOp(const Tensor &operand, int64_t allGatherDim,

auto rendezvousResult =
process->rendezvous(*processGroup, channelId, operand);
SmallVector<Tensor> groupOperands(llvm::map_range(
auto groupOperands = llvm::map_to_vector(
*processGroup,
[&](const ProcessId &id) { return rendezvousResult.lookup(id); }));
[&](const ProcessId &id) { return rendezvousResult.lookup(id); });

return concatenateOp(groupOperands, allGatherDim, resultType);
}
Expand Down
4 changes: 2 additions & 2 deletions stablehlo/reference/ProcessGrid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ Tensor RendezvousResult::lookup(ProcessId processId) const {
}

SmallVector<Tensor> RendezvousResult::getSortedTensors() const {
return llvm::to_vector(
llvm::map_range(result_, [](const auto &pair) { return pair.second; }));
return llvm::map_to_vector(result_,
[](const auto &pair) { return pair.second; });
}

//===----------------------------------------------------------------------===//
Expand Down
12 changes: 6 additions & 6 deletions stablehlo/reference/Scope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,26 +80,26 @@ InterpreterValue Scope::find(Value ssaValue) const {
}

SmallVector<InterpreterValue> Scope::find(ValueRange ssaValues) const {
return llvm::to_vector(
llvm::map_range(ssaValues, [&](Value value) { return find(value); }));
return llvm::map_to_vector(ssaValues,
[&](Value value) { return find(value); });
}

Tensor Scope::findTensor(Value ssaValue) const {
return find(ssaValue).getTensor();
}

SmallVector<Tensor> Scope::findTensors(ValueRange ssaValues) const {
return llvm::to_vector(llvm::map_range(
ssaValues, [&](Value value) { return find(value).getTensor(); }));
return llvm::map_to_vector(
ssaValues, [&](Value value) { return find(value).getTensor(); });
}

Token Scope::findToken(Value ssaValue) const {
return find(ssaValue).getToken();
}

SmallVector<Token> Scope::findTokens(ValueRange ssaValues) const {
return llvm::to_vector(llvm::map_range(
ssaValues, [&](Value value) { return find(value).getToken(); }));
return llvm::map_to_vector(
ssaValues, [&](Value value) { return find(value).getToken(); });
}

Tuple Scope::findTuple(Value ssaValue) const {
Expand Down
Loading

0 comments on commit e0c6d5d

Please sign in to comment.