diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 3d6a911f365410..94197e473ce012 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -214,8 +214,7 @@ void MmaOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames); // Print the types of the operands and result. - p << " : " - << "("; + p << " : " << "("; llvm::interleaveComma(SmallVector{frags[0].regs[0].getType(), frags[1].regs[0].getType(), frags[2].regs[0].getType()}, @@ -956,9 +955,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() { ss << "},"; // Need to map read/write registers correctly. regCnt = (regCnt * 2); - ss << " $" << (regCnt) << "," - << " $" << (regCnt + 1) << "," - << " p"; + ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p"; if (getTypeD() != WGMMATypes::s32) { ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4); } @@ -1056,14 +1053,10 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op, StringAttr attrName = attr.getName(); // Kernel function attribute should be attached to functions. if (attrName == NVVMDialect::getKernelFuncAttrName()) { - auto funcOp = dyn_cast(op); - if (!funcOp) { + if (!isa(op)) { return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName() << "' attribute attached to unexpected op"; } - if (!funcOp.getResultTypes().empty()) { - return op->emitError() << "kernel function cannot have results"; - } } // If maxntid and reqntid exist, it must be an array with max 3 dim if (attrName == NVVMDialect::getMaxntidAttrName() || diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index 26ba80cba6ed58..a8ae4d97888c90 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -574,10 +574,3 @@ llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}, %arg1: f32, %arg2: !llvm.ptr {llvm.byval = f32, nvvm.grid_constant}) attributes {nvvm.kernel} { llvm.return } - -// ----- - -// expected-error @below{{kernel function cannot have results}} -llvm.func @kernel_with_result(%i: i32) -> i32 attributes {nvvm.kernel} { - llvm.return %i : i32 -}