removed duplicate TensorMemoryLayout enum
LPanosTT committed Sep 24, 2024
1 parent 9abcba1 commit abc3b55
Showing 8 changed files with 21 additions and 87 deletions.
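
The commit drops TTNN's private copy of the enum and reuses the one already defined in the TT dialect, so both dialects now share a single tt::TensorMemoryLayout. As a hedged illustration only (not part of this diff; isShardedTensorMemoryLayout is a hypothetical helper, and the TTOpsTypes.h include is assumed to declare the shared enum), TTNN-side code can now switch directly over tt::TensorMemoryLayout, including its extra None case:

// Hypothetical helper, not part of this commit: with the duplicate TTNN enum gone,
// TTNN utilities can work directly with the shared tt::TensorMemoryLayout.
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" // assumed to declare tt::TensorMemoryLayout
#include "llvm/Support/ErrorHandling.h"

namespace mlir::tt::ttnn::utils {

// Returns true for any sharded layout variant of the shared enum.
inline bool isShardedTensorMemoryLayout(::mlir::tt::TensorMemoryLayout layout) {
  switch (layout) {
  case ::mlir::tt::TensorMemoryLayout::HeightSharded:
  case ::mlir::tt::TensorMemoryLayout::WidthSharded:
  case ::mlir::tt::TensorMemoryLayout::BlockSharded:
    return true;
  case ::mlir::tt::TensorMemoryLayout::Interleaved:
  case ::mlir::tt::TensorMemoryLayout::SingleBank:
  case ::mlir::tt::TensorMemoryLayout::None:
    return false;
  }
  llvm_unreachable("Unknown TensorMemoryLayout");
}

} // namespace mlir::tt::ttnn::utils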
3 changes: 2 additions & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td
@@ -11,6 +11,7 @@ include "mlir/IR/BuiltinTypeInterfaces.td"
include "mlir/IR/CommonTypeConstraints.td"
include "ttmlir/Dialect/TTNN/IR/TTNNBase.td"
include "ttmlir/Dialect/TTNN/IR/TTNNOpsEnums.td"
include "ttmlir/Dialect/TT/IR/TTOpsEnums.td"

//===----------------------------------------------------------------------===//
// TTNN attr definitions
@@ -48,7 +49,7 @@ def TTNN_LayoutAttr : EnumAttr<TTNN_Dialect, TTNN_Layout, "layout"> {
let assemblyFormat = "`<` $value `>`";
}

def TTNN_TensorMemoryLayoutAttr : EnumAttr<TTNN_Dialect, TTNN_TensorMemoryLayout, "tensor_memory_layout"> {
def TTNN_TensorMemoryLayoutAttr : EnumAttr<TTNN_Dialect, TT_TensorMemoryLayout, "tensor_memory_layout"> {
let assemblyFormat = "`<` $value `>`";
}

18 changes: 0 additions & 18 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsEnums.td
@@ -21,24 +21,6 @@ def TTNN_Layout : I32EnumAttr<"Layout", "TTNN Layout",
let cppNamespace = "::mlir::tt::ttnn";
}

def TTNN_TensorMemoryLayout_Interleaved : I32EnumAttrCase<"Interleaved", 0, "interleaved">;
def TTNN_TensorMemoryLayout_SingleBank : I32EnumAttrCase<"SingleBank", 1, "single_bank">;
def TTNN_TensorMemoryLayout_HeightSharded : I32EnumAttrCase<"HeightSharded", 2, "height_sharded">;
def TTNN_TensorMemoryLayout_WidthSharded : I32EnumAttrCase<"WidthSharded", 3, "width_sharded">;
def TTNN_TensorMemoryLayout_BlockSharded : I32EnumAttrCase<"BlockSharded", 4, "block_sharded">;

def TTNN_TensorMemoryLayout : I32EnumAttr<"TensorMemoryLayout", "TTNN Tensor Memory Layout",
[
TTNN_TensorMemoryLayout_Interleaved,
TTNN_TensorMemoryLayout_SingleBank,
TTNN_TensorMemoryLayout_HeightSharded,
TTNN_TensorMemoryLayout_WidthSharded,
TTNN_TensorMemoryLayout_BlockSharded,
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::tt::ttnn";
}

def TTNN_BufferType_DRAM : I32EnumAttrCase<"DRAM", 0, "dram">;
def TTNN_BufferType_L1 : I32EnumAttrCase<"L1", 1, "l1">;
def TTNN_BufferType_SystemMemory : I32EnumAttrCase<"SystemMemory", 2, "system_memory">;
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTNN/Utils/Utils.h
@@ -18,7 +18,7 @@ toTTNNBufferType(const mlir::tt::MemorySpace memorySpace);

// Map TT::TensorMemoryLayout to TTNN::TensorMemoryLayout
//
ttnn::TensorMemoryLayout
TensorMemoryLayout
toTTNNTensorMemoryLayout(const tt::TensorMemoryLayout ttTensorMemoryLayout);

} // namespace mlir::tt::ttnn::utils
19 changes: 10 additions & 9 deletions include/ttmlir/Target/TTNN/utils.h
@@ -5,30 +5,31 @@
#ifndef TTMLIR_TARGET_TTNN_UTILS_H
#define TTMLIR_TARGET_TTNN_UTILS_H

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"
#include "ttmlir/Target/Common/types_generated.h"
#include <llvm/Support/ErrorHandling.h>

namespace tt::mlir::ttnn::utils {

::tt::target::TensorMemoryLayout toTargetTensorMemoryLayout(
::mlir::tt::ttnn::TensorMemoryLayout tensorMemoryLayout) {
::tt::target::TensorMemoryLayout
toTargetTensorMemoryLayout(::mlir::tt::TensorMemoryLayout tensorMemoryLayout) {

switch (tensorMemoryLayout) {
case ::mlir::tt::ttnn::TensorMemoryLayout::Interleaved:
case ::mlir::tt::TensorMemoryLayout::Interleaved:
return ::tt::target::TensorMemoryLayout::Interleaved;
case ::mlir::tt::ttnn::TensorMemoryLayout::SingleBank:
case ::mlir::tt::TensorMemoryLayout::SingleBank:
return ::tt::target::TensorMemoryLayout::SingleBank;
case ::mlir::tt::ttnn::TensorMemoryLayout::HeightSharded:
case ::mlir::tt::TensorMemoryLayout::HeightSharded:
return ::tt::target::TensorMemoryLayout::HeightSharded;
case ::mlir::tt::ttnn::TensorMemoryLayout::WidthSharded:
case ::mlir::tt::TensorMemoryLayout::WidthSharded:
return ::tt::target::TensorMemoryLayout::WidthSharded;
case ::mlir::tt::ttnn::TensorMemoryLayout::BlockSharded:
case ::mlir::tt::TensorMemoryLayout::BlockSharded:
return ::tt::target::TensorMemoryLayout::BlockSharded;
case ::mlir::tt::TensorMemoryLayout::None:
llvm_unreachable("Unsupported TensorMemoryLayout");
}

llvm_unreachable("Unsupported TensorMemoryLayout");
}

::tt::target::BufferType
8 changes: 3 additions & 5 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -98,15 +98,14 @@ class TensorEmptyConversionPattern

ttnn::BufferType bufferType =
ttnn::utils::toTTNNBufferType(ttLayoutAttr.getMemorySpace());
ttnn::TensorMemoryLayout tensorMemoryLayout =
ttnn::utils::toTTNNTensorMemoryLayout(ttLayoutAttr.getMemLayout());

// Create MemoryConfigAttr
//
auto device = getOrInsertDevice(rewriter, op);
ttnn::MemoryConfigAttr memoryConfigAttr = ttnn::MemoryConfigAttr::get(
op.getContext(),
ttnn::TensorMemoryLayoutAttr::get(op.getContext(), tensorMemoryLayout),
ttnn::TensorMemoryLayoutAttr::get(op.getContext(),
ttTensorMemoryLayout),
ttnn::BufferTypeAttr::get(op.getContext(), bufferType));

rewriter.replaceOpWithNewOp<ttnn::EmptyOp>(
@@ -199,8 +198,7 @@ class ToLayoutOpConversionPattern

// Set the tensor memory layout
//
ttnn::TensorMemoryLayout tensorMemoryLayout =
ttnn::utils::toTTNNTensorMemoryLayout(ttLayoutAttr.getMemLayout());
TensorMemoryLayout tensorMemoryLayout = ttLayoutAttr.getMemLayout();

// TODO(bug #621):
// Add ttnn::Tensor(tensor, dtype) op call once tt-metal is updated
30 changes: 4 additions & 26 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -5,6 +5,7 @@
#include "ttmlir/Conversion/TTNNToEmitC/TTNNToEmitC.h"

#include "ttmlir/Dialect/TT/IR/TTOpsDialect.h.inc"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNN.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
@@ -30,6 +31,7 @@
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/LogicalResult.h"
#include <cassert>

using namespace mlir;
using namespace mlir::tt;
@@ -59,28 +61,6 @@ emitc::OpaqueAttr convertLayoutAttr(Builder &builder, ttnn::LayoutAttr attr) {

// Create emitc::OpaqueAttr for ttnn::TensorMemoryLayout
//
emitc::OpaqueAttr convertTensorMemoryLayout(Builder &builder,
ttnn::TensorMemoryLayoutAttr attr) {
switch (attr.getValue()) {
case ttnn::TensorMemoryLayout::BlockSharded:
return builder.getType<emitc::OpaqueAttr>(
"ttnn::TensorMemoryLayout::BLOCK_SHARDED");
case ttnn::TensorMemoryLayout::HeightSharded:
return builder.getType<emitc::OpaqueAttr>(
"ttnn::TensorMemoryLayout::HEIGHT_SHARDED");
case ttnn::TensorMemoryLayout::Interleaved:
return builder.getType<emitc::OpaqueAttr>(
"ttnn::TensorMemoryLayout::INTERLEAVED");
case ttnn::TensorMemoryLayout::SingleBank:
return builder.getType<emitc::OpaqueAttr>(
"ttnn::TensorMemoryLayout::SINGLE_BANK");
case ttnn::TensorMemoryLayout::WidthSharded:
return builder.getType<emitc::OpaqueAttr>(
"ttnn::TensorMemoryLayout::WIDTH_SHARDED");
}

llvm_unreachable("Unknown ttnn::TensorMemoryLayout");
}

// Create emitc::OpaqueAttr for ttnn::BufferType
//
@@ -305,8 +285,7 @@ class ToDeviceOpConversionPattern
// Create ArrayAttr object holding MemoryConfig attributes
//
ArrayAttr arrayAttrs = rewriter.getArrayAttr(
{convertTensorMemoryLayout(
rewriter, srcOp.getMemoryConfig().getTensorMemoryLayout()),
{srcOp.getMemoryConfig().getTensorMemoryLayout(),
convertBufferType(rewriter, srcOp.getMemoryConfig().getBufferType())});

// Create MemoryConfig object first, then pass it to the op
@@ -456,8 +435,7 @@ class EmptyOpConversionPattern
// Create ArrayAttr object holding MemoryConfig attributes
//
ArrayAttr memCfgArrayAttrs = rewriter.getArrayAttr(
{convertTensorMemoryLayout(
rewriter, srcOp.getMemoryConfig()->getTensorMemoryLayout()),
{srcOp.getMemoryConfig()->getTensorMemoryLayout(),
convertBufferType(rewriter,
srcOp.getMemoryConfig()->getBufferType())});

4 changes: 1 addition & 3 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -549,9 +549,7 @@ ::mlir::LogicalResult mlir::tt::ttnn::EmptyOp::verify() {
if (getMemoryConfig().has_value()) {
ttnn::BufferType bufferType =
mlir::tt::ttnn::utils::toTTNNBufferType(ttLayoutAttr.getMemorySpace());
ttnn::TensorMemoryLayout tensorMemoryLayout =
mlir::tt::ttnn::utils::toTTNNTensorMemoryLayout(
ttLayoutAttr.getMemLayout());
TensorMemoryLayout tensorMemoryLayout = ttLayoutAttr.getMemLayout();
assert(bufferType == getMemoryConfig()->getBufferType().getValue());
assert(tensorMemoryLayout ==
getMemoryConfig()->getTensorMemoryLayout().getValue());
24 changes: 0 additions & 24 deletions lib/Dialect/TTNN/Utils/Utils.cpp
@@ -20,27 +20,3 @@ mlir::tt::ttnn::BufferType mlir::tt::ttnn::utils::toTTNNBufferType(

llvm_unreachable("Unknown MemorySpace");
}

// Map TT::TensorMemoryLayout to TTNN::TensorMemoryLayout
//
mlir::tt::ttnn::TensorMemoryLayout
mlir::tt::ttnn::utils::toTTNNTensorMemoryLayout(
const ::mlir::tt::TensorMemoryLayout ttTensorMemoryLayout) {

switch (ttTensorMemoryLayout) {
case ::mlir::tt::TensorMemoryLayout::HeightSharded:
return ttnn::TensorMemoryLayout::HeightSharded;
case ::mlir::tt::TensorMemoryLayout::Interleaved:
return ttnn::TensorMemoryLayout::Interleaved;
case ::mlir::tt::TensorMemoryLayout::WidthSharded:
return ttnn::TensorMemoryLayout::WidthSharded;
case ::mlir::tt::TensorMemoryLayout::BlockSharded:
return ttnn::TensorMemoryLayout::BlockSharded;
case ::mlir::tt::TensorMemoryLayout::SingleBank:
return ttnn::TensorMemoryLayout::SingleBank;
case ::mlir::tt::TensorMemoryLayout::None:
assert(false && "TensorMemoryLayout::None not supported");
}

llvm_unreachable("Unknown TensorMemoryLayout");
}
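
With the body of toTTNNTensorMemoryLayout removed, call sites read the layout straight off the tt::LayoutAttr instead of converting between two enums. A minimal sketch of the resulting pattern, assuming placeholder names ctx, ttLayoutAttr, and bufferType (the attribute getters themselves appear in the TTIRToTTNN.cpp hunk above):

// Sketch only: ctx, ttLayoutAttr, and bufferType stand in for values available at
// a real call site (see lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp above).
mlir::tt::TensorMemoryLayout memLayout = ttLayoutAttr.getMemLayout();
ttnn::MemoryConfigAttr memCfg = ttnn::MemoryConfigAttr::get(
    ctx,
    ttnn::TensorMemoryLayoutAttr::get(ctx, memLayout),
    ttnn::BufferTypeAttr::get(ctx, bufferType));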
