Skip to content

Commit

Permalink
Wrap compound components in struct instead of tuple
Browse files Browse the repository at this point in the history
  • Loading branch information
jnie-TT committed Sep 3, 2024
1 parent e5822b8 commit 3262f11
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 49 deletions.
13 changes: 11 additions & 2 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,17 @@ def TTIR_ToLayoutOp : TTIR_Op<"to_layout", [DestinationStyleOpInterface, TTIROpI
// TODO return below, but we need a way to properly create an ArrayAttr:
// return {OperandConstraint::Any, OperandConstraint::Any};
}
// Returns a tuple of five booleans indicating if the op changes layout, grid, format, memory space, or memory layout.
std::tuple<bool, bool, bool, bool, bool> compoundComponents();

// Describes which independent components of a tensor layout a to_layout op
// modifies. If more than one flag is set the op is "compound" and is expected
// to be split into a sequence of single-component layout changes
// (see TTIRSplitCompoundLayoutRewriter).
//
// Members carry default initializers so a default-constructed instance
// reports "no change" instead of reading uninitialized booleans; aggregate
// initialization (`return {a, b, c, d, e};`) continues to work.
struct CompoundComponents {
  bool isLayoutChange = false;       // linear map / affine layout differs
  bool isGridChange = false;         // grid shape differs
  bool isFormatChange = false;       // element type / data format differs
  bool isMemorySpaceChange = false;  // e.g. system <-> device memory
  bool isMemoryLayoutChange = false; // e.g. interleaved <-> sharded
};

// Returns booleans indicating if the op changes layout, grid, format, memory space or memory layout.
CompoundComponents compoundComponents();
}];

let hasVerifier = 1;
Expand Down
6 changes: 3 additions & 3 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ ::mlir::LogicalResult mlir::tt::ttir::ToLayoutOp::verify() {
return success();
}

