Commit
memory layout analysis policy flag incorporated into both optimizer pass and TTIRToTTNNBackendPipeline
fbajraktariTT committed Oct 31, 2024
1 parent f723ac7 commit a93e83a
Showing 13 changed files with 92 additions and 80 deletions.
47 changes: 47 additions & 0 deletions include/ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h
@@ -0,0 +1,47 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TT_UTILS_MEMORYLAYOUTANALYSIS_H
#define TTMLIR_DIALECT_TT_UTILS_MEMORYLAYOUTANALYSIS_H

#include <llvm/ADT/StringSwitch.h>
#include <llvm/Support/CommandLine.h>

namespace mlir::tt {

enum class MemoryLayoutAnalysisPolicyType { DFSharding, L1Interleaved };

struct MemoryLayoutAnalysisPolicyTypeParser
: public llvm::cl::parser<MemoryLayoutAnalysisPolicyType> {
public:
MemoryLayoutAnalysisPolicyTypeParser(llvm::cl::Option &opt)
: llvm::cl::parser<MemoryLayoutAnalysisPolicyType>(opt) {}

bool parse(llvm::cl::Option &opt, llvm::StringRef argName,
llvm::StringRef arg, MemoryLayoutAnalysisPolicyType &value) {
value = llvm::StringSwitch<MemoryLayoutAnalysisPolicyType>(arg)
.Case("DFSharding", MemoryLayoutAnalysisPolicyType::DFSharding)
.Case("L1Interleaved",
MemoryLayoutAnalysisPolicyType::L1Interleaved);
return false;
}

static void print(llvm::raw_ostream &os,
const MemoryLayoutAnalysisPolicyType &value) {
llvm::StringRef policy;
switch (value) {
case MemoryLayoutAnalysisPolicyType::DFSharding:
policy = "DFSharding";
break;
case MemoryLayoutAnalysisPolicyType::L1Interleaved:
policy = "L1Interleaved";
break;
}
os << "memory-layout-analysis-policy=" << policy << "\n";
}
};

} // namespace mlir::tt

#endif // TTMLIR_DIALECT_TT_UTILS_MEMORYLAYOUTANALYSIS_H
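
For reference, a minimal sketch of how this shared parser can back a plain llvm::cl option; the in-tree code registers it through ::mlir::Pass::Option instead (see Optimizer.h below), so this standalone registration is only an illustrative assumption:

// Hypothetical standalone registration of the policy flag using the shared
// parser; not part of this commit.
#include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h"

#include <llvm/Support/CommandLine.h>

using mlir::tt::MemoryLayoutAnalysisPolicyType;
using mlir::tt::MemoryLayoutAnalysisPolicyTypeParser;

// llvm::cl::opt accepts a custom parser as its third template argument; the
// option name and default mirror the pass option added in this commit.
static llvm::cl::opt<MemoryLayoutAnalysisPolicyType, /*ExternalStorage=*/false,
                     MemoryLayoutAnalysisPolicyTypeParser>
    memoryLayoutAnalysisPolicy(
        "memory-layout-analysis-policy",
        llvm::cl::desc("Specify policy for memory layout analysis."),
        llvm::cl::init(MemoryLayoutAnalysisPolicyType::DFSharding));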
1 change: 0 additions & 1 deletion include/ttmlir/Dialect/TT/Utils/OverrideParams.h
@@ -6,7 +6,6 @@
#define TTMLIR_DIALECT_TT_UTILS_OVERRIDEPARAMS_H

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include <cstdint>
#include <llvm/Support/CommandLine.h>

namespace mlir::tt {
12 changes: 3 additions & 9 deletions include/ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.h
@@ -6,20 +6,13 @@
#define TTMLIR_DIALECT_TTNN_ANALYSIS_MEMORYLAYOUTANALYSIS_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h"
#include "ttmlir/Dialect/TTNN/Analysis/Edge.h"
#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"
#include "ttmlir/Dialect/TTNN/Analysis/TTNNAnalysis.h"

namespace mlir::tt::ttnn {

enum class MemoryLayoutAnalysisPolicyType { DFSharding, L1Interleaved };

::llvm::StringRef
stringifyMemoryLayoutAnalysisPolicyType(MemoryLayoutAnalysisPolicyType policy);

MemoryLayoutAnalysisPolicyType
symbolizeMemoryLayoutAnalysisPolicyType(::llvm::StringRef policy);

struct MemoryLayoutAnalysisInput {
llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> legalLayouts;
unsigned usableL1CacheSize = 0;
@@ -32,7 +25,8 @@ struct MemoryLayoutAnalysisInput {
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
&legalLayouts,
unsigned usableL1CacheSize,
const std::unordered_set<Edge> &overrideReshardEdges, MemoryLayoutAnalysisPolicyType policy)
const std::unordered_set<Edge> &overrideReshardEdges,
MemoryLayoutAnalysisPolicyType policy)
: legalLayouts(legalLayouts), usableL1CacheSize(usableL1CacheSize),
overrideReshardEdges(overrideReshardEdges), policy(policy) {}

24 changes: 1 addition & 23 deletions include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
@@ -6,33 +6,11 @@
#define TTMLIR_DIALECT_TTNN_PIPELINES_TTNNPIPELINES_H

#include "mlir/Pass/PassOptions.h"
#include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h"
#include "ttmlir/Dialect/TT/Utils/OverrideParams.h"
#include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.h"

namespace mlir::tt::ttnn {

struct MemoryLayoutAnalysisPolicyTypeParser
: public llvm::cl::parser<MemoryLayoutAnalysisPolicyType> {
public:
MemoryLayoutAnalysisPolicyTypeParser(llvm::cl::Option &opt)
: llvm::cl::parser<MemoryLayoutAnalysisPolicyType>(opt) {}

bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
MemoryLayoutAnalysisPolicyType &value) {
MemoryLayoutAnalysisPolicyType policy =
symbolizeMemoryLayoutAnalysisPolicyType(arg);
value = policy;
return true;
}

static void print(llvm::raw_ostream &os,
const MemoryLayoutAnalysisPolicyType &value) {
os << "memory-layout-analysis-policy="
<< stringifyMemoryLayoutAnalysisPolicyType(value);
os << "\n";
}
};

// Options for the TTIR to TTNN backend pipeline.
//
struct TTIRToTTNNBackendPipelineOptions
9 changes: 9 additions & 0 deletions include/ttmlir/Dialect/TTNN/Transforms/Optimizer.h
@@ -19,6 +19,8 @@ struct TTNNOptimizerOptions {
llvm::StringMap<OutputLayoutOverrideParams> overrideOutputLayout =
llvm::StringMap<OutputLayoutOverrideParams>();
bool memoryLayoutAnalysisEnabled = false;
MemoryLayoutAnalysisPolicyType memoryLayoutAnalysisPolicy =
MemoryLayoutAnalysisPolicyType::DFSharding;
bool memReconfigEnabled = false;
int64_t maxLegalLayouts = 64;
};
@@ -95,6 +97,7 @@ class TTNNOptimizerBase : public ::mlir::OperationPass<::mlir::ModuleOp> {
memoryLayoutAnalysisEnabled =
std::move(options.memoryLayoutAnalysisEnabled);
memReconfigEnabled = std::move(options.memReconfigEnabled);
memoryLayoutAnalysisPolicy = std::move(options.memoryLayoutAnalysisPolicy);
maxLegalLayouts = std::move(options.maxLegalLayouts);
}

@@ -122,6 +125,12 @@ class TTNNOptimizerBase : public ::mlir::OperationPass<::mlir::ModuleOp> {
"we support all "
"types of shard specs."),
::llvm::cl::init(false)};
::mlir::Pass::Option<mlir::tt::MemoryLayoutAnalysisPolicyType,
mlir::tt::MemoryLayoutAnalysisPolicyTypeParser>
memoryLayoutAnalysisPolicy{
*this, "memory-layout-analysis-policy",
llvm::cl::desc("Specify policy for memory layout analysis."),
llvm::cl::init(MemoryLayoutAnalysisPolicyType::DFSharding)};
::mlir::Pass::Option<int64_t> maxLegalLayouts{
*this, "max-legal-layouts",
::llvm::cl::desc(
4 changes: 2 additions & 2 deletions lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp
@@ -32,7 +32,7 @@ void L1InterleavedPolicy::run() {
if (legalLayouts.lookup(currentOp).size() > 0) {
selectedOpLayout[currentOp] = legalLayouts.lookup(currentOp).front();

// Add currentOp to shard chain config.
// Add currentOp to l1 chain config.
//
OpL1MemSpec shardSpec;
shardSpec.op = currentOp;
@@ -53,7 +53,7 @@ void L1InterleavedPolicy::run() {
//
(*schedule)[func] = scheduler.getSchedule();

// Resolve shard chain configs.
// Resolve l1 chain configs.
//
for (auto &l1ChainConfig : *l1ChainConfigs) {
l1ChainConfig.build();
18 changes: 0 additions & 18 deletions lib/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.cpp
@@ -8,24 +8,6 @@

namespace mlir::tt::ttnn {

::llvm::StringRef
stringifyMemoryLayoutAnalysisPolicyType(MemoryLayoutAnalysisPolicyType policy) {
switch (policy) {
case MemoryLayoutAnalysisPolicyType::DFSharding:
return "DFSharding";
case MemoryLayoutAnalysisPolicyType::L1Interleaved:
return "L1Interleaved";
}
return "";
}

MemoryLayoutAnalysisPolicyType
symbolizeMemoryLayoutAnalysisPolicyType(::llvm::StringRef policy) {
return llvm::StringSwitch<MemoryLayoutAnalysisPolicyType>(policy)
.Case("DFSharding", MemoryLayoutAnalysisPolicyType::DFSharding)
.Case("L1Interleaved", MemoryLayoutAnalysisPolicyType::L1Interleaved);
}

bool MemoryLayoutAnalysis::applyOverrides() {

// TODO(nobradovic):
2 changes: 2 additions & 0 deletions lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
@@ -51,6 +51,8 @@ void createTTNNPipelineAnalysisPasses(
optimizerOptions.memoryLayoutAnalysisEnabled =
options.memoryLayoutAnalysisEnabled;
optimizerOptions.memReconfigEnabled = options.memReconfigEnabled;
optimizerOptions.memoryLayoutAnalysisPolicy =
options.memoryLayoutAnalysisPolicy;
optimizerOptions.maxLegalLayouts = options.maxLegalLayouts;
pm.addPass(mlir::tt::ttnn::createTTNNOptimizer(optimizerOptions));
}
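
For callers that assemble the pipeline programmatically, a hypothetical sketch mirroring createTTNNPipelineAnalysisPasses above; the option and factory names come from this diff, while the header path and surrounding PassManager setup are assumed:

// Illustrative only: enable memory layout analysis with the L1Interleaved
// policy when building a pass pipeline by hand.
#include "mlir/Pass/PassManager.h"
#include "ttmlir/Dialect/TTNN/Transforms/Optimizer.h" // assumed location of TTNNOptimizerOptions

void addOptimizerWithL1Interleaved(mlir::PassManager &pm) {
  mlir::tt::ttnn::TTNNOptimizerOptions optimizerOptions;
  optimizerOptions.memoryLayoutAnalysisEnabled = true;
  optimizerOptions.memoryLayoutAnalysisPolicy =
      mlir::tt::MemoryLayoutAnalysisPolicyType::L1Interleaved;
  pm.addPass(mlir::tt::ttnn::createTTNNOptimizer(optimizerOptions));
}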
3 changes: 2 additions & 1 deletion lib/Dialect/TTNN/Transforms/Optimizer.cpp
@@ -71,7 +71,8 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase<TTNNOptimizer> {
MemoryLayoutAnalysis memoryLayoutAnalysis =
getAnalysis<MemoryLayoutAnalysis>();
memoryLayoutAnalysis.init(MemoryLayoutAnalysisInput(
legalLayouts, chipDesc.getUsableL1Size(), overrideReshardEdges, MemoryLayoutAnalysisPolicyType::DFSharding));
legalLayouts, chipDesc.getUsableL1Size(), overrideReshardEdges,
memoryLayoutAnalysisPolicy));
legalLayouts = memoryLayoutAnalysis.getResult().legalLayouts;
opSchedule = memoryLayoutAnalysis.getResult().schedule;
memReconfigEdges = memoryLayoutAnalysis.getResult().memReconfigEdges;
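
Downstream, the analysis is expected to branch on the forwarded policy to choose between the existing strategies; a plausible, non-verbatim sketch of that dispatch, with the helper name invented for illustration:

// Illustrative only: how MemoryLayoutAnalysis might select a strategy from the
// new policy field. The function is hypothetical; the enum values and the
// L1InterleavedPolicy file are the ones touched by this commit.
#include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.h"

void runSelectedPolicy(const mlir::tt::ttnn::MemoryLayoutAnalysisInput &input) {
  switch (input.policy) {
  case mlir::tt::MemoryLayoutAnalysisPolicyType::DFSharding:
    // run the depth-first sharding strategy ...
    break;
  case mlir::tt::MemoryLayoutAnalysisPolicyType::L1Interleaved:
    // run L1InterleavedPolicy (lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp) ...
    break;
  }
}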
12 changes: 6 additions & 6 deletions test/ttmlir/Dialect/TTNN/all_l1_interleaved_policy.mlir
@@ -17,30 +17,30 @@ module attributes {tt.device = #device, tt.system_desc = #system_desc} {
// CHECK: #[[LAYOUT_6:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x12xbf16, #l1_>, interleaved>
// CHECK: #[[LAYOUT_7:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x4xbf16, #l1_>, interleaved>
%0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
%1 = "ttnn.composite_to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<1x1>>>}> : (tensor<64x128xbf16, #layout>, !tt.device<#device>) -> tensor<64x128xbf16, #layout5>
%2 = "ttnn.composite_to_layout"(%arg1, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<1x1>>>}> : (tensor<128x96xbf16, #layout1>, !tt.device<#device>) -> tensor<128x96xbf16, #layout5>
%1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<1x1>>>}> : (tensor<64x128xbf16, #layout>, !tt.device<#device>) -> tensor<64x128xbf16, #layout5>
%2 = "ttnn.to_layout"(%arg1, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<1x1>>>}> : (tensor<128x96xbf16, #layout1>, !tt.device<#device>) -> tensor<128x96xbf16, #layout5>
%3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<64x96>>>, shape = #ttnn.shape<64x96>}> : (!tt.device<#device>) -> tensor<64x96xbf16, #layout6>
// CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x96xbf16, #[[LAYOUT_6]]>
%4 = "ttnn.matmul"(%1, %2, %3) : (tensor<64x128xbf16, #layout5>, tensor<128x96xbf16, #layout5>, tensor<64x96xbf16, #layout6>) -> tensor<64x96xbf16, #layout6>
%5 = "ttnn.composite_to_layout"(%arg2, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<1x1>>>}> : (tensor<64x96xbf16, #layout2>, !tt.device<#device>) -> tensor<64x96xbf16, #layout5>
%5 = "ttnn.to_layout"(%arg2, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<1x1>>>}> : (tensor<64x96xbf16, #layout2>, !tt.device<#device>) -> tensor<64x96xbf16, #layout5>
%6 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<64x96>>>, shape = #ttnn.shape<64x96>}> : (!tt.device<#device>) -> tensor<64x96xbf16, #layout6>
// CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<64x96xbf16, #[[LAYOUT_6]]>
%7 = "ttnn.add"(%4, %5, %6) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x96xbf16, #layout6>, tensor<64x96xbf16, #layout5>, tensor<64x96xbf16, #layout6>) -> tensor<64x96xbf16, #layout6>
%8 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<64x96>>>, shape = #ttnn.shape<64x96>}> : (!tt.device<#device>) -> tensor<64x96xbf16, #layout6>
// CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x96xbf16, #[[LAYOUT_6]]>
%9 = "ttnn.relu"(%7, %8) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<64x96xbf16, #layout6>, tensor<64x96xbf16, #layout6>) -> tensor<64x96xbf16, #layout6>
%10 = "ttnn.composite_to_layout"(%arg3, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<1x1>>>}> : (tensor<96x32xbf16, #layout3>, !tt.device<#device>) -> tensor<96x32xbf16, #layout5>
%10 = "ttnn.to_layout"(%arg3, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<1x1>>>}> : (tensor<96x32xbf16, #layout3>, !tt.device<#device>) -> tensor<96x32xbf16, #layout5>
%11 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<64x32>>>, shape = #ttnn.shape<64x32>}> : (!tt.device<#device>) -> tensor<64x32xbf16, #layout7>
// CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_7]]>
%12 = "ttnn.matmul"(%9, %10, %11) : (tensor<64x96xbf16, #layout6>, tensor<96x32xbf16, #layout5>, tensor<64x32xbf16, #layout7>) -> tensor<64x32xbf16, #layout7>
%13 = "ttnn.composite_to_layout"(%arg4, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<1x1>>>}> : (tensor<64x32xbf16, #layout4>, !tt.device<#device>) -> tensor<64x32xbf16, #layout5>
%13 = "ttnn.to_layout"(%arg4, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<1x1>>>}> : (tensor<64x32xbf16, #layout4>, !tt.device<#device>) -> tensor<64x32xbf16, #layout5>
%14 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<64x32>>>, shape = #ttnn.shape<64x32>}> : (!tt.device<#device>) -> tensor<64x32xbf16, #layout7>
// CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_7]]>
%15 = "ttnn.add"(%12, %13, %14) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x32xbf16, #layout7>, tensor<64x32xbf16, #layout5>, tensor<64x32xbf16, #layout7>) -> tensor<64x32xbf16, #layout7>
%16 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<64x32>>>, shape = #ttnn.shape<64x32>}> : (!tt.device<#device>) -> tensor<64x32xbf16, #layout7>
// CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_7]]>
%17 = "ttnn.relu"(%15, %16) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<64x32xbf16, #layout7>, tensor<64x32xbf16, #layout7>) -> tensor<64x32xbf16, #layout7>
%18 = "ttnn.composite_to_layout"(%17) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<<none>, <system_memory>, <<64x32>>>}> : (tensor<64x32xbf16, #layout7>) -> tensor<64x32xbf16, #layout4>
%18 = "ttnn.to_layout"(%17) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<<none>, <system_memory>, <<64x32>>>}> : (tensor<64x32xbf16, #layout7>) -> tensor<64x32xbf16, #layout4>
return %18 : tensor<64x32xbf16, #layout4>
}
}