Skip to content

Commit

Permalink
make emitted cpp code compilable (#722)
Browse files Browse the repository at this point in the history
  • Loading branch information
svuckovicTT authored Sep 18, 2024
1 parent 5486220 commit 5d11b85
Show file tree
Hide file tree
Showing 23 changed files with 825 additions and 239 deletions.
2 changes: 1 addition & 1 deletion cmake/modules/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
set(ttmlir_cmake_builddir "${CMAKE_BINARY_DIR}/lib/cmake/ttmlir")

set_property(GLOBAL APPEND PROPERTY TTMLIR_EXPORTS "MLIRTTDialect;MLIRTTNNDialect;MLIRTTKernelDialect;MLIRTTMetalDialect;MLIRTTNNTransforms;")
set_property(GLOBAL APPEND PROPERTY TTMLIR_EXPORTS "MLIRTTDialect;MLIRTTNNDialect;TTMLIRTTNNUtils;MLIRTTKernelDialect;MLIRTTMetalDialect;MLIRTTNNTransforms;")
get_property(TTMLIR_EXPORTS GLOBAL PROPERTY TTMLIR_EXPORTS)
export(TARGETS ${TTMLIR_EXPORTS} FILE ${ttmlir_cmake_builddir}/TTMLIRTargets.cmake)

Expand Down
8 changes: 7 additions & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -426,8 +426,14 @@ def TTNN_EmptyOp : TTNN_Op<"empty"> {
Tensor empty operation
}];

let arguments = (ins TT_Device:$device);
let arguments = (ins Optional<TT_Device>:$device,
TTNN_ShapeAttr:$shape,
OptionalAttr<TT_DataTypeAttr>:$dtype,
OptionalAttr<TTNN_LayoutAttr>:$layout,
OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config);
let results = (outs AnyRankedTensor:$result);

let hasVerifier = 1;
}

