Skip to content

Commit

Permalink
Adding support for data type workarounds
Browse files Browse the repository at this point in the history
  • Loading branch information
sdjordjevicTT committed Dec 20, 2024
1 parent 1ce54d5 commit b595a45
Show file tree
Hide file tree
Showing 10 changed files with 164 additions and 29 deletions.
3 changes: 3 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,9 @@ def TTNN_EmbeddingOp : TTNN_NamedDPSOp<"embedding"> {

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
return wa::TTNNOperandsWorkaroundsFactory::createEmbeddingOpOperandsWorkarounds();
}
}];

let hasVerifier = 1;
Expand Down
40 changes: 32 additions & 8 deletions include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ namespace mlir::tt::ttnn::wa {
using TensorLayoutWorkaround = std::optional<Layout>;
using TensorBufferTypeWorkaround = std::optional<BufferType>;
using TensorMemoryLayoutWorkaround = std::optional<TensorMemoryLayout>;
using TensorDataTypeWorkaround = std::optional<DataType>;

// Struct that encapsulates operand workarounds.
// It contains tensor layout, tensor buffer type and tensor memory layout
Expand All @@ -31,35 +32,47 @@ struct TTNNOperandWorkarounds {
// Tensor memory layout workaround.
TensorMemoryLayoutWorkaround tensorMemoryLayoutWorkaround;

// Tensor data format workaround.
TensorDataTypeWorkaround tensorDataTypeWorkaround;

// Default constructor.
TTNNOperandWorkarounds() = default;

// Constructor that takes tensor layout, tensor buffer type and tensor memory.
TTNNOperandWorkarounds(
TensorLayoutWorkaround tensorLayoutWorkaround,
TensorBufferTypeWorkaround tensorBufferTypeWorkaround,
TensorMemoryLayoutWorkaround tensorMemoryLayoutWorkaround)
TensorMemoryLayoutWorkaround tensorMemoryLayoutWorkaround,
TensorDataTypeWorkaround tensorDataTypeWorkaround)
: tensorLayoutWorkaround(tensorLayoutWorkaround),
tensorBufferTypeWorkaround(tensorBufferTypeWorkaround),
tensorMemoryLayoutWorkaround(tensorMemoryLayoutWorkaround) {}
tensorMemoryLayoutWorkaround(tensorMemoryLayoutWorkaround),
tensorDataTypeWorkaround(tensorDataTypeWorkaround) {}

// Constructor that takes tensor layout workaround and sets the other
// workarounds to nullopt.
TTNNOperandWorkarounds(TensorLayoutWorkaround tensorLayoutWorkaround)
: TTNNOperandWorkarounds(tensorLayoutWorkaround, std::nullopt,
std::nullopt) {}
std::nullopt, std::nullopt) {}

// Constructor that takes tensor buffer type workaround and sets the other
// workarounds to nullopt.
TTNNOperandWorkarounds(TensorBufferTypeWorkaround tensorBufferTypeWorkaround)
: TTNNOperandWorkarounds(std::nullopt, tensorBufferTypeWorkaround,
std::nullopt) {}
std::nullopt, std::nullopt) {}

// Constructor that takes tensor memory layout workaround and sets the other
// workarounds to nullopt.
TTNNOperandWorkarounds(
TensorMemoryLayoutWorkaround tensorMemoryLayoutWorkaround)
: TTNNOperandWorkarounds(std::nullopt, std::nullopt,
tensorMemoryLayoutWorkaround) {}
tensorMemoryLayoutWorkaround, std::nullopt) {}

// Constructor that takes tensor data type workaround and sets the other
// workarounds to nullopt.
TTNNOperandWorkarounds(TensorDataTypeWorkaround tensorDataTypeWorkaround)
: TTNNOperandWorkarounds(std::nullopt, std::nullopt, std::nullopt,
tensorDataTypeWorkaround) {}

