[Optimizer/TTNN] Migrating Optimizer to TTNN
nobradovictt committed Oct 18, 2024
1 parent 998c81a commit 94f5a2a
Showing 33 changed files with 371 additions and 293 deletions.
1 change: 0 additions & 1 deletion include/ttmlir/Dialect/TTIR/Transforms/Passes.h
@@ -5,7 +5,6 @@
 #ifndef TTMLIR_DIALECT_TTIR_TRANSFORMS_PASSES_H
 #define TTMLIR_DIALECT_TTIR_TRANSFORMS_PASSES_H

-#include "ttmlir/Dialect/TT/Utils/OverrideParams.h"
 #include "ttmlir/Dialect/TTIR/IR/TTIR.h"
 #include "ttmlir/Dialect/TTIR/IR/TTIROps.h"

26 changes: 0 additions & 26 deletions include/ttmlir/Dialect/TTIR/Transforms/Passes.td
@@ -108,32 +108,6 @@ def TTIRAllocate: Pass<"ttir-allocate", "::mlir::ModuleOp"> {
   }];
 }

-def TTIROptimizer: Pass<"ttir-optimizer", "::mlir::ModuleOp"> {
-  let summary = "Determine op configurations for maximum performance.";
-  let description = [{
-    Go through the ops, set sharding specs for each op based on sharding analysis,
-    by updating layout attribute of each op.
-  }];
-  let options = [
-    Option<"overrideOutputLayout", "override-output-layout",
-           "llvm::StringMap<LayoutOverrideParams>",
-           /*default=*/"llvm::StringMap<LayoutOverrideParams>()",
-           "Override output tensor layout for specific ops.">,
-    Option<"shardingPassEnabled", "sharding-pass-enabled",
-           "bool",
-           /*default=*/"false",
-           "Enable sharding pass.">,
-    Option<"reshardingEnabled", "resharding-enabled",
-           "bool",
-           /*default=*/"false",
-           "Resharding pass. Temp disabled till we support all types of shard specs.">,
-    Option<"maxLegalLayouts", "max-legal-layouts",
-           "int64_t",
-           /*default=*/"64",
-           "Override maximum number of legal layouts for grid analysis.">
-  ];
-}
-
 def TTIRLoadSystemDesc: Pass<"ttir-load-system-desc", "::mlir::ModuleOp"> {
   let summary = "Load system desc.";
   let description = [{
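Since the pass is being migrated rather than deleted, the removed TableGen options are worth reading carefully. Below is a simplified C++ sketch of the option surface they declare; only the names, flag spellings, types, and defaults come from the diff (LayoutOverrideParams lives in ttmlir/Dialect/TT/Utils/OverrideParams.h, the include removed from Passes.h above), and the struct itself is illustrative, not the tblgen-generated class.

#include "ttmlir/Dialect/TT/Utils/OverrideParams.h" // LayoutOverrideParams
#include "llvm/ADT/StringMap.h"
#include <cstdint>

// Illustrative only; the real option members are generated by mlir-tblgen.
struct OptimizerOptionsSketch {
  // override-output-layout: per-op output tensor layout overrides.
  llvm::StringMap<LayoutOverrideParams> overrideOutputLayout;
  // sharding-pass-enabled: run the sharding analysis.
  bool shardingPassEnabled = false;
  // resharding-enabled: off by default until all shard spec types are supported.
  bool reshardingEnabled = false;
  // max-legal-layouts: cap on legal layouts considered per op in grid analysis.
  int64_t maxLegalLayouts = 64;
};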
include/ttmlir/Dialect/{TTIR → TTNN}/Analysis/DFShardingPolicy.h
@@ -2,13 +2,13 @@
 //
 // SPDX-License-Identifier: Apache-2.0

-#ifndef TTMLIR_DIALECT_TTIR_ANALYSIS_DFSHARDINGPOLICY_H
-#define TTMLIR_DIALECT_TTIR_ANALYSIS_DFSHARDINGPOLICY_H
+#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_DFSHARDINGPOLICY_H
+#define TTMLIR_DIALECT_TTNN_ANALYSIS_DFSHARDINGPOLICY_H

 #include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "ttmlir/Dialect/TTIR/Analysis/ShardChainConfig.h"
+#include "ttmlir/Dialect/TTNN/Analysis/ShardChainConfig.h"

-namespace mlir::tt::ttir {
+namespace mlir::tt::ttnn {

 // Process ops in DFS schedulable order and build shard chain configs.
 // Schedule is also produced as a side effect of sharding.
@@ -17,14 +17,15 @@ class DFShardingPolicy {
 private:
   Operation *rootOp;
   std::vector<ShardChainConfig> *shardChainConfigs;
-  llvm::DenseMap<Operation *, std::vector<LayoutAttr>> legalLayouts;
+  llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> legalLayouts;
   llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> *schedule;
   unsigned usableL1CacheSize = 0;

 public:
   DFShardingPolicy(
       Operation *rootOp, std::vector<ShardChainConfig> &shardChainConfigs,
-      const llvm::DenseMap<Operation *, std::vector<LayoutAttr>> &legalLayouts,
+      const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
+          &legalLayouts,
       llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> &schedule,
       unsigned usableL1CacheSize)
       : rootOp(rootOp), shardChainConfigs(&shardChainConfigs),
@@ -34,6 +35,6 @@ class DFShardingPolicy {
   void run();
 };

-} // namespace mlir::tt::ttir
+} // namespace mlir::tt::ttnn

-#endif // TTMLIR_DIALECT_TTIR_ANALYSIS_DFSHARDINGPOLICY_H
+#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_DFSHARDINGPOLICY_H
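The members above describe the policy's whole contract. A hypothetical call site follows; only the constructor signature and run() come from this header, everything else is assumed context:

// Hypothetical wiring; inputs would come from earlier analyses.
mlir::Operation *rootOp = nullptr; // module root, assumed
llvm::DenseMap<mlir::Operation *, std::vector<mlir::tt::LayoutAttr>> legalLayouts;
std::vector<mlir::tt::ttnn::ShardChainConfig> shardChainConfigs;
llvm::DenseMap<mlir::func::FuncOp, llvm::SmallVector<mlir::Operation *>> schedule;

mlir::tt::ttnn::DFShardingPolicy policy(rootOp, shardChainConfigs, legalLayouts,
                                        schedule, /*usableL1CacheSize=*/1 << 20);
policy.run(); // builds shard chain configs; fills the schedule as a side effect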
include/ttmlir/Dialect/{TTIR → TTNN}/Analysis/Edge.h
@@ -2,12 +2,12 @@
 //
 // SPDX-License-Identifier: Apache-2.0

-#ifndef TTMLIR_DIALECT_TTIR_ANALYSIS_EDGE_H
-#define TTMLIR_DIALECT_TTIR_ANALYSIS_EDGE_H
+#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_EDGE_H
+#define TTMLIR_DIALECT_TTNN_ANALYSIS_EDGE_H

 #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"

-namespace mlir::tt::ttir {
+namespace mlir::tt::ttnn {
 struct Edge {
   Operation *producerOp = nullptr;
   Operation *consumerOp = nullptr;
@@ -23,11 +23,11 @@ struct Edge {
   }
 };

-} // namespace mlir::tt::ttir
+} // namespace mlir::tt::ttnn

 namespace std {
-template <> struct hash<mlir::tt::ttir::Edge> {
-  size_t operator()(const mlir::tt::ttir::Edge &edge) const noexcept {
+template <> struct hash<mlir::tt::ttnn::Edge> {
+  size_t operator()(const mlir::tt::ttnn::Edge &edge) const noexcept {
     llvm::hash_code code = llvm::hash_value(edge.operandIndex);
     code = llvm::hash_combine(code, edge.producerOp);
     code = llvm::hash_combine(code, edge.consumerOp);
@@ -36,4 +36,4 @@ template <> struct hash<mlir::tt::ttir::Edge> {
 };
 } // namespace std

-#endif // TTMLIR_DIALECT_TTIR_ANALYSIS_EDGE_H
+#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_EDGE_H
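The std::hash specialization above is what lets Edge serve as a key in unordered containers, such as the std::unordered_set<Edge> of resharded edges that ShardChainConfig::complete() receives. A minimal usage sketch; the field names come from this header, while the producer/consumer pointers are placeholder assumptions:

#include <unordered_set>

// Placeholder ops; in practice these come from the traversed module.
mlir::Operation *producerOp = nullptr;
mlir::Operation *consumerOp = nullptr;

std::unordered_set<mlir::tt::ttnn::Edge> reshardedEdges;
mlir::tt::ttnn::Edge edge;
edge.producerOp = producerOp;
edge.consumerOp = consumerOp;
edge.operandIndex = 0;
reshardedEdges.insert(edge); // hashed via the specialization above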
include/ttmlir/Dialect/{TTIR → TTNN}/Analysis/LegalGridAnalysis.h
@@ -2,15 +2,15 @@
 //
 // SPDX-License-Identifier: Apache-2.0

-#ifndef TTMLIR_DIALECT_TTIR_ANALYSIS_LEGALGRIDANALYSIS_H
-#define TTMLIR_DIALECT_TTIR_ANALYSIS_LEGALGRIDANALYSIS_H
+#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALGRIDANALYSIS_H
+#define TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALGRIDANALYSIS_H

 #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
 #include "ttmlir/Dialect/TT/Utils/OverrideParams.h"
-#include "ttmlir/Dialect/TTIR/Analysis/TTIRAnalysis.h"
+#include "ttmlir/Dialect/TTNN/Analysis/TTNNAnalysis.h"
 #include "llvm/ADT/StringMap.h"

-namespace mlir::tt::ttir {
+namespace mlir::tt::ttnn {

 struct LegalGridAnalysisInput {
   ChipDescAttr chipDesc;
@@ -43,15 +43,15 @@
 };

 class LegalGridAnalysis
-    : public TTIRAnalysis<LegalGridAnalysisInput, std::vector<LayoutAttr>> {
+    : public TTNNAnalysis<LegalGridAnalysisInput, std::vector<tt::LayoutAttr>> {
 private:
   void analysisImplementation() override;
   bool applyOverrides() override;

 public:
-  LegalGridAnalysis(Operation *op) : TTIRAnalysis(op) {}
+  LegalGridAnalysis(Operation *op) : TTNNAnalysis(op) {}
 };

-} // namespace mlir::tt::ttir
+} // namespace mlir::tt::ttnn

-#endif // TTMLIR_DIALECT_TTIR_ANALYSIS_LEGALGRIDANALYSIS_H
+#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALGRIDANALYSIS_H
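LegalGridAnalysis above and OpConfigAnalysis below share the same base-class contract. The TTNNAnalysis template itself is not part of this excerpt, so the sketch below only captures the shape the two derived classes imply; every member name here is an assumption:

// Assumed shape of the TTNNAnalysis base (not shown in this diff): it is
// parameterized on an input struct and a result type, and derived classes
// override analysisImplementation() and applyOverrides().
template <typename InputT, typename ResultT>
class TTNNAnalysisSketch {
protected:
  mlir::Operation *op;  // op the analysis is rooted at
  InputT analysisInput; // e.g. LegalGridAnalysisInput
  ResultT analysisResult;

  virtual void analysisImplementation() = 0; // compute analysisResult
  virtual bool applyOverrides() = 0;         // apply user overrides, if any

public:
  TTNNAnalysisSketch(mlir::Operation *op) : op(op) {}
  virtual ~TTNNAnalysisSketch() = default;
};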
include/ttmlir/Dialect/{TTIR → TTNN}/Analysis/OpConfigAnalysis.h
@@ -2,25 +2,27 @@
 //
 // SPDX-License-Identifier: Apache-2.0

-#ifndef TTMLIR_DIALECT_TTIR_ANALYSIS_OPCONFIGANALYSIS_H
-#define TTMLIR_DIALECT_TTIR_ANALYSIS_OPCONFIGANALYSIS_H
+#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_OPCONFIGANALYSIS_H
+#define TTMLIR_DIALECT_TTNN_ANALYSIS_OPCONFIGANALYSIS_H

 #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
-#include "ttmlir/Dialect/TTIR/Analysis/TTIRAnalysis.h"
+#include "ttmlir/Dialect/TTNN/Analysis/TTNNAnalysis.h"

-namespace mlir::tt::ttir {
+namespace mlir::tt::ttnn {

 struct OpConfigAnalysisInput {
-  llvm::DenseMap<Operation *, std::vector<LayoutAttr>> legalGrids;
+  llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> legalGrids;

   OpConfigAnalysisInput() : legalGrids() {}

   OpConfigAnalysisInput(
-      const llvm::DenseMap<Operation *, std::vector<LayoutAttr>> &&legalGrids)
+      const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
+          &&legalGrids)
       : legalGrids(std::move(legalGrids)) {}

   OpConfigAnalysisInput(
-      const llvm::DenseMap<Operation *, std::vector<LayoutAttr>> &legalGrids)
+      const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
+          &legalGrids)
       : legalGrids(legalGrids) {}

   bool operator==(const OpConfigAnalysisInput &rhs) const {
@@ -35,16 +37,16 @@ struct OpConfigAnalysisInput {
 // Determine optimal configuration for each op.
 //
 class OpConfigAnalysis
-    : public TTIRAnalysis<OpConfigAnalysisInput,
-                          llvm::DenseMap<Operation *, LayoutAttr>> {
+    : public TTNNAnalysis<OpConfigAnalysisInput,
+                          llvm::DenseMap<Operation *, tt::LayoutAttr>> {

 private:
   void analysisImplementation() override;
   bool applyOverrides() override;

 public:
-  OpConfigAnalysis(Operation *op) : TTIRAnalysis(op) {}
+  OpConfigAnalysis(Operation *op) : TTNNAnalysis(op) {}
 };
-} // namespace mlir::tt::ttir
+} // namespace mlir::tt::ttnn

-#endif // TTMLIR_DIALECT_TTIR_ANALYSIS_OPCONFIGANALYSIS_H
+#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_OPCONFIGANALYSIS_H
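A small usage sketch of the two input constructors above; the map contents are placeholders. One genuine C++ subtlety worth noting: the "move" overload binds a const rvalue reference, so the std::move in its initializer list still performs a copy.

// legalGrids would normally be populated from LegalGridAnalysis results.
llvm::DenseMap<mlir::Operation *, std::vector<mlir::tt::LayoutAttr>> legalGrids;

mlir::tt::ttnn::OpConfigAnalysisInput byCopy(legalGrids);            // lvalue overload
mlir::tt::ttnn::OpConfigAnalysisInput byMove(std::move(legalGrids)); // const && overload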
include/ttmlir/Dialect/{TTIR → TTNN}/Analysis/ShardChainConfig.h
@@ -2,19 +2,19 @@
 //
 // SPDX-License-Identifier: Apache-2.0

-#ifndef TTMLIR_DIALECT_TTIR_ANALYSIS_SHARDCHAINCONFIG_H
-#define TTMLIR_DIALECT_TTIR_ANALYSIS_SHARDCHAINCONFIG_H
+#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_SHARDCHAINCONFIG_H
+#define TTMLIR_DIALECT_TTNN_ANALYSIS_SHARDCHAINCONFIG_H

 #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
-#include "ttmlir/Dialect/TTIR/Analysis/ShardSolver.h"
+#include "ttmlir/Dialect/TTNN/Analysis/ShardSolver.h"
 #include <unordered_set>

-namespace mlir::tt::ttir {
+namespace mlir::tt::ttnn {

 struct ShardSpec {
   Operation *op;
   uint tensorSplitFactor;
-  LayoutAttr layout;
+  tt::LayoutAttr layout;
 };

 // Enum to track the state of the shard chain.
@@ -36,12 +36,14 @@ class ShardChainConfig {
 public:
   ShardChainConfig() : shardSpecs(), state() {}

-  ShardSolver resolve(
-      const llvm::DenseMap<Operation *, std::vector<LayoutAttr>> &legalLayouts,
-      unsigned usableL1CacheSize);
+  ShardSolver
+  resolve(const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
+              &legalLayouts,
+          unsigned usableL1CacheSize);
   void build();
-  void complete(const llvm::DenseMap<Operation *, LayoutAttr> &selectedOpLayout,
-                std::unordered_set<Edge> &reshardedEdges);
+  void
+  complete(const llvm::DenseMap<Operation *, tt::LayoutAttr> &selectedOpLayout,
+           std::unordered_set<Edge> &reshardedEdges);

   bool isEmpty() { return shardSpecs.empty(); }
   void addShardSpec(ShardSpec &&spec) {
@@ -56,6 +58,6 @@ class ShardChainConfig {
   }
 };

-} // namespace mlir::tt::ttir
+} // namespace mlir::tt::ttnn

-#endif // TTMLIR_DIALECT_TTIR_ANALYSIS_SHARDCHAINCONFIG_H
+#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_SHARDCHAINCONFIG_H
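The interface above implies a build → resolve → complete lifecycle for each shard chain. A sketch of a plausible call sequence; the signatures come from this header, while the ordering and surrounding values (someOp, layout, legalLayouts) are assumptions:

using namespace mlir::tt::ttnn;

// Hypothetical driver for one shard chain.
ShardChainConfig chain;
chain.addShardSpec(ShardSpec{/*op=*/someOp, /*tensorSplitFactor=*/1, layout});
chain.build();

ShardSolver solver =
    chain.resolve(legalLayouts, /*usableL1CacheSize=*/1 << 20);
// ... consult the solver to pick a tt::LayoutAttr per op ...

llvm::DenseMap<mlir::Operation *, mlir::tt::LayoutAttr> selectedOpLayout;
std::unordered_set<Edge> reshardedEdges;
chain.complete(selectedOpLayout, reshardedEdges); // chain is now completed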
(Diffs for the remaining 26 of the 33 changed files are not shown.)