From 0194e80b4f982d14169984d3562836825d42d3ec Mon Sep 17 00:00:00 2001
From: Filip Bajraktari
Date: Thu, 31 Oct 2024 13:09:56 +0000
Subject: [PATCH] Added abstraction to MemoryLayoutAnalysis policy and removed
 unnecessary tests.

---
 .../Dialect/TTNN/Analysis/DFShardingPolicy.h  | 17 ++---
 .../TTNN/Analysis/L1InterleavedPolicy.h       | 17 ++---
 .../Analysis/MemoryLayoutAnalysisPolicy.h     | 39 +++++++++++
 .../Dialect/TTNN/Analysis/ShardingAnalysis.h  | 67 -------------------
 .../TTNN/Analysis/L1InterleavedPolicy.cpp     |  4 +-
 .../TTNN/Analysis/MemoryLayoutAnalysis.cpp    |  2 +-
 .../TTNN/all_l1_interleaved_policy.mlir       | 46 -------------
 .../Dialect/TTNN/mnist_l1_interleaved.mlir    | 45 -------------
 8 files changed, 52 insertions(+), 185 deletions(-)
 create mode 100644 include/ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysisPolicy.h
 delete mode 100644 include/ttmlir/Dialect/TTNN/Analysis/ShardingAnalysis.h
 delete mode 100644 test/ttmlir/Dialect/TTNN/all_l1_interleaved_policy.mlir
 delete mode 100644 test/ttmlir/Dialect/TTNN/mnist_l1_interleaved.mlir

diff --git a/include/ttmlir/Dialect/TTNN/Analysis/DFShardingPolicy.h b/include/ttmlir/Dialect/TTNN/Analysis/DFShardingPolicy.h
index 6ef8476b00..790773da52 100644
--- a/include/ttmlir/Dialect/TTNN/Analysis/DFShardingPolicy.h
+++ b/include/ttmlir/Dialect/TTNN/Analysis/DFShardingPolicy.h
@@ -7,20 +7,14 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"
+#include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysisPolicy.h"

 namespace mlir::tt::ttnn {

 // Process ops in DFS schedulable order and build shard chain configs.
 // Schedule is also produced as a side effect of sharding.
 //
-class DFShardingPolicy {
-private:
-  Operation *rootOp;
-  std::vector *l1ChainConfigs;
-  llvm::DenseMap> legalLayouts;
-  llvm::DenseMap> *schedule;
-  unsigned usableL1CacheSize = 0;
-
+class DFShardingPolicy : public MemoryLayoutAnalysisPolicy {
 public:
   DFShardingPolicy(
       Operation *rootOp, std::vector &l1ChainConfigs,
@@ -28,11 +22,10 @@ class DFShardingPolicy {
           &legalLayouts,
       llvm::DenseMap> &schedule,
       unsigned usableL1CacheSize)
-      : rootOp(rootOp), l1ChainConfigs(&l1ChainConfigs),
-        legalLayouts(legalLayouts), schedule(&schedule),
-        usableL1CacheSize(usableL1CacheSize) {}
+      : MemoryLayoutAnalysisPolicy(rootOp, l1ChainConfigs, legalLayouts,
+                                   schedule, usableL1CacheSize) {}

-  void run(const std::unordered_set &overrideReshardEdges);
+  void run(const std::unordered_set &overrideReshardEdges) final;
 };

 } // namespace mlir::tt::ttnn
diff --git a/include/ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h b/include/ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h
index 816afed1fb..ac579c7ddf 100644
--- a/include/ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h
+++ b/include/ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h
@@ -7,17 +7,11 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"
+#include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysisPolicy.h"

 namespace mlir::tt::ttnn {

-class L1InterleavedPolicy {
-private:
-  Operation *rootOp;
-  std::vector *l1ChainConfigs;
-  llvm::DenseMap> legalLayouts;
-  llvm::DenseMap> *schedule;
-  unsigned usableL1CacheSize = 0;
-
+class L1InterleavedPolicy : public MemoryLayoutAnalysisPolicy {
 public:
   L1InterleavedPolicy(
       Operation *rootOp, std::vector &l1ChainConfigs,
@@ -25,11 +19,10 @@ class L1InterleavedPolicy {
           &legalLayouts,
       llvm::DenseMap> &schedule,
       unsigned usableL1CacheSize)
-      : rootOp(rootOp), l1ChainConfigs(&l1ChainConfigs),
-        legalLayouts(legalLayouts), schedule(&schedule),
-        usableL1CacheSize(usableL1CacheSize) {}
+      : MemoryLayoutAnalysisPolicy(rootOp, l1ChainConfigs, legalLayouts,
+                                   schedule, usableL1CacheSize) {}

-  void run();
+  void run(const std::unordered_set &overrideReshardEdges) final;
 };

 } // namespace mlir::tt::ttnn
diff --git a/include/ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysisPolicy.h b/include/ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysisPolicy.h
new file mode 100644
index 0000000000..46b6c48cbe
--- /dev/null
+++ b/include/ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysisPolicy.h
@@ -0,0 +1,39 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_MEMORYLAYOUTANALYSISPOLICY_H
+#define TTMLIR_DIALECT_TTNN_ANALYSIS_MEMORYLAYOUTANALYSISPOLICY_H
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"
+
+namespace mlir::tt::ttnn {
+
+class MemoryLayoutAnalysisPolicy {
+protected:
+  Operation *rootOp;
+  std::vector *l1ChainConfigs;
+  llvm::DenseMap> legalLayouts;
+  llvm::DenseMap> *schedule;
+  unsigned usableL1CacheSize = 0;
+
+public:
+  virtual ~MemoryLayoutAnalysisPolicy() {};
+
+  MemoryLayoutAnalysisPolicy(
+      Operation *rootOp, std::vector &l1ChainConfigs,
+      const llvm::DenseMap>
+          &legalLayouts,
+      llvm::DenseMap> &schedule,
+      unsigned usableL1CacheSize)
+      : rootOp(rootOp), l1ChainConfigs(&l1ChainConfigs),
+        legalLayouts(legalLayouts), schedule(&schedule),
+        usableL1CacheSize(usableL1CacheSize) {}
+
+  virtual void run(const std::unordered_set &overrideReshardEdges) = 0;
+};
+
+} // namespace mlir::tt::ttnn
+
+#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_MEMORYLAYOUTANALYSISPOLICY_H
diff --git a/include/ttmlir/Dialect/TTNN/Analysis/ShardingAnalysis.h b/include/ttmlir/Dialect/TTNN/Analysis/ShardingAnalysis.h
deleted file mode 100644
index b5d6e7c2e8..0000000000
--- a/include/ttmlir/Dialect/TTNN/Analysis/ShardingAnalysis.h
+++ /dev/null
@@ -1,67 +0,0 @@
-// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
-//
-// SPDX-License-Identifier: Apache-2.0
-
-#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_SHARDINGANALYSIS_H
-#define TTMLIR_DIALECT_TTNN_ANALYSIS_SHARDINGANALYSIS_H
-
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "ttmlir/Dialect/TTNN/Analysis/Edge.h"
-#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"
-#include "ttmlir/Dialect/TTNN/Analysis/TTNNAnalysis.h"
-
-namespace mlir::tt::ttnn {
-
-enum class PolicyType { DFSharding, L1Interleaved };
-
-struct ShardingAnalysisInput {
-  llvm::DenseMap> legalLayouts;
-  unsigned usableL1CacheSize = 0;
-
-  ShardingAnalysisInput() : legalLayouts() {}
-
-  ShardingAnalysisInput(
-      const llvm::DenseMap>
-          &legalLayouts,
-      unsigned usableL1CacheSize)
-      : legalLayouts(legalLayouts), usableL1CacheSize(usableL1CacheSize) {}
-
-  bool operator==(const ShardingAnalysisInput &rhs) const {
-    return legalLayouts == rhs.legalLayouts;
-  }
-
-  bool operator!=(const ShardingAnalysisInput &rhs) const {
-    return !(*this == rhs);
-  }
-};
-
-struct ShardingAnalysisResult {
-  llvm::DenseMap> legalLayouts;
-  std::unordered_set reshardedEdges;
-  llvm::DenseMap> schedule;
-
-  ShardingAnalysisResult() : legalLayouts(), reshardedEdges(), schedule() {}
-
-  ShardingAnalysisResult(
-      const llvm::DenseMap>
-          &legalLayouts,
-      const std::unordered_set &reshardedEdges)
-      : legalLayouts(legalLayouts), reshardedEdges(reshardedEdges) {}
-};
-
-// Determine shard chain configs.
-//
-class ShardingAnalysis
-    : public TTNNAnalysis {
-
-private:
-  void analysisImplementation() override;
-  bool applyOverrides() override;
-  std::vector l1ChainConfigs;
-
-public:
-  ShardingAnalysis(Operation *op) : TTNNAnalysis(op) {}
-};
-} // namespace mlir::tt::ttnn
-
-#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_SHARDINGANALYSIS_H
diff --git a/lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp b/lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp
index a98c091a24..4944a618a4 100644
--- a/lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp
+++ b/lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp
@@ -5,11 +5,11 @@
 #include "ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h"
 #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
 #include "ttmlir/Scheduler/Scheduler.h"
-#include

 namespace mlir::tt::ttnn {

-void L1InterleavedPolicy::run() {
+void L1InterleavedPolicy::run(
+    const std::unordered_set &overrideReshardEdges) {
   rootOp->walk([&](func::FuncOp func) {
     mlir::tt::scheduler::Scheduler scheduler(&func);
     llvm::SmallVector scheduleableOps;
diff --git a/lib/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.cpp b/lib/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.cpp
index caf0d2b5f8..24fbfb5bfe 100644
--- a/lib/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.cpp
+++ b/lib/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.cpp
@@ -68,7 +68,7 @@ void MemoryLayoutAnalysis::analysisImplementation() {
     L1InterleavedPolicy l1InterleavedPolicy(
         op, l1ChainConfigs, filterL1InterleavedOnly(analysisInput.legalLayouts),
         analysisResult.schedule, analysisInput.usableL1CacheSize);
-    l1InterleavedPolicy.run();
+    l1InterleavedPolicy.run(analysisInput.overrideReshardEdges);
     break;
   }
   }
diff --git a/test/ttmlir/Dialect/TTNN/all_l1_interleaved_policy.mlir b/test/ttmlir/Dialect/TTNN/all_l1_interleaved_policy.mlir
deleted file mode 100644
index ce3cd0258e..0000000000
--- a/test/ttmlir/Dialect/TTNN/all_l1_interleaved_policy.mlir
+++ /dev/null
@@ -1,46 +0,0 @@
-// RUN: ttmlir-opt --ttnn-optimizer="memory-layout-analysis-enabled=true memory-layout-analysis-policy=L1Interleaved" %s | FileCheck %s
-#device = #tt.device (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]>
-#dram = #tt.memory_space
-#system = #tt.memory_space
-#system_desc = #tt.system_desc<[{arch = , grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 1024, erisc_l1_unreserved_base = 1024, dram_unreserved_base = 1024, dram_unreserved_end = 1073741824, physical_cores = {worker = [ 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 1x0, 1x1, 1x2, 1x3, 1x4, 1x5, 1x6, 1x7, 2x0, 2x1, 2x2, 2x3, 2x4, 2x5, 2x6, 2x7, 3x0, 3x1, 3x2, 3x3, 3x4, 3x5, 3x6, 3x7, 4x0, 4x1, 4x2, 4x3, 4x4, 4x5, 4x6, 4x7, 5x0, 5x1, 5x2, 5x3, 5x4, 5x5, 5x6, 5x7, 6x0, 6x1, 6x2, 6x3, 6x4, 6x5, 6x6, 6x7, 7x0, 7x1, 7x2, 7x3, 7x4, 7x5, 7x6, 7x7] dram = [ 8x0, 9x0, 10x0, 8x1, 9x1, 10x1, 8x2, 9x2, 10x2, 8x3, 9x3, 10x3]}, supported_data_types = [, , , , , , , , , , , ], supported_tile_sizes = [ 4x16, 16x16, 32x16, 4x32, 16x32, 32x32], num_cbs = 32}], [0], [3 : i32], [ 0x0x0x0]>
-#layout = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xbf16, #system>>
-#layout1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<128x96xbf16, #system>>
-#layout2 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x96xbf16, #system>>
-#layout3 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<96x32xbf16, #system>>
-#layout4 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x32xbf16, #system>>
-#layout5 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<1x1x!tt.tile<32x32, bf16>, #dram>, interleaved>
-#layout6 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x96xbf16, #dram>, interleaved>
-#layout7 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x32xbf16, #dram>, interleaved>
-module attributes {tt.device = #device, tt.system_desc = #system_desc} {
-  func.func @forward(%arg0: tensor<64x128xbf16, #layout>, %arg1: tensor<128x96xbf16, #layout1>, %arg2: tensor<64x96xbf16, #layout2>, %arg3: tensor<96x32xbf16, #layout3>, %arg4: tensor<64x32xbf16, #layout4>) -> tensor<64x32xbf16, #layout4> {
-    // CHECK: #[[L1_:.*]] = #tt.memory_space
-    // CHECK: #[[LAYOUT_6:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x12xbf16, #l1_>, interleaved>
-    // CHECK: #[[LAYOUT_7:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x4xbf16, #l1_>, interleaved>
-    %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device>
-    %1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<1x1>>>}> : (tensor<64x128xbf16, #layout>, !tt.device<#device>) -> tensor<64x128xbf16, #layout5>
-    %2 = "ttnn.to_layout"(%arg1, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<1x1>>>}> : (tensor<128x96xbf16, #layout1>, !tt.device<#device>) -> tensor<128x96xbf16, #layout5>
-    %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<64x96>>>, shape = #ttnn.shape<64x96>}> : (!tt.device<#device>) -> tensor<64x96xbf16, #layout6>
-    // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x96xbf16, #[[LAYOUT_6]]>
-    %4 = "ttnn.matmul"(%1, %2, %3) : (tensor<64x128xbf16, #layout5>, tensor<128x96xbf16, #layout5>, tensor<64x96xbf16, #layout6>) -> tensor<64x96xbf16, #layout6>
-    %5 = "ttnn.to_layout"(%arg2, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<1x1>>>}> : (tensor<64x96xbf16, #layout2>, !tt.device<#device>) -> tensor<64x96xbf16, #layout5>
-    %6 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<64x96>>>, shape = #ttnn.shape<64x96>}> : (!tt.device<#device>) -> tensor<64x96xbf16, #layout6>
-    // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<64x96xbf16, #[[LAYOUT_6]]>
-    %7 = "ttnn.add"(%4, %5, %6) <{operandSegmentSizes = array}> : (tensor<64x96xbf16, #layout6>, tensor<64x96xbf16, #layout5>, tensor<64x96xbf16, #layout6>) -> tensor<64x96xbf16, #layout6>
-    %8 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<64x96>>>, shape = #ttnn.shape<64x96>}> : (!tt.device<#device>) -> tensor<64x96xbf16, #layout6>
-    // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x96xbf16, #[[LAYOUT_6]]>
-    %9 = "ttnn.relu"(%7, %8) <{operandSegmentSizes = array}> : (tensor<64x96xbf16, #layout6>, tensor<64x96xbf16, #layout6>) -> tensor<64x96xbf16, #layout6>
-    %10 = "ttnn.to_layout"(%arg3, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<1x1>>>}> : (tensor<96x32xbf16, #layout3>, !tt.device<#device>) -> tensor<96x32xbf16, #layout5>
-    %11 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<64x32>>>, shape = #ttnn.shape<64x32>}> : (!tt.device<#device>) -> tensor<64x32xbf16, #layout7>
-    // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_7]]>
-    %12 = "ttnn.matmul"(%9, %10, %11) : (tensor<64x96xbf16, #layout6>, tensor<96x32xbf16, #layout5>, tensor<64x32xbf16, #layout7>) -> tensor<64x32xbf16, #layout7>
-    %13 = "ttnn.to_layout"(%arg4, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<1x1>>>}> : (tensor<64x32xbf16, #layout4>, !tt.device<#device>) -> tensor<64x32xbf16, #layout5>
-    %14 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<64x32>>>, shape = #ttnn.shape<64x32>}> : (!tt.device<#device>) -> tensor<64x32xbf16, #layout7>
-    // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_7]]>
-    %15 = "ttnn.add"(%12, %13, %14) <{operandSegmentSizes = array}> : (tensor<64x32xbf16, #layout7>, tensor<64x32xbf16, #layout5>, tensor<64x32xbf16, #layout7>) -> tensor<64x32xbf16, #layout7>
-    %16 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<64x32>>>, shape = #ttnn.shape<64x32>}> : (!tt.device<#device>) -> tensor<64x32xbf16, #layout7>
-    // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_7]]>
-    %17 = "ttnn.relu"(%15, %16) <{operandSegmentSizes = array}> : (tensor<64x32xbf16, #layout7>, tensor<64x32xbf16, #layout7>) -> tensor<64x32xbf16, #layout7>
-    %18 = "ttnn.to_layout"(%17) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<64x32>>>}> : (tensor<64x32xbf16, #layout7>) -> tensor<64x32xbf16, #layout4>
-    return %18 : tensor<64x32xbf16, #layout4>
-  }
-}
diff --git a/test/ttmlir/Dialect/TTNN/mnist_l1_interleaved.mlir b/test/ttmlir/Dialect/TTNN/mnist_l1_interleaved.mlir
deleted file mode 100644
index 69eab27588..0000000000
--- a/test/ttmlir/Dialect/TTNN/mnist_l1_interleaved.mlir
+++ /dev/null
@@ -1,45 +0,0 @@
-// RUN: ttmlir-opt --ttnn-optimizer="memory-layout-analysis-enabled=true memory-layout-analysis-policy=L1Interleaved" %s | FileCheck %s
-#device = #tt.device (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]>
-#dram = #tt.memory_space
-#system = #tt.memory_space
-#system_desc = #tt.system_desc<[{arch = , grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 1024, erisc_l1_unreserved_base = 1024, dram_unreserved_base = 1024, dram_unreserved_end = 1073741824, physical_cores = {worker = [ 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 1x0, 1x1, 1x2, 1x3, 1x4, 1x5, 1x6, 1x7, 2x0, 2x1, 2x2, 2x3, 2x4, 2x5, 2x6, 2x7, 3x0, 3x1, 3x2, 3x3, 3x4, 3x5, 3x6, 3x7, 4x0, 4x1, 4x2, 4x3, 4x4, 4x5, 4x6, 4x7, 5x0, 5x1, 5x2, 5x3, 5x4, 5x5, 5x6, 5x7, 6x0, 6x1, 6x2, 6x3, 6x4, 6x5, 6x6, 6x7, 7x0, 7x1, 7x2, 7x3, 7x4, 7x5, 7x6, 7x7] dram = [ 8x0, 9x0, 10x0, 8x1, 9x1, 10x1, 8x2, 9x2, 10x2, 8x3, 9x3, 10x3]}, supported_data_types = [, , , , , , , , , , , ], supported_tile_sizes = [ 4x16, 16x16, 32x16, 4x32, 16x32, 32x32], num_cbs = 32}], [0], [3 : i32], [ 0x0x0x0]>
-#layout = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<1x784xf32, #system>>
-#layout1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<1x10xf32, #system>>
-#layout2 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<256x10xf32, #system>>
-#layout3 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<1x256xf32, #system>>
-#layout4 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<784x256xf32, #system>>
-#layout5 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<1x1x!tt.tile<32x32, f32>, #dram>, interleaved>
-#layout6 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<1x256xf32, #dram>, interleaved>
-#layout7 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<1x10xf32, #dram>, interleaved>
-module @"tt-forge-graph" attributes {tt.device = #device, tt.system_desc = #system_desc} {
-  func.func @main(%arg0: tensor<1x784xf32, #layout>, %arg1: tensor<1x10xf32, #layout1>, %arg2: tensor<256x10xf32, #layout2>, %arg3: tensor<1x256xf32, #layout3>, %arg4: tensor<784x256xf32, #layout4>) -> tensor<1x10xf32, #layout1> {
-    // CHECK: #[[L1_:.*]] = #tt.memory_space
-    // CHECK: #[[LAYOUT_6:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<1x32xf32, #l1_>, interleaved>
-    // CHECK: #[[LAYOUT_7:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<1x2xf32, #l1_>, interleaved>
-    %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device>
-    %1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<1x1>>>}> : (tensor<1x784xf32, #layout>, !tt.device<#device>) -> tensor<1x784xf32, #layout5>
-    %2 = "ttnn.to_layout"(%arg4, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<1x1>>>}> : (tensor<784x256xf32, #layout4>, !tt.device<#device>) -> tensor<784x256xf32, #layout5>
-    %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<1x256>>>, shape = #ttnn.shape<1x256>}> : (!tt.device<#device>) -> tensor<1x256xf32, #layout6>
-    // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<1x256xf32, #[[LAYOUT_6]]>
-    %4 = "ttnn.matmul"(%1, %2, %3) : (tensor<1x784xf32, #layout5>, tensor<784x256xf32, #layout5>, tensor<1x256xf32, #layout6>) -> tensor<1x256xf32, #layout6>
-    %5 = "ttnn.to_layout"(%arg3, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<1x1>>>}> : (tensor<1x256xf32, #layout3>, !tt.device<#device>) -> tensor<1x256xf32, #layout5>
-    %6 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<1x256>>>, shape = #ttnn.shape<1x256>}> : (!tt.device<#device>) -> tensor<1x256xf32, #layout6>
-    // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x256xf32, #[[LAYOUT_6]]>
-    %7 = "ttnn.add"(%4, %5, %6) <{operandSegmentSizes = array}> : (tensor<1x256xf32, #layout6>, tensor<1x256xf32, #layout5>, tensor<1x256xf32, #layout6>) -> tensor<1x256xf32, #layout6>
-    %8 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<1x256>>>, shape = #ttnn.shape<1x256>}> : (!tt.device<#device>) -> tensor<1x256xf32, #layout6>
-    // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<1x256xf32, #[[LAYOUT_6]]>
-    %9 = "ttnn.relu"(%7, %8) <{operandSegmentSizes = array}> : (tensor<1x256xf32, #layout6>, tensor<1x256xf32, #layout6>) -> tensor<1x256xf32, #layout6>
-    %10 = "ttnn.to_layout"(%arg2, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<1x1>>>}> : (tensor<256x10xf32, #layout2>, !tt.device<#device>) -> tensor<256x10xf32, #layout5>
-    %11 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<1x10>>>, shape = #ttnn.shape<1x10>}> : (!tt.device<#device>) -> tensor<1x10xf32, #layout7>
-    // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<1x10xf32, #[[LAYOUT_7]]>
-    %12 = "ttnn.matmul"(%9, %10, %11) : (tensor<1x256xf32, #layout6>, tensor<256x10xf32, #layout5>, tensor<1x10xf32, #layout7>) -> tensor<1x10xf32, #layout7>
-    %13 = "ttnn.to_layout"(%arg1, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<1x1>>>}> : (tensor<1x10xf32, #layout1>, !tt.device<#device>) -> tensor<1x10xf32, #layout5>
-    %14 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<1x10>>>, shape = #ttnn.shape<1x10>}> : (!tt.device<#device>) -> tensor<1x10xf32, #layout7>
-    // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x10xf32, #[[LAYOUT_7]]>
-    %15 = "ttnn.add"(%12, %13, %14) <{operandSegmentSizes = array}> : (tensor<1x10xf32, #layout7>, tensor<1x10xf32, #layout5>, tensor<1x10xf32, #layout7>) -> tensor<1x10xf32, #layout7>
-    // CHECK: %{{.*}} = "ttnn.softmax"{{.*}} -> tensor<1x10xf32, #[[LAYOUT_7]]>
-    %16 = "ttnn.softmax"(%15) <{dimension = 1 : si32}> : (tensor<1x10xf32, #layout7>) -> tensor<1x10xf32, #layout7>
-    %17 = "ttnn.to_layout"(%16) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<1x10>>>}> : (tensor<1x10xf32, #layout7>) -> tensor<1x10xf32, #layout1>
-    return %17 : tensor<1x10xf32, #layout1>
-  }
-}
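
Editor's note (illustrative sketch, not part of the commit): with the MemoryLayoutAnalysisPolicy base class introduced above, a further policy only needs to derive from it and override run(). The sketch below assumes the override set's element type is Edge (as suggested by the reshardedEdges set in the removed ShardingAnalysis.h); the GreedyL1Policy name is hypothetical and does not exist in this patch.

    // Hypothetical example only; GreedyL1Policy is not part of this change.
    #include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysisPolicy.h"

    namespace mlir::tt::ttnn {

    class GreedyL1Policy : public MemoryLayoutAnalysisPolicy {
    public:
      // Shared state (rootOp, l1ChainConfigs, legalLayouts, schedule,
      // usableL1CacheSize) is inherited from MemoryLayoutAnalysisPolicy,
      // so the base-class constructor can be reused directly.
      using MemoryLayoutAnalysisPolicy::MemoryLayoutAnalysisPolicy;

      // Same entry point as DFShardingPolicy and L1InterleavedPolicy, so
      // MemoryLayoutAnalysis::analysisImplementation() can construct and
      // dispatch to it without special-casing.
      void run(const std::unordered_set<Edge> &overrideReshardEdges) final;
    };

    } // namespace mlir::tt::ttnn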