Skip to content

Commit

Permalink
Unary broadcast lowering (#995)
Browse files Browse the repository at this point in the history
  • Loading branch information
KavithaTipturMadhu authored Dec 20, 2024
1 parent d7fcc28 commit 8d0da95
Show file tree
Hide file tree
Showing 10 changed files with 437 additions and 42 deletions.
74 changes: 60 additions & 14 deletions lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,15 @@ using namespace mlir::xsmm;
#define DEBUG_TYPE "convert-vector-to-xsmm"

static pair<Operation *, Operation *>
getTransposeXSMMCalls(PatternRewriter &rewriter, xsmm::UnaryInfo &unaryInfo,
Type outputType, xsmm::UnaryKind unaryKind,
Operation *input, Operation *output,
Operation *transposeOp) {
getUnaryXSMMCalls(PatternRewriter &rewriter, xsmm::UnaryInfo &unaryInfo,
Type outputType, xsmm::UnaryKind unaryKind, Operation *input,
Operation *output, Operation *unaryOp, int64_t unaryFlag) {
std::string dispatchName = "xsmm_unary_dispatch";
std::string invokeName = "xsmm_unary_invoke";
Location loc = transposeOp->getLoc();
Location loc = unaryOp->getLoc();

auto dtype =
xsmm::utils::getDataType(rewriter, transposeOp->getOperand(0).getType());
xsmm::utils::getDataType(rewriter, unaryOp->getOperand(0).getType());

SmallVector<xsmm::utils::XsmmOperand> dispatchOperands;
xsmm::UnaryKindAttr kind =
Expand All @@ -49,14 +48,12 @@ getTransposeXSMMCalls(PatternRewriter &rewriter, xsmm::UnaryInfo &unaryInfo,

dispatchOperands.append(SmallVector<xsmm::utils::XsmmOperand>{
unaryInfo.m, unaryInfo.n, unaryInfo.ldi, unaryInfo.ldo});
IntegerAttr unaryFlags = dyn_cast<IntegerAttr>(
xsmm::UnaryFlagsAttr::get(rewriter.getContext(), xsmm::UnaryFlags::NONE));
dispatchOperands.push_back(unaryFlags.getInt());
dispatchOperands.push_back(unaryFlag);

auto dispatchCall = xsmm::utils::buildXsmmCall(
rewriter, xsmm::utils::XsmmCallType::DISPATCH, loc, dtype,
dispatchOperands, IntegerType::get(rewriter.getContext(), 64),
SymbolRefAttr::get(transposeOp->getContext(), dispatchName), transposeOp,
SymbolRefAttr::get(unaryOp->getContext(), dispatchName), unaryOp,
nullptr);
SmallVector<xsmm::utils::XsmmOperand> operandRange{
dyn_cast<DataTypeAttr>(dtype).getInt(),
Expand All @@ -65,8 +62,8 @@ getTransposeXSMMCalls(PatternRewriter &rewriter, xsmm::UnaryInfo &unaryInfo,
input->getOperand(0), output->getOperand(1)};
auto invokeCall = xsmm::utils::buildXsmmCall(
rewriter, xsmm::utils::XsmmCallType::INVOKE, loc, dtype, operandRange,
TypeRange(), SymbolRefAttr::get(transposeOp->getContext(), invokeName),
transposeOp, output);
TypeRange(), SymbolRefAttr::get(unaryOp->getContext(), invokeName),
unaryOp, output);
return std::make_pair(&*dispatchCall, &*invokeCall);
}

Expand Down Expand Up @@ -107,8 +104,11 @@ convertTransposeOp(PatternRewriter &rewriter, Operation *transposeOp,
} else {
std::swap(unaryInfo.m, unaryInfo.n);
}
return getTransposeXSMMCalls(rewriter, unaryInfo, outputType, opType, input,
output, transposeOp);

IntegerAttr unaryFlags = dyn_cast<IntegerAttr>(
xsmm::UnaryFlagsAttr::get(rewriter.getContext(), xsmm::UnaryFlags::NONE));
return getUnaryXSMMCalls(rewriter, unaryInfo, outputType, opType, input,
output, transposeOp, unaryFlags.getInt());
}

static LogicalResult validateTransposeOp(PatternRewriter &rewriter,
Expand Down Expand Up @@ -163,11 +163,57 @@ static LogicalResult validateTransposeOp(PatternRewriter &rewriter,
return success();
}

/// PDL native rewrite registered as "ConvertBroadcast": lowers a
/// transfer_read -> broadcast -> transfer_write chain to an XSMM unary
/// IDENTITY dispatch/invoke call pair.
///
/// \param rewriter    Rewriter used to build the calls.
/// \param broadcastOp The op being rewritten; forwarded to getUnaryXSMMCalls.
/// \param input       Op whose operand 0 is the source buffer and whose
///                    result 0 is the vector fed to the broadcast.
/// \param output      Op whose operand 0 is the broadcast vector and whose
///                    operand 1 is the destination buffer.
/// \returns the {dispatch, invoke} operations built by getUnaryXSMMCalls.
///
/// This function has no way to report failure, so it relies on the
/// "ValidateBroadcast" constraint having accepted the match. The type
/// conversions use `cast<>` (not an unchecked `dyn_cast<>`) so that a violated
/// precondition trips the cast assertion instead of silently producing a null
/// type that is dereferenced later.
static std::pair<Operation *, Operation *>
convertBroadcast(PatternRewriter &rewriter, Operation *broadcastOp,
                 Operation *input, Operation *output) {
  LLVM_DEBUG(llvm::dbgs() << "convertBroadcast\n");

  Type sourceType = input->getOperand(0).getType();
  Type destType = output->getOperand(1).getType();

  // Broadcast flavour is derived from the source and destination buffer types.
  auto unaryFlag = xsmm::utils::getUnaryFlags(sourceType, destType);

  auto inputMemRefType = cast<MemRefType>(sourceType);
  auto outputMemRefType = cast<MemRefType>(destType);
  auto inputVectorType = cast<VectorType>(input->getResult(0).getType());
  auto outputVectorType = cast<VectorType>(output->getOperand(0).getType());

  // Note the argument order: the output memref is passed both as the shape
  // source and in the "input" slot, and the input memref in the "output" slot.
  auto unaryInfo = *xsmm::utils::getVectorUnaryInfo(
      outputMemRefType, outputMemRefType, inputMemRefType, inputVectorType,
      outputVectorType, *unaryFlag);

  // For a row broadcast the two leading dimensions are exchanged.
  // NOTE(review): presumably this compensates for the swapped memref
  // arguments above -- confirm against getVectorUnaryInfo.
  if (*unaryFlag == UnaryFlags::BCAST_ROW)
    std::swap(unaryInfo.ldi, unaryInfo.ldo);

  // The dispatch call takes the flag as a plain integer.
  IntegerAttr unaryFlags = cast<IntegerAttr>(
      xsmm::UnaryFlagsAttr::get(rewriter.getContext(), *unaryFlag));

  return getUnaryXSMMCalls(rewriter, unaryInfo, outputVectorType,
                           xsmm::UnaryKind::IDENTITY, input, output,
                           broadcastOp, unaryFlags.getInt());
}

/// PDL native constraint registered as "ValidateBroadcast": decides whether a
/// transfer_read -> broadcast -> transfer_write chain can be lowered to an
/// XSMM unary call.
///
/// \param rewriter    Rewriter supplied by the PDL driver (unused).
/// \param broadcastOp The matched broadcast op (unused).
/// \param input       Op whose operand 0 is the source buffer and whose
///                    result 0 is the vector fed to the broadcast.
/// \param output      Op whose operand 0 is the broadcast vector and whose
///                    operand 1 is the destination buffer.
/// \returns success() only when both buffers are memrefs, both vector types
///          are present, the broadcast flags can be computed, and the unary
///          info can be computed for every argument order used by this
///          constraint and by convertBroadcast; failure() otherwise.
static LogicalResult validateBroadcastOp(PatternRewriter &rewriter,
                                         Operation *broadcastOp,
                                         Operation *input, Operation *output) {
  LLVM_DEBUG(llvm::dbgs() << "validateBroadcastOp\n");

  Type sourceType = input->getOperand(0).getType();
  Type destType = output->getOperand(1).getType();

  // Reject anything that is not memref-backed (e.g. transfer ops on tensors)
  // before the types are handed to helpers that dereference them.
  auto inputMemRefType = dyn_cast<MemRefType>(sourceType);
  auto outputMemRefType = dyn_cast<MemRefType>(destType);
  auto inputVectorType = dyn_cast<VectorType>(input->getResult(0).getType());
  auto outputVectorType = dyn_cast<VectorType>(output->getOperand(0).getType());
  if (!inputMemRefType || !outputMemRefType || !inputVectorType ||
      !outputVectorType)
    return failure();

  // The flag is dereferenced below, so an unsupported broadcast shape must be
  // turned into a match failure here.
  auto unaryFlag = xsmm::utils::getUnaryFlags(sourceType, destType);
  if (failed(unaryFlag))
    return failure();

  // Original validation: input memref in the "input" slot, output memref in
  // the "output" slot.
  if (failed(xsmm::utils::getVectorUnaryInfo(
          outputMemRefType, inputMemRefType, outputMemRefType, inputVectorType,
          outputVectorType, *unaryFlag)))
    return failure();

  // convertBroadcast calls the helper with the two memrefs exchanged and
  // dereferences the result unconditionally, so that exact call has to be
  // known to succeed as well.
  if (failed(xsmm::utils::getVectorUnaryInfo(
          outputMemRefType, outputMemRefType, inputMemRefType, inputVectorType,
          outputVectorType, *unaryFlag)))
    return failure();

  return success();
}

/// Hooks the C++ implementations of the native PDL rewrites and constraints
/// into the PDL pattern module owned by `patterns`, under the names the PDLL
/// file refers to them by.
static void registerNativeRewrite(RewritePatternSet &patterns) {
  auto &pdlModule = patterns.getPDLPatterns();

  // Transpose lowering.
  pdlModule.registerRewriteFunction("ConvertTranspose", convertTransposeOp);
  pdlModule.registerConstraintFunction("ValidateTranspose",
                                       validateTransposeOp);

  // Broadcast lowering.
  pdlModule.registerRewriteFunction("ConvertBroadcast", convertBroadcast);
  pdlModule.registerConstraintFunction("ValidateBroadcast",
                                       validateBroadcastOp);
}

namespace mlir {
Expand Down
18 changes: 18 additions & 0 deletions lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmmPDL.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ Rewrite ConvertTranspose(op:Op<vector.transpose>, input0:Op<vector.transfer_read

Constraint ValidateTranspose(op:Op<vector.transpose>, input0:Op<vector.transfer_read>, output:Op<vector.transfer_write>, outputType:TypeRange);


// Native rewrite, registered from C++ under the name "ConvertBroadcast".
// Takes the matched broadcast together with the transfer_read that feeds it
// and the transfer_write that consumes it, and yields two ops named
// `dispatch` and `invoke`.
Rewrite ConvertBroadcast(op:Op<vector.broadcast>, input:Op<vector.transfer_read>, output:Op<vector.transfer_write>)->(dispatch:Op<func.callOp>, invoke:Op<func.callOp>);

// Native constraint, registered from C++ under the name "ValidateBroadcast".
// Receives the same three ops as ConvertBroadcast; the pattern only fires
// when it succeeds.
Constraint ValidateBroadcast(op:Op<vector.broadcast>, input0:Op<vector.transfer_read>, output:Op<vector.transfer_write>);

Pattern ConvertTransposePattern{
let input0 = op<vector.transfer_read>(alloc0:Value, indices0:ValueRange, const0:Value, constIndex:ValueRange)->(output:TypeRange);
let transpose = op<vector.transpose>(input0)->(transposeOutput0:Type);
Expand All @@ -20,3 +25,16 @@ Pattern ConvertTransposePattern{
erase output0;
};
}

// Matches a vector.transfer_read whose result is broadcast and then stored by
// a vector.transfer_write, and replaces the chain with the ops produced by
// the native ConvertBroadcast rewrite.
Pattern ConvertBroadcastPattern{
// Producer: the read that supplies the value to broadcast.
let input0 = op<vector.transfer_read>(alloc0:Value, indices0:ValueRange, const0:Value, constIndex:ValueRange)->(output:TypeRange);
// The broadcast itself; this is the root op that gets rewritten.
let broadcast = op<vector.broadcast>(input0)->(broadcastOutput0:Type);
// Consumer: the write that stores the broadcast result into `alloc1`.
let output0 = op<vector.transfer_write>(broadcast, alloc1:Value, outindices:ValueRange, constIndex2:ValueRange)->(typeRange:TypeRange);
// Only rewrite when the native constraint accepts the three ops.
ValidateBroadcast(broadcast, input0, output0);
rewrite broadcast with{
let replacement = ConvertBroadcast(broadcast, input0, output0);
// Substitute the dispatch/invoke pair for the broadcast, then drop the
// transfer_write.
replace broadcast with (replacement.dispatch, replacement.invoke);
erase output0;
};

}
12 changes: 6 additions & 6 deletions lib/TPP/DefaultTppPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ struct DefaultTppPasses
}
if (vectorToXSMM) {
skipOperations.clear();
skipOperations.push_back("unary");
skipOperations.push_back("transpose");
skipOperations.push_back("vnni");
}
Expand Down Expand Up @@ -143,11 +144,11 @@ struct DefaultTppPasses
pm.addNestedPass<func::FuncOp>(createLoopInvariantCodeMotionPass());
pm.addNestedPass<func::FuncOp>(createVectorizationPass());

//Please note, canonicalizer should be after hoisting pass because
//it fuses outer tiling loops and it results in no pattern
//matching for hoisting pass. Moved inside VectorToKernel Path.
if (vectorToXSMM) {
// Please note, canonicalizer should be after hoisting pass because
// it fuses outer tiling loops and it results in no pattern
// matching for hoisting pass. Moved inside VectorToKernel Path.

if (vectorToXSMM) {
pm.addPass(createVectorToXSMM());
}
if (vectorToKernel) {
Expand Down Expand Up @@ -193,4 +194,3 @@ struct DefaultTppPasses
};

} // namespace

32 changes: 21 additions & 11 deletions lib/TPP/Dialect/Xsmm/XsmmUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,23 +140,31 @@ getVectorUnaryInfo(MemRefType shapedType, MemRefType inputType,

UnaryInfo unaryInfo;

unaryInfo.m = shapedType.getShape()[0];
unaryInfo.n = shapedType.getShape()[1];
unaryInfo.m = 1;
for (unsigned i = 0; i < shapedType.getShape().size() - 1; i++) {
unaryInfo.m *= shapedType.getShape()[i];
}
unaryInfo.n = shapedType.getShape()[shapedType.getShape().size() - 1];

auto getStrideLoc = [&](MemRefType inputType) -> FailureOr<int64_t> {
auto getStrideLoc = [&](MemRefType memrefType) -> FailureOr<int64_t> {
int64_t strideAtLoc;
SmallVector<int64_t> strides;
int64_t offset;

if (failed(getStridesAndOffset(inputType, strides, offset)))
if (failed(getStridesAndOffset(memrefType, strides, offset))) {
return failure();
}
if (strides.empty()) {
return failure();
}
strideAtLoc = strides[0];
return strideAtLoc;
};

auto strideLdi = getStrideLoc(inputType);
if (failed(strideLdi))
if (failed(strideLdi)) {
return failure();
}
unaryInfo.ldi = *strideLdi;
int ldo = 1;
// If we are broadcasting a row into cols, the leading
Expand All @@ -166,12 +174,16 @@ getVectorUnaryInfo(MemRefType shapedType, MemRefType inputType,
ldo = 1;
// If we are broadcasting a col into rows, the leading
// dimension is the size of the tensor.
else if (inputFlag == UnaryFlags::BCAST_COL)
ldo = inputVectorType.getShape()[0];
else {
else if (inputFlag == UnaryFlags::BCAST_COL) {
if (inputVectorType.getShape().size() == 0)
ldo = 1;
else
ldo = inputVectorType.getShape()[0];
} else {
auto strideLdo = getStrideLoc(outputType);
if (failed(strideLdo))
if (failed(strideLdo)) {
return failure();
}
ldo = *strideLdo;
}
unaryInfo.ldo = ldo;
Expand Down Expand Up @@ -241,8 +253,6 @@ FailureOr<BinaryInfo> getBinaryInfo(Value lhs, BinaryFlags lhsFlag, Value rhs,

FailureOr<UnaryFlags> getUnaryFlags(Type inputType, Type outputType) {
assert(isa<ShapedType>(outputType) && "expect shaped type on output");
assert(cast<ShapedType>(outputType).getRank() == 2 &&
"expect rank 2 on output");

if (!isa<ShapedType>(inputType) ||
cast<ShapedType>(inputType).getRank() == 0) {
Expand Down
26 changes: 15 additions & 11 deletions lib/TPP/Transforms/Utils/BuilderUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,22 @@ arith::ConstantOp getConstant(OpBuilder &builder, Type type, ValueT value) {

func::FuncOp createFunction(OpBuilder &builder, ModuleOp module, StringRef name,
TypeRange args, TypeRange ret, bool createBody) {
auto unkLoc = builder.getUnknownLoc();
auto funcType = FunctionType::get(builder.getContext(), args, ret);
auto func = func::FuncOp::create(unkLoc, name, funcType);
func.setVisibility(SymbolTable::Visibility::Private);
if (createBody) {
func.setVisibility(SymbolTable::Visibility::Public);
auto *entryBlock = func.addEntryBlock();
builder.setInsertionPointToEnd(entryBlock);
auto oper = module.lookupSymbol(name);
if (oper)
return dyn_cast<func::FuncOp>(oper);
else {
auto unkLoc = builder.getUnknownLoc();
auto funcType = FunctionType::get(builder.getContext(), args, ret);
auto func = func::FuncOp::create(unkLoc, name, funcType);
func.setVisibility(SymbolTable::Visibility::Private);
if (createBody) {
func.setVisibility(SymbolTable::Visibility::Public);
auto *entryBlock = func.addEntryBlock();
builder.setInsertionPointToEnd(entryBlock);
}
module.push_back(func);
return func;
}
module.push_back(func);

return func;
}

Value getConstInt(OpBuilder &builder, int value, int width) {
Expand Down
Loading

0 comments on commit 8d0da95

Please sign in to comment.