// Operand workarounds factory methods.
static TTNNOperandWorkarounds createEmptyTTNNOperandWorkarounds();
Expand All @@ -68,7 +81,8 @@ struct TTNNOperandWorkarounds {
bool operator==(const TTNNOperandWorkarounds &rhs) const {
return tensorLayoutWorkaround == rhs.tensorLayoutWorkaround &&
tensorBufferTypeWorkaround == rhs.tensorBufferTypeWorkaround &&
tensorMemoryLayoutWorkaround == rhs.tensorMemoryLayoutWorkaround;
tensorMemoryLayoutWorkaround == rhs.tensorMemoryLayoutWorkaround &&
tensorDataTypeWorkaround == rhs.tensorDataTypeWorkaround;
}

// Inequality operator.
Expand All @@ -79,7 +93,7 @@ struct TTNNOperandWorkarounds {
// Returns true if any of the workarounds is set.
bool hasAnyWorkaround() const {
return tensorLayoutWorkaround || tensorBufferTypeWorkaround ||
tensorMemoryLayoutWorkaround;
tensorMemoryLayoutWorkaround || tensorDataTypeWorkaround;
}
};

Expand All @@ -103,6 +117,9 @@ struct BufferTypeWorkaroundResult : public WorkaroundResult<BufferType> {};
struct MemoryLayoutWorkaroundResult
: public WorkaroundResult<std::optional<TensorMemoryLayout>> {};

// Data type workaround result struct. Records the target tensor data type
// chosen after applying a data type workaround together with the previous
// (pre-workaround) data type, via the inherited WorkaroundResult fields.
struct DataTypeWorkaroundResult : public WorkaroundResult<DataType> {};

// Struct that encapsulates the result of applying the workarounds.
// It contains the target tensor layout, buffer type and tensor memory layout
// results and a flag indicating whether the workarounds were applied.
Expand All @@ -116,11 +133,15 @@ struct WorkaroundResults {
// Tensor memory layout workaround result.
MemoryLayoutWorkaroundResult tensorMemoryLayoutResult;

// Tensor data type workaround result.
DataTypeWorkaroundResult tensorDataTypeResult;

// Returns true if any of the workarounds were applied.
bool isModified() const {
return tensorLayoutResult.isModified() ||
tensorBufferTypeResult.isModified() ||
tensorMemoryLayoutResult.isModified();
tensorMemoryLayoutResult.isModified() ||
tensorDataTypeResult.isModified();
}
};

Expand Down Expand Up @@ -194,6 +215,9 @@ class TTNNOperandsWorkaroundsFactory {
public:
// Create workarounds for max_pool2d op operands.
static TTNNOperandsWorkarounds createMaxPool2DOpOperandsWorkarounds();

// Create workarounds for embedding op operands.
static TTNNOperandsWorkarounds createEmbeddingOpOperandsWorkarounds();
};

} // namespace mlir::tt::ttnn::wa
Expand Down
3 changes: 2 additions & 1 deletion include/ttmlir/Dialect/TTNN/Utils/TransformUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ createToLayoutOp(mlir::Operation *op,
mlir::TypedValue<RankedTensorType> inputValue,
PatternRewriter &rewriter, Layout targetTensorLayout,
BufferType targetTensorBufferType,
std::optional<TensorMemoryLayout> targetTensorMemoryLayout);
std::optional<TensorMemoryLayout> targetTensorMemoryLayout,
DataType targetTensorDataType);
} // namespace mlir::tt::ttnn::utils

#endif
7 changes: 6 additions & 1 deletion include/ttmlir/Dialect/TTNN/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,16 @@ toTTMemorySpace(const mlir::tt::ttnn::BufferType bufferType);
mlir::Type createRowMajorTypeFromDtype(::mlir::MLIRContext *context,
DataType dtype);

// Helper method to create a RankedTensorType with the given encoding
// Helper method to create a RankedTensorType with the given encoding.
RankedTensorType
createRankedTensorTypeWithEncoding(RankedTensorType tensorType,
ttnn::TTNNLayoutAttr encoding);

// Helper method to create a RankedTensorType with the given element type.
RankedTensorType
createRankedTensorTypeWithElementType(RankedTensorType tensorType,
Type elementType);

