From d6d6aea54b2e2842dc25c5e14ce6f1a3b4b92469 Mon Sep 17 00:00:00 2001 From: Nikola Obradovic <132568163+nobradovictt@users.noreply.github.com> Date: Tue, 22 Oct 2024 11:40:59 +0200 Subject: [PATCH] [Optimizer/TTNN] Migrating Optimizer to TTNN (#944) --- .../ttmlir/Dialect/TTIR/Transforms/Passes.h | 1 - .../ttmlir/Dialect/TTIR/Transforms/Passes.td | 26 --- .../Analysis/DFShardingPolicy.h | 17 +- .../Dialect/{TTIR => TTNN}/Analysis/Edge.h | 14 +- .../Analysis/LegalGridAnalysis.h | 16 +- .../Analysis/OpConfigAnalysis.h | 26 +-- .../Analysis/ShardChainConfig.h | 26 +-- .../{TTIR => TTNN}/Analysis/ShardSolver.h | 56 +++--- .../Analysis/ShardingAnalysis.h | 30 +-- .../Analysis/TTNNAnalysis.h} | 16 +- .../ttmlir/Dialect/TTNN/Transforms/Passes.h | 1 + .../ttmlir/Dialect/TTNN/Transforms/Passes.td | 26 +++ lib/CAPI/CMakeLists.txt | 2 +- lib/CMakeLists.txt | 2 +- lib/Dialect/TTIR/CMakeLists.txt | 1 - lib/Dialect/TTIR/Transforms/CMakeLists.txt | 1 - lib/Dialect/TTMetal/Pipelines/CMakeLists.txt | 2 +- .../{TTIR => TTNN}/Analysis/CMakeLists.txt | 6 +- .../Analysis/DFShardingPolicy.cpp | 34 ++-- .../Analysis/LegalGridAnalysis.cpp | 55 +++--- .../Analysis/OpConfigAnalysis.cpp | 6 +- .../Analysis/ShardChainConfig.cpp | 13 +- .../{TTIR => TTNN}/Analysis/ShardSolver.cpp | 40 ++-- .../Analysis/ShardingAnalysis.cpp | 19 +- lib/Dialect/TTNN/CMakeLists.txt | 1 + lib/Dialect/TTNN/Pipelines/CMakeLists.txt | 2 +- lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp | 11 +- lib/Dialect/TTNN/Transforms/CMakeLists.txt | 1 + .../{TTIR => TTNN}/Transforms/Optimizer.cpp | 182 +++++++++++------- lib/Scheduler/Scheduler.cpp | 17 +- lib/SharedLib/CMakeLists.txt | 2 +- test/ttmlir/Dialect/TTIR/test_grid_set.mlir | 11 -- test/ttmlir/Dialect/TTNN/test_grid_set.mlir | 22 +++ test/unittests/TestScheduler/CMakeLists.txt | 1 + 34 files changed, 393 insertions(+), 293 deletions(-) rename include/ttmlir/Dialect/{TTIR => TTNN}/Analysis/DFShardingPolicy.h (66%) rename include/ttmlir/Dialect/{TTIR => TTNN}/Analysis/Edge.h (73%) rename include/ttmlir/Dialect/{TTIR => TTNN}/Analysis/LegalGridAnalysis.h (76%) rename include/ttmlir/Dialect/{TTIR => TTNN}/Analysis/OpConfigAnalysis.h (50%) rename include/ttmlir/Dialect/{TTIR => TTNN}/Analysis/ShardChainConfig.h (70%) rename include/ttmlir/Dialect/{TTIR => TTNN}/Analysis/ShardSolver.h (83%) rename include/ttmlir/Dialect/{TTIR => TTNN}/Analysis/ShardingAnalysis.h (59%) rename include/ttmlir/Dialect/{TTIR/Analysis/TTIRAnalysis.h => TTNN/Analysis/TTNNAnalysis.h} (80%) rename lib/Dialect/{TTIR => TTNN}/Analysis/CMakeLists.txt (77%) rename lib/Dialect/{TTIR => TTNN}/Analysis/DFShardingPolicy.cpp (87%) rename lib/Dialect/{TTIR => TTNN}/Analysis/LegalGridAnalysis.cpp (80%) rename lib/Dialect/{TTIR => TTNN}/Analysis/OpConfigAnalysis.cpp (81%) rename lib/Dialect/{TTIR => TTNN}/Analysis/ShardChainConfig.cpp (75%) rename lib/Dialect/{TTIR => TTNN}/Analysis/ShardSolver.cpp (93%) rename lib/Dialect/{TTIR => TTNN}/Analysis/ShardingAnalysis.cpp (73%) rename lib/Dialect/{TTIR => TTNN}/Transforms/Optimizer.cpp (53%) delete mode 100644 test/ttmlir/Dialect/TTIR/test_grid_set.mlir create mode 100644 test/ttmlir/Dialect/TTNN/test_grid_set.mlir diff --git a/include/ttmlir/Dialect/TTIR/Transforms/Passes.h b/include/ttmlir/Dialect/TTIR/Transforms/Passes.h index 1331caf93..dd3772c37 100644 --- a/include/ttmlir/Dialect/TTIR/Transforms/Passes.h +++ b/include/ttmlir/Dialect/TTIR/Transforms/Passes.h @@ -5,7 +5,6 @@ #ifndef TTMLIR_DIALECT_TTIR_TRANSFORMS_PASSES_H #define TTMLIR_DIALECT_TTIR_TRANSFORMS_PASSES_H -#include 
"ttmlir/Dialect/TT/Utils/OverrideParams.h" #include "ttmlir/Dialect/TTIR/IR/TTIR.h" #include "ttmlir/Dialect/TTIR/IR/TTIROps.h" diff --git a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td index 3e5298440..1cee4cbb5 100644 --- a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td +++ b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td @@ -108,32 +108,6 @@ def TTIRAllocate: Pass<"ttir-allocate", "::mlir::ModuleOp"> { }]; } -def TTIROptimizer: Pass<"ttir-optimizer", "::mlir::ModuleOp"> { - let summary = "Determine op configurations for maximum performance."; - let description = [{ - Go through the ops, set sharding specs for each op based on sharding analysis, - by updating layout attribute of each op. - }]; - let options = [ - Option<"overrideOutputLayout", "override-output-layout", - "llvm::StringMap", - /*default=*/"llvm::StringMap()", - "Override output tensor layout for specific ops.">, - Option<"shardingPassEnabled", "sharding-pass-enabled", - "bool", - /*default=*/"false", - "Enable sharding pass.">, - Option<"reshardingEnabled", "resharding-enabled", - "bool", - /*default=*/"false", - "Resharding pass. Temp disabled till we support all types of shard specs.">, - Option<"maxLegalLayouts", "max-legal-layouts", - "int64_t", - /*default=*/"64", - "Override maximum number of legal layouts for grid analysis."> - ]; -} - def TTIRLoadSystemDesc: Pass<"ttir-load-system-desc", "::mlir::ModuleOp"> { let summary = "Load system desc."; let description = [{ diff --git a/include/ttmlir/Dialect/TTIR/Analysis/DFShardingPolicy.h b/include/ttmlir/Dialect/TTNN/Analysis/DFShardingPolicy.h similarity index 66% rename from include/ttmlir/Dialect/TTIR/Analysis/DFShardingPolicy.h rename to include/ttmlir/Dialect/TTNN/Analysis/DFShardingPolicy.h index 38eae455e..75ef796c9 100644 --- a/include/ttmlir/Dialect/TTIR/Analysis/DFShardingPolicy.h +++ b/include/ttmlir/Dialect/TTNN/Analysis/DFShardingPolicy.h @@ -2,13 +2,13 @@ // // SPDX-License-Identifier: Apache-2.0 -#ifndef TTMLIR_DIALECT_TTIR_ANALYSIS_DFSHARDINGPOLICY_H -#define TTMLIR_DIALECT_TTIR_ANALYSIS_DFSHARDINGPOLICY_H +#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_DFSHARDINGPOLICY_H +#define TTMLIR_DIALECT_TTNN_ANALYSIS_DFSHARDINGPOLICY_H #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "ttmlir/Dialect/TTIR/Analysis/ShardChainConfig.h" +#include "ttmlir/Dialect/TTNN/Analysis/ShardChainConfig.h" -namespace mlir::tt::ttir { +namespace mlir::tt::ttnn { // Process ops in DFS schedulable order and build shard chain configs. // Schedule is also produced as a side effect of sharding. 
@@ -17,14 +17,15 @@ class DFShardingPolicy { private: Operation *rootOp; std::vector *shardChainConfigs; - llvm::DenseMap> legalLayouts; + llvm::DenseMap> legalLayouts; llvm::DenseMap> *schedule; unsigned usableL1CacheSize = 0; public: DFShardingPolicy( Operation *rootOp, std::vector &shardChainConfigs, - const llvm::DenseMap> &legalLayouts, + const llvm::DenseMap> + &legalLayouts, llvm::DenseMap> &schedule, unsigned usableL1CacheSize) : rootOp(rootOp), shardChainConfigs(&shardChainConfigs), @@ -34,6 +35,6 @@ class DFShardingPolicy { void run(); }; -} // namespace mlir::tt::ttir +} // namespace mlir::tt::ttnn -#endif // TTMLIR_DIALECT_TTIR_ANALYSIS_DFSHARDINGPOLICY_H +#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_DFSHARDINGPOLICY_H diff --git a/include/ttmlir/Dialect/TTIR/Analysis/Edge.h b/include/ttmlir/Dialect/TTNN/Analysis/Edge.h similarity index 73% rename from include/ttmlir/Dialect/TTIR/Analysis/Edge.h rename to include/ttmlir/Dialect/TTNN/Analysis/Edge.h index 56597bdb1..1ee9f801f 100644 --- a/include/ttmlir/Dialect/TTIR/Analysis/Edge.h +++ b/include/ttmlir/Dialect/TTNN/Analysis/Edge.h @@ -2,12 +2,12 @@ // // SPDX-License-Identifier: Apache-2.0 -#ifndef TTMLIR_DIALECT_TTIR_ANALYSIS_EDGE_H -#define TTMLIR_DIALECT_TTIR_ANALYSIS_EDGE_H +#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_EDGE_H +#define TTMLIR_DIALECT_TTNN_ANALYSIS_EDGE_H #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" -namespace mlir::tt::ttir { +namespace mlir::tt::ttnn { struct Edge { Operation *producerOp = nullptr; Operation *consumerOp = nullptr; @@ -23,11 +23,11 @@ struct Edge { } }; -} // namespace mlir::tt::ttir +} // namespace mlir::tt::ttnn namespace std { -template <> struct hash { - size_t operator()(const mlir::tt::ttir::Edge &edge) const noexcept { +template <> struct hash { + size_t operator()(const mlir::tt::ttnn::Edge &edge) const noexcept { llvm::hash_code code = llvm::hash_value(edge.operandIndex); code = llvm::hash_combine(code, edge.producerOp); code = llvm::hash_combine(code, edge.consumerOp); @@ -36,4 +36,4 @@ template <> struct hash { }; } // namespace std -#endif // TTMLIR_DIALECT_TTIR_ANALYSIS_EDGE_H +#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_EDGE_H diff --git a/include/ttmlir/Dialect/TTIR/Analysis/LegalGridAnalysis.h b/include/ttmlir/Dialect/TTNN/Analysis/LegalGridAnalysis.h similarity index 76% rename from include/ttmlir/Dialect/TTIR/Analysis/LegalGridAnalysis.h rename to include/ttmlir/Dialect/TTNN/Analysis/LegalGridAnalysis.h index 337163fe5..5a5db9057 100644 --- a/include/ttmlir/Dialect/TTIR/Analysis/LegalGridAnalysis.h +++ b/include/ttmlir/Dialect/TTNN/Analysis/LegalGridAnalysis.h @@ -2,15 +2,15 @@ // // SPDX-License-Identifier: Apache-2.0 -#ifndef TTMLIR_DIALECT_TTIR_ANALYSIS_LEGALGRIDANALYSIS_H -#define TTMLIR_DIALECT_TTIR_ANALYSIS_LEGALGRIDANALYSIS_H +#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALGRIDANALYSIS_H +#define TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALGRIDANALYSIS_H #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" #include "ttmlir/Dialect/TT/Utils/OverrideParams.h" -#include "ttmlir/Dialect/TTIR/Analysis/TTIRAnalysis.h" +#include "ttmlir/Dialect/TTNN/Analysis/TTNNAnalysis.h" #include "llvm/ADT/StringMap.h" -namespace mlir::tt::ttir { +namespace mlir::tt::ttnn { struct LegalGridAnalysisInput { ChipDescAttr chipDesc; @@ -43,15 +43,15 @@ struct LegalGridAnalysisInput { }; class LegalGridAnalysis - : public TTIRAnalysis> { + : public TTNNAnalysis> { private: void analysisImplementation() override; bool applyOverrides() override; public: - LegalGridAnalysis(Operation *op) : TTIRAnalysis(op) {} + 
LegalGridAnalysis(Operation *op) : TTNNAnalysis(op) {} }; -} // namespace mlir::tt::ttir +} // namespace mlir::tt::ttnn -#endif // TTMLIR_DIALECT_TTIR_ANALYSIS_LEGALGRIDANALYSIS_H +#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALGRIDANALYSIS_H diff --git a/include/ttmlir/Dialect/TTIR/Analysis/OpConfigAnalysis.h b/include/ttmlir/Dialect/TTNN/Analysis/OpConfigAnalysis.h similarity index 50% rename from include/ttmlir/Dialect/TTIR/Analysis/OpConfigAnalysis.h rename to include/ttmlir/Dialect/TTNN/Analysis/OpConfigAnalysis.h index 13c2478c9..5e0e79580 100644 --- a/include/ttmlir/Dialect/TTIR/Analysis/OpConfigAnalysis.h +++ b/include/ttmlir/Dialect/TTNN/Analysis/OpConfigAnalysis.h @@ -2,25 +2,27 @@ // // SPDX-License-Identifier: Apache-2.0 -#ifndef TTMLIR_DIALECT_TTIR_ANALYSIS_OPCONFIGANALYSIS_H -#define TTMLIR_DIALECT_TTIR_ANALYSIS_OPCONFIGANALYSIS_H +#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_OPCONFIGANALYSIS_H +#define TTMLIR_DIALECT_TTNN_ANALYSIS_OPCONFIGANALYSIS_H #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" -#include "ttmlir/Dialect/TTIR/Analysis/TTIRAnalysis.h" +#include "ttmlir/Dialect/TTNN/Analysis/TTNNAnalysis.h" -namespace mlir::tt::ttir { +namespace mlir::tt::ttnn { struct OpConfigAnalysisInput { - llvm::DenseMap> legalGrids; + llvm::DenseMap> legalGrids; OpConfigAnalysisInput() : legalGrids() {} OpConfigAnalysisInput( - const llvm::DenseMap> &&legalGrids) + const llvm::DenseMap> + &&legalGrids) : legalGrids(std::move(legalGrids)) {} OpConfigAnalysisInput( - const llvm::DenseMap> &legalGrids) + const llvm::DenseMap> + &legalGrids) : legalGrids(legalGrids) {} bool operator==(const OpConfigAnalysisInput &rhs) const { @@ -35,16 +37,16 @@ struct OpConfigAnalysisInput { // Determine optimal configuration for each op. // class OpConfigAnalysis - : public TTIRAnalysis> { + : public TTNNAnalysis> { private: void analysisImplementation() override; bool applyOverrides() override; public: - OpConfigAnalysis(Operation *op) : TTIRAnalysis(op) {} + OpConfigAnalysis(Operation *op) : TTNNAnalysis(op) {} }; -} // namespace mlir::tt::ttir +} // namespace mlir::tt::ttnn -#endif // TTMLIR_DIALECT_TTIR_ANALYSIS_OPCONFIGANALYSIS_H +#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_OPCONFIGANALYSIS_H diff --git a/include/ttmlir/Dialect/TTIR/Analysis/ShardChainConfig.h b/include/ttmlir/Dialect/TTNN/Analysis/ShardChainConfig.h similarity index 70% rename from include/ttmlir/Dialect/TTIR/Analysis/ShardChainConfig.h rename to include/ttmlir/Dialect/TTNN/Analysis/ShardChainConfig.h index 16ad2da23..5b61c558b 100644 --- a/include/ttmlir/Dialect/TTIR/Analysis/ShardChainConfig.h +++ b/include/ttmlir/Dialect/TTNN/Analysis/ShardChainConfig.h @@ -2,19 +2,19 @@ // // SPDX-License-Identifier: Apache-2.0 -#ifndef TTMLIR_DIALECT_TTIR_ANALYSIS_SHARDCHAINCONFIG_H -#define TTMLIR_DIALECT_TTIR_ANALYSIS_SHARDCHAINCONFIG_H +#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_SHARDCHAINCONFIG_H +#define TTMLIR_DIALECT_TTNN_ANALYSIS_SHARDCHAINCONFIG_H #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" -#include "ttmlir/Dialect/TTIR/Analysis/ShardSolver.h" +#include "ttmlir/Dialect/TTNN/Analysis/ShardSolver.h" #include -namespace mlir::tt::ttir { +namespace mlir::tt::ttnn { struct ShardSpec { Operation *op; uint tensorSplitFactor; - LayoutAttr layout; + tt::LayoutAttr layout; }; // Enum to track the state of the shard chain. 
@@ -36,12 +36,14 @@ class ShardChainConfig { public: ShardChainConfig() : shardSpecs(), state() {} - ShardSolver resolve( - const llvm::DenseMap> &legalLayouts, - unsigned usableL1CacheSize); + ShardSolver + resolve(const llvm::DenseMap> + &legalLayouts, + unsigned usableL1CacheSize); void build(); - void complete(const llvm::DenseMap &selectedOpLayout, - std::unordered_set &reshardedEdges); + void + complete(const llvm::DenseMap &selectedOpLayout, + std::unordered_set &reshardedEdges); bool isEmpty() { return shardSpecs.empty(); } void addShardSpec(ShardSpec &&spec) { @@ -56,6 +58,6 @@ class ShardChainConfig { } }; -} // namespace mlir::tt::ttir +} // namespace mlir::tt::ttnn -#endif // TTMLIR_DIALECT_TTIR_ANALYSIS_SHARDCHAINCONFIG_H +#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_SHARDCHAINCONFIG_H diff --git a/include/ttmlir/Dialect/TTIR/Analysis/ShardSolver.h b/include/ttmlir/Dialect/TTNN/Analysis/ShardSolver.h similarity index 83% rename from include/ttmlir/Dialect/TTIR/Analysis/ShardSolver.h rename to include/ttmlir/Dialect/TTNN/Analysis/ShardSolver.h index 561627cc8..1999cbf46 100644 --- a/include/ttmlir/Dialect/TTIR/Analysis/ShardSolver.h +++ b/include/ttmlir/Dialect/TTNN/Analysis/ShardSolver.h @@ -2,11 +2,11 @@ // // SPDX-License-Identifier: Apache-2.0 -#ifndef TTMLIR_DIALECT_TTIR_ANALYSIS_SHARDSOLVER_H -#define TTMLIR_DIALECT_TTIR_ANALYSIS_SHARDSOLVER_H +#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_SHARDSOLVER_H +#define TTMLIR_DIALECT_TTNN_ANALYSIS_SHARDSOLVER_H #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" -#include "ttmlir/Dialect/TTIR/Analysis/Edge.h" +#include "ttmlir/Dialect/TTNN/Analysis/Edge.h" #include #include #include @@ -14,16 +14,16 @@ #include #include -namespace mlir::tt::ttir { +namespace mlir::tt::ttnn { struct ShardSpec; struct ShardSolverSolution { - llvm::DenseMap selectedOpLayout; + llvm::DenseMap selectedOpLayout; std::unordered_set reshardedEdges; ShardSolverSolution( - const llvm::DenseMap &selectedOpLayout, + const llvm::DenseMap &selectedOpLayout, const std::unordered_set &reshardedEdges) : selectedOpLayout(selectedOpLayout), reshardedEdges(reshardedEdges) {} }; @@ -39,7 +39,7 @@ class ShardSolver { struct RemainingLayoutAttrs { class Iterator { std::uint64_t i = 0; - std::vector const *p = nullptr; + std::vector const *p = nullptr; Bitset mask = 0; private: @@ -58,12 +58,12 @@ class ShardSolver { public: using iterator_category = std::input_iterator_tag; - using value_type = const LayoutAttr; - using difference_type = const LayoutAttr; - using pointer = const LayoutAttr *; - using reference = const LayoutAttr &; + using value_type = const tt::LayoutAttr; + using difference_type = const tt::LayoutAttr; + using pointer = const tt::LayoutAttr *; + using reference = const tt::LayoutAttr &; - Iterator(std::vector const *p, const Bitset &mask, + Iterator(std::vector const *p, const Bitset &mask, std::uint64_t i = 0) : i(i), p(p), mask(mask) { nextValid(); @@ -87,7 +87,8 @@ class ShardSolver { reference operator*() const { return (*p)[i]; } }; - RemainingLayoutAttrs(std::vector const &p, const Bitset &mask) + RemainingLayoutAttrs(std::vector const &p, + const Bitset &mask) : p(&p), mask(mask) {} Iterator begin() const { return Iterator(p, mask); } @@ -96,7 +97,7 @@ class ShardSolver { } size_t size() const { return mask.count(); } - std::vector const *p = nullptr; + std::vector const *p = nullptr; Bitset mask = 0; }; @@ -246,7 +247,8 @@ class ShardSolver { Paths paths; }; - const std::vector &getLegalLayouts(Operation *operation) const; + const std::vector & + 
getLegalLayouts(Operation *operation) const; void reset(); PathSet *getPathSetPt(const Edge &edge); @@ -269,21 +271,21 @@ class ShardSolver { void preprocessFirstOp(); bool checkShardCompatible(Operation *producerOp, - LayoutAttr const &producerLayout, + tt::LayoutAttr const &producerLayout, Operation *consumerOp, - LayoutAttr const &consumerLayout) const; + tt::LayoutAttr const &consumerLayout) const; public: - ShardSolver( - const llvm::DenseMap> &legalLayouts, - const std::vector &shardSpecs, - const llvm::DenseSet &shardedOps, - const unsigned usableL1CacheSize); + ShardSolver(const llvm::DenseMap> + &legalLayouts, + const std::vector &shardSpecs, + const llvm::DenseSet &shardedOps, + const unsigned usableL1CacheSize); RemainingLayoutAttrs at(Operation *operation) const; - void set(Operation *operation, LayoutAttr const &layout); + void set(Operation *operation, tt::LayoutAttr const &layout); private: - const llvm::DenseMap> *legalLayouts; + const llvm::DenseMap> *legalLayouts; const std::vector *shardSpecs; const llvm::DenseSet *shardedOps; unsigned usableL1CacheSize; @@ -296,10 +298,10 @@ class ShardSolver { std::unordered_map pathSetIds; std::unordered_map bitsetIds; - llvm::DenseMap selectedOpLayout; + llvm::DenseMap selectedOpLayout; std::unordered_set reshardedEdges; }; -} // namespace mlir::tt::ttir +} // namespace mlir::tt::ttnn -#endif // TTMLIR_DIALECT_TTIR_ANALYSIS_SHARDSOLVER_H +#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_SHARDSOLVER_H diff --git a/include/ttmlir/Dialect/TTIR/Analysis/ShardingAnalysis.h b/include/ttmlir/Dialect/TTNN/Analysis/ShardingAnalysis.h similarity index 59% rename from include/ttmlir/Dialect/TTIR/Analysis/ShardingAnalysis.h rename to include/ttmlir/Dialect/TTNN/Analysis/ShardingAnalysis.h index 4b179c35d..a25738c2e 100644 --- a/include/ttmlir/Dialect/TTIR/Analysis/ShardingAnalysis.h +++ b/include/ttmlir/Dialect/TTNN/Analysis/ShardingAnalysis.h @@ -2,28 +2,29 @@ // // SPDX-License-Identifier: Apache-2.0 -#ifndef TTMLIR_DIALECT_TTIR_ANALYSIS_SHARDINGANALYSIS_H -#define TTMLIR_DIALECT_TTIR_ANALYSIS_SHARDINGANALYSIS_H +#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_SHARDINGANALYSIS_H +#define TTMLIR_DIALECT_TTNN_ANALYSIS_SHARDINGANALYSIS_H #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "ttmlir/Dialect/TTIR/Analysis/Edge.h" -#include "ttmlir/Dialect/TTIR/Analysis/ShardChainConfig.h" -#include "ttmlir/Dialect/TTIR/Analysis/TTIRAnalysis.h" +#include "ttmlir/Dialect/TTNN/Analysis/Edge.h" +#include "ttmlir/Dialect/TTNN/Analysis/ShardChainConfig.h" +#include "ttmlir/Dialect/TTNN/Analysis/TTNNAnalysis.h" -namespace mlir::tt::ttir { +namespace mlir::tt::ttnn { enum class ShardingPolicyType { DFSharding, }; struct ShardingAnalysisInput { - llvm::DenseMap> legalLayouts; + llvm::DenseMap> legalLayouts; unsigned usableL1CacheSize = 0; ShardingAnalysisInput() : legalLayouts() {} ShardingAnalysisInput( - const llvm::DenseMap> &legalLayouts, + const llvm::DenseMap> + &legalLayouts, unsigned usableL1CacheSize) : legalLayouts(legalLayouts), usableL1CacheSize(usableL1CacheSize) {} @@ -37,14 +38,15 @@ struct ShardingAnalysisInput { }; struct ShardingAnalysisResult { - llvm::DenseMap> legalLayouts; + llvm::DenseMap> legalLayouts; std::unordered_set reshardedEdges; llvm::DenseMap> schedule; ShardingAnalysisResult() : legalLayouts(), reshardedEdges(), schedule() {} ShardingAnalysisResult( - const llvm::DenseMap> &legalLayouts, + const llvm::DenseMap> + &legalLayouts, const std::unordered_set &reshardedEdges) : legalLayouts(legalLayouts), reshardedEdges(reshardedEdges) {} }; @@ -52,7 
+54,7 @@ struct ShardingAnalysisResult { // Determine shard chain configs. // class ShardingAnalysis - : public TTIRAnalysis { + : public TTNNAnalysis { private: void analysisImplementation() override; @@ -60,8 +62,8 @@ class ShardingAnalysis std::vector shardChainConfigs; public: - ShardingAnalysis(Operation *op) : TTIRAnalysis(op) {} + ShardingAnalysis(Operation *op) : TTNNAnalysis(op) {} }; -} // namespace mlir::tt::ttir +} // namespace mlir::tt::ttnn -#endif // TTMLIR_DIALECT_TTIR_ANALYSIS_SHARDINGANALYSIS_H +#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_SHARDINGANALYSIS_H diff --git a/include/ttmlir/Dialect/TTIR/Analysis/TTIRAnalysis.h b/include/ttmlir/Dialect/TTNN/Analysis/TTNNAnalysis.h similarity index 80% rename from include/ttmlir/Dialect/TTIR/Analysis/TTIRAnalysis.h rename to include/ttmlir/Dialect/TTNN/Analysis/TTNNAnalysis.h index 1c0bb13f4..622272e00 100644 --- a/include/ttmlir/Dialect/TTIR/Analysis/TTIRAnalysis.h +++ b/include/ttmlir/Dialect/TTNN/Analysis/TTNNAnalysis.h @@ -2,22 +2,22 @@ // // SPDX-License-Identifier: Apache-2.0 -#ifndef TTMLIR_DIALECT_TTIR_ANALYSIS_TTIRANALYSIS_H -#define TTMLIR_DIALECT_TTIR_ANALYSIS_TTIRANALYSIS_H +#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_TTNNANALYSIS_H +#define TTMLIR_DIALECT_TTNN_ANALYSIS_TTNNANALYSIS_H #include "mlir/IR/Operation.h" -namespace mlir::tt::ttir { -// Base class for all TTIR analyses. +namespace mlir::tt::ttnn { +// Base class for all TTNN analyses. // -template class TTIRAnalysis { +template class TTNNAnalysis { protected: Operation *op; bool isValid = false; R analysisResult; I analysisInput; - TTIRAnalysis(Operation *op) : op(op) {} + TTNNAnalysis(Operation *op) : op(op) {} // Actual implementation of the analysis. // Must be implemented by every analysis type. @@ -31,7 +31,7 @@ template class TTIRAnalysis { virtual bool applyOverrides() = 0; public: - virtual ~TTIRAnalysis() {}; + virtual ~TTNNAnalysis() {}; // Initialize the analysis with the input if needed. // @@ -70,6 +70,6 @@ template class TTIRAnalysis { } } }; -} // namespace mlir::tt::ttir +} // namespace mlir::tt::ttnn #endif diff --git a/include/ttmlir/Dialect/TTNN/Transforms/Passes.h b/include/ttmlir/Dialect/TTNN/Transforms/Passes.h index ed8d2a964..6df19409a 100644 --- a/include/ttmlir/Dialect/TTNN/Transforms/Passes.h +++ b/include/ttmlir/Dialect/TTNN/Transforms/Passes.h @@ -8,6 +8,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" +#include "ttmlir/Dialect/TT/Utils/OverrideParams.h" #include "ttmlir/Dialect/TTNN/IR/TTNN.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" diff --git a/include/ttmlir/Dialect/TTNN/Transforms/Passes.td b/include/ttmlir/Dialect/TTNN/Transforms/Passes.td index bb002f0f9..825154569 100644 --- a/include/ttmlir/Dialect/TTNN/Transforms/Passes.td +++ b/include/ttmlir/Dialect/TTNN/Transforms/Passes.td @@ -14,4 +14,30 @@ def TTNNDeallocate: Pass<"ttnn-deallocate", "::mlir::ModuleOp"> { }]; } +def TTNNOptimizer: Pass<"ttnn-optimizer", "::mlir::ModuleOp"> { + let summary = "Determine op configurations for maximum performance."; + let description = [{ + Go through the ops, set sharding specs for each op based on sharding analysis, + by updating layout attribute of each op. 
+ }]; + let options = [ + Option<"overrideOutputLayout", "override-output-layout", + "llvm::StringMap", + /*default=*/"llvm::StringMap()", + "Override output tensor layout for specific ops.">, + Option<"shardingPassEnabled", "sharding-pass-enabled", + "bool", + /*default=*/"false", + "Enable sharding pass.">, + Option<"reshardingEnabled", "resharding-enabled", + "bool", + /*default=*/"false", + "Resharding pass. Temp disabled till we support all types of shard specs.">, + Option<"maxLegalLayouts", "max-legal-layouts", + "int64_t", + /*default=*/"64", + "Override maximum number of legal layouts for grid analysis."> + ]; +} + #endif diff --git a/lib/CAPI/CMakeLists.txt b/lib/CAPI/CMakeLists.txt index d3c6752b5..d256ad3a5 100644 --- a/lib/CAPI/CMakeLists.txt +++ b/lib/CAPI/CMakeLists.txt @@ -17,5 +17,5 @@ add_mlir_public_c_api_library(TTMLIRCAPI MLIRTTIRDialect MLIRTTKernelDialect MLIRTTIRTransforms - MLIRTTIRAnalysis + MLIRTTNNAnalysis ) diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 9d3162da2..c3dc3a4b7 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -18,7 +18,7 @@ MLIRTTDialect MLIRTTIRDialect MLIRTTIRTransforms TTMLIRConversions -MLIRTTIRAnalysis +MLIRTTNNAnalysis MLIRTTNNDialect MLIRTTNNTransforms MLIRTTKernelDialect diff --git a/lib/Dialect/TTIR/CMakeLists.txt b/lib/Dialect/TTIR/CMakeLists.txt index 920567e7c..0713751a0 100644 --- a/lib/Dialect/TTIR/CMakeLists.txt +++ b/lib/Dialect/TTIR/CMakeLists.txt @@ -1,4 +1,3 @@ add_subdirectory(IR) add_subdirectory(Pipelines) add_subdirectory(Transforms) -add_subdirectory(Analysis) diff --git a/lib/Dialect/TTIR/Transforms/CMakeLists.txt b/lib/Dialect/TTIR/Transforms/CMakeLists.txt index 29ac7ec80..f5fec45a8 100644 --- a/lib/Dialect/TTIR/Transforms/CMakeLists.txt +++ b/lib/Dialect/TTIR/Transforms/CMakeLists.txt @@ -3,7 +3,6 @@ add_mlir_dialect_library(MLIRTTIRTransforms Constant.cpp Generic.cpp Layout.cpp - Optimizer.cpp Transforms.cpp Utility.cpp diff --git a/lib/Dialect/TTMetal/Pipelines/CMakeLists.txt b/lib/Dialect/TTMetal/Pipelines/CMakeLists.txt index 2a0958cf4..3a9c51eab 100644 --- a/lib/Dialect/TTMetal/Pipelines/CMakeLists.txt +++ b/lib/Dialect/TTMetal/Pipelines/CMakeLists.txt @@ -8,7 +8,7 @@ add_mlir_dialect_library(MLIRTTMetalPipelines MLIRTTIRDialect MLIRTTMetalDialect MLIRTTIRTransforms - MLIRTTIRAnalysis + MLIRTTNNAnalysis MLIRPass MLIRTransforms ) diff --git a/lib/Dialect/TTIR/Analysis/CMakeLists.txt b/lib/Dialect/TTNN/Analysis/CMakeLists.txt similarity index 77% rename from lib/Dialect/TTIR/Analysis/CMakeLists.txt rename to lib/Dialect/TTNN/Analysis/CMakeLists.txt index 2a4f132ab..20385e3b1 100644 --- a/lib/Dialect/TTIR/Analysis/CMakeLists.txt +++ b/lib/Dialect/TTNN/Analysis/CMakeLists.txt @@ -1,4 +1,4 @@ -add_mlir_dialect_library(MLIRTTIRAnalysis +add_mlir_dialect_library(MLIRTTNNAnalysis LegalGridAnalysis.cpp OpConfigAnalysis.cpp ShardingAnalysis.cpp @@ -10,8 +10,8 @@ add_mlir_dialect_library(MLIRTTIRAnalysis ${PROJECT_SOURCE_DIR}/include/ttmlir DEPENDS - MLIRTTIROpsIncGen - MLIRTTIRPassesIncGen + MLIRTTNNOpsIncGen + MLIRTTNNPassesIncGen MLIRTTOpsIncGen LINK_LIBS diff --git a/lib/Dialect/TTIR/Analysis/DFShardingPolicy.cpp b/lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp similarity index 87% rename from lib/Dialect/TTIR/Analysis/DFShardingPolicy.cpp rename to lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp index 3ab32e75c..434071a55 100644 --- a/lib/Dialect/TTIR/Analysis/DFShardingPolicy.cpp +++ b/lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp @@ -2,11 +2,16 @@ // // SPDX-License-Identifier: Apache-2.0 
-#include "ttmlir/Dialect/TTIR/Analysis/DFShardingPolicy.h" -#include "ttmlir/Dialect/TTIR/IR/TTIROps.h" +#include "ttmlir/Dialect/TTNN/Analysis/DFShardingPolicy.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" #include "ttmlir/Scheduler/Scheduler.h" -namespace mlir::tt::ttir { +namespace mlir::tt::ttnn { + +bool isMemoryManagementOp(mlir::Operation *op) { + return isa(op) || isa(op) || + isa(op); +} void DFShardingPolicy::run() { rootOp->walk([&](func::FuncOp func) { @@ -25,16 +30,16 @@ void DFShardingPolicy::run() { while (scheduler.hasUnscheduledOps()) { scheduleableOps = scheduler.getScheduleableOps(); - // Before starting a sharding chain, schedule ttir.to_layout ops first - // until they are exhausted from schedulable ops. + // Before starting a sharding chain, schedule layout/memory management ops + // first until they are exhausted from schedulable ops. // TODO(nobradovic) : - // We need to examine type of to_layout op and determine if for + // We need to examine type of memory op and determine if for // example we have a space in DRAM to perform this?(system->dram, double // check this) // if (shardChainConfigs->back().isEmpty()) { for (auto *op : scheduleableOps) { - if (isa(op)) { + if (isMemoryManagementOp(op)) { currentOp = op; break; } @@ -49,10 +54,10 @@ void DFShardingPolicy::run() { // scheduler.scheduleOp(currentOp); - // Skip sharding process if currentOp is a ttir.to_layout op. + // Skip starting sharding chain if currentOp is a memory management op. // if (shardChainConfigs->back().isEmpty() && - isa(currentOp)) { + isMemoryManagementOp(currentOp)) { currentOp = nullptr; continue; } @@ -88,7 +93,8 @@ void DFShardingPolicy::run() { // currentOp output tensor shard spec, nextOp exec and nextOp output // tensor. // - LayoutAttr currentOpLayout = legalLayouts.lookup(currentOp).front(); + tt::LayoutAttr currentOpLayout = + legalLayouts.lookup(currentOp).front(); assert(currentOpLayout.hasShardedL1TensorMemoryLayout()); llvm::ArrayRef currentOpOutputTensorShape = mlir::cast(currentOp->getResult(0).getType()) @@ -97,7 +103,7 @@ void DFShardingPolicy::run() { currentOpOutputTensorShape, currentOpLayout, currentOpLayout.getMemorySpace()); - LayoutAttr nextOpLayout = legalLayouts.lookup(nextOp).front(); + tt::LayoutAttr nextOpLayout = legalLayouts.lookup(nextOp).front(); assert(nextOpLayout.hasShardedL1TensorMemoryLayout()); llvm::ArrayRef nextOpOutputTensorShape = mlir::cast(nextOp->getResult(0).getType()) @@ -129,10 +135,10 @@ void DFShardingPolicy::run() { .getDefiningOp() ->getResult(0) .getType()); - LayoutAttr firstOpInputLayout = mlir::cast( + tt::LayoutAttr firstOpInputLayout = mlir::cast( firstOpInputTensorType.getEncoding()); - LayoutAttr firstOpInputShardedLayout = + tt::LayoutAttr firstOpInputShardedLayout = firstOpInputLayout .withMemorySpace(currentOp->getContext(), currentOpLayout.getMemorySpace()) @@ -207,4 +213,4 @@ void DFShardingPolicy::run() { } } -} // namespace mlir::tt::ttir +} // namespace mlir::tt::ttnn diff --git a/lib/Dialect/TTIR/Analysis/LegalGridAnalysis.cpp b/lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp similarity index 80% rename from lib/Dialect/TTIR/Analysis/LegalGridAnalysis.cpp rename to lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp index c0861bbd7..7bb6d27e2 100644 --- a/lib/Dialect/TTIR/Analysis/LegalGridAnalysis.cpp +++ b/lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp @@ -2,18 +2,19 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "ttmlir/Dialect/TTIR/Analysis/LegalGridAnalysis.h" +#include 
"ttmlir/Dialect/TTNN/Analysis/LegalGridAnalysis.h" #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" -#include "ttmlir/Dialect/TTIR/IR/TTIROps.h" +#include "ttmlir/Dialect/TTNN/IR/TTNN.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" -namespace mlir::tt::ttir { +namespace mlir::tt::ttnn { -bool mock_is_output_tensor_legal_for_op(Operation *op, LayoutAttr layout) { +bool mock_is_output_tensor_legal_for_op(Operation *op, tt::LayoutAttr layout) { // Placeholder, needs to be replaced with a call the the TTNN op interface. return true; } -bool tensor_shape_compatible_with_shard(Operation *op, LayoutAttr layout) { +bool tensor_shape_compatible_with_shard(Operation *op, tt::LayoutAttr layout) { // These constraints are implemented seperatelly in every TTNN op. // Almost nothing seems to be shared between EVERY op, so is hard to have any // logic here without the risk of discarding a valid configuraiton or modeling @@ -41,13 +42,21 @@ bool tensor_shape_compatible_with_shard(Operation *op, LayoutAttr layout) { } bool cantChangeOutputLayout(Operation *op) { - // Only TTIR ops. - if (not llvm::isa(op)) { + // Check if OP belongs to TTNN dialect. + // + if (!isa(op->getDialect())) { + return true; + } + + if (llvm::isa(op)) { return true; } - if (llvm::isa(op)) { + + if (llvm::isa(op) || llvm::isa(op) || + llvm::isa(op)) { return true; } + return false; } @@ -74,7 +83,7 @@ bool LegalGridAnalysis::applyOverrides() { LayoutOverrideParams override = gridOverride->getValue(); RankedTensorType tensorType = mlir::cast(op->getResult(0).getType()); - LayoutAttr layout = mlir::cast(tensorType.getEncoding()); + tt::LayoutAttr layout = mlir::cast(tensorType.getEncoding()); analysisResult.push_back( layout.withMemorySpace(op->getContext(), override.memorySpace) @@ -99,7 +108,7 @@ void LegalGridAnalysis::analysisImplementation() { // Get output tensor type. RankedTensorType tensorType = mlir::cast(op->getResult(0).getType()); - LayoutAttr layout = mlir::cast(tensorType.getEncoding()); + tt::LayoutAttr layout = mlir::cast(tensorType.getEncoding()); // Return existing layout if it is not possible to change it. if (cantChangeOutputLayout(op)) { @@ -111,9 +120,10 @@ void LegalGridAnalysis::analysisImplementation() { // No grid is set since the tensor is not sharded. // TODO(odjuricic): We need to set grid here since it will be used as the // compute gird. (not implemented in runtime atm) - LayoutAttr dram = + tt::LayoutAttr dram = layout.withMemorySpace(op->getContext(), MemorySpace::DeviceDRAM) - .withMemoryLayout(op->getContext(), TensorMemoryLayout::Interleaved) + .withMemoryLayout(op->getContext(), + tt::TensorMemoryLayout::Interleaved) .withGrid(op->getContext(), tensorType, GridAttr::get(op->getContext(), analysisInput.maxGrid.getShape())); @@ -122,9 +132,10 @@ void LegalGridAnalysis::analysisImplementation() { } // L1 Interleaved (same as above). 
- LayoutAttr l1Interleaved = + tt::LayoutAttr l1Interleaved = layout.withMemorySpace(op->getContext(), MemorySpace::DeviceL1) - .withMemoryLayout(op->getContext(), TensorMemoryLayout::Interleaved) + .withMemoryLayout(op->getContext(), + tt::TensorMemoryLayout::Interleaved) .withGrid(op->getContext(), tensorType, GridAttr::get(op->getContext(), analysisInput.maxGrid.getShape())); @@ -133,9 +144,9 @@ void LegalGridAnalysis::analysisImplementation() { } // L1 Sharded - LayoutAttr shardedBase = + tt::LayoutAttr shardedBase = layout.withMemorySpace(op->getContext(), MemorySpace::DeviceL1); - std::vector shardedResults; + std::vector shardedResults; // Block Sharded for (auto width = 1; width <= analysisInput.maxGrid.getShape()[0]; ++width) { @@ -146,7 +157,7 @@ void LegalGridAnalysis::analysisImplementation() { .withGrid(op->getContext(), tensorType, GridAttr::get(op->getContext(), {width, height})) .withMemoryLayout(op->getContext(), - TensorMemoryLayout::BlockSharded)); + tt::TensorMemoryLayout::BlockSharded)); } } @@ -161,7 +172,7 @@ void LegalGridAnalysis::analysisImplementation() { .withGrid(op->getContext(), tensorType, GridAttr::get(op->getContext(), {height, 1})) .withMemoryLayout(op->getContext(), - TensorMemoryLayout::HeightSharded)); + tt::TensorMemoryLayout::HeightSharded)); } // Width Sharded @@ -171,13 +182,13 @@ void LegalGridAnalysis::analysisImplementation() { .withGrid(op->getContext(), tensorType, GridAttr::get(op->getContext(), {1, width})) .withMemoryLayout(op->getContext(), - TensorMemoryLayout::WidthSharded)); + tt::TensorMemoryLayout::WidthSharded)); } // Filter layouts based on output tensor legality for current op. shardedResults.erase( std::remove_if(shardedResults.begin(), shardedResults.end(), - [this](LayoutAttr layout) { + [this](tt::LayoutAttr layout) { return !tensor_shape_compatible_with_shard(op, layout) || !mock_is_output_tensor_legal_for_op(op, layout); }), @@ -185,7 +196,7 @@ void LegalGridAnalysis::analysisImplementation() { // Pick top largest sharded grids. 
std::sort(shardedResults.begin(), shardedResults.end(), - [](LayoutAttr a, LayoutAttr b) { + [](tt::LayoutAttr a, tt::LayoutAttr b) { return a.getGrid().getShape()[0] * a.getGrid().getShape()[1] > b.getGrid().getShape()[0] * b.getGrid().getShape()[1]; }); @@ -196,4 +207,4 @@ void LegalGridAnalysis::analysisImplementation() { std::min(analysisInput.maxShardedGrids, static_cast(shardedResults.size()))); } -} // namespace mlir::tt::ttir +} // namespace mlir::tt::ttnn diff --git a/lib/Dialect/TTIR/Analysis/OpConfigAnalysis.cpp b/lib/Dialect/TTNN/Analysis/OpConfigAnalysis.cpp similarity index 81% rename from lib/Dialect/TTIR/Analysis/OpConfigAnalysis.cpp rename to lib/Dialect/TTNN/Analysis/OpConfigAnalysis.cpp index ae1d64cf8..d4a79d64e 100644 --- a/lib/Dialect/TTIR/Analysis/OpConfigAnalysis.cpp +++ b/lib/Dialect/TTNN/Analysis/OpConfigAnalysis.cpp @@ -2,9 +2,9 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "ttmlir/Dialect/TTIR/Analysis/OpConfigAnalysis.h" +#include "ttmlir/Dialect/TTNN/Analysis/OpConfigAnalysis.h" -namespace mlir::tt::ttir { +namespace mlir::tt::ttnn { bool OpConfigAnalysis::applyOverrides() { @@ -24,4 +24,4 @@ void OpConfigAnalysis::analysisImplementation() { } } } -} // namespace mlir::tt::ttir +} // namespace mlir::tt::ttnn diff --git a/lib/Dialect/TTIR/Analysis/ShardChainConfig.cpp b/lib/Dialect/TTNN/Analysis/ShardChainConfig.cpp similarity index 75% rename from lib/Dialect/TTIR/Analysis/ShardChainConfig.cpp rename to lib/Dialect/TTNN/Analysis/ShardChainConfig.cpp index 398592746..a70e93364 100644 --- a/lib/Dialect/TTIR/Analysis/ShardChainConfig.cpp +++ b/lib/Dialect/TTNN/Analysis/ShardChainConfig.cpp @@ -2,10 +2,10 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "ttmlir/Dialect/TTIR/Analysis/ShardChainConfig.h" -#include "ttmlir/Dialect/TTIR/Analysis/ShardSolver.h" +#include "ttmlir/Dialect/TTNN/Analysis/ShardChainConfig.h" +#include "ttmlir/Dialect/TTNN/Analysis/ShardSolver.h" -namespace mlir::tt::ttir { +namespace mlir::tt::ttnn { void ShardChainConfig::build() { assert(state == ShardChainState::InBuild); @@ -13,7 +13,8 @@ void ShardChainConfig::build() { } ShardSolver ShardChainConfig::resolve( - const llvm::DenseMap> &legalLayouts, + const llvm::DenseMap> + &legalLayouts, unsigned usableL1CacheSize) { assert(state == ShardChainState::Built); @@ -28,7 +29,7 @@ ShardSolver ShardChainConfig::resolve( } void ShardChainConfig::complete( - const llvm::DenseMap &selectedOpLayout, + const llvm::DenseMap &selectedOpLayout, std::unordered_set &reshardedEdges) { assert(state == ShardChainState::Resolved); for (auto &shardSpec : shardSpecs) { @@ -42,4 +43,4 @@ void ShardChainConfig::complete( state = ShardChainState::Completed; } -} // namespace mlir::tt::ttir +} // namespace mlir::tt::ttnn diff --git a/lib/Dialect/TTIR/Analysis/ShardSolver.cpp b/lib/Dialect/TTNN/Analysis/ShardSolver.cpp similarity index 93% rename from lib/Dialect/TTIR/Analysis/ShardSolver.cpp rename to lib/Dialect/TTNN/Analysis/ShardSolver.cpp index 649f5d1f1..7e98be3a8 100644 --- a/lib/Dialect/TTIR/Analysis/ShardSolver.cpp +++ b/lib/Dialect/TTNN/Analysis/ShardSolver.cpp @@ -2,16 +2,17 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "ttmlir/Dialect/TTIR/Analysis/ShardSolver.h" -#include "ttmlir/Dialect/TTIR/Analysis/ShardChainConfig.h" +#include "ttmlir/Dialect/TTNN/Analysis/ShardSolver.h" +#include "ttmlir/Dialect/TTNN/Analysis/ShardChainConfig.h" #include -namespace mlir::tt::ttir { +namespace mlir::tt::ttnn { ShardSolver::Bitset ShardSolver::kBitsetAll = ~kBitsetNone; 
ShardSolver::ShardSolver( - const llvm::DenseMap> &legalLayouts, + const llvm::DenseMap> + &legalLayouts, const std::vector &shardSpecs, const llvm::DenseSet &shardedOps, const unsigned usableL1CacheSize) @@ -67,7 +68,7 @@ bool ShardSolver::resolveStep() { for (const auto shardSpec : *shardSpecs) { Operation *consumerOp = shardSpec.op; Bitset *consumerBitset = getOrInsertBitset(consumerOp, kBitsetAll); - std::vector const &consumerLayouts = + std::vector const &consumerLayouts = getLegalLayouts(consumerOp); for (Edge edge : operandOpEdges[consumerOp]) { @@ -76,7 +77,7 @@ bool ShardSolver::resolveStep() { Operation *producerOp = edge.producerOp; Bitset *producerBitset = getOrInsertBitset(producerOp, kBitsetAll); - std::vector const &producerLayouts = + std::vector const &producerLayouts = getLegalLayouts(producerOp); assert(not(consumerLayouts.empty() && producerLayouts.empty())); @@ -91,7 +92,7 @@ bool ShardSolver::resolveStep() { for (std::uint64_t producerId = 0; producerId < producer_count; ++producerId) { // If the producer cannot accomodate this path, continue. - // Also if this is not the LayoutAttr we selected, continue. + // Also if this is not the tt::LayoutAttr we selected, continue. // if (!producerBitset->test(producerId)) { continue; @@ -189,13 +190,13 @@ void ShardSolver::preprocessFirstOp() { } Bitset *firstOpBitset = getOrInsertBitset(firstOp, kBitsetAll); - std::vector const &firstOpLayouts = getLegalLayouts(firstOp); + std::vector const &firstOpLayouts = getLegalLayouts(firstOp); Operation *operandOp = firstOp->getOperand(0).getDefiningOp(); RankedTensorType firstOpInputTensorType = mlir::cast(operandOp->getResult(0).getType()); - LayoutAttr firstOpInputLayout = - mlir::cast(firstOpInputTensorType.getEncoding()); + tt::LayoutAttr firstOpInputLayout = + mlir::cast(firstOpInputTensorType.getEncoding()); constexpr float tensorL1UsageCap = 0.8; for (size_t i = 0; i < firstOpLayouts.size(); ++i) { @@ -203,10 +204,10 @@ void ShardSolver::preprocessFirstOp() { continue; } - LayoutAttr firstOpLayout = firstOpLayouts[i]; + tt::LayoutAttr firstOpLayout = firstOpLayouts[i]; assert(firstOpLayout.hasShardedL1TensorMemoryLayout()); - LayoutAttr firstOpInputShardedLayout = + tt::LayoutAttr firstOpInputShardedLayout = firstOpInputLayout .withMemorySpace(firstOp->getContext(), firstOpLayout.getMemorySpace()) @@ -441,9 +442,9 @@ ShardSolver::Bitset *ShardSolver::getOrInsertBitset(Operation *op, // Returns vector of legal LayoutAttrs for passed in op. 
// -const std::vector & +const std::vector & ShardSolver::getLegalLayouts(Operation *op) const { - static std::vector nullLayouts; + static std::vector nullLayouts; const auto legalIt = legalLayouts->find(op); @@ -460,7 +461,7 @@ ShardSolver::RemainingLayoutAttrs ShardSolver::at(Operation *op) const { return layouts; } -void ShardSolver::set(Operation *op, LayoutAttr const &layout) { +void ShardSolver::set(Operation *op, tt::LayoutAttr const &layout) { assert(selectedOpLayout.count(op) == 0); selectedOpLayout[op] = layout; @@ -486,10 +487,9 @@ void ShardSolver::set(Operation *op, LayoutAttr const &layout) { updateSolver(op, true /*expand_root*/, true /*invokedBySet*/); } -bool ShardSolver::checkShardCompatible(Operation *producerOp, - LayoutAttr const &producerLayout, - Operation *consumerOp, - LayoutAttr const &consumerLayout) const { +bool ShardSolver::checkShardCompatible( + Operation *producerOp, tt::LayoutAttr const &producerLayout, + Operation *consumerOp, tt::LayoutAttr const &consumerLayout) const { // TEMP : Dummy mock implementation, will be replaced. // @@ -535,4 +535,4 @@ ShardSolverSolution const ShardSolver::finish() { assert(selectedOpLayout.size() == shardedOps->size()); return ShardSolverSolution(selectedOpLayout, reshardedEdges); } -} // namespace mlir::tt::ttir +} // namespace mlir::tt::ttnn diff --git a/lib/Dialect/TTIR/Analysis/ShardingAnalysis.cpp b/lib/Dialect/TTNN/Analysis/ShardingAnalysis.cpp similarity index 73% rename from lib/Dialect/TTIR/Analysis/ShardingAnalysis.cpp rename to lib/Dialect/TTNN/Analysis/ShardingAnalysis.cpp index d18300b9e..83cc1f074 100644 --- a/lib/Dialect/TTIR/Analysis/ShardingAnalysis.cpp +++ b/lib/Dialect/TTNN/Analysis/ShardingAnalysis.cpp @@ -2,10 +2,10 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "ttmlir/Dialect/TTIR/Analysis/ShardingAnalysis.h" -#include "ttmlir/Dialect/TTIR/Analysis/DFShardingPolicy.h" +#include "ttmlir/Dialect/TTNN/Analysis/ShardingAnalysis.h" +#include "ttmlir/Dialect/TTNN/Analysis/DFShardingPolicy.h" -namespace mlir::tt::ttir { +namespace mlir::tt::ttnn { bool ShardingAnalysis::applyOverrides() { @@ -15,11 +15,12 @@ bool ShardingAnalysis::applyOverrides() { return false; } -llvm::DenseMap> filterShardedOnly( - const llvm::DenseMap> &legalLayouts) { - llvm::DenseMap> shardedLayouts; +llvm::DenseMap> +filterShardedOnly(const llvm::DenseMap> + &legalLayouts) { + llvm::DenseMap> shardedLayouts; for (const auto &opLayouts : legalLayouts) { - std::vector opShardedLayouts; + std::vector opShardedLayouts; for (const auto &layout : opLayouts.second) { if (layout.hasShardedL1TensorMemoryLayout()) { opShardedLayouts.push_back(layout); @@ -54,7 +55,7 @@ void ShardingAnalysis::analysisImplementation() { assert(shardChainConfig.getState() == ShardChainState::Completed); for (const auto &shardSpec : shardChainConfig.getShardSpecs()) { analysisResult.legalLayouts[shardSpec.op] = - std::vector{shardSpec.layout}; + std::vector{shardSpec.layout}; } analysisResult.reshardedEdges.insert( @@ -62,4 +63,4 @@ void ShardingAnalysis::analysisImplementation() { shardChainConfig.getReshardedEdges().end()); } } -} // namespace mlir::tt::ttir +} // namespace mlir::tt::ttnn diff --git a/lib/Dialect/TTNN/CMakeLists.txt b/lib/Dialect/TTNN/CMakeLists.txt index e7b9dd7c3..4385b925d 100644 --- a/lib/Dialect/TTNN/CMakeLists.txt +++ b/lib/Dialect/TTNN/CMakeLists.txt @@ -2,3 +2,4 @@ add_subdirectory(IR) add_subdirectory(Pipelines) add_subdirectory(Transforms) add_subdirectory(Utils) +add_subdirectory(Analysis) diff --git 
a/lib/Dialect/TTNN/Pipelines/CMakeLists.txt b/lib/Dialect/TTNN/Pipelines/CMakeLists.txt index 114af95c6..71b531a92 100644 --- a/lib/Dialect/TTNN/Pipelines/CMakeLists.txt +++ b/lib/Dialect/TTNN/Pipelines/CMakeLists.txt @@ -9,7 +9,7 @@ add_mlir_dialect_library(MLIRTTNNPipelines MLIRTTNNDialect MLIRTTIRTransforms MLIRTTNNTransforms - MLIRTTIRAnalysis + MLIRTTNNAnalysis MLIRPass MLIRTransforms ) diff --git a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp index fdd3a5eba..fb84388d5 100644 --- a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp +++ b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp @@ -38,12 +38,15 @@ void createTTNNPipelineTTIRPasses( void createTTNNPipelineAnalysisPasses( OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) { if (options.optimizerPassEnabled) { - ttir::TTIROptimizerOptions optimizerOptions; + ttnn::TTNNOptimizerOptions optimizerOptions; optimizerOptions.overrideOutputLayout = options.overrideOutputLayout; optimizerOptions.shardingPassEnabled = options.shardingPassEnabled; optimizerOptions.maxLegalLayouts = options.maxLegalLayouts; - pm.addPass(mlir::tt::ttir::createTTIROptimizer(optimizerOptions)); + pm.addPass(mlir::tt::ttnn::createTTNNOptimizer(optimizerOptions)); } + + // Dealloc pass for tensor memory deallocation after last use. + pm.addPass(createTTNNDeallocate()); } void createTTNNPipelineLoweringPasses( @@ -52,8 +55,6 @@ void createTTNNPipelineLoweringPasses( pm.addPass(createConvertTTIRToTTNNPass()); // Add pass to remove unused values. pm.addPass(mlir::createRemoveDeadValuesPass()); - // Dealloc pass for tensor memory deallocation after last use. - pm.addPass(createTTNNDeallocate()); } void createTTNNPipelineTTIRPassesFromString(OpPassManager &pm, @@ -80,8 +81,8 @@ void createTTNNPipelineLoweringPassesFromString(OpPassManager &pm, void createTTIRToTTNNBackendPipeline( OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) { createTTNNPipelineTTIRPasses(pm, options); - createTTNNPipelineAnalysisPasses(pm, options); createTTNNPipelineLoweringPasses(pm, options); + createTTNNPipelineAnalysisPasses(pm, options); } //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TTNN/Transforms/CMakeLists.txt b/lib/Dialect/TTNN/Transforms/CMakeLists.txt index 88b4bff54..7232acdfd 100644 --- a/lib/Dialect/TTNN/Transforms/CMakeLists.txt +++ b/lib/Dialect/TTNN/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRTTNNTransforms Passes.cpp + Optimizer.cpp TTNNToCpp.cpp ADDITIONAL_HEADER_DIRS diff --git a/lib/Dialect/TTIR/Transforms/Optimizer.cpp b/lib/Dialect/TTNN/Transforms/Optimizer.cpp similarity index 53% rename from lib/Dialect/TTIR/Transforms/Optimizer.cpp rename to lib/Dialect/TTNN/Transforms/Optimizer.cpp index de43ec771..a3649cab8 100644 --- a/lib/Dialect/TTIR/Transforms/Optimizer.cpp +++ b/lib/Dialect/TTNN/Transforms/Optimizer.cpp @@ -2,25 +2,23 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "ttmlir/Dialect/TTIR/Analysis/LegalGridAnalysis.h" -#include "ttmlir/Dialect/TTIR/Analysis/OpConfigAnalysis.h" -#include "ttmlir/Dialect/TTIR/Analysis/ShardingAnalysis.h" -#include "ttmlir/Dialect/TTIR/Transforms/Passes.h" - +#include "mlir/Analysis/Liveness.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" - -namespace mlir::tt::ttir { -#define GEN_PASS_DEF_TTIROPTIMIZER -#include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc" - 
-//===----------------------------------------------------------------------===// -// Optimizer pass -//===----------------------------------------------------------------------===// - -class TTIROptimizer : public impl::TTIROptimizerBase { +#include "mlir/IR/PatternMatch.h" +#include "ttmlir/Dialect/TTNN/Analysis/LegalGridAnalysis.h" +#include "ttmlir/Dialect/TTNN/Analysis/OpConfigAnalysis.h" +#include "ttmlir/Dialect/TTNN/Analysis/ShardingAnalysis.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" +#include "ttmlir/Dialect/TTNN/Transforms/Passes.h" +#include "ttmlir/Dialect/TTNN/Utils/Utils.h" + +namespace mlir::tt::ttnn { +#define GEN_PASS_DEF_TTNNOPTIMIZER +#include "ttmlir/Dialect/TTNN/Transforms/Passes.h.inc" + +class TTNNOptimizer : public impl::TTNNOptimizerBase { public: - using impl::TTIROptimizerBase::TTIROptimizerBase; + using impl::TTNNOptimizerBase::TTNNOptimizerBase; void runOnOperation() final { // Generate legal OP configuration candidates. // Perform sharding analysis. @@ -39,13 +37,17 @@ class TTIROptimizer : public impl::TTIROptimizerBase { SystemDescAttr systemDesc = mlir::cast( moduleOp->getAttr(tt::SystemDescAttr::name)); ChipDescAttr chipDesc = systemDesc.getChipDescs()[0]; - llvm::DenseMap> legalLayouts; + llvm::DenseMap> legalLayouts; moduleOp->walk([&](Operation *op) { if (op->getNumResults() == 0) { return; } + if (!isa(op->getResult(0).getType())) { + return; + } + RankedTensorType tensorType = mlir::cast(op->getResult(0).getType()); LegalGridAnalysis legalGridAnalysis = @@ -91,7 +93,7 @@ class TTIROptimizer : public impl::TTIROptimizerBase { // Move DPS operand with the op. // - if (llvm::isa(nextOp)) { + if (isa(nextOp)) { nextOp->getOperands().back().getDefiningOp()->moveBefore(nextOp); } } @@ -107,6 +109,10 @@ class TTIROptimizer : public impl::TTIROptimizerBase { return; } + if (!isa(op->getResult(0).getType())) { + return; + } + RankedTensorType tensorType = mlir::cast(op->getResult(0).getType()); llvm::ArrayRef tensorShape = tensorType.getShape(); @@ -118,11 +124,48 @@ class TTIROptimizer : public impl::TTIROptimizerBase { RankedTensorType::get(tensorShape, tensorType.getElementType(), opConfigAnalysis.getResult().at(op)); + // Update the memory space and layout of the op. + // + tt::LayoutAttr ttLayoutAttr = + mlir::cast(newTensorType.getEncoding()); + op->getResult(0).setType(newTensorType); - if (llvm::isa(op)) { - // Update dps operand layout as well. + // Update DPS operand layout as well. + // + if (isa(op)) { + BufferType bufferType = + utils::toTTNNBufferType(ttLayoutAttr.getMemorySpace()); + TensorMemoryLayout tensorMemoryLayout = + utils::toTTNNTensorMemoryLayout(ttLayoutAttr.getMemLayout()); + op->getOperands().back().setType(newTensorType); + EmptyOp emptyOp = + mlir::cast(op->getOperands().back().getDefiningOp()); + + emptyOp.setMemoryConfigAttr(ttnn::MemoryConfigAttr::get( + op->getContext(), + TensorMemoryLayoutAttr::get(op->getContext(), + tensorMemoryLayout), + BufferTypeAttr::get(op->getContext(), bufferType))); + } + // TODO (nobradovic): Other memory management ops after lowering to + // TTNN will need to be special handled as well. Depends on ttnn + // layout attr refactor and lowering. + // + else if (isa(op)) { + BufferType bufferType = + utils::toTTNNBufferType(ttLayoutAttr.getMemorySpace()); + TensorMemoryLayout tensorMemoryLayout = + utils::toTTNNTensorMemoryLayout(ttLayoutAttr.getMemLayout()); + // Update the device op with the new tensor type. 
+ // + ttnn::ToDeviceOp toDeviceOp = llvm::cast(op); + toDeviceOp.setMemoryConfigAttr(ttnn::MemoryConfigAttr::get( + op->getContext(), + ttnn::TensorMemoryLayoutAttr::get(op->getContext(), + tensorMemoryLayout), + ttnn::BufferTypeAttr::get(op->getContext(), bufferType))); } } }); @@ -152,9 +195,9 @@ class TTIROptimizer : public impl::TTIROptimizerBase { // to reflect consumerOp's output layout. If producerOp is not a // ToLayoutOp, insert a ToLayoutOp in between producerOp and consumerOp. // - if (llvm::isa(producerOp)) { - ttir::ToLayoutOp toLayoutOp = llvm::cast(producerOp); - LayoutAttr consumerOpOutputLayout = mlir::cast( + if (isa(producerOp)) { + ttnn::ToLayoutOp toLayoutOp = llvm::cast(producerOp); + tt::LayoutAttr consumerOpOutputLayout = mlir::cast( mlir::cast(consumerOp->getResult(0).getType()) .getEncoding()); @@ -162,8 +205,8 @@ class TTIROptimizer : public impl::TTIROptimizerBase { mlir::cast(toLayoutOp.getResult().getType()); llvm::ArrayRef toLayoutOpTensorShape = toLayoutOpTensorType.getShape(); - LayoutAttr toLayoutOpLayout = - mlir::cast(toLayoutOpTensorType.getEncoding()); + tt::LayoutAttr toLayoutOpLayout = + mlir::cast(toLayoutOpTensorType.getEncoding()); // TODO(nobradovic): Match memory space and layout of consumer op. This // actually needs to be properly resolved based on op type, output @@ -183,48 +226,53 @@ class TTIROptimizer : public impl::TTIROptimizerBase { toLayoutOp.getResult().setType(newTensorType); toLayoutOp.getOperands().back().setType(newTensorType); - } else { - LayoutAttr consumerOpOutputLayout = mlir::cast( - mlir::cast(consumerOp->getResult(0).getType()) - .getEncoding()); - - RankedTensorType producerOpTensorType = - mlir::cast(producerOp->getResult(0).getType()); - llvm::ArrayRef producerOpTensorShape = - producerOpTensorType.getShape(); - LayoutAttr producerOpLayout = - mlir::cast(producerOpTensorType.getEncoding()); - - // TODO(nobradovic): Match memory space and layout of consumer op. This - // actually needs to be properly resolved based on op type, output - // layout and other inputs. - // - RankedTensorType newTensorType = RankedTensorType::get( - producerOpTensorShape, producerOpTensorType.getElementType(), - producerOpLayout - .withElementType(consumerOp->getContext(), - consumerOpOutputLayout.getElementType()) - .withMemorySpace(consumerOp->getContext(), - consumerOpOutputLayout.getMemorySpace()) - .withMemoryLayout(consumerOp->getContext(), - consumerOpOutputLayout.getMemLayout()) - .withGrid(consumerOp->getContext(), producerOpTensorType, - consumerOpOutputLayout.getGrid())); - - OpBuilder builder(consumerOp); - - mlir::tensor::EmptyOp emptyOp = builder.create( - consumerOp->getLoc(), producerOpTensorShape, - producerOpTensorType.getElementType(), - mlir::cast(newTensorType.getEncoding())); - - Operation *toLayoutOp = builder.create( - consumerOp->getLoc(), newTensorType, producerOp->getResult(0), - emptyOp); - - consumerOp->setOperand(edge.operandIndex, toLayoutOp->getResult(0)); } + // TODO (nobradovic): Resharding needs to be reimplemented for TTNN + // dialect. 
+ // else { + // tt::LayoutAttr consumerOpOutputLayout = mlir::cast( + // mlir::cast(consumerOp->getResult(0).getType()) + // .getEncoding()); + + // RankedTensorType producerOpTensorType = + // mlir::cast(producerOp->getResult(0).getType()); + // llvm::ArrayRef producerOpTensorShape = + // producerOpTensorType.getShape(); + // tt::LayoutAttr producerOpLayout = + // mlir::cast(producerOpTensorType.getEncoding()); + + // // TODO(nobradovic): Match memory space and layout of consumer op. + // This + // // actually needs to be properly resolved based on op type, output + // // layout and other inputs. + // // + // RankedTensorType newTensorType = RankedTensorType::get( + // producerOpTensorShape, producerOpTensorType.getElementType(), + // producerOpLayout + // .withElementType(consumerOp->getContext(), + // consumerOpOutputLayout.getElementType()) + // .withMemorySpace(consumerOp->getContext(), + // consumerOpOutputLayout.getMemorySpace()) + // .withMemoryLayout(consumerOp->getContext(), + // consumerOpOutputLayout.getMemLayout()) + // .withGrid(consumerOp->getContext(), producerOpTensorType, + // consumerOpOutputLayout.getGrid())); + + // OpBuilder builder(consumerOp); + + // mlir::tensor::EmptyOp emptyOp = builder.create( + // consumerOp->getLoc(), producerOpTensorShape, + // producerOpTensorType.getElementType(), + // mlir::cast(newTensorType.getEncoding())); + + // Operation *toLayoutOp = builder.create( + // consumerOp->getLoc(), newTensorType, producerOp->getResult(0), + // emptyOp); + + // consumerOp->setOperand(edge.operandIndex, toLayoutOp->getResult(0)); + // } } } }; -} // namespace mlir::tt::ttir + +} // namespace mlir::tt::ttnn diff --git a/lib/Scheduler/Scheduler.cpp b/lib/Scheduler/Scheduler.cpp index 63b3529be..25923fffd 100644 --- a/lib/Scheduler/Scheduler.cpp +++ b/lib/Scheduler/Scheduler.cpp @@ -5,19 +5,30 @@ #include "ttmlir/Scheduler/Scheduler.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "ttmlir/Dialect/TTIR/IR/TTIROpsDialect.h.inc" +#include "ttmlir/Dialect/TTNN/IR/TTNN.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" #include #include namespace mlir::tt::scheduler { +bool isTTNNOp(mlir::Operation *op) { + return isa(op->getDialect()) && op->getNumResults() > 0 && + !llvm::isa(op); +} + bool isTTIROp(mlir::Operation *op) { return isa(op->getDialect()); } +bool isTTShedulableOp(mlir::Operation *op) { + return isTTNNOp(op) || isTTIROp(op); +} + // Init the dependencies map of all ops which are TTIR ops Scheduler::Scheduler(func::FuncOp *func) { for (auto &op : func->getOps()) { - if (isTTIROp(&op)) { + if (isTTShedulableOp(&op)) { dependencies[&op] = {}; unscheduledOps.insert(&op); } @@ -26,7 +37,7 @@ Scheduler::Scheduler(func::FuncOp *func) { for (auto &op : func->getOps()) { // Skip non TTIR operations // Skip operations which do not implement DestinationStyleOpInterface - if (!isTTIROp(&op)) { + if (!isTTShedulableOp(&op)) { continue; } @@ -35,7 +46,7 @@ Scheduler::Scheduler(func::FuncOp *func) { for (mlir::Operation *use : result.getUsers()) { // Skip non TTIR operations // Skip operations which set the result - if (isTTIROp(use) && use->getResult(0) != result) { + if (isTTShedulableOp(use) && use->getResult(0) != result) { dependencies[use].push_back(&op); } } diff --git a/lib/SharedLib/CMakeLists.txt b/lib/SharedLib/CMakeLists.txt index ab7a4c4c0..d63f30eaa 100644 --- a/lib/SharedLib/CMakeLists.txt +++ b/lib/SharedLib/CMakeLists.txt @@ -20,7 +20,7 @@ set(TTMLIR_LIBS MLIRTTMetalDialect MLIRTTIRTransforms MLIRTTNNTransforms - MLIRTTIRAnalysis + 
MLIRTTNNAnalysis MLIRTTNNPipelines MLIRTTMetalPipelines TTMLIRTTNNToEmitC diff --git a/test/ttmlir/Dialect/TTIR/test_grid_set.mlir b/test/ttmlir/Dialect/TTIR/test_grid_set.mlir deleted file mode 100644 index 78caf2d55..000000000 --- a/test/ttmlir/Dialect/TTIR/test_grid_set.mlir +++ /dev/null @@ -1,11 +0,0 @@ -// RUN: ttmlir-opt --ttir-load-system-desc --ttir-implicit-device --ttir-layout --ttir-optimizer %s | FileCheck %s -#any_device = #tt.operand_constraint -module attributes {} { - func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { - %0 = tensor.empty() : tensor<64x128xf32> - // CHECK: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x16xf32, #dram>, interleaved> - // CHECK: %[[C:.*]] = "ttir.multiply"[[C:.*]] -> tensor<64x128xf32, #[[LAYOUT_2]]> - %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - return %1 : tensor<64x128xf32> - } -} diff --git a/test/ttmlir/Dialect/TTNN/test_grid_set.mlir b/test/ttmlir/Dialect/TTNN/test_grid_set.mlir new file mode 100644 index 000000000..e91fca93e --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/test_grid_set.mlir @@ -0,0 +1,22 @@ +// RUN: ttmlir-opt --ttnn-optimizer %s | FileCheck %s +#device = #tt.device (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]> +#dram = #tt.memory_space +#system = #tt.memory_space +#system_desc = #tt.system_desc<[{arch = , grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 1024, erisc_l1_unreserved_base = 1024, dram_unreserved_base = 1024, dram_unreserved_end = 1073741824, physical_cores = {worker = [ 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 1x0, 1x1, 1x2, 1x3, 1x4, 1x5, 1x6, 1x7, 2x0, 2x1, 2x2, 2x3, 2x4, 2x5, 2x6, 2x7, 3x0, 3x1, 3x2, 3x3, 3x4, 3x5, 3x6, 3x7, 4x0, 4x1, 4x2, 4x3, 4x4, 4x5, 4x6, 4x7, 5x0, 5x1, 5x2, 5x3, 5x4, 5x5, 5x6, 5x7, 6x0, 6x1, 6x2, 6x3, 6x4, 6x5, 6x6, 6x7, 7x0, 7x1, 7x2, 7x3, 7x4, 7x5, 7x6, 7x7] dram = [ 8x0, 9x0, 10x0, 8x1, 9x1, 10x1, 8x2, 9x2, 10x2, 8x3, 9x3, 10x3]}, supported_data_types = [, , , , , , , , , , , ], supported_tile_sizes = [ 4x16, 16x16, 32x16, 4x32, 16x32, 32x32]}], [0], [3 : i32], [ 0x0x0x0]> +#layout = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #system>> +#layout1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #dram>, interleaved> +module attributes {tt.device = #device, tt.system_desc = #system_desc} { + func.func @forward(%arg0: tensor<64x128xf32, #layout>, %arg1: tensor<64x128xf32, #layout>) -> tensor<64x128xf32, #layout> { + %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> + %1 = "ttnn.to_layout"(%arg0, %0) <{layout = #ttnn.layout}> : (tensor<64x128xf32, #layout>, !tt.device<#device>) -> tensor<64x128xf32, #layout1> + %2 = "ttnn.to_device"(%1, %0) <{memory_config = #ttnn.memory_config<, >}> : 
(tensor<64x128xf32, #layout1>, !tt.device<#device>) -> tensor<64x128xf32, #layout1> + %3 = "ttnn.to_layout"(%arg1, %0) <{layout = #ttnn.layout}> : (tensor<64x128xf32, #layout>, !tt.device<#device>) -> tensor<64x128xf32, #layout1> + %4 = "ttnn.to_device"(%3, %0) <{memory_config = #ttnn.memory_config<, >}> : (tensor<64x128xf32, #layout1>, !tt.device<#device>) -> tensor<64x128xf32, #layout1> + %5 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, >, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xf32, #layout1> + %6 = "ttnn.multiply"(%2, %4, %5) <{operandSegmentSizes = array}> : (tensor<64x128xf32, #layout1>, tensor<64x128xf32, #layout1>, tensor<64x128xf32, #layout1>) -> tensor<64x128xf32, #layout1> + // CHECK: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x16xf32, #dram>, interleaved> + // CHECK: %{{.+}} = "ttnn.multiply"{{.+}} -> tensor<64x128xf32, #[[LAYOUT_2]]> + %7 = "ttnn.to_memory_config"(%6, %0) : (tensor<64x128xf32, #layout1>, !tt.device<#device>) -> tensor<64x128xf32, #layout> + return %7 : tensor<64x128xf32, #layout> + } +} diff --git a/test/unittests/TestScheduler/CMakeLists.txt b/test/unittests/TestScheduler/CMakeLists.txt index da9229758..34a53e00e 100644 --- a/test/unittests/TestScheduler/CMakeLists.txt +++ b/test/unittests/TestScheduler/CMakeLists.txt @@ -7,5 +7,6 @@ target_link_libraries(SchedulerTests MLIR MLIRTTDialect MLIRTTIRDialect + MLIRTTNNPipelines MLIRScheduler )
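Usage sketch (not part of the patch): with the optimizer now registered as the TTNN pass "ttnn-optimizer" and its options declared in TTNN/Transforms/Passes.td, the pass can be exercised standalone on already-lowered TTNN IR, as the new test does with "// RUN: ttmlir-opt --ttnn-optimizer %s | FileCheck %s". Only the pass name and the option names (override-output-layout, sharding-pass-enabled, resharding-enabled, max-legal-layouts) come from this patch; the option values below are illustrative, and the key=value option string is the standard mlir-opt convention rather than something introduced here:

    ttmlir-opt --ttnn-optimizer="sharding-pass-enabled=true max-legal-layouts=32" lowered_ttnn.mlir

Note the pipeline reordering in TTNNPipelines.cpp: createTTNNPipelineAnalysisPasses (which adds TTNNOptimizer and now also TTNNDeallocate) runs after createTTNNPipelineLoweringPasses, so the analyses operate on TTNN ops instead of TTIR ops.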