From 8d0da9559cf57778a4b6a9676316f3258179d7b4 Mon Sep 17 00:00:00 2001 From: Kavitha Date: Fri, 20 Dec 2024 21:59:27 +0530 Subject: [PATCH] Unary broadcast lowering (#995) --- .../ConvertVectorToXsmm.cpp | 74 +++++-- .../ConvertVectorToXsmmPDL.pdll | 18 ++ lib/TPP/DefaultTppPasses.cpp | 12 +- lib/TPP/Dialect/Xsmm/XsmmUtils.cpp | 32 ++- lib/TPP/Transforms/Utils/BuilderUtils.cpp | 26 ++- .../VectorToXsmm/vector-to-identity.mlir | 195 ++++++++++++++++++ test/Integration/broadcast-2d.mlir | 33 +++ test/Integration/broadcast-row-1d.mlir | 26 +++ test/Integration/broadcast-row.mlir | 25 +++ test/Integration/broadcast-transpose.mlir | 38 ++++ 10 files changed, 437 insertions(+), 42 deletions(-) create mode 100644 test/Conversion/VectorToXsmm/vector-to-identity.mlir create mode 100644 test/Integration/broadcast-2d.mlir create mode 100644 test/Integration/broadcast-row-1d.mlir create mode 100644 test/Integration/broadcast-row.mlir create mode 100644 test/Integration/broadcast-transpose.mlir diff --git a/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp b/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp index cf516e690..04c1f6fdb 100644 --- a/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp +++ b/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp @@ -29,16 +29,15 @@ using namespace mlir::xsmm; #define DEBUG_TYPE "convert-vector-to-xsmm" static pair -getTransposeXSMMCalls(PatternRewriter &rewriter, xsmm::UnaryInfo &unaryInfo, - Type outputType, xsmm::UnaryKind unaryKind, - Operation *input, Operation *output, - Operation *transposeOp) { +getUnaryXSMMCalls(PatternRewriter &rewriter, xsmm::UnaryInfo &unaryInfo, + Type outputType, xsmm::UnaryKind unaryKind, Operation *input, + Operation *output, Operation *unaryOp, int64_t unaryFlag) { std::string dispatchName = "xsmm_unary_dispatch"; std::string invokeName = "xsmm_unary_invoke"; - Location loc = transposeOp->getLoc(); + Location loc = unaryOp->getLoc(); auto dtype = - xsmm::utils::getDataType(rewriter, transposeOp->getOperand(0).getType()); + xsmm::utils::getDataType(rewriter, unaryOp->getOperand(0).getType()); SmallVector dispatchOperands; xsmm::UnaryKindAttr kind = @@ -49,14 +48,12 @@ getTransposeXSMMCalls(PatternRewriter &rewriter, xsmm::UnaryInfo &unaryInfo, dispatchOperands.append(SmallVector{ unaryInfo.m, unaryInfo.n, unaryInfo.ldi, unaryInfo.ldo}); - IntegerAttr unaryFlags = dyn_cast( - xsmm::UnaryFlagsAttr::get(rewriter.getContext(), xsmm::UnaryFlags::NONE)); - dispatchOperands.push_back(unaryFlags.getInt()); + dispatchOperands.push_back(unaryFlag); auto dispatchCall = xsmm::utils::buildXsmmCall( rewriter, xsmm::utils::XsmmCallType::DISPATCH, loc, dtype, dispatchOperands, IntegerType::get(rewriter.getContext(), 64), - SymbolRefAttr::get(transposeOp->getContext(), dispatchName), transposeOp, + SymbolRefAttr::get(unaryOp->getContext(), dispatchName), unaryOp, nullptr); SmallVector operandRange{ dyn_cast(dtype).getInt(), @@ -65,8 +62,8 @@ getTransposeXSMMCalls(PatternRewriter &rewriter, xsmm::UnaryInfo &unaryInfo, input->getOperand(0), output->getOperand(1)}; auto invokeCall = xsmm::utils::buildXsmmCall( rewriter, xsmm::utils::XsmmCallType::INVOKE, loc, dtype, operandRange, - TypeRange(), SymbolRefAttr::get(transposeOp->getContext(), invokeName), - transposeOp, output); + TypeRange(), SymbolRefAttr::get(unaryOp->getContext(), invokeName), + unaryOp, output); return std::make_pair(&*dispatchCall, &*invokeCall); } @@ -107,8 +104,11 @@ convertTransposeOp(PatternRewriter &rewriter, Operation *transposeOp, } else { std::swap(unaryInfo.m, unaryInfo.n); } - return getTransposeXSMMCalls(rewriter, unaryInfo, outputType, opType, input, - output, transposeOp); + + IntegerAttr unaryFlags = dyn_cast( + xsmm::UnaryFlagsAttr::get(rewriter.getContext(), xsmm::UnaryFlags::NONE)); + return getUnaryXSMMCalls(rewriter, unaryInfo, outputType, opType, input, + output, transposeOp, unaryFlags.getInt()); } static LogicalResult validateTransposeOp(PatternRewriter &rewriter, @@ -163,11 +163,57 @@ static LogicalResult validateTransposeOp(PatternRewriter &rewriter, return success(); } +static std::pair +convertBroadcast(PatternRewriter &rewriter, Operation *broadcastOp, + Operation *input, Operation *output) { + LLVM_DEBUG(llvm::dbgs() << "convertBroadcast\n"); + auto unaryFlag = xsmm::utils::getUnaryFlags(input->getOperand(0).getType(), + output->getOperand(1).getType()); + auto inputMemRefType = dyn_cast(input->getOperand(0).getType()); + auto outputMemRefType = dyn_cast(output->getOperand(1).getType()); + auto inputVectorType = dyn_cast(input->getResult(0).getType()); + auto outputVectorType = dyn_cast(output->getOperand(0).getType()); + auto unaryInfo = *xsmm::utils::getVectorUnaryInfo( + outputMemRefType, outputMemRefType, inputMemRefType, inputVectorType, + outputVectorType, *unaryFlag); + if (*unaryFlag == UnaryFlags::BCAST_ROW) + std::swap(unaryInfo.ldi, unaryInfo.ldo); + + IntegerAttr unaryFlags = dyn_cast( + xsmm::UnaryFlagsAttr::get(rewriter.getContext(), *unaryFlag)); + + return getUnaryXSMMCalls(rewriter, unaryInfo, outputVectorType, + xsmm::UnaryKind::IDENTITY, input, output, + broadcastOp, unaryFlags.getInt()); +} + +static LogicalResult validateBroadcastOp(PatternRewriter &rewriter, + Operation *broadcastOp, + Operation *input, Operation *output) { + LLVM_DEBUG(llvm::dbgs() << "validateBroadcastOp\n"); + auto unaryFlag = xsmm::utils::getUnaryFlags(input->getOperand(0).getType(), + output->getOperand(1).getType()); + auto inputMemRefType = dyn_cast(input->getOperand(0).getType()); + auto outputMemRefType = dyn_cast(output->getOperand(1).getType()); + auto inputVectorType = dyn_cast(input->getResult(0).getType()); + auto outputVectorType = dyn_cast(output->getOperand(0).getType()); + auto unaryInfo = xsmm::utils::getVectorUnaryInfo( + outputMemRefType, inputMemRefType, outputMemRefType, inputVectorType, + outputVectorType, *unaryFlag); + if (failed(unaryInfo)) + return failure(); + return success(); +} + static void registerNativeRewrite(RewritePatternSet &patterns) { patterns.getPDLPatterns().registerRewriteFunction("ConvertTranspose", convertTransposeOp); patterns.getPDLPatterns().registerConstraintFunction("ValidateTranspose", validateTransposeOp); + patterns.getPDLPatterns().registerRewriteFunction("ConvertBroadcast", + convertBroadcast); + patterns.getPDLPatterns().registerConstraintFunction("ValidateBroadcast", + validateBroadcastOp); } namespace mlir { diff --git a/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmmPDL.pdll b/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmmPDL.pdll index 202786c2c..004bc6cff 100644 --- a/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmmPDL.pdll +++ b/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmmPDL.pdll @@ -9,6 +9,11 @@ Rewrite ConvertTranspose(op:Op, input0:Op, input0:Op, output:Op, outputType:TypeRange); + +Rewrite ConvertBroadcast(op:Op, input:Op, output:Op)->(dispatch:Op, invoke:Op); + +Constraint ValidateBroadcast(op:Op, input0:Op, output:Op); + Pattern ConvertTransposePattern{ let input0 = op(alloc0:Value, indices0:ValueRange, const0:Value, constIndex:ValueRange)->(output:TypeRange); let transpose = op(input0)->(transposeOutput0:Type); @@ -20,3 +25,16 @@ Pattern ConvertTransposePattern{ erase output0; }; } + +Pattern ConvertBroadcastPattern{ + let input0 = op(alloc0:Value, indices0:ValueRange, const0:Value, constIndex:ValueRange)->(output:TypeRange); + let broadcast = op(input0)->(broadcastOutput0:Type); + let output0 = op(broadcast, alloc1:Value, outindices:ValueRange, constIndex2:ValueRange)->(typeRange:TypeRange); + ValidateBroadcast(broadcast, input0, output0); + rewrite broadcast with{ + let replacement = ConvertBroadcast(broadcast, input0, output0); + replace broadcast with (replacement.dispatch, replacement.invoke); + erase output0; + }; + +} diff --git a/lib/TPP/DefaultTppPasses.cpp b/lib/TPP/DefaultTppPasses.cpp index afe8a3a51..fc900f4f6 100644 --- a/lib/TPP/DefaultTppPasses.cpp +++ b/lib/TPP/DefaultTppPasses.cpp @@ -94,6 +94,7 @@ struct DefaultTppPasses } if (vectorToXSMM) { skipOperations.clear(); + skipOperations.push_back("unary"); skipOperations.push_back("transpose"); skipOperations.push_back("vnni"); } @@ -143,11 +144,11 @@ struct DefaultTppPasses pm.addNestedPass(createLoopInvariantCodeMotionPass()); pm.addNestedPass(createVectorizationPass()); - //Please note, canonicalizer should be after hoisting pass because - //it fuses outer tiling loops and it results in no pattern - //matching for hoisting pass. Moved inside VectorToKernel Path. - - if (vectorToXSMM) { + // Please note, canonicalizer should be after hoisting pass because + // it fuses outer tiling loops and it results in no pattern + // matching for hoisting pass. Moved inside VectorToKernel Path. + + if (vectorToXSMM) { pm.addPass(createVectorToXSMM()); } if (vectorToKernel) { @@ -193,4 +194,3 @@ struct DefaultTppPasses }; } // namespace - diff --git a/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp b/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp index 7a070f928..70f915329 100644 --- a/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp +++ b/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp @@ -140,23 +140,31 @@ getVectorUnaryInfo(MemRefType shapedType, MemRefType inputType, UnaryInfo unaryInfo; - unaryInfo.m = shapedType.getShape()[0]; - unaryInfo.n = shapedType.getShape()[1]; + unaryInfo.m = 1; + for (unsigned i = 0; i < shapedType.getShape().size() - 1; i++) { + unaryInfo.m *= shapedType.getShape()[i]; + } + unaryInfo.n = shapedType.getShape()[shapedType.getShape().size() - 1]; - auto getStrideLoc = [&](MemRefType inputType) -> FailureOr { + auto getStrideLoc = [&](MemRefType memrefType) -> FailureOr { int64_t strideAtLoc; SmallVector strides; int64_t offset; - if (failed(getStridesAndOffset(inputType, strides, offset))) + if (failed(getStridesAndOffset(memrefType, strides, offset))) { + return failure(); + } + if (strides.empty()) { return failure(); + } strideAtLoc = strides[0]; return strideAtLoc; }; auto strideLdi = getStrideLoc(inputType); - if (failed(strideLdi)) + if (failed(strideLdi)) { return failure(); + } unaryInfo.ldi = *strideLdi; int ldo = 1; // If we are broascasting a row into cols, the leading @@ -166,12 +174,16 @@ getVectorUnaryInfo(MemRefType shapedType, MemRefType inputType, ldo = 1; // If we are broascasting a col into rows, the leading // dimension is the size of the tensor. - else if (inputFlag == UnaryFlags::BCAST_COL) - ldo = inputVectorType.getShape()[0]; - else { + else if (inputFlag == UnaryFlags::BCAST_COL) { + if (inputVectorType.getShape().size() == 0) + ldo = 1; + else + ldo = inputVectorType.getShape()[0]; + } else { auto strideLdo = getStrideLoc(outputType); - if (failed(strideLdo)) + if (failed(strideLdo)) { return failure(); + } ldo = *strideLdo; } unaryInfo.ldo = ldo; @@ -241,8 +253,6 @@ FailureOr getBinaryInfo(Value lhs, BinaryFlags lhsFlag, Value rhs, FailureOr getUnaryFlags(Type inputType, Type outputType) { assert(isa(outputType) && "expect shaped type on output"); - assert(cast(outputType).getRank() == 2 && - "expect rank 2 on output"); if (!isa(inputType) || cast(inputType).getRank() == 0) { diff --git a/lib/TPP/Transforms/Utils/BuilderUtils.cpp b/lib/TPP/Transforms/Utils/BuilderUtils.cpp index 56012c0b0..fc230591b 100644 --- a/lib/TPP/Transforms/Utils/BuilderUtils.cpp +++ b/lib/TPP/Transforms/Utils/BuilderUtils.cpp @@ -36,18 +36,22 @@ arith::ConstantOp getConstant(OpBuilder &builder, Type type, ValueT value) { func::FuncOp createFunction(OpBuilder &builder, ModuleOp module, StringRef name, TypeRange args, TypeRange ret, bool createBody) { - auto unkLoc = builder.getUnknownLoc(); - auto funcType = FunctionType::get(builder.getContext(), args, ret); - auto func = func::FuncOp::create(unkLoc, name, funcType); - func.setVisibility(SymbolTable::Visibility::Private); - if (createBody) { - func.setVisibility(SymbolTable::Visibility::Public); - auto *entryBlock = func.addEntryBlock(); - builder.setInsertionPointToEnd(entryBlock); + auto oper = module.lookupSymbol(name); + if (oper) + return dyn_cast(oper); + else { + auto unkLoc = builder.getUnknownLoc(); + auto funcType = FunctionType::get(builder.getContext(), args, ret); + auto func = func::FuncOp::create(unkLoc, name, funcType); + func.setVisibility(SymbolTable::Visibility::Private); + if (createBody) { + func.setVisibility(SymbolTable::Visibility::Public); + auto *entryBlock = func.addEntryBlock(); + builder.setInsertionPointToEnd(entryBlock); + } + module.push_back(func); + return func; } - module.push_back(func); - - return func; } Value getConstInt(OpBuilder &builder, int value, int width) { diff --git a/test/Conversion/VectorToXsmm/vector-to-identity.mlir b/test/Conversion/VectorToXsmm/vector-to-identity.mlir new file mode 100644 index 000000000..69651508c --- /dev/null +++ b/test/Conversion/VectorToXsmm/vector-to-identity.mlir @@ -0,0 +1,195 @@ +// RUN: tpp-opt --vector-to-xsmm %s --split-input-file | FileCheck %s + +func.func @identity(%arg0: memref<512xf32>, %arg1: memref<128x512xf32>) { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = vector.transfer_read %arg0[%c0], %cst {in_bounds = [true]} : memref<512xf32>, vector<512xf32> + %1 = vector.broadcast %0 : vector<512xf32> to vector<128x512xf32> + vector.transfer_write %1, %arg1[%c0, %c0] {in_bounds = [true, true]} : vector<128x512xf32>, memref<128x512xf32> + return +} + +// CHECK-LABEL: func.func @identity( +// CHECK: %[[arg0:.*]]: memref<512xf32>, %[[arg1:.*]]: memref<128x512xf32>) { +// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[c1_i64:.*]] = arith.constant 1 : i64 +// CHECK-DAG: %[[c128_i64:.*]] = arith.constant 128 : i64 +// CHECK-DAG: %[[c512_i64:.*]] = arith.constant 512 : i64 +// CHECK-DAG: %[[c4_i64:.*]] = arith.constant 4 : i64 +// CHECK: %[[dispatch:.*]] = call @xsmm_unary_dispatch(%[[c1_i64]], %[[c1_i64]], %[[c128_i64]], %[[c512_i64]], %[[c512_i64]], %[[c512_i64]], %[[c4_i64]]) +// CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[arg0]] +// CHECK: %[[indexcast:.*]] = arith.index_cast %[[intptr]] +// CHECK: %[[inttoptr:.*]] = llvm.inttoptr %[[indexcast]] +// CHECK: %[[intptr_0:.*]] = memref.extract_aligned_pointer_as_index %[[arg1]] +// CHECK: %[[indexcast3:.*]] = arith.index_cast %[[intptr_0]] +// CHECK: %[[inttoptr4:.*]] = llvm.inttoptr %[[indexcast3]] +// CHECK: call @xsmm_unary_invoke(%[[c1_i64]], %[[dispatch]], %[[inttoptr]], %[[c0]], %[[inttoptr4]], %[[c0]]) + +// ----- + +func.func @identity_2d(%arg0: memref<512xf32>, %arg1: memref<128x4x512xf32>) { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = vector.transfer_read %arg0[%c0], %cst {in_bounds = [true]} : memref<512xf32>, vector<512xf32> + %1 = vector.broadcast %0 : vector<512xf32> to vector<128x4x512xf32> + vector.transfer_write %1, %arg1[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<128x4x512xf32>, memref<128x4x512xf32> + return +} +// CHECK-LABEL: func.func @identity_2d( +// CHECK: %[[arg0:.*]]: memref<512xf32>, %[[arg1:.*]]: memref<128x4x512xf32>) { +// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[c1_i64:.*]] = arith.constant 1 : i64 +// CHECK-DAG: %[[c512_i64:.*]] = arith.constant 512 : i64 +// CHECK-DAG: %[[c4_i64:.*]] = arith.constant 4 : i64 +// CHECK-DAG: %[[c2048_i64:.*]] = arith.constant 2048 : i64 +// CHECK: %[[dispatch:.*]] = call @xsmm_unary_dispatch(%[[c1_i64]], %[[c1_i64]], %[[c512_i64]], %[[c512_i64]], %[[c2048_i64]], %[[c512_i64]], %[[c4_i64]]) +// CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[arg0]] +// CHECK: %[[indexcast:.*]] = arith.index_cast %[[intptr]] +// CHECK: %[[inttoptr:.*]] = llvm.inttoptr %[[indexcast]] +// CHECK: %[[intptr_0:.*]] = memref.extract_aligned_pointer_as_index %[[arg1]] +// CHECK: %[[indexcast3:.*]] = arith.index_cast %[[intptr_0]] +// CHECK: %[[inttoptr4:.*]] = llvm.inttoptr %[[indexcast3]] +// CHECK: call @xsmm_unary_invoke(%[[c1_i64]], %[[dispatch]], %[[inttoptr]], %[[c0]], %[[inttoptr4]], %[[c0]]) + +// ----- + +func.func @identity_subview_copy(%arg0: memref<128x1xf32>, %arg1: memref<512x128xf32>) { + %cst = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %subview = memref.subview %arg0[0, 0] [128, 1] [1, 1] : memref<128x1xf32> to memref<128xf32, strided<[1]>> + %0 = vector.transfer_read %subview[%c0], %cst {in_bounds = [true]} : memref<128xf32, strided<[1]>>, vector<128xf32> + %1 = vector.broadcast %0 : vector<128xf32> to vector<512x128xf32> + vector.transfer_write %1, %arg1[%c0, %c0] {in_bounds = [true, true]} : vector<512x128xf32>, memref<512x128xf32> + return +} + +// CHECK-LABEL: func.func @identity_subview_copy( +// CHECK: %[[arg0:.*]]: memref<128x1xf32>, %[[arg1:.*]]: memref<512x128xf32>) { +// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[c1_i64:.*]] = arith.constant 1 : i64 +// CHECK-DAG: %[[c128_i64:.*]] = arith.constant 128 : i64 +// CHECK-DAG: %[[c512_i64:.*]] = arith.constant 512 : i64 +// CHECK-DAG: %[[c4_i64:.*]] = arith.constant 4 : i64 +// CHECK: %[[dispatch:.*]] = call @xsmm_unary_dispatch(%[[c1_i64]], %[[c1_i64]], %[[c512_i64]], %[[c128_i64]], %[[c128_i64]], %[[c128_i64]], %[[c4_i64]]) +// CHECK: %[[subview:.*]] = memref.subview %[[arg0]][0, 0] [128, 1] [1, 1] +// CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[subview]] +// CHECK: %[[indexcast:.*]] = arith.index_cast %[[intptr]] +// CHECK: %[[inttoptr:.*]] = llvm.inttoptr %[[indexcast]] +// CHECK: %[[intptr_0:.*]] = memref.extract_aligned_pointer_as_index %[[arg1]] +// CHECK: %[[indexcast3:.*]] = arith.index_cast %[[intptr_0]] +// CHECK: %[[inttoptr4:.*]] = llvm.inttoptr %[[indexcast3]] +// CHECK: call @xsmm_unary_invoke(%[[c1_i64]], %[[dispatch]], %[[inttoptr]], %[[c0]], %[[inttoptr4]], %[[c0]]) + +// ----- + +func.func @identity_2d_bcast_to_3d(%arg0: memref<128x256xf32>, %arg1: memref<512x128x256xf32>) { + %cst = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x256xf32>, vector<128x256xf32> + %1 = vector.broadcast %0 : vector<128x256xf32> to vector<512x128x256xf32> + vector.transfer_write %1, %arg1[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<512x128x256xf32>, memref<512x128x256xf32> + return +} + +// CHECK-LABEL: func.func @identity_2d_bcast_to_3d( +// CHECK: %[[arg0:.*]]: memref<128x256xf32>, %[[arg1:.*]]: memref<512x128x256xf32>) { +// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[c1_i64:.*]] = arith.constant 1 : i64 +// CHECK-DAG: %[[c65536_i64:.*]] = arith.constant 65536 : i64 +// CHECK-DAG: %[[c256_i64:.*]] = arith.constant 256 : i64 +// CHECK-DAG: %[[c32768_i64:.*]] = arith.constant 32768 : i64 +// CHECK-DAG: %[[c128_i64:.*]] = arith.constant 128 : i64 +// CHECK-DAG: %[[c4_i64:.*]] = arith.constant 4 : i64 +// CHECK: %[[dispatch:.*]] = call @xsmm_unary_dispatch(%[[c1_i64]], %[[c1_i64]], %[[c65536_i64]], %[[c256_i64]], %[[c32768_i64]], %[[c128_i64]], %[[c4_i64]]) +// CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[arg0]] +// CHECK: %[[indexcast:.*]] = arith.index_cast %[[intptr]] +// CHECK: %[[inttoptr:.*]] = llvm.inttoptr %[[indexcast]] +// CHECK: %[[intptr_0:.*]] = memref.extract_aligned_pointer_as_index %[[arg1]] +// CHECK: %[[indexcast3:.*]] = arith.index_cast %[[intptr_0]] +// CHECK: %[[inttoptr4:.*]] = llvm.inttoptr %[[indexcast3]] +// CHECK: call @xsmm_unary_invoke(%[[c1_i64]], %[[dispatch]], %[[inttoptr]], %[[c0]], %[[inttoptr4]], %[[c0]]) + +// ----- + +func.func @identity_broadcast_exact_dim_match(%arg0: memref<4x1xf32>, %arg1: memref<4x2xf32>) -> memref<4x2xf32> { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x1xf32>, vector<4x1xf32> + %1 = vector.broadcast %0 : vector<4x1xf32> to vector<4x2xf32> + vector.transfer_write %1, %arg1[%c0, %c0] {in_bounds = [true, true]} + : vector<4x2xf32>, memref<4x2xf32> + return %arg1 : memref<4x2xf32> +} + +// CHECK-LABEL: func.func @identity_broadcast_exact_dim_match( +// CHECK: %[[arg0:.*]]: memref<4x1xf32>, %[[arg1:.*]]: memref<4x2xf32>) -> memref<4x2xf32> { +// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[c1_i64:.*]] = arith.constant 1 : i64 +// CHECK-DAG: %[[c4_i64:.*]] = arith.constant 4 : i64 +// CHECK-DAG: %[[c2_i64:.*]] = arith.constant 2 : i64 +// CHECK: %[[dispatch:.*]] = call @xsmm_unary_dispatch(%[[c1_i64]], %[[c1_i64]], %[[c4_i64]], %[[c2_i64]], %[[c1_i64]], %[[c2_i64]], %[[c2_i64]]) +// CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[arg0]] +// CHECK: %[[indexcast:.*]] = arith.index_cast %[[intptr]] +// CHECK: %[[inttoptr:.*]] = llvm.inttoptr %[[indexcast]] +// CHECK: %[[intptr_0:.*]] = memref.extract_aligned_pointer_as_index %[[arg1]] +// CHECK: %[[indexcast3:.*]] = arith.index_cast %[[intptr_0]] +// CHECK: %[[inttoptr4:.*]] = llvm.inttoptr %[[indexcast3]] +// CHECK: call @xsmm_unary_invoke(%[[c1_i64]], %[[dispatch]], %[[inttoptr]], %[[c0]], %[[inttoptr4]], %[[c0]]) + +// ----- + +func.func @identity_broadcast_same_rank(%arg0: memref<256xf32>, %arg1: memref<256xf32>) { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = vector.transfer_read %arg0[%c0], %cst {in_bounds = [true]} : memref<256xf32>, vector<256xf32> + %1 = vector.broadcast %0 : vector<256xf32> to vector<256xf32> + vector.transfer_write %1, %arg1[%c0] {in_bounds = [true]} : vector<256xf32>, memref<256xf32> + return +} + +// CHECK-LABEL: func.func @identity_broadcast_same_rank( +// CHECK: %[[arg0:.*]]: memref<256xf32>, %[[arg1:.*]]: memref<256xf32>) { +// CHECK-NOT: call @xsmm_unary_dispatch +// CHECK-NOT: call @xsmm_unary_invoke + +// ----- + +func.func @identity_empty_buffer(%arg0: memref>, %arg1: memref<6x9xf32>) { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = vector.transfer_read %arg0[], %cst {in_bounds = []} : memref>, vector + %1 = vector.broadcast %0 : vector to vector<6x9xf32> + vector.transfer_write %1, %arg1[%c0, %c0] {in_bounds = [true, true]} : vector<6x9xf32>, memref<6x9xf32> + return +} + +// CHECK-LABEL: func.func @identity_empty_buffer( +// CHECK: %[[arg0:.*]]: memref>, %[[arg1:.*]]: memref<6x9xf32>) { +// CHECK-NOT: call @xsmm_unary_dispatch +// CHECK-NOT: call @xsmm_unary_invoke + +// ----- + +func.func @identity_strided_buffer(%arg0: memref<6x1xf32, strided<[6, 1]>>, %arg1: memref<6x9xf32>) { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = vector.transfer_read %arg0[%c0, %c0], %cst{in_bounds=[true, true]} : memref<6x1xf32, strided<[6, 1]>>, vector<6x1xf32> + %1 = vector.broadcast %0 : vector<6x1xf32> to vector<6x9xf32> + vector.transfer_write %1, %arg1[%c0, %c0] {in_bounds = [true, true]} : vector<6x9xf32>, memref<6x9xf32> + return +} +// CHECK-LABEL: func.func @identity_strided_buffer( +// CHECK: %[[arg0:.*]]: memref<6x1xf32, strided<[6, 1]>>, %[[arg1:.*]]: memref<6x9xf32>) { +// CHECK-DAG: %[[c6_i64:.*]] = arith.constant 6 : i64 +// CHECK-DAG: %[[c1_i64:.*]] = arith.constant 1 : i64 +// CHECK-DAG: %[[c9_i64:.*]] = arith.constant 9 : i64 +// CHECK-DAG: %[[c2_i64:.*]] = arith.constant 2 : i64 +// CHECK: %[[dispatch:.*]] = call @xsmm_unary_dispatch(%[[c1_i64]], %[[c1_i64]], %[[c6_i64]], %[[c9_i64]], %[[c1_i64]], %[[c9_i64]], %[[c2_i64]]) +// CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[arg0]] +// CHECK: %[[indexcast:.*]] = arith.index_cast %[[intptr]] +// CHECK: %[[inttoptr:.*]] = llvm.inttoptr %[[indexcast]] +// CHECK: %[[intptr_0:.*]] = memref.extract_aligned_pointer_as_index %[[arg1]] +// CHECK: %[[indexcast3:.*]] = arith.index_cast %[[intptr_0]] +// CHECK: %[[inttoptr4:.*]] = llvm.inttoptr %[[indexcast3]] +// CHECK: call @xsmm_unary_invoke(%[[c1_i64]], %[[dispatch]], %[[inttoptr]], %[[c0]], %[[inttoptr4]], %[[c0]]) + diff --git a/test/Integration/broadcast-2d.mlir b/test/Integration/broadcast-2d.mlir new file mode 100644 index 000000000..58ae43dd5 --- /dev/null +++ b/test/Integration/broadcast-2d.mlir @@ -0,0 +1,33 @@ +// RUN: tpp-run --vector-to-XSMM %s -e columnBroadcast --entry-point-result=void -print --seed 123 2>&1 | FileCheck %s -check-prefix=COLUMNBROADCAST +// RUN: tpp-run --linalg-to-loops %s -e columnBroadcast --entry-point-result=void -print -seed 123 2>&1 | FileCheck %s -check-prefix=COLUMNBROADCAST +// RUN: tpp-run --vector-to-XSMM %s -e columnBroadcast --entry-point-result=void --seed 123 2>&1 --mlir-print-ir-after=vectorization-pass | FileCheck %s --check-prefix=VECTOR + + + +#map2 = affine_map<(d0, d1, d2) -> (d2)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> + +func.func @columnBroadcast(%arg0: tensor<8xf32>, %arg1: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> { + %arg = linalg.generic { + indexing_maps = [#map2, #map3], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%arg0 : tensor<8xf32>) outs(%arg1 : tensor<2x4x8xf32>) { + ^bb0(%arg9: f32, %arg10: f32): + linalg.yield %arg9 : f32 + }->tensor<2x4x8xf32> + return %arg: tensor<2x4x8xf32> +} + +// VECTOR: vector.transfer_read +// VECTOR: vector.broadcast +// VECTOR: vector.transfer_write + +//COLUMNBROADCAST: ( 0, 0.130352, 0.151291, 0.0106365, 0.000375301, 0.298506, 0.0983867, 0.011257 ) +//COLUMNBROADCAST: ( 0, 0.130352, 0.151291, 0.0106365, 0.000375301, 0.298506, 0.0983867, 0.011257 ) +//COLUMNBROADCAST: ( 0, 0.130352, 0.151291, 0.0106365, 0.000375301, 0.298506, 0.0983867, 0.011257 ) +//COLUMNBROADCAST: ( 0, 0.130352, 0.151291, 0.0106365, 0.000375301, 0.298506, 0.0983867, 0.011257 ) +//COLUMNBROADCAST: ( 0, 0.130352, 0.151291, 0.0106365, 0.000375301, 0.298506, 0.0983867, 0.011257 ) +//COLUMNBROADCAST: ( 0, 0.130352, 0.151291, 0.0106365, 0.000375301, 0.298506, 0.0983867, 0.011257 ) +//COLUMNBROADCAST: ( 0, 0.130352, 0.151291, 0.0106365, 0.000375301, 0.298506, 0.0983867, 0.011257 ) +//COLUMNBROADCAST: ( 0, 0.130352, 0.151291, 0.0106365, 0.000375301, 0.298506, 0.0983867, 0.011257 ) + diff --git a/test/Integration/broadcast-row-1d.mlir b/test/Integration/broadcast-row-1d.mlir new file mode 100644 index 000000000..0a6b7c1f3 --- /dev/null +++ b/test/Integration/broadcast-row-1d.mlir @@ -0,0 +1,26 @@ +// RUN: tpp-run --vector-to-XSMM %s -e rowBroadcast --entry-point-result=void -print-mlir=mid 2>&1 | FileCheck %s -check-prefix=ROWBROADCAST +// RUN: tpp-run --vector-to-XSMM %s -e rowBroadcast --entry-point-result=void -print --seed 123 2>&1 +// RUN: tpp-run --linalg-to-loops %s -e rowBroadcast --entry-point-result=void -print --seed 123 2>&1 + +func.func @rowBroadcast(%arg0: memref<4x1xf32>, %arg1: memref<4x2xf32>) -> memref<4x2xf32> { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x1xf32>, vector<4x1xf32> + %1 = vector.broadcast %0 : vector<4x1xf32> to vector<4x2xf32> + vector.transfer_write %1, %arg1[%c0, %c0] {in_bounds = [true, true]} + : vector<4x2xf32>, memref<4x2xf32> + return %arg1 : memref<4x2xf32> +} + + +// ROWBROADCAST-DAG: %[[c4_i64:.*]] = arith.constant 4 : i64 +// ROWBROADCAST-DAG: %[[c1_i64:.*]] = arith.constant 1 : i64 +// ROWBROADCAST-DAG: %[[c2_i64:.*]] = arith.constant 2 : i64 +// ROWBROADCAST: call @xsmm_unary_dispatch(%[[c1_i64]], %[[c1_i64]], %[[c4_i64]], %[[c2_i64]], %[[c1_i64]], %[[c2_i64]], %[[c2_i64]]) +// ROWBROADCAST: call @xsmm_unary_invoke + +// CHECK: ( 0, 0 ) +// CHECK: ( 0.130352, 0.130352 ) +// CHECK: ( 0.151291, 0.151291 ) +// CHECK: ( 0.0106365, 0.0106365 ) + diff --git a/test/Integration/broadcast-row.mlir b/test/Integration/broadcast-row.mlir new file mode 100644 index 000000000..434d8f464 --- /dev/null +++ b/test/Integration/broadcast-row.mlir @@ -0,0 +1,25 @@ +// RUN: tpp-run --vector-to-XSMM %s -e rowBroadcast --entry-point-result=void -print-mlir=mid 2>&1 | FileCheck %s -check-prefix=ROWBROADCAST +// RUN: tpp-run %s -e rowBroadcast --entry-point-result=void -print-mlir=mid 2>&1 | FileCheck %s -check-prefix=ROWBROADCAST + + +#map4 = affine_map<(d0, d1) -> (d1, 0)> +#map5 = affine_map<(d0, d1) -> (d0, d1)> + +func.func @rowBroadcast(%arg0: tensor<128x1xf32>, %arg1: tensor<512x128xf32>) -> tensor<512x128xf32> { + %arg = linalg.generic { + indexing_maps = [#map4, #map5], + iterator_types = ["parallel", "parallel"]} + ins(%arg0 : tensor<128x1xf32>) outs(%arg1 : tensor<512x128xf32>) { + ^bb0(%arg9: f32, %arg10: f32): + linalg.yield %arg9 : f32 + }->tensor<512x128xf32> + return %arg: tensor<512x128xf32> +} +// ROWBROADCAST-DAG: %[[c0:.*]] = arith.constant 0 : index +// ROWBROADCAST-DAG: %[[c4_i64:.*]] = arith.constant 4 : i64 +// ROWBROADCAST-DAG: %[[c128_i64:.*]] = arith.constant 128 : i64 +// ROWBROADCAST-DAG: %[[c512_i64:.*]] = arith.constant 512 : i64 +// ROWBROADCAST-DAG: %[[c1_i64:.*]] = arith.constant 1 : i64 +// ROWBROADCAST: call @xsmm_unary_dispatch(%[[c1_i64]], %[[c1_i64]], %[[c512_i64]], %[[c128_i64]], %[[c128_i64]], %[[c128_i64]], %[[c4_i64]]) +// ROWBROADCAST: call @xsmm_unary_invoke + diff --git a/test/Integration/broadcast-transpose.mlir b/test/Integration/broadcast-transpose.mlir new file mode 100644 index 000000000..1dcaa1c7b --- /dev/null +++ b/test/Integration/broadcast-transpose.mlir @@ -0,0 +1,38 @@ +// RUN: tpp-run --vector-to-XSMM %s -e broadcast_transpose --entry-point-result=void -print --seed 123 2>&1 | FileCheck %s -check-prefix=BROADCASTTRANSPOSE +// RUN: tpp-run --linalg-to-loops %s -e broadcast_transpose --entry-point-result=void -print -seed 123 2>&1 | FileCheck %s -check-prefix=BROADCASTTRANSPOSE +// RUN: tpp-run --vector-to-XSMM %s -e broadcast_transpose --entry-point-result=void -print-after=vectorization-pass --seed 123 2>&1 --mlir-print-ir-after=vectorization-pass | FileCheck %s --check-prefix=XSMM + + + +#map0 = affine_map<(d0, d1) -> (d1)> +#map1 = affine_map<(d0, d1) -> (d0, d1)> + +func.func @broadcast_transpose(%arg0: tensor<8xf32>, %arg1: tensor<4x8xf32>, %arg2: tensor<8x4xf32>) -> tensor<8x4xf32> { + %arg = linalg.generic { + indexing_maps = [#map0, #map1], + iterator_types = ["parallel", "parallel"]} + ins(%arg0 : tensor<8xf32>) outs(%arg1 : tensor<4x8xf32>) { + ^bb0(%arg9: f32, %arg10: f32): + linalg.yield %arg9 : f32 + }->tensor<4x8xf32> + + %out = linalg.transpose ins(%arg : tensor<4x8xf32>) outs(%arg2 : tensor<8x4xf32>) permutation = [1, 0] + return %out: tensor<8x4xf32> + +} + +// XSMM: vector.transfer_read +// XSMM: vector.broadcast +// XSMM: vector.transfer_write +// XSMM: vector.transfer_read +// XSMM: vector.transpose +// XSMM: vector.transfer_write + +//BROADCASTTRANSPOSE: ( 0, 0, 0, 0 ) +//BROADCASTTRANSPOSE: ( 0.130352, 0.130352, 0.130352, 0.130352 ) +//BROADCASTTRANSPOSE: ( 0.151291, 0.151291, 0.151291, 0.151291 ) +//BROADCASTTRANSPOSE: ( 0.0106365, 0.0106365, 0.0106365, 0.0106365 ) +//BROADCASTTRANSPOSE: ( 0.000375301, 0.000375301, 0.000375301, 0.000375301 ) +//BROADCASTTRANSPOSE: ( 0.298506, 0.298506, 0.298506, 0.298506 ) +//BROADCASTTRANSPOSE: ( 0.0983867, 0.0983867, 0.0983867, 0.0983867 ) +//BROADCASTTRANSPOSE: ( 0.011257, 0.011257, 0.011257, 0.011257 )