Skip to content

Commit

Permalink
[Optimizer] Add ability to pass in custom constraint check function to ShardSolver. (#1573)
Browse files Browse the repository at this point in the history
  • Loading branch information
nobradovictt authored Dec 13, 2024
1 parent ed84a56 commit b0c0c2b
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 4 deletions.
8 changes: 7 additions & 1 deletion include/ttmlir/Dialect/TTNN/Analysis/ShardSolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,10 @@ class ShardSolver {
const std::vector<OpL1MemSpec> &shardSpecs,
const llvm::DenseSet<Operation *> &shardedOps,
const unsigned usableL1CacheSize,
const std::unordered_set<Edge> &overrideReshardEdges);
const std::unordered_set<Edge> &overrideReshardEdges,
std::function<bool(mlir::Operation *, TTNNLayoutAttr const &,
mlir::Operation *, TTNNLayoutAttr const &)>
customCheckShardCompatible = nullptr);
RemainingLayoutAttrs at(Operation *operation) const;
void set(Operation *operation, TTNNLayoutAttr const &layout);
static bool supportsInterleavedInputShardedOutput(Operation *op);
Expand All @@ -310,6 +313,9 @@ class ShardSolver {

llvm::DenseMap<Operation *, TTNNLayoutAttr> selectedOpLayout;
std::unordered_set<Edge> memReconfigEdges;
std::function<bool(mlir::Operation *, TTNNLayoutAttr const &,
mlir::Operation *, TTNNLayoutAttr const &)>
customCheckShardCompatible;
};

} // namespace mlir::tt::ttnn
Expand Down
15 changes: 13 additions & 2 deletions lib/Dialect/TTNN/Analysis/ShardSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,14 @@ ShardSolver::ShardSolver(
const std::vector<OpL1MemSpec> &shardSpecs,
const llvm::DenseSet<Operation *> &shardedOps,
const unsigned usableL1CacheSize,
const std::unordered_set<Edge> &overrideReshardEdges)
const std::unordered_set<Edge> &overrideReshardEdges,
std::function<bool(Operation *, TTNNLayoutAttr const &, Operation *,
TTNNLayoutAttr const &)>
customCheckShardCompatible)
: legalLayouts(&legalLayouts), shardSpecs(&shardSpecs),
shardedOps(&shardedOps), usableL1CacheSize(usableL1CacheSize),
memReconfigEdges(overrideReshardEdges) {
memReconfigEdges(overrideReshardEdges),
customCheckShardCompatible(customCheckShardCompatible) {
pathSets.reserve(shardSpecs.size());
pathSetIds.reserve(shardSpecs.size());
bitsets.reserve(shardedOps.size());
Expand Down Expand Up @@ -505,6 +509,13 @@ bool ShardSolver::checkShardCompatible(
Operation *producerOp, TTNNLayoutAttr const &producerLayout,
Operation *consumerOp, TTNNLayoutAttr const &consumerLayout) const {

// Custom(test) hook for shard compatibility check.
//
if (customCheckShardCompatible) {
return customCheckShardCompatible(producerOp, producerLayout, consumerOp,
consumerLayout);
}

// TEMP : Dummy mock implementation, will be replaced.
//

Expand Down
20 changes: 19 additions & 1 deletion test/unittests/Optimizer/TestShardSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,26 @@ TEST_F(ShardSolverBase, VerifyProduceMaxCoreUsage) {
addLayoutForOp(op, legalLayouts, BufferType::L1,
TensorMemoryLayout::BlockSharded, 1, 1);

// Create custom checkShardCompatible function.
//
std::function<bool(mlir::Operation *, TTNNLayoutAttr const &,
mlir::Operation *, TTNNLayoutAttr const &)>
checkShardCompatible = [](mlir::Operation *producerOp,
TTNNLayoutAttr const &producerLayout,
mlir::Operation *consumerOp,
TTNNLayoutAttr const &consumerLayout) {
// Simple shard compat assumption. Try to keep same shard layout.
//
if (producerLayout.getMemLayout() != consumerLayout.getMemLayout()) {
return false;
}

return true;
};

ShardSolver shardSolver(legalLayouts, opL1MemSpecs, l1ChainedOps,
usableL1CacheSize, overrideReshardEdges);
usableL1CacheSize, overrideReshardEdges,
checkShardCompatible);

llvm::DenseMap<mlir::Operation *, llvm::SmallVector<float, 64>>
accMaxCoreUsage = shardSolver.produceMaxCoreUsage();
Expand Down

0 comments on commit b0c0c2b

Please sign in to comment.