// Return the L1 memory usage of the output tensor of the given op.
// Used within L1 interleaved policies.
//
Expand Down
1 change: 0 additions & 1 deletion lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,6 @@ class ToLayoutOpConversionPattern
// operands.
for (mlir::Operation *user : op.getResult().getUsers()) {
if (isa<ttir::Conv2dOp>(user) || isa<ttir::SliceOp>(user) ||
isa<ttir::EmbeddingOp>(user) ||
(isa<ttir::EmbeddingBackwardOp>(user) &&
(user->getOperand(0) == op || user->getOperand(1) == op))) {
return true;
Expand Down
26 changes: 26 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ WorkaroundResults applyWorkarounds(const TTNNOperandWorkarounds &workaround,
results.tensorMemoryLayoutResult.previousValue =
inputLayoutAttr.getMemLayoutOpt();

results.tensorDataTypeResult.targetValue =
workaround.tensorDataTypeWorkaround.value_or(
inputLayoutAttr.getDataType());
results.tensorDataTypeResult.previousValue = inputLayoutAttr.getDataType();

return results;
}

Expand Down Expand Up @@ -87,4 +92,25 @@ TTNNOperandsWorkaroundsFactory::createMaxPool2DOpOperandsWorkarounds() {
.addInputOperandWorkaround(rowMajorLayoutWorkaround)
.addOutputOperandWorkaround(rowMajorLayoutWorkaround);
}

// Factory method that builds the operand workaround set for the embedding
// operation. The embedding op expects its input operand in row-major layout,
// and its weight operand in the bf16 data type. Since the embedding output
// follows the same format as the weight operand, the bf16 workaround is also
// applied to the output operand.
// Metal issue for the input operand workaround:
// https://github.com/tenstorrent/tt-metal/issues/14915
// Metal issue for the weight operand workaround: to be added
TTNNOperandsWorkarounds
TTNNOperandsWorkaroundsFactory::createEmbeddingOpOperandsWorkarounds() {
  // Row-major layout constraint for the input operand.
  TTNNOperandWorkarounds rowMajorInputWorkaround(Layout::RowMajor);
  // bf16 data type constraint for the weight operand (and the output).
  TTNNOperandWorkarounds bf16WeightWorkaround(DataType::BFloat16);
  // NOTE(review): the bf16 workaround is registered on two input operands —
  // presumably the weight and the DPS init operand; confirm against the op's
  // operand list.
  return TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds(0, 0)
      .addInputOperandWorkaround(rowMajorInputWorkaround)
      .addInputOperandWorkaround(bf16WeightWorkaround)
      .addInputOperandWorkaround(bf16WeightWorkaround)
      .addOutputOperandWorkaround(bf16WeightWorkaround);
}
} // namespace mlir::tt::ttnn::wa
39 changes: 28 additions & 11 deletions lib/Dialect/TTNN/Transforms/TTNNWorkarounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"

#include <optional>
#include <tuple>
Expand Down Expand Up @@ -57,7 +57,8 @@ static void revertOutputLayout(wa::TTNNWorkaroundInterface &op,
op.getOperation(), newOpResult, rewriter,
workaroundResults.tensorLayoutResult.previousValue,
workaroundResults.tensorBufferTypeResult.previousValue,
workaroundResults.tensorMemoryLayoutResult.previousValue);
workaroundResults.tensorMemoryLayoutResult.previousValue,
workaroundResults.tensorDataTypeResult.previousValue);

// Replace the new output result with the casted output result.
rewriter.replaceUsesWithIf(
Expand Down Expand Up @@ -93,7 +94,8 @@ static bool workaroundInputOperand(
op.getOperation(), inputValue, rewriter,
inputWorkaroundResults.tensorLayoutResult.targetValue,
inputWorkaroundResults.tensorBufferTypeResult.targetValue,
inputWorkaroundResults.tensorMemoryLayoutResult.targetValue);
inputWorkaroundResults.tensorMemoryLayoutResult.targetValue,
inputWorkaroundResults.tensorDataTypeResult.targetValue);

// Insert to layout op between the current op and the input operand
// to convert the input operand to the desired tensor layout, buffer type.
Expand Down Expand Up @@ -136,7 +138,7 @@ workaroundOutputOperand(mlir::TypedValue<RankedTensorType> opResult,
Type elementType = utils::getElementType(
rewriter.getContext(),
outputWorkaroundResults.tensorLayoutResult.targetValue,
opResultLayoutAttr.getDataType());
outputWorkaroundResults.tensorDataTypeResult.targetValue);

// Get the input operand type.
RankedTensorType opResultType =
Expand All @@ -150,16 +152,24 @@ workaroundOutputOperand(mlir::TypedValue<RankedTensorType> opResult,
*outputWorkaroundResults.tensorMemoryLayoutResult.targetValue)
: nullptr;

// Create the new output result type with the updated tensor layout, buffer
// type and memory layout.
// Create the new output layout attribute with the updated tensor layout,
// buffer type, memory layout and data type.
TTNNLayoutAttr newOutputLayoutAttr =
opResultLayoutAttr.withElementType(rewriter.getContext(), elementType)
.withBufferType(
rewriter.getContext(),
outputWorkaroundResults.tensorBufferTypeResult.targetValue)
.withMemoryLayout(rewriter.getContext(), outputMemLayoutAttr);

// Create the new output result type with the updated data type and layout.
RankedTensorType newOutputResultType =
ttnn::utils::createRankedTensorTypeWithEncoding(
opResultType,
opResultLayoutAttr.withElementType(rewriter.getContext(), elementType)
.withBufferType(
ttnn::utils::createRankedTensorTypeWithElementType(
opResultType,
ttnn::utils::createRowMajorTypeFromDtype(
rewriter.getContext(),
outputWorkaroundResults.tensorBufferTypeResult.targetValue)
.withMemoryLayout(rewriter.getContext(), outputMemLayoutAttr));
outputWorkaroundResults.tensorDataTypeResult.targetValue)),
newOutputLayoutAttr);

