From 4e86e50ce13406935990916685bbe1b1f0f9803b Mon Sep 17 00:00:00 2001 From: Charitha Saumya <136391709+charithaintc@users.noreply.github.com> Date: Tue, 30 Jul 2024 14:28:48 -0700 Subject: [PATCH] Fix lowering logic for low-precision types in XeTileToXeGPU (#817) --- lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp index b9d1a4ea9..5081677a0 100644 --- a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp +++ b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp @@ -610,6 +610,11 @@ struct SgLoadTileOpPattern : public XeOneToNConversion { // TODO: move these two into architecture abstracture in future. const int SIMD_WIDTH_IN_BITS = 32; int factor = SIMD_WIDTH_IN_BITS / elemTy.getIntOrFloatBitWidth(); + // TODO: use uArch for this? + auto isLowPrecision = [](unsigned int width) -> bool { + bool isPowerOf2 = (width & (width - 1)) == 0; + return isPowerOf2 & (width < 32) & (width > 1); + }; if (isForDPASB(op) && factor > 1) vnniAttr = mlir::UnitAttr::get(ctx); @@ -621,7 +626,7 @@ struct SgLoadTileOpPattern : public XeOneToNConversion { auto elemWidth = elemTy.getIntOrFloatBitWidth(); if (elemWidth == 32) { transposeAttr = rewriter.getDenseI64ArrayAttr({1, 0}); - } else if (elemWidth == 16 && vnniAttr) { + } else if (isLowPrecision(elemWidth) && vnniAttr) { transposeAttr = rewriter.getDenseI64ArrayAttr({1, 0}); transposeBitWidthAttr = rewriter.getI32IntegerAttr(32); vnniAttr = nullptr;