Skeleton for l1 interleaved policy
Refactor sharding into mem layout analysis.

Add l1 interleaved policy for mem layout analysis

Add MNIST test

Add an option in both the optimizer pass and the ttir-to-ttnn-backend-pipeline to specify the memory layout analysis policy type
fbajraktariTT committed Oct 31, 2024
1 parent 33ac41f commit f723ac7
Showing 16 changed files with 468 additions and 15 deletions.
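For context, the new option can be exercised on the command line roughly as follows. This is a hypothetical invocation: the memory-layout-analysis-policy flag and its L1Interleaved value come from this diff, while the ttmlir-opt command, the pipeline name, and the input file name are assumptions drawn from the surrounding project.

ttmlir-opt --ttir-to-ttnn-backend-pipeline="memory-layout-analysis-policy=L1Interleaved" mnist.mlir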
3 changes: 3 additions & 0 deletions .gitignore
@@ -6,3 +6,6 @@ third_party/tt-metal
.cache
*pycache*
*.egg-info
ttrt-artifacts/*
query_results.json
run_results.json
2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
@@ -288,7 +288,9 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> {
bool isSystemMemorySpace() const { return ::mlir::tt::isSystemMemorySpace(getMemorySpace()); }
bool isDeviceMemorySpace() const { return ::mlir::tt::isDeviceMemorySpace(getMemorySpace()); }
bool hasShardedTensorMemoryLayout() const;
bool hasInterleavedTensorMemoryLayout() const;
bool hasShardedL1TensorMemoryLayout() const;
bool hasInterleavedL1TensorMemoryLayout() const;
bool isTiled() const;
Type getElementType() const;
Type getScalarElementType() const;
37 changes: 37 additions & 0 deletions include/ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h
@@ -0,0 +1,37 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_L1INTERLEAVEDPOLICY_H
#define TTMLIR_DIALECT_TTNN_ANALYSIS_L1INTERLEAVEDPOLICY_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"

namespace mlir::tt::ttnn {

class L1InterleavedPolicy {
private:
Operation *rootOp;
std::vector<L1ChainConfig> *l1ChainConfigs;
llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> legalLayouts;
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> *schedule;
unsigned usableL1CacheSize = 0;

public:
L1InterleavedPolicy(
Operation *rootOp, std::vector<L1ChainConfig> &l1ChainConfigs,
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
&legalLayouts,
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> &schedule,
unsigned usableL1CacheSize)
: rootOp(rootOp), l1ChainConfigs(&l1ChainConfigs),
legalLayouts(legalLayouts), schedule(&schedule),
usableL1CacheSize(usableL1CacheSize) {}

void run();
};

} // namespace mlir::tt::ttnn

#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_L1INTERLEAVEDPOLICY_H
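A minimal usage sketch for this class, mirroring the call site added to MemoryLayoutAnalysis.cpp later in this diff; rootOp, legalLayouts, and usableL1CacheSize are assumed to be supplied by the enclosing analysis:

// The policy keeps pointers to l1ChainConfigs and schedule and fills them
// in when run() executes; legalLayouts is copied.
std::vector<L1ChainConfig> l1ChainConfigs;
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> schedule;
L1InterleavedPolicy policy(rootOp, l1ChainConfigs, legalLayouts, schedule,
                           usableL1CacheSize);
policy.run();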
15 changes: 10 additions & 5 deletions include/ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.h
@@ -12,24 +12,29 @@

namespace mlir::tt::ttnn {

-enum class MemoryLayoutAnalysisPolicyType {
-  DFSharding,
-};
+enum class MemoryLayoutAnalysisPolicyType { DFSharding, L1Interleaved };

::llvm::StringRef
stringifyMemoryLayoutAnalysisPolicyType(MemoryLayoutAnalysisPolicyType policy);

MemoryLayoutAnalysisPolicyType
symbolizeMemoryLayoutAnalysisPolicyType(::llvm::StringRef policy);

struct MemoryLayoutAnalysisInput {
llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> legalLayouts;
unsigned usableL1CacheSize = 0;
std::unordered_set<Edge> overrideReshardEdges;
MemoryLayoutAnalysisPolicyType policy;

MemoryLayoutAnalysisInput() : legalLayouts() {}

MemoryLayoutAnalysisInput(
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
&legalLayouts,
unsigned usableL1CacheSize,
-      const std::unordered_set<Edge> &overrideReshardEdges)
+      const std::unordered_set<Edge> &overrideReshardEdges, MemoryLayoutAnalysisPolicyType policy)
       : legalLayouts(legalLayouts), usableL1CacheSize(usableL1CacheSize),
-        overrideReshardEdges(overrideReshardEdges) {}
+        overrideReshardEdges(overrideReshardEdges), policy(policy) {}

bool operator==(const MemoryLayoutAnalysisInput &rhs) const {
return legalLayouts == rhs.legalLayouts;
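For illustration, a caller would now construct the analysis input with an explicit policy. A sketch, where legalLayouts and overrideReshardEdges are assumed to come from earlier passes and the cache size is an arbitrary example value:

MemoryLayoutAnalysisInput input(
    legalLayouts, /*usableL1CacheSize=*/1 << 20, overrideReshardEdges,
    MemoryLayoutAnalysisPolicyType::L1Interleaved);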
67 changes: 67 additions & 0 deletions include/ttmlir/Dialect/TTNN/Analysis/ShardingAnalysis.h
@@ -0,0 +1,67 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_SHARDINGANALYSIS_H
#define TTMLIR_DIALECT_TTNN_ANALYSIS_SHARDINGANALYSIS_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "ttmlir/Dialect/TTNN/Analysis/Edge.h"
#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"
#include "ttmlir/Dialect/TTNN/Analysis/TTNNAnalysis.h"

namespace mlir::tt::ttnn {

enum class PolicyType { DFSharding, L1Interleaved };

struct ShardingAnalysisInput {
llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> legalLayouts;
unsigned usableL1CacheSize = 0;

ShardingAnalysisInput() : legalLayouts() {}

ShardingAnalysisInput(
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
&legalLayouts,
unsigned usableL1CacheSize)
: legalLayouts(legalLayouts), usableL1CacheSize(usableL1CacheSize) {}

bool operator==(const ShardingAnalysisInput &rhs) const {
return legalLayouts == rhs.legalLayouts;
}

bool operator!=(const ShardingAnalysisInput &rhs) const {
return !(*this == rhs);
}
};

struct ShardingAnalysisResult {
llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> legalLayouts;
std::unordered_set<Edge> reshardedEdges;
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> schedule;

ShardingAnalysisResult() : legalLayouts(), reshardedEdges(), schedule() {}

ShardingAnalysisResult(
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
&legalLayouts,
const std::unordered_set<Edge> &reshardedEdges)
: legalLayouts(legalLayouts), reshardedEdges(reshardedEdges) {}
};

// Determine shard chain configs.
//
class ShardingAnalysis
: public TTNNAnalysis<ShardingAnalysisInput, ShardingAnalysisResult> {

private:
void analysisImplementation() override;
bool applyOverrides() override;
std::vector<L1ChainConfig> l1ChainConfigs;

public:
ShardingAnalysis(Operation *op) : TTNNAnalysis(op) {}
};
} // namespace mlir::tt::ttnn

#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_SHARDINGANALYSIS_H
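Presumably this analysis is driven through the TTNNAnalysis<Input, Result> pattern like its siblings; a sketch assuming the base class exposes init() and getResult(), neither of which is shown in this diff:

ShardingAnalysis shardingAnalysis(moduleOp);
shardingAnalysis.init(ShardingAnalysisInput(legalLayouts, usableL1CacheSize));
const ShardingAnalysisResult &result = shardingAnalysis.getResult();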
36 changes: 32 additions & 4 deletions include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
@@ -7,12 +7,32 @@

#include "mlir/Pass/PassOptions.h"
#include "ttmlir/Dialect/TT/Utils/OverrideParams.h"
-#include <cstdint>
-#include <llvm/ADT/SmallVector.h>
-#include <llvm/ADT/StringRef.h>
-#include <llvm/Support/CommandLine.h>
+#include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.h"

namespace mlir::tt::ttnn {

struct MemoryLayoutAnalysisPolicyTypeParser
: public llvm::cl::parser<MemoryLayoutAnalysisPolicyType> {
public:
MemoryLayoutAnalysisPolicyTypeParser(llvm::cl::Option &opt)
: llvm::cl::parser<MemoryLayoutAnalysisPolicyType>(opt) {}

bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
           MemoryLayoutAnalysisPolicyType &value) {
  value = symbolizeMemoryLayoutAnalysisPolicyType(arg);
  // llvm::cl parsers return false on success and true on parse failure.
  return false;
}

static void print(llvm::raw_ostream &os,
const MemoryLayoutAnalysisPolicyType &value) {
os << "memory-layout-analysis-policy="
<< stringifyMemoryLayoutAnalysisPolicyType(value);
os << "\n";
}
};

// Options for the TTIR to TTNN backend pipeline.
//
struct TTIRToTTNNBackendPipelineOptions
@@ -85,6 +105,14 @@ struct TTIRToTTNNBackendPipelineOptions
"of shard specs."),
llvm::cl::init(false)};

// Specify policy for memory layout analysis.
//
Option<MemoryLayoutAnalysisPolicyType, MemoryLayoutAnalysisPolicyTypeParser>
memoryLayoutAnalysisPolicy{
*this, "memory-layout-analysis-policy",
llvm::cl::desc("Specify policy for memory layout analysis."),
llvm::cl::init(MemoryLayoutAnalysisPolicyType::DFSharding)};

// Option to provide a system descriptor flatbuffer file to compile
// against.
//
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTNN/Transforms/Passes.h
@@ -9,6 +9,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "ttmlir/Dialect/TT/Utils/OverrideParams.h"
#include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.h"
#include "ttmlir/Dialect/TTNN/IR/TTNN.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"

9 changes: 9 additions & 0 deletions lib/Dialect/TT/IR/TTOpsTypes.cpp
@@ -579,13 +579,22 @@ bool LayoutAttr::hasShardedTensorMemoryLayout() const {
getMemLayout() == TensorMemoryLayout::BlockSharded);
}

bool LayoutAttr::hasInterleavedTensorMemoryLayout() const {
return (getMemLayout() == TensorMemoryLayout::Interleaved);
}

bool LayoutAttr::hasShardedL1TensorMemoryLayout() const {
return ::mlir::tt::isL1MemorySpace(getMemorySpace()) and
(getMemLayout() == TensorMemoryLayout::HeightSharded or
getMemLayout() == TensorMemoryLayout::WidthSharded or
getMemLayout() == TensorMemoryLayout::BlockSharded);
}

bool LayoutAttr::hasInterleavedL1TensorMemoryLayout() const {
return ::mlir::tt::isL1MemorySpace(getMemorySpace()) and
(getMemLayout() == TensorMemoryLayout::Interleaved);
}

bool LayoutAttr::isTiled() const {
return ::mlir::isa<::mlir::tt::TileType>(getElementType());
}
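The new L1 predicates conjoin a memory-space test with the layout test, so they are strictly narrower than the layout-only checks. A hypothetical contrast, where dramInterleaved is assumed to be a LayoutAttr with DRAM memory space and TensorMemoryLayout::Interleaved:

// Interleaved in DRAM: passes the layout-only check...
assert(dramInterleaved.hasInterleavedTensorMemoryLayout());
// ...but fails the L1-qualified check, since the tensor is not in L1.
assert(!dramInterleaved.hasInterleavedL1TensorMemoryLayout());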
1 change: 1 addition & 0 deletions lib/Dialect/TTNN/Analysis/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRTTNNAnalysis
MemoryLayoutAnalysis.cpp
L1ChainConfig.cpp
DFShardingPolicy.cpp
L1InterleavedPolicy.cpp
ShardSolver.cpp

ADDITIONAL_HEADER_DIRS
69 changes: 69 additions & 0 deletions lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp
@@ -0,0 +1,69 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Scheduler/Scheduler.h"
#include <llvm/Support/raw_ostream.h>

namespace mlir::tt::ttnn {

void L1InterleavedPolicy::run() {
rootOp->walk([&](func::FuncOp func) {
mlir::tt::scheduler::Scheduler scheduler(&func);
llvm::SmallVector<mlir::Operation *> scheduleableOps;
Operation *currentOp = nullptr;
llvm::DenseMap<Operation *, tt::LayoutAttr> selectedOpLayout;

// TODO(fbajraktari): Algo
//
l1ChainConfigs->push_back(L1ChainConfig());
while (scheduler.hasUnscheduledOps()) {
scheduleableOps = scheduler.getScheduleableOps();
currentOp = scheduleableOps[0];

// Schedule currentOp.
//
scheduler.scheduleOp(currentOp);

// Check if currentOp is a valid l1 interleaved op.
//
if (legalLayouts.lookup(currentOp).size() > 0) {
selectedOpLayout[currentOp] = legalLayouts.lookup(currentOp).front();

// Add currentOp to shard chain config.
//
OpL1MemSpec shardSpec;
shardSpec.op = currentOp;

// Hardcoded tensor split factor for now, until pipeline OP
// support is added.
//
shardSpec.tensorSplitFactor = 1;
l1ChainConfigs->back().addOpL1MemSpec(std::move(shardSpec));
}
}

if (l1ChainConfigs->back().isEmpty()) {
l1ChainConfigs->pop_back();
}

// Schedule
//
(*schedule)[func] = scheduler.getSchedule();

// Resolve shard chain configs.
//
for (auto &l1ChainConfig : *l1ChainConfigs) {
l1ChainConfig.build();
l1ChainConfig.resolve();

std::unordered_set<Edge> memReconfigEdges;
l1ChainConfig.complete(selectedOpLayout, memReconfigEdges);
}
});
llvm::errs() << "usableL1CacheSize: " << usableL1CacheSize << "\n";
}

} // namespace mlir::tt::ttnn
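Worth noting about this skeleton: the caller pre-filters legalLayouts down to L1-interleaved candidates (see filterL1InterleavedOnly in MemoryLayoutAnalysis.cpp below), so any op with a non-empty layout list is greedily assigned the first such layout and appended to a single L1 chain, while ops with no candidate layout are scheduled but left out of the chain.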
55 changes: 50 additions & 5 deletions lib/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.cpp
@@ -4,9 +4,28 @@

#include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.h"
#include "ttmlir/Dialect/TTNN/Analysis/DFShardingPolicy.h"
#include "ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h"

namespace mlir::tt::ttnn {

::llvm::StringRef
stringifyMemoryLayoutAnalysisPolicyType(MemoryLayoutAnalysisPolicyType policy) {
switch (policy) {
case MemoryLayoutAnalysisPolicyType::DFSharding:
return "DFSharding";
case MemoryLayoutAnalysisPolicyType::L1Interleaved:
return "L1Interleaved";
}
return "";
}

MemoryLayoutAnalysisPolicyType
symbolizeMemoryLayoutAnalysisPolicyType(::llvm::StringRef policy) {
return llvm::StringSwitch<MemoryLayoutAnalysisPolicyType>(policy)
.Case("DFSharding", MemoryLayoutAnalysisPolicyType::DFSharding)
.Case("L1Interleaved", MemoryLayoutAnalysisPolicyType::L1Interleaved);
}

bool MemoryLayoutAnalysis::applyOverrides() {

// TODO(nobradovic):
@@ -33,18 +52,44 @@ filterShardedOnly(const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
return shardedLayouts;
}

-void MemoryLayoutAnalysis::analysisImplementation() {
-  MemoryLayoutAnalysisPolicyType policy =
-      MemoryLayoutAnalysisPolicyType::DFSharding;
-
-  switch (policy) {
-  case MemoryLayoutAnalysisPolicyType::DFSharding:
+llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
+filterL1InterleavedOnly(
+    const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
+        &legalLayouts) {
+  llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> l1InterleavedLayouts;
+  for (const auto &opLayouts : legalLayouts) {
+    std::vector<tt::LayoutAttr> opL1InterleavedLayouts;
+    for (const auto &layout : opLayouts.second) {
+      if (layout.hasInterleavedL1TensorMemoryLayout()) {
+        opL1InterleavedLayouts.push_back(layout);
+      }
+    }
+
+    l1InterleavedLayouts[opLayouts.first] = opL1InterleavedLayouts;
+  }
+
+  return l1InterleavedLayouts;
+}
+
+void MemoryLayoutAnalysis::analysisImplementation() {
+  // Apply specific memory layout analysis policy.
+  //
+  switch (analysisInput.policy) {
+  case MemoryLayoutAnalysisPolicyType::DFSharding: {
     DFShardingPolicy dfShardingPolicy(
         op, l1ChainConfigs, filterShardedOnly(analysisInput.legalLayouts),
         analysisResult.schedule, analysisInput.usableL1CacheSize);
     dfShardingPolicy.run(analysisInput.overrideReshardEdges);
+    break;
+  }
+  case MemoryLayoutAnalysisPolicyType::L1Interleaved: {
+    L1InterleavedPolicy l1InterleavedPolicy(
+        op, l1ChainConfigs, filterL1InterleavedOnly(analysisInput.legalLayouts),
+        analysisResult.schedule, analysisInput.usableL1CacheSize);
+    l1InterleavedPolicy.run();
+    break;
+  }
+  }

// Copy over default legal layouts.
//