def TTNN_FullOp : TTNN_Op<"full"> {
Expand Down
10 changes: 10 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ def TTNN_MemoryConfigAttr : TTNN_Attr<"MemoryConfig", "memory_config"> {
let assemblyFormat = "`<` params `>`";
}

def TTNN_ShapeAttr : TTNN_Attr<"Shape", "shape"> {
let summary = "TTNN Shape attribute";
let description = [{
TTNN shape attribute
}];

let parameters = (ins ArrayRefParameter<"int64_t">:$shape);
let assemblyFormat = "`<` custom<DimensionList>($shape) `>`";
}

def TTNN_MeshShapeAttr : TTNN_Attr<"MeshShape", "mesh_shape"> {
let summary = "TTNN Mesh Shape";
let description = [{
Expand Down
26 changes: 26 additions & 0 deletions include/ttmlir/Dialect/TTNN/Utils/Utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TTNN_UTILS_UTILS_H
#define TTMLIR_DIALECT_TTNN_UTILS_UTILS_H

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"

namespace mlir::tt::ttnn::utils {

// Translate a TT dialect MemorySpace value into the equivalent TTNN dialect
// BufferType value.
//
mlir::tt::ttnn::BufferType
toTTNNBufferType(const mlir::tt::MemorySpace ttMemorySpace);

// Translate a TT dialect TensorMemoryLayout value into the equivalent TTNN
// dialect TensorMemoryLayout value.
//
ttnn::TensorMemoryLayout
toTTNNTensorMemoryLayout(const tt::TensorMemoryLayout memLayout);

} // namespace mlir::tt::ttnn::utils

#endif // TTMLIR_DIALECT_TTNN_UTILS_UTILS_H
6 changes: 5 additions & 1 deletion include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ table ToDeviceOp {
}

table EmptyOp {
device: tt.target.DeviceRef;
shape: [int64];
dtype: DataType;
layout: TensorLayout;
device: tt.target.DeviceRef; // optional
memcfg: tt.target.MemoryConfigDesc; // optional
out: tt.target.TensorRef;
}

Expand Down
96 changes: 96 additions & 0 deletions include/ttmlir/Target/TTNN/utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_TARGET_TTNN_UTILS_H
#define TTMLIR_TARGET_TTNN_UTILS_H

#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"
#include "ttmlir/Target/Common/types_generated.h"
#include <llvm/Support/ErrorHandling.h>

// Conversion helpers that map TTNN dialect enums onto the flatbuffer target
// schema enums used during serialization.
//
// All functions below are *definitions* in a header, so they are marked
// `inline`: without it, every translation unit that includes this header
// would emit its own external-linkage copy of each function and the link
// would fail with duplicate-symbol / ODR violations.
//
// NOTE(review): this namespace is `tt::mlir::ttnn::utils`, while the
// dialect-side helpers live in `mlir::tt::ttnn::utils` — confirm the
// component ordering here is intentional and not a transposition.
namespace tt::mlir::ttnn::utils {

// Map TTNN dialect TensorMemoryLayout to the target flatbuffer enum.
// Unreachable fallthrough guards against a future enumerator being added
// without extending this switch.
inline ::tt::target::TensorMemoryLayout toTargetTensorMemoryLayout(
    ::mlir::tt::ttnn::TensorMemoryLayout tensorMemoryLayout) {

  switch (tensorMemoryLayout) {
  case ::mlir::tt::ttnn::TensorMemoryLayout::Interleaved:
    return ::tt::target::TensorMemoryLayout::Interleaved;
  case ::mlir::tt::ttnn::TensorMemoryLayout::SingleBank:
    return ::tt::target::TensorMemoryLayout::SingleBank;
  case ::mlir::tt::ttnn::TensorMemoryLayout::HeightSharded:
    return ::tt::target::TensorMemoryLayout::HeightSharded;
  case ::mlir::tt::ttnn::TensorMemoryLayout::WidthSharded:
    return ::tt::target::TensorMemoryLayout::WidthSharded;
  case ::mlir::tt::ttnn::TensorMemoryLayout::BlockSharded:
    return ::tt::target::TensorMemoryLayout::BlockSharded;
  }

  llvm_unreachable("Unsupported TensorMemoryLayout");
}

// Map TTNN dialect BufferType to the target flatbuffer enum.
inline ::tt::target::BufferType
toTargetBufferType(::mlir::tt::ttnn::BufferType bufferType) {

  switch (bufferType) {
  case ::mlir::tt::ttnn::BufferType::DRAM:
    return ::tt::target::BufferType::DRAM;
  case ::mlir::tt::ttnn::BufferType::L1:
    return ::tt::target::BufferType::L1;
  case ::mlir::tt::ttnn::BufferType::SystemMemory:
    return ::tt::target::BufferType::SystemMemory;
  case ::mlir::tt::ttnn::BufferType::L1Small:
    return ::tt::target::BufferType::L1Small;
  case ::mlir::tt::ttnn::BufferType::Trace:
    return ::tt::target::BufferType::Trace;
  }

  llvm_unreachable("Unsupported BufferType");
}

// Map TTNN dialect Layout to the target flatbuffer enum.
// `Invalid` has no target counterpart and is rejected outright.
inline ::tt::target::TensorLayout
toTargetTensorLayout(::mlir::tt::ttnn::Layout layout) {
  switch (layout) {
  case ::mlir::tt::ttnn::Layout::RowMajor:
    return ::tt::target::TensorLayout::RowMajor;
  case ::mlir::tt::ttnn::Layout::Tile:
    return ::tt::target::TensorLayout::Tile;
  case ::mlir::tt::ttnn::Layout::Invalid:
    llvm_unreachable("Unsupported Layout");
  }

  llvm_unreachable("Unsupported Layout");
}

// Map TT dialect DataType to the target flatbuffer enum.
// The trailing group of cases (Float16, BFP_Float*, BFP_BFloat2) is not
// representable in the target schema and is rejected.
inline ::tt::target::DataType toTargetDataType(::mlir::tt::DataType dataType) {
  switch (dataType) {
  case ::mlir::tt::DataType::BFloat16:
    return ::tt::target::DataType::BFloat16;
  case ::mlir::tt::DataType::Float32:
    return ::tt::target::DataType::Float32;
  case ::mlir::tt::DataType::UInt32:
    return ::tt::target::DataType::UInt32;
  case ::mlir::tt::DataType::BFP_BFloat8:
    return ::tt::target::DataType::BFP_BFloat8;
  case ::mlir::tt::DataType::BFP_BFloat4:
    return ::tt::target::DataType::BFP_BFloat4;
  case ::mlir::tt::DataType::UInt8:
    return ::tt::target::DataType::UInt8;
  case ::mlir::tt::DataType::UInt16:
    return ::tt::target::DataType::UInt16;
  case ::mlir::tt::DataType::Float16:
  case ::mlir::tt::DataType::BFP_Float2:
  case ::mlir::tt::DataType::BFP_Float4:
  case ::mlir::tt::DataType::BFP_Float8:
  case ::mlir::tt::DataType::BFP_BFloat2:
    llvm_unreachable("Unsupported DataType");
  }

  llvm_unreachable("Unsupported DataType");
}

} // namespace tt::mlir::ttnn::utils

#endif // TTMLIR_TARGET_TTNN_UTILS_H
101 changes: 62 additions & 39 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"
#include "ttmlir/Dialect/TTNN/Utils/Utils.h"

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
Expand All @@ -19,6 +20,7 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/LogicalResult.h"

using namespace mlir;
Expand Down Expand Up @@ -54,9 +56,62 @@ class TensorEmptyConversionPattern
LogicalResult
matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// Get tt::LayoutAttr of the result type
//
tt::LayoutAttr ttLayoutAttr =
mlir::cast<tt::LayoutAttr>(op.getResult().getType().getEncoding());

// Get the shape of the tensor, tensor layout, and data type
//
mlir::MemRefType memref = ttLayoutAttr.getMemref();
ttnn::ShapeAttr shapeAttr = ttnn::ShapeAttr::get(
rewriter.getContext(),
mlir::cast<RankedTensorType>(op->getResult(0).getType()).getShape());
Type elementType = memref.getElementType();
DataType dtype = DataType::Float32;
ttnn::Layout ttnnLayoutEnum = ttnn::Layout::RowMajor;
if (llvm::isa<TileType>(elementType)) {
ttnnLayoutEnum = ttnn::Layout::Tile;
auto tileType = mlir::cast<TileType>(elementType);
dtype = tileType.getDataType();
} else {
ttnnLayoutEnum = ttnn::Layout::RowMajor;
dtype = elementTypeToDataType(elementType);
}
DataTypeAttr dTypeAttr = DataTypeAttr::get(rewriter.getContext(), dtype);
ttnn::LayoutAttr tensorLayoutAttr =
ttnn::LayoutAttr::get(op.getContext(), ttnnLayoutEnum);

// If the tensor is not going to device, we can create the op without
// device-specific attributes
//
tt::TensorMemoryLayout ttTensorMemoryLayout = ttLayoutAttr.getMemLayout();
if (ttTensorMemoryLayout == TensorMemoryLayout::None) {
rewriter.replaceOpWithNewOp<ttnn::EmptyOp>(
op, this->getTypeConverter()->convertType(op.getType()), nullptr,
shapeAttr, dTypeAttr, tensorLayoutAttr, nullptr);

return success();
}

ttnn::BufferType bufferType =
ttnn::utils::toTTNNBufferType(ttLayoutAttr.getMemorySpace());
ttnn::TensorMemoryLayout tensorMemoryLayout =
ttnn::utils::toTTNNTensorMemoryLayout(ttLayoutAttr.getMemLayout());

// Create MemoryConfigAttr
//
auto device = getOrInsertDevice(rewriter, op);
ttnn::MemoryConfigAttr memoryConfigAttr = ttnn::MemoryConfigAttr::get(
op.getContext(),
ttnn::TensorMemoryLayoutAttr::get(op.getContext(), tensorMemoryLayout),
ttnn::BufferTypeAttr::get(op.getContext(), bufferType));

rewriter.replaceOpWithNewOp<ttnn::EmptyOp>(
op, this->getTypeConverter()->convertType(op.getType()), device);
op, this->getTypeConverter()->convertType(op.getType()), device,
shapeAttr, dTypeAttr, tensorLayoutAttr, memoryConfigAttr);

return success();
}
};
Expand Down Expand Up @@ -118,27 +173,15 @@ class ToLayoutOpConversionPattern
}