std::tuple<bool, bool, bool, bool, bool>
mlir::tt::ttir::ToLayoutOp::CompoundComponents
mlir::tt::ttir::ToLayoutOp::compoundComponents() {
auto inputLayout =
mlir::cast<tt::LayoutAttr>(getInput().getType().getEncoding());
Expand All @@ -49,8 +49,8 @@ mlir::tt::ttir::ToLayoutOp::compoundComponents() {
inputLayout.getMemorySpace() != outputLayout.getMemorySpace();
bool isMemoryLayoutChange =
inputLayout.getMemLayout() != outputLayout.getMemLayout();
return std::make_tuple(isLayoutChange, isGridChange, isFormatChange,
isMemorySpaceChange, isMemoryLayoutChange);
return {isLayoutChange, isGridChange, isFormatChange, isMemorySpaceChange,
isMemoryLayoutChange};
}

::mlir::LogicalResult mlir::tt::ttir::GenericOp::verify() {
Expand Down
22 changes: 10 additions & 12 deletions lib/Dialect/TTIR/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include "ttmlir/Dialect/TTIR/Analysis/OptimalTargetGridAnalysis.h"
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h"
#include "ttmlir/Utils.h"
#include <iostream>

namespace mlir::tt::ttir {
#define GEN_PASS_DEF_TTIRGENERICKERNEL
Expand Down Expand Up @@ -792,13 +791,12 @@ class TTIRSplitCompoundLayoutRewriter : public OpRewritePattern<ToLayoutOp> {

LogicalResult matchAndRewrite(ToLayoutOp op,
PatternRewriter &rewriter) const final {
auto [isLayoutChange, isGridChange, isFormatChange, isMemorySpaceChange,
isMemoryLayoutChange] = op.compoundComponents();
bool isCompound =
(static_cast<int>(isLayoutChange) + static_cast<int>(isGridChange) +
static_cast<int>(isFormatChange) +
static_cast<int>(isMemorySpaceChange) +
static_cast<int>(isMemoryLayoutChange)) > 1;
auto components = op.compoundComponents();
bool isCompound = (static_cast<int>(components.isLayoutChange) +
static_cast<int>(components.isGridChange) +
static_cast<int>(components.isFormatChange) +
static_cast<int>(components.isMemorySpaceChange) +
static_cast<int>(components.isMemoryLayoutChange)) > 1;

if (!isCompound) {
return failure();
Expand Down Expand Up @@ -837,20 +835,20 @@ class TTIRSplitCompoundLayoutRewriter : public OpRewritePattern<ToLayoutOp> {
inputLayout.withElementType(rewriter.getContext(),
outputLayout.getElementType()));
}
} else if (isLayoutChange && inputLayout.isTiled()) {
} else if (components.isLayoutChange && inputLayout.isTiled()) {
// For now to flexibly support layout changes, we need to bounce to scalar
// first
bounce(rewriter, op,
inputLayout.withElementType(rewriter.getContext(),
inputLayout.getScalarElementType()));
} else if (isGridChange) {
assert(!isLayoutChange &&
} else if (components.isGridChange) {
assert(!components.isLayoutChange &&
"Changing layout and grid at the same time is currently "
"not supported");
bounce(rewriter, op,
outputLayout.withGrid(rewriter.getContext(), outputType,
inputLayout.getGrid()));
} else if (isMemoryLayoutChange) {
} else if (components.isMemoryLayoutChange) {
bounce(rewriter, op,
inputLayout.withMemoryLayout(rewriter.getContext(),
outputLayout.getMemLayout()));
Expand Down
22 changes: 11 additions & 11 deletions lib/Dialect/TTMetal/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,17 +364,17 @@ class TTIRToTTMetalLayoutRewriter : public OpRewritePattern<ttir::ToLayoutOp> {
auto inputLayout = mlir::cast<tt::LayoutAttr>(inputTy.getEncoding());
auto outputLayout = mlir::cast<tt::LayoutAttr>(outputTy.getEncoding());

auto [isLayoutChange, isGridChange, isFormatChange, isMemorySpaceChange,
isMemoryLayoutChange] = op.compoundComponents();
bool isCompound =
(static_cast<int>(isLayoutChange) + static_cast<int>(isGridChange) +
static_cast<int>(isFormatChange) +
static_cast<int>(isMemorySpaceChange) +
static_cast<int>(isMemoryLayoutChange)) > 1;
auto components = op.compoundComponents();
bool isCompound = (static_cast<int>(components.isLayoutChange) +
static_cast<int>(components.isGridChange) +
static_cast<int>(components.isFormatChange) +
static_cast<int>(components.isMemorySpaceChange) +
static_cast<int>(components.isMemoryLayoutChange)) > 1;

assert(!isCompound && "Only one change is allowed");
assert(!isMemoryLayoutChange &&
assert(!components.isMemoryLayoutChange &&
"Tensor memory layout shouldn't change in metal backend");
if (isMemorySpaceChange) {
if (components.isMemorySpaceChange) {
if (inputLayout.isSystemMemorySpace()) {
assert(outputLayout.isDeviceMemorySpace());
rewriter.replaceOpWithNewOp<ttmetal::HostWriteOp>(
Expand All @@ -386,10 +386,10 @@ class TTIRToTTMetalLayoutRewriter : public OpRewritePattern<ttir::ToLayoutOp> {
} else {
assert(false && "L1 <-> DRAM not supported yet");
}
} else if (isLayoutChange || isGridChange) {
} else if (components.isLayoutChange || components.isGridChange) {
return relayout(op, rewriter);
} else {
assert(isFormatChange);
assert(components.isFormatChange);
return reformat(op, rewriter);
}
return failure();
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TTNN/IR/TTNNOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ ::mlir::LogicalResult mlir::tt::ttnn::ToMemoryConfigOp::verify() {
"sharded memory layouts");
}

if (outputMemorySpace == MemorySpace::DeviceDRAM &&
if (outputMemorySpace == ::mlir::tt::MemorySpace::DeviceDRAM &&
outputMemoryLayout != ::mlir::tt::TensorMemoryLayout::Interleaved) {
return emitOpError(
"Device DRAM memory space only supports interleaved memory layout");
Expand Down
2 changes: 1 addition & 1 deletion runtime/lib/ttnn/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ createL1MemoryConfig(const ::tt::target::TensorRef *tensorRef) {
std::copy(memoryDescShape->begin(), memoryDescShape->end(),
shardShape.begin());
TT_FATAL(shardShape[0] % ::tt::constants::TILE_HEIGHT == 0 and
shardShape[1] % ::tt::constants::TILE_WIDTH != 0,
shardShape[1] % ::tt::constants::TILE_WIDTH == 0,
"Shard shape ({}, {}) does not divide tile shape ({}, {}) evenly",
shardShape[0], shardShape[1], ::tt::constants::TILE_HEIGHT,
::tt::constants::TILE_WIDTH);
Expand Down
19 changes: 0 additions & 19 deletions runtime/lib/ttnn/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,25 +56,6 @@ inline ::tt::target::DataType fromTTNNDataType(::ttnn::DataType dataType) {
}
}

// Translates a flatbuffer-serialized tensor memory layout enum into the
// corresponding tt_metal runtime enum. Throws std::runtime_error for any
// value without a known mapping.
inline ::tt::tt_metal::TensorMemoryLayout
toTensorMemoryLayout(::tt::target::TensorMemoryLayout memLayout) {
  using Src = ::tt::target::TensorMemoryLayout;
  using Dst = ::tt::tt_metal::TensorMemoryLayout;
  switch (memLayout) {
  case Src::Interleaved:
    return Dst::INTERLEAVED;
  case Src::SingleBank:
    return Dst::SINGLE_BANK;
  case Src::HeightSharded:
    return Dst::HEIGHT_SHARDED;
  case Src::WidthSharded:
    return Dst::WIDTH_SHARDED;
  case Src::BlockSharded:
    return Dst::BLOCK_SHARDED;
  default:
    throw std::runtime_error("Unsupported shard strategy");
  }
}

inline std::vector<uint32_t>
toShapeFromFBShape(const flatbuffers::Vector<int32_t> &vec) {
return std::vector<uint32_t>(vec.begin(), vec.end());
Expand Down

0 comments on commit 3262f11

Please sign in to comment.