Skip to content

Commit

Permalink
Changed the signature of run() method of memory layout analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
fbajraktariTT committed Nov 4, 2024
1 parent 3ac9157 commit e303453
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 11 deletions.
12 changes: 10 additions & 2 deletions include/ttmlir/Dialect/TTNN/Analysis/DFShardingPolicy.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ namespace mlir::tt::ttnn {
// Schedule is also produced as a side effect of sharding.
//
class DFShardingPolicy : public MemoryLayoutAnalysisPolicy {
private:
std::unordered_set<Edge> overrideReshardEdges;

public:
DFShardingPolicy(
Operation *rootOp, std::vector<L1ChainConfig> &l1ChainConfigs,
Expand All @@ -23,9 +26,14 @@ class DFShardingPolicy : public MemoryLayoutAnalysisPolicy {
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> &schedule,
unsigned usableL1CacheSize)
: MemoryLayoutAnalysisPolicy(rootOp, l1ChainConfigs, legalLayouts,
schedule, usableL1CacheSize) {}
schedule, usableL1CacheSize),
overrideReshardEdges() {}

void run() final;

void run(const std::unordered_set<Edge> &overrideReshardEdges) final;
void setOverrideReshardEdges(const std::unordered_set<Edge> &reshardEdges) {
overrideReshardEdges = reshardEdges;
}
};

} // namespace mlir::tt::ttnn
Expand Down
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class L1InterleavedPolicy : public MemoryLayoutAnalysisPolicy {
: MemoryLayoutAnalysisPolicy(rootOp, l1ChainConfigs, legalLayouts,
schedule, usableL1CacheSize) {}

void run(const std::unordered_set<Edge> &overrideReshardEdges) final;
void run() final;
};

} // namespace mlir::tt::ttnn
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class MemoryLayoutAnalysisPolicy {
legalLayouts(legalLayouts), schedule(&schedule),
usableL1CacheSize(usableL1CacheSize) {}

virtual void run(const std::unordered_set<Edge> &overrideReshardEdges) = 0;
virtual void run() = 0;
};

} // namespace mlir::tt::ttnn
Expand Down
3 changes: 1 addition & 2 deletions lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@

namespace mlir::tt::ttnn {

void DFShardingPolicy::run(
const std::unordered_set<Edge> &overrideReshardEdges) {
void DFShardingPolicy::run() {
rootOp->walk([&](func::FuncOp func) {
DeviceAttr deviceAttr = getCurrentScopeDevice(func);
mlir::tt::scheduler::Scheduler scheduler(&func);
Expand Down
5 changes: 2 additions & 3 deletions lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@

namespace mlir::tt::ttnn {

void L1InterleavedPolicy::run(
const std::unordered_set<Edge> &overrideReshardEdges) {
void L1InterleavedPolicy::run() {
rootOp->walk([&](func::FuncOp func) {
mlir::tt::scheduler::Scheduler scheduler(&func);
llvm::SmallVector<mlir::Operation *> scheduleableOps;
Expand All @@ -19,7 +18,7 @@ void L1InterleavedPolicy::run(
// TODO(fbajraktari):
// This is V0 implementation of L1 interleaved policy. In the current
// implementation we have a single L1ChainCofig per FuncOp. This implies
// that in case of DRAM spil we will have a disconnected chain of L1 ops.
// that in case of DRAM spill we will have a disconnected chain of L1 ops.
// This will be fixed in V1.
//
l1ChainConfigs->push_back(L1ChainConfig());
Expand Down
6 changes: 4 additions & 2 deletions lib/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,16 @@ void MemoryLayoutAnalysis::analysisImplementation() {
DFShardingPolicy dfShardingPolicy(
op, l1ChainConfigs, filterShardedOnly(analysisInput.legalLayouts),
analysisResult.schedule, analysisInput.usableL1CacheSize);
dfShardingPolicy.run(analysisInput.overrideReshardEdges);
dfShardingPolicy.setOverrideReshardEdges(
analysisInput.overrideReshardEdges);
dfShardingPolicy.run();
break;
}
case MemoryLayoutAnalysisPolicyType::L1Interleaved: {
L1InterleavedPolicy l1InterleavedPolicy(
op, l1ChainConfigs, filterL1InterleavedOnly(analysisInput.legalLayouts),
analysisResult.schedule, analysisInput.usableL1CacheSize);
l1InterleavedPolicy.run(analysisInput.overrideReshardEdges);
l1InterleavedPolicy.run();
break;
}
}
Expand Down

0 comments on commit e303453

Please sign in to comment.