Skip to content

Commit

Permalink
Fix incomplete legality check for SCF dialect in XeGPUToVC (#816)
Browse files Browse the repository at this point in the history
  • Loading branch information
charithaintc authored Jul 30, 2024
1 parent 5a7bb80 commit a4d993d
Showing 1 changed file with 18 additions and 22 deletions.
40 changes: 18 additions & 22 deletions lib/Conversion/XeGPUToVC/XeGPUToVC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"

Expand All @@ -31,6 +34,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"

#include "../PassDetail.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"

Expand Down Expand Up @@ -1270,28 +1274,18 @@ struct SCFYieldOpVCPattern final
}
};

bool isLegalXeGPUSCFOp(mlir::Operation *op) {
bool result = true;
if (llvm::isa<mlir::scf::ForOp>(op)) {
auto forOp = llvm::cast<mlir::scf::ForOp>(op);
for (const auto &arg : forOp.getInitArgs()) {
auto type = arg.getType();
if (mlir::isa<mlir::VectorType>(type))
result &= (mlir::cast<mlir::VectorType>(type).getRank() == 1);
}
}

if (llvm::isa<mlir::scf::YieldOp>(op)) {
auto yieldOp = llvm::cast<mlir::scf::YieldOp>(op);
for (const auto &arg : yieldOp.getResults()) {
auto type = arg.getType();

if (mlir::isa<mlir::VectorType>(type))
result &= (mlir::cast<mlir::VectorType>(type).getRank() == 1);
}
bool isLegalXeGPUSCFOp(mlir::Operation *op, mlir::TypeConverter typeConverter) {
llvm::SmallVector<mlir::Value> args;
if (llvm::isa<mlir::scf::ForOp>(op))
args = llvm::cast<mlir::scf::ForOp>(op).getInitArgs();
else if (llvm::isa<mlir::scf::YieldOp>(op))
args = llvm::cast<mlir::scf::YieldOp>(op).getResults();
// Check the legality of arguments using the type converter.
for (const auto &arg : args) {
if (!typeConverter.isLegal(arg.getType()))
return false;
}

return result;
return true;
}

static bool isGenericVectorTy(mlir::Type type) {
Expand Down Expand Up @@ -1377,7 +1371,9 @@ struct XeGPUToVCPass : public ::imex::ConvertXeGPUToVCBase<XeGPUToVCPass> {
target.addIllegalDialect<mlir::xegpu::XeGPUDialect>();

target.addDynamicallyLegalDialect<mlir::scf::SCFDialect>(
[&](mlir::Operation *op) { return isLegalXeGPUSCFOp(op); });
[&](mlir::Operation *op) {
return isLegalXeGPUSCFOp(op, typeConverter);
});

target.addDynamicallyLegalOp<::mlir::arith::MaximumFOp>(
[&](::mlir::arith::MaximumFOp op) {
Expand Down

0 comments on commit a4d993d

Please sign in to comment.