// Update the type of result with applied workarounds.
rewriter.modifyOpInPlace(op, [&]() {
Expand All @@ -175,6 +185,13 @@ workaroundOutputOperand(mlir::TypedValue<RankedTensorType> opResult,
op->setAttr("layout", updatedLayoutAttr);
}

if (outputWorkaroundResults.tensorDataTypeResult.isModified() &&
op->getAttrDictionary().get("dtype")) {
DataTypeAttr updatedDataTypeAttr = rewriter.getAttr<DataTypeAttr>(
outputWorkaroundResults.tensorDataTypeResult.targetValue);
op->setAttr("dtype", updatedDataTypeAttr);
}

if ((outputWorkaroundResults.tensorBufferTypeResult.isModified() ||
outputWorkaroundResults.tensorMemoryLayoutResult.isModified()) &&
op->getAttrDictionary().get("memory_config")) {
Expand Down
17 changes: 11 additions & 6 deletions lib/Dialect/TTNN/Utils/TransformUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,13 @@ ToLayoutOp
createToLayoutOp(Operation *op, mlir::TypedValue<RankedTensorType> inputValue,
PatternRewriter &rewriter, Layout targetTensorLayout,
BufferType targetTensorBufferType,
std::optional<TensorMemoryLayout> targetTensorMemoryLayout) {
std::optional<TensorMemoryLayout> targetTensorMemoryLayout,
DataType targetTensorDataType) {
TTNNLayoutAttr inputLayoutAttr = getLayoutAttrFromTensor(inputValue);

// Create element type based on tensor layout.
Type elementType = getElementType(rewriter.getContext(), targetTensorLayout,
inputLayoutAttr.getDataType());
targetTensorDataType);

// Create tensor memory layout attribute.
ttnn::TensorMemoryLayoutAttr outputMemLayoutAttr =
Expand All @@ -63,10 +64,14 @@ createToLayoutOp(Operation *op, mlir::TypedValue<RankedTensorType> inputValue,
.withBufferType(rewriter.getContext(), targetTensorBufferType)
.withMemoryLayout(rewriter.getContext(), outputMemLayoutAttr);

// Create the output result type with the new encoding.
// Create the output result type with the new data type and encoding.
RankedTensorType toLayoutOpResultType =
ttnn::utils::createRankedTensorTypeWithEncoding(inputToLayoutOpType,
toLayoutOpResultEncoding);
ttnn::utils::createRankedTensorTypeWithEncoding(
ttnn::utils::createRankedTensorTypeWithElementType(
inputToLayoutOpType,
utils::createRowMajorTypeFromDtype(rewriter.getContext(),
targetTensorDataType)),
toLayoutOpResultEncoding);

// Create the output memory config attribute.
ttnn::MemoryConfigAttr outputMemConfigAttr = ttnn::MemoryConfigAttr::get(
Expand All @@ -88,7 +93,7 @@ createToLayoutOp(Operation *op, mlir::TypedValue<RankedTensorType> inputValue,
return rewriter.create<ttnn::ToLayoutOp>(
op->getLoc(), toLayoutOpResultType, inputValue,
LayoutAttr::get(rewriter.getContext(), targetTensorLayout),
DataTypeAttr::get(rewriter.getContext(), inputLayoutAttr.getDataType()),
DataTypeAttr::get(rewriter.getContext(), targetTensorDataType),
outputMemConfigAttr, deviceValue);
}
} // namespace mlir::tt::ttnn::utils
10 changes: 9 additions & 1 deletion lib/Dialect/TTNN/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,22 @@ Type createRowMajorTypeFromDtype(::mlir::MLIRContext *context, DataType dtype) {
}
}

// Helper method to create a RankedTensorType with the given encoding
// Helper method to create a RankedTensorType with the given encoding.
// Keeps the shape and element type of `tensorType` and swaps in `encoding`.
RankedTensorType
createRankedTensorTypeWithEncoding(RankedTensorType tensorType,
                                   ttnn::TTNNLayoutAttr encoding) {
  auto shape = tensorType.getShape();
  Type elementType = tensorType.getElementType();
  return RankedTensorType::get(shape, elementType, encoding);
}

// Helper method to create a RankedTensorType with the given element type.
// Keeps the shape and encoding of `tensorType` and swaps in `elementType`.
RankedTensorType
createRankedTensorTypeWithElementType(RankedTensorType tensorType,
                                      Type elementType) {
  auto shape = tensorType.getShape();
  auto encoding = tensorType.getEncoding();
  return RankedTensorType::get(shape, elementType, encoding);
}

uint64_t getOpOutputL1Usage(TTNNLayoutAttr opLayout) {
// In case the opLayout is not in L1 memory space, L1 memory usage is 0.
//
Expand Down
Loading

0 comments on commit b595a45

Please sign in to comment.