diff --git a/lib/Dialect/TTIR/Analysis/LegalGridAnalysis.cpp b/lib/Dialect/TTIR/Analysis/LegalGridAnalysis.cpp
index 524940561..d945ad0da 100644
--- a/lib/Dialect/TTIR/Analysis/LegalGridAnalysis.cpp
+++ b/lib/Dialect/TTIR/Analysis/LegalGridAnalysis.cpp
@@ -39,6 +39,17 @@ bool tensor_shape_compatible_with_shard(Operation *op, LayoutAttr layout) {
   return (MTiles % gridR == 0) && (KTIles % gridC == 0);
 }
 
+bool cantChangeOutputLayout(Operation *op) {
+  // Only TTIR ops.
+  if (not llvm::isa<ttir::TTIROp>(op)) {
+    return true;
+  }
+  if (llvm::isa<ttir::ToLayoutOp>(op)) {
+    return true;
+  }
+  return false;
+}
+
 bool LegalGridAnalysis::applyOverrides() {
   // Lookup grid size overrides based on location information for current
   // operation.
@@ -74,23 +85,22 @@ void LegalGridAnalysis::analysisImplementation() {
 
   // This implementation is a placeholder and is meant to just enable testing of
   // other components.
-  // Process only TTIR ops.
-  if (not llvm::isa<ttir::TTIROp>(op)) {
-    return;
-  }
   // Skip operations that don't have output tensors.
   if (op->getNumResults() == 0) {
     return;
   }
 
-  if (llvm::isa<ttir::ToLayoutOp>(op)) {
-    return;
-  }
 
   // Get output tensor type.
   RankedTensorType tensorType =
       mlir::cast<RankedTensorType>(op->getResult(0).getType());
   LayoutAttr layout = mlir::cast<LayoutAttr>(tensorType.getEncoding());
+  // Return existing layout if it is not possible to change it.
+  if (cantChangeOutputLayout(op)) {
+    analysisResult.push_back(layout);
+    return;
+  }
+
   // DRAM
   // No grid is set since the tensor is not sharded.
   // TODO(odjuricic): We need to set grid here since it will be used as the