L1 interleaved policy #1117

Merged · 10 commits · Nov 5, 2024
4 changes: 4 additions & 0 deletions .gitignore
@@ -6,3 +6,7 @@ third_party/tt-metal
.cache
*pycache*
*.egg-info
ttrt-artifacts/*
query_results.json
run_results.json
ttrt_report.xml
2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
@@ -288,7 +288,9 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> {
bool isSystemMemorySpace() const { return ::mlir::tt::isSystemMemorySpace(getMemorySpace()); }
bool isDeviceMemorySpace() const { return ::mlir::tt::isDeviceMemorySpace(getMemorySpace()); }
bool hasShardedTensorMemoryLayout() const;
bool hasInterleavedTensorMemoryLayout() const;
bool hasShardedL1TensorMemoryLayout() const;
bool hasInterleavedL1TensorMemoryLayout() const;
bool isTiled() const;
Type getElementType() const;
Type getScalarElementType() const;
47 changes: 47 additions & 0 deletions include/ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h
@@ -0,0 +1,47 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TT_UTILS_MEMORYLAYOUTANALYSISPARAMS_H
#define TTMLIR_DIALECT_TT_UTILS_MEMORYLAYOUTANALYSISPARAMS_H

#include <llvm/ADT/StringSwitch.h>
#include <llvm/Support/CommandLine.h>

namespace mlir::tt {

enum class MemoryLayoutAnalysisPolicyType { DFSharding, L1Interleaved };

struct MemoryLayoutAnalysisPolicyTypeParser
: public llvm::cl::parser<MemoryLayoutAnalysisPolicyType> {
public:
MemoryLayoutAnalysisPolicyTypeParser(llvm::cl::Option &opt)
: llvm::cl::parser<MemoryLayoutAnalysisPolicyType>(opt) {}

bool parse(llvm::cl::Option &opt, llvm::StringRef argName,
llvm::StringRef arg, MemoryLayoutAnalysisPolicyType &value) {
value = llvm::StringSwitch<MemoryLayoutAnalysisPolicyType>(arg)
.Case("DFSharding", MemoryLayoutAnalysisPolicyType::DFSharding)
.Case("L1Interleaved",
MemoryLayoutAnalysisPolicyType::L1Interleaved);
return false;
}

static void print(llvm::raw_ostream &os,
const MemoryLayoutAnalysisPolicyType &value) {
llvm::StringRef policy;
switch (value) {
case MemoryLayoutAnalysisPolicyType::DFSharding:
policy = "DFSharding";
break;
case MemoryLayoutAnalysisPolicyType::L1Interleaved:
policy = "L1Interleaved";
break;
}
os << "memory-layout-analysis-policy=" << policy << "\n";
}
};

} // namespace mlir::tt

#endif // TTMLIR_DIALECT_TT_UTILS_MEMORYLAYOUTANALYSISPARAMS_H
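For context, a minimal sketch of wiring this parser into a raw llvm::cl option (the real registrations are in TTNNPipelines.h and Optimizer.h below; the variable name is illustrative). One caveat worth noting in review: the StringSwitch above has no .Default case, so an unrecognized policy string trips an assertion in asserts-enabled builds instead of producing a parse error.

#include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h"

#include <llvm/Support/CommandLine.h>

// Illustrative standalone option using the custom parser.
static llvm::cl::opt<mlir::tt::MemoryLayoutAnalysisPolicyType, false,
                     mlir::tt::MemoryLayoutAnalysisPolicyTypeParser>
    policyFlag("memory-layout-analysis-policy",
               llvm::cl::desc("Specify policy for memory layout analysis."),
               llvm::cl::init(
                   mlir::tt::MemoryLayoutAnalysisPolicyType::DFSharding));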
1 change: 0 additions & 1 deletion include/ttmlir/Dialect/TT/Utils/OverrideParams.h
@@ -6,7 +6,6 @@
#define TTMLIR_DIALECT_TT_UTILS_OVERRIDEPARAMS_H

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include <cstdint>
#include <llvm/Support/CommandLine.h>

namespace mlir::tt {
21 changes: 11 additions & 10 deletions include/ttmlir/Dialect/TTNN/Analysis/DFShardingPolicy.h
@@ -7,19 +7,16 @@

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"
#include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysisPolicy.h"

namespace mlir::tt::ttnn {

// Process ops in DFS schedulable order and build shard chain configs.
// Schedule is also produced as a side effect of sharding.
//
class DFShardingPolicy {
class DFShardingPolicy : public MemoryLayoutAnalysisPolicy {
private:
Operation *rootOp;
std::vector<L1ChainConfig> *l1ChainConfigs;
llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> legalLayouts;
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> *schedule;
unsigned usableL1CacheSize = 0;
std::unordered_set<Edge> overrideReshardEdges;

public:
DFShardingPolicy(
@@ -28,11 +25,15 @@ class DFShardingPolicy {
&legalLayouts,
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> &schedule,
unsigned usableL1CacheSize)
: rootOp(rootOp), l1ChainConfigs(&l1ChainConfigs),
legalLayouts(legalLayouts), schedule(&schedule),
usableL1CacheSize(usableL1CacheSize) {}
: MemoryLayoutAnalysisPolicy(rootOp, l1ChainConfigs, legalLayouts,
schedule, usableL1CacheSize),
overrideReshardEdges() {}

void run(const std::unordered_set<Edge> &overrideReshardEdges);
void run() final;

void setOverrideReshardEdges(const std::unordered_set<Edge> &reshardEdges) {
overrideReshardEdges = reshardEdges;
}
};

} // namespace mlir::tt::ttnn
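Because run() no longer takes the override edges, a call site (hypothetical names; the actual one is in MemoryLayoutAnalysis.cpp, not expanded in this view) now configures the policy before running it:

DFShardingPolicy policy(rootOp, l1ChainConfigs, legalLayouts, schedule,
                        usableL1CacheSize);
policy.setOverrideReshardEdges(overrideReshardEdges);
policy.run(); // virtual dispatch through MemoryLayoutAnalysisPolicy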
30 changes: 30 additions & 0 deletions include/ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h
@@ -0,0 +1,30 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_L1INTERLEAVEDPOLICY_H
#define TTMLIR_DIALECT_TTNN_ANALYSIS_L1INTERLEAVEDPOLICY_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"
#include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysisPolicy.h"

namespace mlir::tt::ttnn {

class L1InterleavedPolicy : public MemoryLayoutAnalysisPolicy {
public:
L1InterleavedPolicy(
Operation *rootOp, std::vector<L1ChainConfig> &l1ChainConfigs,
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
&legalLayouts,
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> &schedule,
unsigned usableL1CacheSize)
: MemoryLayoutAnalysisPolicy(rootOp, l1ChainConfigs, legalLayouts,
schedule, usableL1CacheSize) {}

void run() final;
};

} // namespace mlir::tt::ttnn

#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_L1INTERLEAVEDPOLICY_H
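The implementation (L1InterleavedPolicy.cpp, registered in the CMake change below) is not expanded in this view; judging by the name and the new hasInterleavedL1TensorMemoryLayout() predicate, the policy presumably keeps tensors resident in L1 with an interleaved layout rather than sharding them.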
11 changes: 5 additions & 6 deletions include/ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.h
@@ -6,30 +6,29 @@
#define TTMLIR_DIALECT_TTNN_ANALYSIS_MEMORYLAYOUTANALYSIS_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h"
#include "ttmlir/Dialect/TTNN/Analysis/Edge.h"
#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"
#include "ttmlir/Dialect/TTNN/Analysis/TTNNAnalysis.h"

namespace mlir::tt::ttnn {

enum class MemoryLayoutAnalysisPolicyType {
DFSharding,
};

struct MemoryLayoutAnalysisInput {
llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> legalLayouts;
unsigned usableL1CacheSize = 0;
std::unordered_set<Edge> overrideReshardEdges;
MemoryLayoutAnalysisPolicyType policy;

MemoryLayoutAnalysisInput() : legalLayouts() {}

MemoryLayoutAnalysisInput(
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
&legalLayouts,
unsigned usableL1CacheSize,
const std::unordered_set<Edge> &overrideReshardEdges)
const std::unordered_set<Edge> &overrideReshardEdges,
MemoryLayoutAnalysisPolicyType policy)
: legalLayouts(legalLayouts), usableL1CacheSize(usableL1CacheSize),
overrideReshardEdges(overrideReshardEdges) {}
overrideReshardEdges(overrideReshardEdges), policy(policy) {}

bool operator==(const MemoryLayoutAnalysisInput &rhs) const {
return legalLayouts == rhs.legalLayouts;
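Call sites now pass the policy as a trailing constructor argument; a sketch with illustrative values:

MemoryLayoutAnalysisInput input(
    legalLayouts, /*usableL1CacheSize=*/1 << 20, overrideReshardEdges,
    MemoryLayoutAnalysisPolicyType::L1Interleaved);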
39 changes: 39 additions & 0 deletions include/ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysisPolicy.h
@@ -0,0 +1,39 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_MEMORYLAYOUTANALYSISPOLICY_H
#define TTMLIR_DIALECT_TTNN_ANALYSIS_MEMORYLAYOUTANALYSISPOLICY_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"

namespace mlir::tt::ttnn {

class MemoryLayoutAnalysisPolicy {
protected:
Operation *rootOp;
std::vector<L1ChainConfig> *l1ChainConfigs;
llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> legalLayouts;
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> *schedule;
unsigned usableL1CacheSize = 0;

public:
virtual ~MemoryLayoutAnalysisPolicy() {};

MemoryLayoutAnalysisPolicy(
Operation *rootOp, std::vector<L1ChainConfig> &l1ChainConfigs,
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
&legalLayouts,
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> &schedule,
unsigned usableL1CacheSize)
: rootOp(rootOp), l1ChainConfigs(&l1ChainConfigs),
legalLayouts(legalLayouts), schedule(&schedule),
usableL1CacheSize(usableL1CacheSize) {}

virtual void run() = 0;
};

} // namespace mlir::tt::ttnn

#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_MEMORYLAYOUTANALYSISPOLICY_H
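This base class turns the analysis into a classic strategy pattern. A minimal sketch of how the dispatch site might look (an assumption — MemoryLayoutAnalysis.cpp is not expanded in this view, and the function name is hypothetical):

void runMemoryLayoutAnalysis(
    Operation *op, std::vector<L1ChainConfig> &l1ChainConfigs,
    llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> &schedule,
    const MemoryLayoutAnalysisInput &analysisInput) {
  std::unique_ptr<MemoryLayoutAnalysisPolicy> policy;
  switch (analysisInput.policy) {
  case MemoryLayoutAnalysisPolicyType::DFSharding: {
    auto dfSharding = std::make_unique<DFShardingPolicy>(
        op, l1ChainConfigs, analysisInput.legalLayouts, schedule,
        analysisInput.usableL1CacheSize);
    dfSharding->setOverrideReshardEdges(analysisInput.overrideReshardEdges);
    policy = std::move(dfSharding);
    break;
  }
  case MemoryLayoutAnalysisPolicyType::L1Interleaved:
    policy = std::make_unique<L1InterleavedPolicy>(
        op, l1ChainConfigs, analysisInput.legalLayouts, schedule,
        analysisInput.usableL1CacheSize);
    break;
  }
  policy->run();
}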
14 changes: 10 additions & 4 deletions include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
@@ -6,13 +6,11 @@
#define TTMLIR_DIALECT_TTNN_PIPELINES_TTNNPIPELINES_H

#include "mlir/Pass/PassOptions.h"
#include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h"
#include "ttmlir/Dialect/TT/Utils/OverrideParams.h"
#include <cstdint>
#include <llvm/ADT/SmallVector.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/Support/CommandLine.h>

namespace mlir::tt::ttnn {

// Options for the TTIR to TTNN backend pipeline.
//
struct TTIRToTTNNBackendPipelineOptions
@@ -85,6 +83,14 @@ struct TTIRToTTNNBackendPipelineOptions
"of shard specs."),
llvm::cl::init(false)};

// Specify policy for memory layout analysis.
//
Option<MemoryLayoutAnalysisPolicyType, MemoryLayoutAnalysisPolicyTypeParser>
memoryLayoutAnalysisPolicy{
*this, "memory-layout-analysis-policy",
llvm::cl::desc("Specify policy for memory layout analysis."),
llvm::cl::init(MemoryLayoutAnalysisPolicyType::DFSharding)};

// Option to provide a system descriptor flatbuffer file to compile
// against.
//
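With this option in place, the policy can presumably be selected per pipeline invocation, e.g. ttmlir-opt --ttir-to-ttnn-backend-pipeline="memory-layout-analysis-policy=L1Interleaved" input.mlir (the tool name and quoting are assumptions; only the option name and its DFSharding default come from this diff).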
9 changes: 9 additions & 0 deletions include/ttmlir/Dialect/TTNN/Transforms/Optimizer.h
@@ -19,6 +19,8 @@ struct TTNNOptimizerOptions {
llvm::StringMap<OutputLayoutOverrideParams> overrideOutputLayout =
llvm::StringMap<OutputLayoutOverrideParams>();
bool memoryLayoutAnalysisEnabled = false;
MemoryLayoutAnalysisPolicyType memoryLayoutAnalysisPolicy =
MemoryLayoutAnalysisPolicyType::DFSharding;
bool memReconfigEnabled = false;
int64_t maxLegalLayouts = 64;
};
@@ -95,6 +97,7 @@ class TTNNOptimizerBase : public ::mlir::OperationPass<::mlir::ModuleOp> {
memoryLayoutAnalysisEnabled =
std::move(options.memoryLayoutAnalysisEnabled);
memReconfigEnabled = std::move(options.memReconfigEnabled);
memoryLayoutAnalysisPolicy = std::move(options.memoryLayoutAnalysisPolicy);
maxLegalLayouts = std::move(options.maxLegalLayouts);
}

@@ -122,6 +125,12 @@ class TTNNOptimizerBase : public ::mlir::OperationPass<::mlir::ModuleOp> {
"we support all "
"types of shard specs."),
::llvm::cl::init(false)};
::mlir::Pass::Option<mlir::tt::MemoryLayoutAnalysisPolicyType,
mlir::tt::MemoryLayoutAnalysisPolicyTypeParser>
memoryLayoutAnalysisPolicy{
*this, "memory-layout-analysis-policy",
llvm::cl::desc("Specify policy for memory layout analysis."),
llvm::cl::init(MemoryLayoutAnalysisPolicyType::DFSharding)};
::mlir::Pass::Option<int64_t> maxLegalLayouts{
*this, "max-legal-layouts",
::llvm::cl::desc(
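For programmatic construction of the pass, the new field sits next to the existing toggles; a hedged sketch:

// Illustrative only; how the options struct reaches the pass is not
// shown in this diff.
mlir::tt::ttnn::TTNNOptimizerOptions options;
options.memoryLayoutAnalysisEnabled = true;
options.memoryLayoutAnalysisPolicy =
    mlir::tt::MemoryLayoutAnalysisPolicyType::L1Interleaved;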
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTNN/Transforms/Passes.h
@@ -9,6 +9,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "ttmlir/Dialect/TT/Utils/OverrideParams.h"
#include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.h"
#include "ttmlir/Dialect/TTNN/IR/TTNN.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"

9 changes: 9 additions & 0 deletions lib/Dialect/TT/IR/TTOpsTypes.cpp
@@ -579,13 +579,22 @@ bool LayoutAttr::hasShardedTensorMemoryLayout() const {
getMemLayout() == TensorMemoryLayout::BlockSharded);
}

bool LayoutAttr::hasInterleavedTensorMemoryLayout() const {
return (getMemLayout() == TensorMemoryLayout::Interleaved);
}

bool LayoutAttr::hasShardedL1TensorMemoryLayout() const {
return ::mlir::tt::isL1MemorySpace(getMemorySpace()) and
(getMemLayout() == TensorMemoryLayout::HeightSharded or
getMemLayout() == TensorMemoryLayout::WidthSharded or
getMemLayout() == TensorMemoryLayout::BlockSharded);
}

bool LayoutAttr::hasInterleavedL1TensorMemoryLayout() const {
return ::mlir::tt::isL1MemorySpace(getMemorySpace()) and
(getMemLayout() == TensorMemoryLayout::Interleaved);
}

bool LayoutAttr::isTiled() const {
return ::mlir::isa<::mlir::tt::TileType>(getElementType());
}
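The L1 variants simply conjoin the memory-space check with the layout check. A hypothetical helper (not part of this PR) showing how a pass might combine them:

// Both sharded and interleaved L1 layouts keep the tensor resident in
// L1; downstream policies distinguish the two cases.
bool residesInL1(mlir::tt::LayoutAttr layout) {
  return layout.hasShardedL1TensorMemoryLayout() ||
         layout.hasInterleavedL1TensorMemoryLayout();
}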
1 change: 1 addition & 0 deletions lib/Dialect/TTNN/Analysis/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRTTNNAnalysis
MemoryLayoutAnalysis.cpp
L1ChainConfig.cpp
DFShardingPolicy.cpp
L1InterleavedPolicy.cpp
ShardSolver.cpp

ADDITIONAL_HEADER_DIRS
3 changes: 1 addition & 2 deletions lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp
@@ -8,8 +8,7 @@

namespace mlir::tt::ttnn {

void DFShardingPolicy::run(
const std::unordered_set<Edge> &overrideReshardEdges) {
void DFShardingPolicy::run() {
rootOp->walk([&](func::FuncOp func) {
DeviceAttr deviceAttr = getCurrentScopeDevice(func);
mlir::tt::scheduler::Scheduler scheduler(&func);