Skip to content

Commit

Permalink
Update XeTileAttr constructors and add pattern for cmpi op
Browse files Browse the repository at this point in the history
  • Loading branch information
chencha3 committed Nov 13, 2024
1 parent 09e7136 commit 0aac661
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 6 deletions.
9 changes: 5 additions & 4 deletions include/imex/Dialect/XeTile/IR/XeTileAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,10 @@ def XeTile_TileAttr : XeTile_Attr<"XeTile", "tile_attr"> {
[{
mlir::Type intType = mlir::IntegerType::get($_ctxt, 32);
mlir::BoolAttr scatteredAttr = mlir::BoolAttr::get($_ctxt, scattered);
mlir::DenseI64ArrayAttr blkAttr = inner_blocks.empty()? mlir::DenseI64ArrayAttr():
mlir::DenseI64ArrayAttr::get($_ctxt, inner_blocks);
return $_get($_ctxt, sg_map, wg_map, mlir::DenseI32ArrayAttr::get($_ctxt, order),
mlir::DenseI64ArrayAttr::get($_ctxt, inner_blocks),
mlir::IntegerAttr::get(intType, memory_space), scatteredAttr);
blkAttr, mlir::IntegerAttr::get(intType, memory_space), scatteredAttr);
}]>,
AttrBuilder<(ins CArg<"llvm::ArrayRef<int32_t>", "{1, 0}">:$order,
CArg<"int", "0">:$memory_space, CArg<"bool", "false">:$scattered),
Expand All @@ -90,7 +91,7 @@ def XeTile_TileAttr : XeTile_Attr<"XeTile", "tile_attr"> {
mlir::BoolAttr scatteredAttr = mlir::BoolAttr::get($_ctxt, scattered);
return $_get($_ctxt, xetile::SubGroupMapAttr(), xetile::WorkGroupMapAttr(),
mlir::DenseI32ArrayAttr::get($_ctxt, order),
mlir::DenseI64ArrayAttr::get($_ctxt, {}),
mlir::DenseI64ArrayAttr(),
mlir::IntegerAttr::get(intType, memory_space), scatteredAttr);
}]>,
AttrBuilder<(ins CArg<"xetile::SubGroupMapAttr", "{}">:$sg_map,
Expand All @@ -101,7 +102,7 @@ def XeTile_TileAttr : XeTile_Attr<"XeTile", "tile_attr"> {
mlir::Type intType = mlir::IntegerType::get($_ctxt, 32);
mlir::BoolAttr scatteredAttr = mlir::BoolAttr::get($_ctxt, scattered);
return $_get($_ctxt, sg_map, wg_map, mlir::DenseI32ArrayAttr::get($_ctxt, order),
mlir::DenseI64ArrayAttr::get($_ctxt, {}),
mlir::DenseI64ArrayAttr(),
mlir::IntegerAttr::get(intType, memory_space), scatteredAttr);
}]>
];
Expand Down
30 changes: 28 additions & 2 deletions lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1090,6 +1090,32 @@ struct TypecastOpPattern : public XeOneToNConversion<CastOp> {
}
};

struct SgArithCmpIOpPattern : public XeOneToNConversion<mlir::arith::CmpIOp> {
using XeOneToNConversion<mlir::arith::CmpIOp>::XeOneToNConversion;

mlir::LogicalResult
matchAndRewrite(mlir::arith::CmpIOp op, OpAdaptor adaptor,
XeOneToNPatternRewriter &rewriter) const override {
auto res = op.getResult();
auto resType = mlir::dyn_cast<mlir::VectorType>(res.getType());
if (!resType || resType.getRank() != 4)
return mlir::failure();

auto lhs = adaptor.getLhs();
auto rhs = adaptor.getRhs();

llvm::SmallVector<mlir::Value> newOps;
for (auto [l, r] : llvm::zip_equal(lhs, rhs)) {
auto newOp = rewriter.create<mlir::arith::CmpIOp>(
op.getLoc(), op.getPredicate(), l, r);
newOps.push_back(newOp);
}

rewriter.replaceOp(op, newOps);
return mlir::success();
}
};

struct SgBroadcastOpPattern : public XeOneToNConversion<xetile::BroadcastOp> {
using XeOneToNConversion<xetile::BroadcastOp>::XeOneToNConversion;

Expand Down Expand Up @@ -1256,8 +1282,8 @@ void populateXeTileOpConversionPatterns(imex::XeOneToNTypeConverter &converter,
SgVectorSplatOpPattern, SgUpdateTileOffsetOpPattern,
SgTransposeOpPattern<mlir::vector::TransposeOp>,
SgTransposeOpPattern<xetile::TransposeOp>, SgBroadcastOpPattern,
SgTileReductionOpPattern, SgVectorCreateMaskOpPattern>(
patterns.getContext(), converter, analysis);
SgTileReductionOpPattern, SgVectorCreateMaskOpPattern,
SgArithCmpIOpPattern>(patterns.getContext(), converter, analysis);

// Element-wise math operations
patterns.insert<ElementWiseOpPattern<mlir::arith::NegFOp, 1>,
Expand Down

0 comments on commit 0aac661

Please sign in to comment.