Skip to content

Commit

Permalink
WG To SG Transformation pass (intel#809)
Browse files Browse the repository at this point in the history
  • Loading branch information
nbpatel authored Jul 17, 2024
1 parent 70cb6d0 commit 417a449
Show file tree
Hide file tree
Showing 28 changed files with 1,362 additions and 210 deletions.
4 changes: 2 additions & 2 deletions include/imex/Conversion/XeTileToXeGPU/XeTileToXeGPU.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ class RewritePatternSet;
} // namespace mlir

namespace imex {
class XeGPUTypeConverter;
class XeOneToNTypeConverter;

/// Populate the given list with patterns that rewrite XeTile ops.
void populateXeTileToXeGPUConversionPatterns(XeGPUTypeConverter &converter,
void populateXeTileToXeGPUConversionPatterns(XeOneToNTypeConverter &converter,
mlir::RewritePatternSet &patterns,
imex::TileUsageAnalysis &analysis);

Expand Down
44 changes: 23 additions & 21 deletions include/imex/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
//===----------------------------------------------------------------------===//
///
/// \file
/// This file defines the SgXeTileToXeGPUConversion, the base class for
/// XeTileToXeGPU conversion, XeGPUTypeConverter, converting types used in
/// XeTile dialect to types used in XeGPU dialect, XeGPUOneToNPatterRewriter a
/// This file defines the XeOneToNConversion, the base class for
/// XeTileToXeGPU conversion, XeOneToNTypeConverter, converting types used in
/// XeTile dialect to types used in XeGPU dialect, XeOneToNPatternRewriter a
/// wrapper around ConversionPatternRewriter providing an interface for supporting
/// OneToN replace.
///
Expand All @@ -36,9 +36,9 @@

namespace imex {

class XeGPUTypeConverter : public imex::XeTypeConverter {
class XeOneToNTypeConverter : public imex::XeTypeConverter {
public:
XeGPUTypeConverter(mlir::MLIRContext &context);
XeOneToNTypeConverter(mlir::MLIRContext &context);

std::optional<mlir::LogicalResult>
convertTileType(xetile::TileType tileTy,
Expand All @@ -56,11 +56,11 @@ class XeGPUTypeConverter : public imex::XeTypeConverter {
mlir::Operation *targetOp;
};

class XeGPUOneToNPatterRewriter : public mlir::PatternRewriter,
public mlir::RewriterBase::Listener {
class XeOneToNPatternRewriter : public mlir::PatternRewriter,
public mlir::RewriterBase::Listener {
public:
explicit XeGPUOneToNPatterRewriter(mlir::ConversionPatternRewriter &rewriter,
XeGPUTypeConverter &converter)
explicit XeOneToNPatternRewriter(mlir::ConversionPatternRewriter &rewriter,
XeOneToNTypeConverter &converter)
: mlir::PatternRewriter(rewriter.getContext()), typeConverter(converter),
rewriter(rewriter) {
setListener(this);
Expand Down Expand Up @@ -94,6 +94,9 @@ class XeGPUOneToNPatterRewriter : public mlir::PatternRewriter,

void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) override;

void replaceOp(mlir::Operation *op, mlir::ValueRange newValues,
const mlir::OneToNTypeMapping &resultMapping);

void eraseOp(mlir::Operation *op) override { rewriter.eraseOp(op); }

void eraseBlock(mlir::Block *block) override { rewriter.eraseBlock(block); }
Expand All @@ -108,18 +111,17 @@ class XeGPUOneToNPatterRewriter : public mlir::PatternRewriter,
};

private:
XeGPUTypeConverter &typeConverter;
XeOneToNTypeConverter &typeConverter;
mlir::ConversionPatternRewriter &rewriter;
};

template <typename SourceOp>
class SgXeTileToXeGPUConversion
: public XeConversionPattern<TileUsageAnalysis> {
class XeOneToNConversion : public XeConversionPattern<TileUsageAnalysis> {
public:
SgXeTileToXeGPUConversion(mlir::MLIRContext *context,
XeGPUTypeConverter &typeConverter,
TileUsageAnalysis &analysis,
mlir::PatternBenefit benefit = 1)
XeOneToNConversion(mlir::MLIRContext *context,
XeOneToNTypeConverter &typeConverter,
TileUsageAnalysis &analysis,
mlir::PatternBenefit benefit = 1)
: XeConversionPattern(typeConverter, analysis,
SourceOp::getOperationName(), benefit, context) {}

Expand All @@ -130,7 +132,7 @@ class SgXeTileToXeGPUConversion
* This overrides RewritePattern::matchAndRewrite, as it is the entry
* point. It will set up the OpAdaptor such that it contains the converted
* values, and wrap the ConversionPatternRewriter with
* XeGPUOneToNPatterRewriter to provide a clean interface for users.
* XeOneToNPatternRewriter to provide a clean interface for users.
*/
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
Expand All @@ -145,7 +147,7 @@ class SgXeTileToXeGPUConversion
// One-To-One mapping provided by mlir::ConversionPatternRewriter.
// remappedValues contains new values for each operand of the operation. It
// is supposed to be an UnrealizedConversionCastOp (created by the replaceOp
// of XeGPUOneToNPatternRewriter in form of cast newvalues to oldType) for
// of XeOneToNPatternRewriter in form of cast newvalues to oldType) for
// each operand that has One-to-N mapping.
llvm::SmallVector<mlir::Value> remappedValues;
if (mlir::failed(convertionPatternRewriter.getRemappedValues(
Expand All @@ -167,14 +169,14 @@ class SgXeTileToXeGPUConversion

auto sourceOp = llvm::dyn_cast<SourceOp>(op);
OpAdaptor adaptor(convertedValues, sourceOp);
XeGPUOneToNPatterRewriter OneToNRewriter(
convertionPatternRewriter, getTypeConverter<XeGPUTypeConverter>());
XeOneToNPatternRewriter OneToNRewriter(
convertionPatternRewriter, getTypeConverter<XeOneToNTypeConverter>());
return matchAndRewrite(sourceOp, adaptor, OneToNRewriter);
}

virtual mlir::LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
XeGPUOneToNPatterRewriter &rewriter) const {
XeOneToNPatternRewriter &rewriter) const {
llvm_unreachable("must override matchAndRewrite or a rewrite method");
}
};
Expand Down
3 changes: 3 additions & 0 deletions include/imex/Dialect/XeTile/IR/XeTileTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ def XeTile : XeTile_Type<"Tile", "tile", [ShapedTypeInterface],
}

xetile::WorkGroupMapAttr getWgMap() {
auto wgmap = llvm::dyn_cast_if_present<xetile::WorkGroupMapAttr>(getEncoding());
if (wgmap)
return wgmap;
auto encoding = llvm::dyn_cast_if_present<xetile::XeTileAttr>(getEncoding());
if (encoding)
return encoding.getWgMap();
Expand Down
1 change: 1 addition & 0 deletions include/imex/Dialect/XeTile/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ std::unique_ptr<mlir::Pass> createXeTileInitDuplicatePass();
std::unique_ptr<mlir::Pass>
createXeTileBlockingPass(const std::string &device = "pvc");
std::unique_ptr<mlir::Pass> createXeTileBlockAligningPass();
std::unique_ptr<mlir::Pass> createXeTileWgToSgPass();
std::unique_ptr<mlir::Pass> createXeTileOptimizeTransposePass();

///
Expand Down
16 changes: 16 additions & 0 deletions include/imex/Dialect/XeTile/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,22 @@ def XeTileBlockAligning: Pass <"xetile-block-aligning", "::mlir::gpu::GPUModuleO
"mlir::vector::VectorDialect"];
}

// Declarative definition of the `xetile-wg-to-sg` pass, anchored on
// ::mlir::gpu::GPUModuleOp.
// NOTE(review): "WG"/"SG" presumably stand for work-group/sub-group — confirm.
def XeTileWgToSg : Pass<"xetile-wg-to-sg", "::mlir::gpu::GPUModuleOp">{
let summary = "Transform WG level XeTile code to SG XeTile";

let description = [{
This transform pass transforms WG level XeTile code to SG XeTile.
}];

// Factory expression used to construct the pass instance.
let constructor = "imex::createXeTileWgToSgPass()";
// Dialects this pass declares as dependencies.
let dependentDialects = ["imex::xetile::XeTileDialect",
"mlir::arith::ArithDialect",
"mlir::gpu::GPUDialect",
"mlir::index::IndexDialect",
"mlir::memref::MemRefDialect",
"mlir::vector::VectorDialect"];
}

def XeTileOptimizeTranspose : Pass<"xetile-optimize-transpose"> {
let summary = "Fuse tile loads and transpose operations.";

Expand Down
2 changes: 2 additions & 0 deletions lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h>
#include <mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h>
#include <mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h>
#include <mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h>
#include <mlir/Conversion/MathToSPIRV/MathToSPIRV.h>
#include <mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h>
#include <mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h>
Expand Down Expand Up @@ -352,6 +353,7 @@ void GPUXToSPIRVPass::runOnOperation() {
mlir::populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
mlir::populateVectorToSPIRVPatterns(typeConverter, patterns);
mlir::populateMathToSPIRVPatterns(typeConverter, patterns);
mlir::index::populateIndexToSPIRVPatterns(typeConverter, patterns);
mlir::populateMemRefToSPIRVPatterns(typeConverter, patterns);
mlir::populateFuncToSPIRVPatterns(typeConverter, patterns);
// ---------------------------------------
Expand Down
27 changes: 13 additions & 14 deletions lib/Conversion/XeTileToXeGPU/ArithOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ extern VectorTypedValue concat(mlir::Value v1, mlir::Value v2,
extern mlir::Value mergeVectorsWrapper(mlir::ValueRange ins,
std::function<funcTy> transFunc,
mlir::Location loc,
XeGPUOneToNPatterRewriter &rewriter);
XeOneToNPatternRewriter &rewriter);

static mlir::Value createBinOp(mlir::vector::CombiningKind kind,
mlir::Value lhs, mlir::Value rhs,
mlir::Type elemTy, mlir::Location &loc,
XeGPUOneToNPatterRewriter &rewriter) {
XeOneToNPatternRewriter &rewriter) {

// ADD and MUL are defined for both Integers and Floats,
// need to generate code based on element data type.
Expand Down Expand Up @@ -91,7 +91,7 @@ static mlir::Value createBinOp(mlir::vector::CombiningKind kind,
llvm::SmallVector<mlir::Value>
lowerOuterReduction(mlir::ValueRange sources, llvm::ArrayRef<int64_t> shape,
mlir::vector::CombiningKind kind, mlir::Location loc,
mlir::Type elemTy, XeGPUOneToNPatterRewriter &rewriter) {
mlir::Type elemTy, XeOneToNPatternRewriter &rewriter) {
assert(shape.size() == 4 && "shape should be 4D.");
llvm::SmallVector<mlir::Value> intermediates;
for (auto j = 0; j < shape[1]; j++) {
Expand Down Expand Up @@ -136,7 +136,7 @@ lowerOuterReduction(mlir::ValueRange sources, llvm::ArrayRef<int64_t> shape,
llvm::SmallVector<mlir::Value> lowerInnerReductionWithIntraVectorShuffles(
mlir::ValueRange sources, llvm::ArrayRef<int64_t> shape,
mlir::vector::CombiningKind kind, mlir::Location loc, mlir::Type elemTy,
XeGPUOneToNPatterRewriter &rewriter) {
XeOneToNPatternRewriter &rewriter) {

assert(shape.size() == 4 && "shape should be 4D.");

Expand Down Expand Up @@ -235,7 +235,7 @@ llvm::SmallVector<mlir::Value> lowerInnerReductionWithIntraVectorShuffles(
llvm::SmallVector<mlir::Value> lowerInnerReductionWithVectorReduction(
mlir::ValueRange sources, llvm::ArrayRef<int64_t> shape,
mlir::vector::CombiningKind kind, mlir::Location loc, mlir::Type elemTy,
XeGPUOneToNPatterRewriter &rewriter) {
XeOneToNPatternRewriter &rewriter) {

assert(shape.size() == 4 && "shape should be 4D.");
// vector<ixjx1xnxf16> equals to a grid of ixj of vector<1xnxf16>
Expand Down Expand Up @@ -266,13 +266,13 @@ llvm::SmallVector<mlir::Value> lowerInnerReductionWithVectorReduction(
}

class SgVectorMultiDimReductionOpPattern
: public SgXeTileToXeGPUConversion<mlir::vector::MultiDimReductionOp> {
using SgXeTileToXeGPUConversion<
mlir::vector::MultiDimReductionOp>::SgXeTileToXeGPUConversion;
: public XeOneToNConversion<mlir::vector::MultiDimReductionOp> {
using XeOneToNConversion<
mlir::vector::MultiDimReductionOp>::XeOneToNConversion;

mlir::LogicalResult
matchAndRewrite(mlir::vector::MultiDimReductionOp op, OpAdaptor adaptor,
XeGPUOneToNPatterRewriter &rewriter) const override {
XeOneToNPatternRewriter &rewriter) const override {
auto srcTy = op.getSource().getType();
auto elemTy = srcTy.getElementType();
auto dims = op.getReductionDims();
Expand Down Expand Up @@ -331,13 +331,12 @@ class SgVectorMultiDimReductionOpPattern
};

class SgArithConstantOpPattern
: public SgXeTileToXeGPUConversion<mlir::arith::ConstantOp> {
using SgXeTileToXeGPUConversion<
mlir::arith::ConstantOp>::SgXeTileToXeGPUConversion;
: public XeOneToNConversion<mlir::arith::ConstantOp> {
using XeOneToNConversion<mlir::arith::ConstantOp>::XeOneToNConversion;

mlir::LogicalResult
matchAndRewrite(mlir::arith::ConstantOp op, OpAdaptor adaptor,
XeGPUOneToNPatterRewriter &rewriter) const override {
XeOneToNPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto value = llvm::dyn_cast<mlir::DenseElementsAttr>(op.getValue());

Expand Down Expand Up @@ -392,7 +391,7 @@ bool isLegalArithOp(mlir::Operation *op) {
return true;
}

void populateArithOpConversionPatterns(imex::XeGPUTypeConverter &converter,
void populateArithOpConversionPatterns(imex::XeOneToNTypeConverter &converter,
mlir::RewritePatternSet &patterns,
TileUsageAnalysis &analysis) {
patterns.add<SgArithConstantOpPattern, SgVectorMultiDimReductionOpPattern>(
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/XeTileToXeGPU/ArithOpConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
namespace imex {
bool isLegalArithOp(mlir::Operation *op);

void populateArithOpConversionPatterns(imex::XeGPUTypeConverter &converter,
void populateArithOpConversionPatterns(imex::XeOneToNTypeConverter &converter,
mlir::RewritePatternSet &patterns,
TileUsageAnalysis &analysis);

Expand Down
21 changes: 9 additions & 12 deletions lib/Conversion/XeTileToXeGPU/SCFOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,12 @@

namespace imex {

struct SgSCFForOpBlockPattern
: public SgXeTileToXeGPUConversion<mlir::scf::ForOp> {
using SgXeTileToXeGPUConversion<mlir::scf::ForOp>::SgXeTileToXeGPUConversion;
struct SgSCFForOpBlockPattern : public XeOneToNConversion<mlir::scf::ForOp> {
using XeOneToNConversion<mlir::scf::ForOp>::XeOneToNConversion;

mlir::LogicalResult
matchAndRewrite(mlir::scf::ForOp op, OpAdaptor adaptor,
imex::XeGPUOneToNPatterRewriter &rewriter) const override {
imex::XeOneToNPatternRewriter &rewriter) const override {
// OpAdaptor is defined with ValueRange, so it contains results after
// One-to-N mapping
llvm::SmallVector<mlir::Value> convertedArgs;
Expand All @@ -42,9 +41,9 @@ struct SgSCFForOpBlockPattern
// lowered into many different types of TensorDescType (due to different
// setting of array_length). But typeconverter has no knowledge about when
// to use array_length and when not.
auto typeConverter = getTypeConverter<XeGPUTypeConverter>();
auto typeConverter = getTypeConverter<XeOneToNTypeConverter>();
auto argTys = op.getRegion().getArgumentTypes();
mlir::OneToNTypeMapping argumentMapping(argTys);
mlir::OneToNTypeMapping argumentMapping(argTys); // vectorty
llvm::ArrayRef<mlir::Value> args(op.getRegion().getArguments().begin(),
op.getRegion().getArguments().end());
llvm::ArrayRef<mlir::Value> newArgs(
Expand All @@ -71,14 +70,12 @@ struct SgSCFForOpBlockPattern
}
};

struct SgSCFYieldOpPattern
: public SgXeTileToXeGPUConversion<mlir::scf::YieldOp> {
using SgXeTileToXeGPUConversion<
mlir::scf::YieldOp>::SgXeTileToXeGPUConversion;
struct SgSCFYieldOpPattern : public XeOneToNConversion<mlir::scf::YieldOp> {
using XeOneToNConversion<mlir::scf::YieldOp>::XeOneToNConversion;

mlir::LogicalResult
matchAndRewrite(mlir::scf::YieldOp op, OpAdaptor adaptor,
imex::XeGPUOneToNPatterRewriter &rewriter) const override {
imex::XeOneToNPatternRewriter &rewriter) const override {
llvm::SmallVector<mlir::Value> convertedResults;
for (auto &values : adaptor.getResults())
convertedResults.append(values.begin(), values.end());
Expand Down Expand Up @@ -116,7 +113,7 @@ bool isLegalSCFOp(mlir::Operation *op) {
return result;
}

void populateSCFOpConversionPatterns(imex::XeGPUTypeConverter &converter,
void populateSCFOpConversionPatterns(imex::XeOneToNTypeConverter &converter,
mlir::RewritePatternSet &patterns,
TileUsageAnalysis &analysis) {
patterns.add<SgSCFForOpBlockPattern, SgSCFYieldOpPattern>(
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/XeTileToXeGPU/SCFOpConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
namespace imex {
bool isLegalSCFOp(mlir::Operation *op);

void populateSCFOpConversionPatterns(imex::XeGPUTypeConverter &converter,
void populateSCFOpConversionPatterns(imex::XeOneToNTypeConverter &converter,
mlir::RewritePatternSet &patterns,
TileUsageAnalysis &analysis);

Expand Down
Loading

0 comments on commit 417a449

Please sign in to comment.