From a4d993d4b2505eac3cf4b258fd2d01551ae92560 Mon Sep 17 00:00:00 2001 From: Charitha Saumya <136391709+charithaintc@users.noreply.github.com> Date: Tue, 30 Jul 2024 11:15:32 -0700 Subject: [PATCH] Fix incomplete legality check for SCF dialect in XeGPUToVC (#816) --- lib/Conversion/XeGPUToVC/XeGPUToVC.cpp | 40 ++++++++++++-------------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp b/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp index 9153333d6..bfee41a47 100644 --- a/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp +++ b/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp @@ -17,11 +17,14 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" @@ -31,6 +34,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "../PassDetail.h" +#include "mlir/Transforms/DialectConversion.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" @@ -1270,28 +1274,18 @@ struct SCFYieldOpVCPattern final } }; -bool isLegalXeGPUSCFOp(mlir::Operation *op) { - bool result = true; - if (llvm::isa(op)) { - auto forOp = llvm::cast(op); - for (const auto &arg : forOp.getInitArgs()) { - auto type = arg.getType(); - if (mlir::isa(type)) - result &= (mlir::cast(type).getRank() == 1); - } - } - - if (llvm::isa(op)) { - auto yieldOp = llvm::cast(op); - for (const auto &arg : yieldOp.getResults()) { - auto type = arg.getType(); - - if (mlir::isa(type)) - result &= (mlir::cast(type).getRank() == 1); - } +bool isLegalXeGPUSCFOp(mlir::Operation *op, mlir::TypeConverter typeConverter) { + llvm::SmallVector args; + if (llvm::isa(op)) + args = llvm::cast(op).getInitArgs(); + else if (llvm::isa(op)) + args = llvm::cast(op).getResults(); + // Check the legality of arguments using the type converter. + for (const auto &arg : args) { + if (!typeConverter.isLegal(arg.getType())) + return false; } - - return result; + return true; } static bool isGenericVectorTy(mlir::Type type) { @@ -1377,7 +1371,9 @@ struct XeGPUToVCPass : public ::imex::ConvertXeGPUToVCBase { target.addIllegalDialect(); target.addDynamicallyLegalDialect( - [&](mlir::Operation *op) { return isLegalXeGPUSCFOp(op); }); + [&](mlir::Operation *op) { + return isLegalXeGPUSCFOp(op, typeConverter); + }); target.addDynamicallyLegalOp<::mlir::arith::MaximumFOp>( [&](::mlir::arith::MaximumFOp op) {