// TODO(bug #665):
// Binary ops fail with row major layout in ttnn, defaulting to tile layout
// for all ops...
// Binary ops fail with row major layout in ttnn, defaulting to tile
// layout for all ops...
//
ttnnLayoutEnum = ttnn::Layout::Tile;

// Map TT::MemorySpace to TTNN::BufferType
//
tt::MemorySpace memorySpace = ttLayoutAttr.getMemorySpace();
ttnn::BufferType bufferType = ttnn::BufferType::DRAM; // default to DRAM
switch (memorySpace) {
case tt::MemorySpace::System:
case tt::MemorySpace::SystemMMIO:
bufferType = ttnn::BufferType::SystemMemory;
break;
case tt::MemorySpace::DeviceDRAM:
bufferType = ttnn::BufferType::DRAM;
break;
case tt::MemorySpace::DeviceL1:
bufferType = ttnn::BufferType::L1;
break;
}
ttnn::BufferType bufferType =
ttnn::utils::toTTNNBufferType(ttLayoutAttr.getMemorySpace());

// If the ToLayoutOp is applied to empty tensor, we need to check whether
// the empty tensor is going back to system memory; if so, we should not
Expand All @@ -154,27 +197,7 @@ class ToLayoutOpConversionPattern
// Set the tensor memory layout
//
ttnn::TensorMemoryLayout tensorMemoryLayout =
ttnn::TensorMemoryLayout::HeightSharded;
switch (ttLayoutAttr.getMemLayout()) {
case tt::TensorMemoryLayout::None:
assert(false && "TensorMemoryLayout::None not supported");
break;
case tt::TensorMemoryLayout::HeightSharded:
tensorMemoryLayout = ttnn::TensorMemoryLayout::HeightSharded;
break;
case tt::TensorMemoryLayout::Interleaved:
tensorMemoryLayout = ttnn::TensorMemoryLayout::Interleaved;
break;
case tt::TensorMemoryLayout::WidthSharded:
tensorMemoryLayout = ttnn::TensorMemoryLayout::WidthSharded;
break;
case tt::TensorMemoryLayout::BlockSharded:
tensorMemoryLayout = ttnn::TensorMemoryLayout::BlockSharded;
break;
case tt::TensorMemoryLayout::SingleBank:
tensorMemoryLayout = ttnn::TensorMemoryLayout::SingleBank;
break;
}
ttnn::utils::toTTNNTensorMemoryLayout(ttLayoutAttr.getMemLayout());

// TODO(bug #621):
// Add ttnn::Tensor(tensor, dtype) op call once tt-metal is updated
Expand Down Expand Up @@ -433,7 +456,7 @@ class ConstantOpConversionPattern
? static_cast<float>(valueAttr.getSplatValue<int>())
: valueAttr.getSplatValue<float>();
if (fillValue == 0) {
rewriter.replaceOpWithNewOp<ttnn::EmptyOp>(
rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
op, this->getTypeConverter()->convertType(op.getType()), device);
} else {
::mlir::FloatAttr fillValueAttr = rewriter.getF32FloatAttr(fillValue);
Expand Down
Loading

0 comments on commit 5d11b85

Please sign in to comment.