diff --git a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp index b9d1a4ea9..5081677a0 100644 --- a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp +++ b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp @@ -610,6 +610,11 @@ struct SgLoadTileOpPattern : public XeOneToNConversion { // TODO: move these two into architecture abstracture in future. const int SIMD_WIDTH_IN_BITS = 32; int factor = SIMD_WIDTH_IN_BITS / elemTy.getIntOrFloatBitWidth(); + // TODO: use uArch for this? + auto isLowPrecision = [](unsigned int width) -> bool { + bool isPowerOf2 = (width & (width - 1)) == 0; + return isPowerOf2 & (width < 32) & (width > 1); + }; if (isForDPASB(op) && factor > 1) vnniAttr = mlir::UnitAttr::get(ctx); @@ -621,7 +626,7 @@ struct SgLoadTileOpPattern : public XeOneToNConversion { auto elemWidth = elemTy.getIntOrFloatBitWidth(); if (elemWidth == 32) { transposeAttr = rewriter.getDenseI64ArrayAttr({1, 0}); - } else if (elemWidth == 16 && vnniAttr) { + } else if (isLowPrecision(elemWidth) && vnniAttr) { transposeAttr = rewriter.getDenseI64ArrayAttr({1, 0}); transposeBitWidthAttr = rewriter.getI32IntegerAttr(32); vnniAttr = nullptr;