Skip to content

Commit

Permalink
Added abstraction to MemoryLayoutAnalysis policy and removed unnecessary
Browse files Browse the repository at this point in the history
tests.
  • Loading branch information
fbajraktariTT committed Oct 31, 2024
1 parent a93e83a commit 0194e80
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 185 deletions.
17 changes: 5 additions & 12 deletions include/ttmlir/Dialect/TTNN/Analysis/DFShardingPolicy.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,25 @@

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"
#include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysisPolicy.h"

namespace mlir::tt::ttnn {

// Process ops in DFS schedulable order and build shard chain configs.
// Schedule is also produced as a side effect of sharding.
//
class DFShardingPolicy {
private:
Operation *rootOp;
std::vector<L1ChainConfig> *l1ChainConfigs;
llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> legalLayouts;
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> *schedule;
unsigned usableL1CacheSize = 0;

class DFShardingPolicy : public MemoryLayoutAnalysisPolicy {
public:
DFShardingPolicy(
Operation *rootOp, std::vector<L1ChainConfig> &l1ChainConfigs,
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
&legalLayouts,
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> &schedule,
unsigned usableL1CacheSize)
: rootOp(rootOp), l1ChainConfigs(&l1ChainConfigs),
legalLayouts(legalLayouts), schedule(&schedule),
usableL1CacheSize(usableL1CacheSize) {}
: MemoryLayoutAnalysisPolicy(rootOp, l1ChainConfigs, legalLayouts,
schedule, usableL1CacheSize) {}

void run(const std::unordered_set<Edge> &overrideReshardEdges);
void run(const std::unordered_set<Edge> &overrideReshardEdges) final;
};

} // namespace mlir::tt::ttnn
Expand Down
17 changes: 5 additions & 12 deletions include/ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,22 @@

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"
#include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysisPolicy.h"

namespace mlir::tt::ttnn {

class L1InterleavedPolicy {
private:
Operation *rootOp;
std::vector<L1ChainConfig> *l1ChainConfigs;
llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> legalLayouts;
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> *schedule;
unsigned usableL1CacheSize = 0;

class L1InterleavedPolicy : public MemoryLayoutAnalysisPolicy {
public:
L1InterleavedPolicy(
Operation *rootOp, std::vector<L1ChainConfig> &l1ChainConfigs,
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
&legalLayouts,
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> &schedule,
unsigned usableL1CacheSize)
: rootOp(rootOp), l1ChainConfigs(&l1ChainConfigs),
legalLayouts(legalLayouts), schedule(&schedule),
usableL1CacheSize(usableL1CacheSize) {}
: MemoryLayoutAnalysisPolicy(rootOp, l1ChainConfigs, legalLayouts,
schedule, usableL1CacheSize) {}

void run();
void run(const std::unordered_set<Edge> &overrideReshardEdges) final;
};

} // namespace mlir::tt::ttnn
Expand Down
39 changes: 39 additions & 0 deletions include/ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysisPolicy.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_MEMORYLAYOUTANALYSISPOLICY_H
#define TTMLIR_DIALECT_TTNN_ANALYSIS_MEMORYLAYOUTANALYSISPOLICY_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"

namespace mlir::tt::ttnn {

class MemoryLayoutAnalysisPolicy {
protected:
Operation *rootOp;
std::vector<L1ChainConfig> *l1ChainConfigs;
llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> legalLayouts;
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> *schedule;
unsigned usableL1CacheSize = 0;

public:
virtual ~MemoryLayoutAnalysisPolicy() {};

MemoryLayoutAnalysisPolicy(
Operation *rootOp, std::vector<L1ChainConfig> &l1ChainConfigs,
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
&legalLayouts,
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> &schedule,
unsigned usableL1CacheSize)
: rootOp(rootOp), l1ChainConfigs(&l1ChainConfigs),
legalLayouts(legalLayouts), schedule(&schedule),
usableL1CacheSize(usableL1CacheSize) {}

virtual void run(const std::unordered_set<Edge> &overrideReshardEdges) = 0;
};

} // namespace mlir::tt::ttnn

#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_MEMORYLAYOUTANALYSISPOLICY_H
67 changes: 0 additions & 67 deletions include/ttmlir/Dialect/TTNN/Analysis/ShardingAnalysis.h

This file was deleted.

4 changes: 2 additions & 2 deletions lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
#include "ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Scheduler/Scheduler.h"
#include <llvm/Support/raw_ostream.h>

namespace mlir::tt::ttnn {

void L1InterleavedPolicy::run() {
void L1InterleavedPolicy::run(
const std::unordered_set<Edge> &overrideReshardEdges) {
rootOp->walk([&](func::FuncOp func) {
mlir::tt::scheduler::Scheduler scheduler(&func);
llvm::SmallVector<mlir::Operation *> scheduleableOps;
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ void MemoryLayoutAnalysis::analysisImplementation() {
L1InterleavedPolicy l1InterleavedPolicy(
op, l1ChainConfigs, filterL1InterleavedOnly(analysisInput.legalLayouts),
analysisResult.schedule, analysisInput.usableL1CacheSize);
l1InterleavedPolicy.run();
l1InterleavedPolicy.run(analysisInput.overrideReshardEdges);
break;
}
}
Expand Down
46 changes: 0 additions & 46 deletions test/ttmlir/Dialect/TTNN/all_l1_interleaved_policy.mlir

This file was deleted.

Loading

0 comments on commit 0194e80

Please sign in to comment.