-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor sharding into mem layout analysis. Add l1 interleaved policy for mem layout analysis. Add mnist test. Add option in both optimizer pass and ttnn-ttir-backend-pipeline to specify memory layout analysis policy type.
- Loading branch information
1 parent
33ac41f
commit f723ac7
Showing
16 changed files
with
468 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,3 +6,6 @@ third_party/tt-metal | |
.cache | ||
*pycache* | ||
*.egg-info | ||
ttrt-artifacts/* | ||
query_results.json | ||
run_results.json |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
37 changes: 37 additions & 0 deletions
37
include/ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_L1INTERLEAVEDPOLICY_H | ||
#define TTMLIR_DIALECT_TTNN_ANALYSIS_L1INTERLEAVEDPOLICY_H | ||
|
||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h" | ||
|
||
namespace mlir::tt::ttnn { | ||
|
||
class L1InterleavedPolicy { | ||
private: | ||
Operation *rootOp; | ||
std::vector<L1ChainConfig> *l1ChainConfigs; | ||
llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> legalLayouts; | ||
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> *schedule; | ||
unsigned usableL1CacheSize = 0; | ||
|
||
public: | ||
L1InterleavedPolicy( | ||
Operation *rootOp, std::vector<L1ChainConfig> &l1ChainConfigs, | ||
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> | ||
&legalLayouts, | ||
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> &schedule, | ||
unsigned usableL1CacheSize) | ||
: rootOp(rootOp), l1ChainConfigs(&l1ChainConfigs), | ||
legalLayouts(legalLayouts), schedule(&schedule), | ||
usableL1CacheSize(usableL1CacheSize) {} | ||
|
||
void run(); | ||
}; | ||
|
||
} // namespace mlir::tt::ttnn | ||
|
||
#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_L1INTERLEAVEDPOLICY_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_SHARDINGANALYSIS_H | ||
#define TTMLIR_DIALECT_TTNN_ANALYSIS_SHARDINGANALYSIS_H | ||
|
||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "ttmlir/Dialect/TTNN/Analysis/Edge.h" | ||
#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h" | ||
#include "ttmlir/Dialect/TTNN/Analysis/TTNNAnalysis.h" | ||
|
||
namespace mlir::tt::ttnn { | ||
|
||
// Identifies which memory layout analysis policy to use.
//
enum class PolicyType {
  DFSharding,
  L1Interleaved,
};
|
||
struct ShardingAnalysisInput { | ||
llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> legalLayouts; | ||
unsigned usableL1CacheSize = 0; | ||
|
||
ShardingAnalysisInput() : legalLayouts() {} | ||
|
||
ShardingAnalysisInput( | ||
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> | ||
&legalLayouts, | ||
unsigned usableL1CacheSize) | ||
: legalLayouts(legalLayouts), usableL1CacheSize(usableL1CacheSize) {} | ||
|
||
bool operator==(const ShardingAnalysisInput &rhs) const { | ||
return legalLayouts == rhs.legalLayouts; | ||
} | ||
|
||
bool operator!=(const ShardingAnalysisInput &rhs) const { | ||
return !(*this == rhs); | ||
} | ||
}; | ||
|
||
struct ShardingAnalysisResult { | ||
llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> legalLayouts; | ||
std::unordered_set<Edge> reshardedEdges; | ||
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> schedule; | ||
|
||
ShardingAnalysisResult() : legalLayouts(), reshardedEdges(), schedule() {} | ||
|
||
ShardingAnalysisResult( | ||
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> | ||
&legalLayouts, | ||
const std::unordered_set<Edge> &reshardedEdges) | ||
: legalLayouts(legalLayouts), reshardedEdges(reshardedEdges) {} | ||
}; | ||
|
||
// Determine shard chain configs. | ||
// | ||
class ShardingAnalysis | ||
: public TTNNAnalysis<ShardingAnalysisInput, ShardingAnalysisResult> { | ||
|
||
private: | ||
void analysisImplementation() override; | ||
bool applyOverrides() override; | ||
std::vector<L1ChainConfig> l1ChainConfigs; | ||
|
||
public: | ||
ShardingAnalysis(Operation *op) : TTNNAnalysis(op) {} | ||
}; | ||
} // namespace mlir::tt::ttnn | ||
|
||
#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_SHARDINGANALYSIS_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h" | ||
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" | ||
#include "ttmlir/Scheduler/Scheduler.h" | ||
#include <llvm/Support/raw_ostream.h> | ||
|
||
namespace mlir::tt::ttnn { | ||
|
||
void L1InterleavedPolicy::run() { | ||
rootOp->walk([&](func::FuncOp func) { | ||
mlir::tt::scheduler::Scheduler scheduler(&func); | ||
llvm::SmallVector<mlir::Operation *> scheduleableOps; | ||
Operation *currentOp = nullptr; | ||
llvm::DenseMap<Operation *, tt::LayoutAttr> selectedOpLayout; | ||
|
||
// TODO(fbajraktari): Algo | ||
// | ||
l1ChainConfigs->push_back(L1ChainConfig()); | ||
while (scheduler.hasUnscheduledOps()) { | ||
scheduleableOps = scheduler.getScheduleableOps(); | ||
currentOp = scheduleableOps[0]; | ||
|
||
// Schedule currentOp. | ||
// | ||
scheduler.scheduleOp(currentOp); | ||
|
||
// Check if currentOp is valid l1 interleaved op. | ||
// | ||
if (legalLayouts.lookup(currentOp).size() > 0) { | ||
selectedOpLayout[currentOp] = legalLayouts.lookup(currentOp).front(); | ||
|
||
// Add currentOp to shard chain config. | ||
// | ||
OpL1MemSpec shardSpec; | ||
shardSpec.op = currentOp; | ||
|
||
// Hardcoded tensor split factor for now, until pipeline OP | ||
// support is added. | ||
// | ||
shardSpec.tensorSplitFactor = 1; | ||
l1ChainConfigs->back().addOpL1MemSpec(std::move(shardSpec)); | ||
} | ||
} | ||
|
||
if (l1ChainConfigs->back().isEmpty()) { | ||
l1ChainConfigs->pop_back(); | ||
} | ||
|
||
// Schedule | ||
// | ||
(*schedule)[func] = scheduler.getSchedule(); | ||
|
||
// Resolve shard chain configs. | ||
// | ||
for (auto &l1ChainConfig : *l1ChainConfigs) { | ||
l1ChainConfig.build(); | ||
l1ChainConfig.resolve(); | ||
|
||
std::unordered_set<Edge> memReconfigEdges; | ||
l1ChainConfig.complete(selectedOpLayout, memReconfigEdges); | ||
} | ||
}); | ||
llvm::errs() << "usableL1CacheSize: " << usableL1CacheSize << "\n"; | ||
} | ||
|
||
} // namespace mlir::tt::ttnn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.