From d22057f9ac87616e590f179bdece3a771e47c607 Mon Sep 17 00:00:00 2001 From: Milan Topalovic <163355844+mtopalovicTT@users.noreply.github.com> Date: Wed, 27 Nov 2024 12:42:16 +0100 Subject: [PATCH] Minor API fixes for TTNN encoding attribute (#1390) This PR adds a couple of convenience methods to the TTNN tensor encoding attribute and also removes redundant utils functions. Renaming/Adding some new functions... * `getDataType` to get scalar data type: * `memref<2x2x!tt.tile<32x32xf32>>` returns float data type * `memref<128x128xi32>` returns int data type * `getElementType` to get type from memref: * `memref<2x2x!tt.tile<32x32xf32>>` returns TileType * `memref<128x128xi32>` returns IntegerType * `getLayout` - gets layout of encoding i.e. Tile/RowMajor * `getShardShape`: * `memref<2x2x!tt.tile<32x32xf32>>` returns `(2, 2)` * `memref<128x128xi32>` returns `(128, 128)` * `getScalarShardShape`: * `memref<2x2x!tt.tile<32x32xf32>>` returns `(64, 64)` * `memref<128x128xi32>` returns `(128, 128)` --- .../ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td | 15 +- include/ttmlir/Dialect/TTNN/Utils/Utils.h | 4 - lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp | 42 ++--- lib/Dialect/TTNN/IR/TTNNOps.cpp | 17 +- lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp | 145 ++++++++++-------- lib/Dialect/TTNN/Transforms/Optimizer.cpp | 18 +-- lib/Dialect/TTNN/Transforms/Passes.cpp | 18 +-- 7 files changed, 121 insertions(+), 138 deletions(-) diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td index bba7fe6f2..e45fba003 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td @@ -109,6 +109,13 @@ def TTNN_TTNNLayoutAttr: TTNN_Attr<"TTNNLayout", "ttnn_layout"> { let summary = "Tensor encoding attribute used for types in ttnn"; let description = [{ Layout attribute in ttnn. This attribute is used to encode different information about tensor memory layout. 
+ Here is how the tensor type will look after the layout is applied: tensor<32x32x64xf32, #ttnn.ttnn_layout> + Let's break down what each parameter means: + - linear: An affine map that defines how the logical tensor dimensions map to physical space. + - grid: The grid shape (of tensix cores) onto which the tensor is divided. + - memref: A memref is used to describe shard size and memory space. Shard size is calculated by dividing the tensor size by grid size. + - mem_layout: The layout of the tensor in memory. For a tensor on host it should be None. For a tensor on device + it can be interleaved or sharded. }]; let parameters = (ins AttrParameter<"AffineMap", "An affine map that defines how the logical tensor dimensions map to a grid shape.">:$linear, @@ -142,15 +149,15 @@ def TTNN_TTNNLayoutAttr: TTNN_Attr<"TTNNLayout", "ttnn_layout"> { bool hasShardedL1TensorMemoryLayout() const; bool hasInterleavedL1TensorMemoryLayout() const; bool isTiled() const; + Layout getLayout() const; Type getElementType() const; - DataType getDataTypeFromMemRef() const; + DataType getDataType() const; uint64_t getElementSizeBytes() const; int64_t getTensorSizeInBytes(ArrayRef tensorShape, ::mlir::tt::DeviceAttr device) const; llvm::SmallVector getStride(ArrayRef logicalShape) const; - llvm::SmallVector getPhysicalShape(ArrayRef logicalShape) const; - llvm::SmallVector getShardShape(bool convertTileToScalar = true) const; + llvm::SmallVector getShardShape() const; + llvm::SmallVector getScalarShardShape() const; AffineMap replaceMemoryMapSymbolsWithShardShape(AffineMap physicalMemoryMap) const; - AffineMap projectOnto(AffineMap linearMap, AffineMap physicalMemoryMap) const; AffineMap getIdentityTileLinearMap() const; llvm::SmallVector getTiledShape(ArrayRef logicalTensorShape) const; }]; diff --git a/include/ttmlir/Dialect/TTNN/Utils/Utils.h b/include/ttmlir/Dialect/TTNN/Utils/Utils.h index a6e10c099..533235a61 100644 --- a/include/ttmlir/Dialect/TTNN/Utils/Utils.h +++ b/include/ttmlir/Dialect/TTNN/Utils/Utils.h @@ 
-31,10 +31,6 @@ mlir::tt::TensorMemoryLayout toTTTensorMemoryLayout( mlir::tt::MemorySpace toTTMemorySpace(const mlir::tt::ttnn::BufferType bufferType); -DataType getDataTypeFromMemRef(mlir::MemRefType memref); - -Layout getLayoutFromMemRef(mlir::MemRefType memref); - mlir::Type createRowMajorTypeFromDtype(::mlir::MLIRContext *context, DataType dtype); diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index 9dbc9cf97..3241928f4 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -65,20 +65,15 @@ class TensorEmptyConversionPattern // Get the shape of the tensor, tensor layout, and data type // - mlir::MemRefType memref = layoutAttr.getMemref(); ttnn::ShapeAttr shapeAttr = ttnn::ShapeAttr::get( rewriter.getContext(), mlir::cast(op->getResult(0).getType()).getShape()); - Type elementType = memref.getElementType(); - DataType dtype = DataType::Float32; + DataType dtype = layoutAttr.getDataType(); ttnn::Layout ttnnLayoutEnum = ttnn::Layout::RowMajor; - if (llvm::isa(elementType)) { + if (layoutAttr.isTiled()) { ttnnLayoutEnum = ttnn::Layout::Tile; - auto tileType = mlir::cast(elementType); - dtype = tileType.getDataType(); } else { ttnnLayoutEnum = ttnn::Layout::RowMajor; - dtype = elementTypeToDataType(elementType); } DataTypeAttr dTypeAttr = DataTypeAttr::get(rewriter.getContext(), dtype); ttnn::LayoutAttr tensorLayoutAttr = @@ -101,13 +96,14 @@ class TensorEmptyConversionPattern // Create MemoryConfigAttr // auto device = getOrInsertDevice(rewriter, op); + llvm::SmallVector shardShape = layoutAttr.getShardShape(); ttnn::MemoryConfigAttr memoryConfigAttr = ttnn::MemoryConfigAttr::get( op.getContext(), ttnn::TensorMemoryLayoutAttr::get(op.getContext(), memLayout), ttnn::BufferTypeAttr::get(op.getContext(), bufferType), ttnn::ShardSpecAttr::get( op.getContext(), - ttnn::ShapeAttr::get(op.getContext(), memref.getShape()))); + ttnn::ShapeAttr::get(op.getContext(), 
shardShape))); rewriter.replaceOpWithNewOp( op, this->getTypeConverter()->convertType(op.getType()), device, @@ -137,18 +133,15 @@ class ToLayoutOpConversionPattern auto outputLayoutAttr = mlir::cast( op.getResult().getType().getEncoding()); - auto outputMemref = outputLayoutAttr.getMemref(); - // Determine the output data type - DataType dtype = ttnn::utils::getDataTypeFromMemRef(outputMemref); + DataType dtype = outputLayoutAttr.getDataType(); DataTypeAttr outputDataType = DataTypeAttr::get(rewriter.getContext(), dtype); // Determine the output layout (tile or row major) ttnn::BufferType outputBufferType = outputLayoutAttr.getBufferType(); - ttnn::Layout outputLayoutEnum = - ttnn::utils::getLayoutFromMemRef(outputMemref); + ttnn::Layout outputLayoutEnum = outputLayoutAttr.getLayout(); bool isOutputOnHost = (outputBufferType == ttnn::BufferType::SystemMemory); @@ -176,13 +169,14 @@ class ToLayoutOpConversionPattern op.getResult().setType(result); outputLayoutAttr = mlir::cast(result.getEncoding()); - outputMemref = outputLayoutAttr.getMemref(); outputLayoutEnum = newOutputLayoutEnum; } } ttnn::LayoutAttr outputLayout = ttnn::LayoutAttr::get(rewriter.getContext(), outputLayoutEnum); + llvm::SmallVector outputShardShape = + outputLayoutAttr.getShardShape(); // Determine output memory config attr ttnn::TensorMemoryLayout outputTensorMemoryLayout = @@ -193,8 +187,8 @@ class ToLayoutOpConversionPattern outputTensorMemoryLayout), ttnn::BufferTypeAttr::get(rewriter.getContext(), outputBufferType), ttnn::ShardSpecAttr::get( - op.getContext(), ttnn::ShapeAttr::get(rewriter.getContext(), - outputMemref.getShape()))); + op.getContext(), + ttnn::ShapeAttr::get(rewriter.getContext(), outputShardShape))); rewriter.replaceOpWithNewOp( op, this->getTypeConverter()->convertType(result), adaptor.getInput(), @@ -222,15 +216,16 @@ class ToLayoutOpConversionPattern ttnn::Layout newOutputLayoutEnum) const { auto oldOutputLayoutAttr = mlir::cast(oldOutput.getEncoding()); - auto 
oldOutputMemref = oldOutputLayoutAttr.getMemref(); - DataType outputDtype = ttnn::utils::getDataTypeFromMemRef(oldOutputMemref); - llvm::ArrayRef oldShardShape = oldOutputMemref.getShape(); + DataType outputDtype = oldOutputLayoutAttr.getDataType(); + SmallVector oldShardShape = + oldOutputLayoutAttr.getShardShape(); size_t shardShapeSize = oldShardShape.size(); assert(shardShapeSize >= 2 && "expected at least 2D shape"); if (newOutputLayoutEnum == ttnn::Layout::RowMajor) { // Set shard shape to match convention of row major layout - auto tileType = mlir::cast(oldOutputMemref.getElementType()); + auto tileType = + mlir::cast(oldOutputLayoutAttr.getElementType()); llvm::SmallVector newShardShape(oldShardShape.begin(), oldShardShape.end()); newShardShape[shardShapeSize - 2] = @@ -804,9 +799,7 @@ class TypecastOpConversionPattern ttnn::TTNNLayoutAttr outputLayoutAttr = mlir::cast(result.getType().getEncoding()); - mlir::MemRefType outputMemref = outputLayoutAttr.getMemref(); - - DataType outputDataType = ttnn::utils::getDataTypeFromMemRef(outputMemref); + DataType outputDataType = outputLayoutAttr.getDataType(); if (op->getUsers().empty()) { return rewriter.notifyMatchFailure( @@ -950,8 +943,7 @@ class ArangeOpConversionPattern : public OpConversionPattern { layoutAttr.getMemLayout()), rewriter.getAttr(layoutAttr.getBufferType()), rewriter.getAttr( - rewriter.getAttr( - layoutAttr.getMemref().getShape()))); + rewriter.getAttr(layoutAttr.getShardShape()))); rewriter.replaceOpWithNewOp( op, outputType, adaptor.getStart(), adaptor.getEnd(), adaptor.getStep(), diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp index b3201cf67..cd2746aad 100644 --- a/lib/Dialect/TTNN/IR/TTNNOps.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp @@ -190,25 +190,12 @@ ::mlir::LogicalResult mlir::tt::ttnn::EmptyOp::verify() { // DataType and Layout // - mlir::MemRefType memref = layoutAttr.getMemref(); - Type elementType = memref.getElementType(); if 
(getLayout().has_value()) { - ttnn::Layout ttnnLayoutEnum; - if (llvm::isa(elementType)) { - ttnnLayoutEnum = ttnn::Layout::Tile; - } else { - ttnnLayoutEnum = ttnn::Layout::RowMajor; - } + ttnn::Layout ttnnLayoutEnum = layoutAttr.getLayout(); assert(ttnnLayoutEnum == getLayoutAttr().getValue()); } if (getDtype().has_value()) { - tt::DataType dtype; - if (llvm::isa(elementType)) { - auto tileType = mlir::cast(elementType); - dtype = tileType.getDataType(); - } else { - dtype = elementTypeToDataType(elementType); - } + tt::DataType dtype = layoutAttr.getDataType(); assert(dtype == getDtype()); } diff --git a/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp b/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp index d80815f91..8aaae1261 100644 --- a/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp @@ -34,6 +34,11 @@ bool TTNNLayoutAttr::isTiled() const { return ::mlir::isa<::mlir::tt::TileType>(getElementType()); } +// Get layout of the tensor (RowMajor/Tile) +Layout TTNNLayoutAttr::getLayout() const { + return isTiled() ? Layout::Tile : Layout::RowMajor; +} + // Check if the tensor memory layout is sharded bool TTNNLayoutAttr::hasShardedTensorMemoryLayout() const { return (getMemLayout() == TensorMemoryLayout::HeightSharded || @@ -119,19 +124,19 @@ mlir::Type TTNNLayoutAttr::getElementType() const { return getMemref().getElementType(); } -// Extract data type from the memref. Example: -// memref<2x2xf32> -> f32 -// memref<2x2x!tt.tile<32x32xf32>> -> f32 -mlir::tt::DataType TTNNLayoutAttr::getDataTypeFromMemRef() const { +// Get scalar element type. +// Example: memref<2x2xf32> -> f32 +// Example: memref<2x2x!tt.tile<32x32xf32>> -> f32 +// +// return The scalar element type. 
+mlir::tt::DataType TTNNLayoutAttr::getDataType() const { Type elementType = getElementType(); - DataType dtype = DataType::Float32; - if (llvm::isa(elementType)) { + if (isTiled()) { TileType tileType = mlir::cast(elementType); - dtype = tileType.getDataType(); - } else { - dtype = elementTypeToDataType(elementType); + return tileType.getDataType(); } - return dtype; + + return elementTypeToDataType(elementType); } // Gets the size of shard in bytes @@ -139,10 +144,10 @@ mlir::tt::DataType TTNNLayoutAttr::getDataTypeFromMemRef() const { // This function returns the size of the shard in bytes. // Size is calculated by multiplying shard shape with element size. // -// /return The size of the shard in bytes. +// return The size of the shard in bytes. uint64_t TTNNLayoutAttr::getElementSizeBytes() const { mlir::Type elementType = getElementType(); - if (mlir::isa(elementType)) { + if (isTiled()) { TileType tileType = mlir::cast(elementType); return tileType.getSizeBytes(); } @@ -151,21 +156,31 @@ uint64_t TTNNLayoutAttr::getElementSizeBytes() const { // Get shard shape // -// This function returns the shape of the shard. If element type is TileType -// and convertTileToScalar is true, then the shape is converted to scalar shape. -// Example: (convertToScalar = true) memref<2x2x!tt.tile<32x32xf32>> -> {64, 64} -// Example: (convertToScalar = false) memref<2x2x!tt.tile<32x32xf32>> -> {2, 2} -// Example: memref<128x128xf32> -> {128, 128} +// Return the shape of the shard. +// Example: memref<2x2x!tt.tile<32x32xf32>> -> { 2, 2 } +// Example: memref<128x128xf32> -> { 128, 128 } +// Example: memref<2x3!tt.tile<32x32xf32>> -> { 2, 3 } // -// /param convertTileToScalar If true, convert tile shape to scalar shape. -// /return The shape of the shard. -llvm::SmallVector -TTNNLayoutAttr::getShardShape(bool convertTileToScalar) const { +// return The shape of the shard. 
+llvm::SmallVector TTNNLayoutAttr::getShardShape() const { + return SmallVector(getMemref().getShape()); +} + +// Get scalar shard shape +// +// If the element type is TileType, this function returns the scalar shape of +// the shard. +// Example: memref<2x2x!tt.tile<32x32xf32>> -> { 64, 64 } +// Example: memref<128x128xf32> -> { 128, 128 } +// Example: memref<2x3!tt.tile<32x32xf32>> -> { 64, 96 } +// +// return The scalar shape of the shard. +llvm::SmallVector TTNNLayoutAttr::getScalarShardShape() const { SmallVector shardShape(getMemref().getShape()); - Type elementType = getElementType(); - if (mlir::isa(elementType) && convertTileToScalar) { - return mlir::cast(elementType).getScalarShape(shardShape); + if (isTiled()) { + return mlir::cast(getElementType()).getScalarShape(shardShape); } + return shardShape; } @@ -178,8 +193,8 @@ TTNNLayoutAttr::getShardShape(bool convertTileToScalar) const { // d2) and tile shape (32, 32) The result is (90, 10) which is then divided by // tile shape (32, 32) -> (3, 1) // -// /param tensorShape The shape of the tensor -// /return The size of the tensor in tiles. +// param tensorShape The shape of the tensor +// return The size of the tensor in tiles. llvm::SmallVector TTNNLayoutAttr::getTiledShape(llvm::ArrayRef tensorShape) const { assert(isTiled() && "Expected a tiled layout"); @@ -214,10 +229,9 @@ TTNNLayoutAttr::getTiledShape(llvm::ArrayRef tensorShape) const { // Element size for TileType is tile width * tile height * sizeof(element). // For scalar types, element size is sizeof(element). // -// /return The size of the shard in bytes. +// return The size of the shard in bytes. 
uint64_t TTNNLayoutAttr::getShardSizeInBytes() const { - MemRefType ty = getMemref(); - ArrayRef shape = ty.getShape(); + SmallVector shape = getShardShape(); uint64_t size = getElementSizeBytes(); return std::accumulate(shape.begin(), shape.end(), size, std::multiplies()); @@ -228,7 +242,7 @@ uint64_t TTNNLayoutAttr::getShardSizeInBytes() const { // This function returns a new identity affine map // with the same number of dimensions as the linear map. // -// /return The new identity affine map. +// return The new identity affine map. mlir::AffineMap TTNNLayoutAttr::getIdentityTileLinearMap() const { assert(isTiled() && "Expected a tiled layout"); @@ -241,12 +255,11 @@ mlir::AffineMap TTNNLayoutAttr::getIdentityTileLinearMap() const { // This function takes a physical memory map and replaces the symbols with the // shard shape // -// /param physicalMemoryMap The physical memory map (d0, d1)[s0, s1] -// /return New memory map with symbols replaced with shard shape. +// param physicalMemoryMap The physical memory map (d0, d1)[s0, s1] +// return New memory map with symbols replaced with shard shape. mlir::AffineMap TTNNLayoutAttr::replaceMemoryMapSymbolsWithShardShape( AffineMap physicalMemoryMap) const { - mlir::SmallVector shardShape = - getShardShape(false /*convertTileToScalar*/); + mlir::SmallVector shardShape = getShardShape(); assert(physicalMemoryMap.getNumSymbols() == shardShape.size() && "Physical memory map must have same number of symbols as logical " "shard rank"); @@ -289,11 +302,11 @@ int64_t TTNNLayoutAttr::getTensorSizeInBytes(ArrayRef tensorShape, // This function creates a new TTNNLayoutAttr with the given parameters. // The element type, buffer type and memory layout are preserved. // -// /param context The MLIR context. -// /param tensorShape The shape of the tensor (i.e 6x10x10) -// /param grid The grid where the tensor will be placed (i.e 2x3) -// /param collapseIntervals The intervals to collapse (i.e. 
{{0, -1}}) -// /return The constructed TTNNLayoutAttr +// param context The MLIR context. +// param tensorShape The shape of the tensor (i.e 6x10x10) +// param grid The grid where the tensor will be placed (i.e 2x3) +// param collapseIntervals The intervals to collapse (i.e. {{0, -1}}) +// return The constructed TTNNLayoutAttr TTNNLayoutAttr TTNNLayoutAttr::withGrid( ::mlir::MLIRContext *context, ArrayRef tensorShape, GridAttr grid, ArrayRef> collapseIntervals) { @@ -307,10 +320,10 @@ TTNNLayoutAttr TTNNLayoutAttr::withGrid( // The shape of the tensor, buffer type, element type and memory layout are // preserved. // -// /param context The MLIR context. -// /param grid The grid where the tensor will be placed. -// /param collapseIntervals The intervals to collapse (i.e. {{0, -1}}) -// /return The constructed TTNNLayoutAttr +// param context The MLIR context. +// param grid The grid where the tensor will be placed. +// param collapseIntervals The intervals to collapse (i.e. {{0, -1}}) +// return The constructed TTNNLayoutAttr TTNNLayoutAttr TTNNLayoutAttr::withGrid( ::mlir::MLIRContext *context, RankedTensorType ty, GridAttr grid, ArrayRef> collapseIntervals) { @@ -324,14 +337,14 @@ TTNNLayoutAttr TTNNLayoutAttr::withGrid( // This function creates a deep copy of the current TTNNLayoutAttr and // replaces the element type with the given one. // -// /param context The MLIR context. -// /param elementType The new element type. -// /return The new TTNNLayoutAttr with the given element type. +// param context The MLIR context. +// param elementType The new element type. +// return The new TTNNLayoutAttr with the given element type. 
TTNNLayoutAttr TTNNLayoutAttr::withElementType(::mlir::MLIRContext *context, Type elementType) { return TTNNLayoutAttr::get( context, getLinear(), getGrid(), - buildMemRef(context, getShardShape(), + buildMemRef(context, getScalarShardShape(), elementType, getBufferType()), getMemLayout()); } @@ -341,14 +354,14 @@ TTNNLayoutAttr TTNNLayoutAttr::withElementType(::mlir::MLIRContext *context, // This function creates a deep copy of the current TTNNLayoutAttr and // replaces the memory space with the given one. // -// /param context The MLIR context. -// /param memorySpace The new memory space. -// /return The new TTNNLayoutAttr with the given memory space. +// param context The MLIR context. +// param memorySpace The new memory space. +// return The new TTNNLayoutAttr with the given memory space. TTNNLayoutAttr TTNNLayoutAttr::withBufferType(::mlir::MLIRContext *context, BufferType memorySpace) { return TTNNLayoutAttr::get( context, getLinear(), getGrid(), - buildMemRef(context, getShardShape(), + buildMemRef(context, getScalarShardShape(), getElementType(), memorySpace), getMemLayout()); } @@ -358,15 +371,15 @@ TTNNLayoutAttr TTNNLayoutAttr::withBufferType(::mlir::MLIRContext *context, // This function creates a deep copy of the current TTNNLayoutAttr and // replaces the memory layout with the given one. // -// /param context The MLIR context. -// /param memLayout The new memory layout. -// /return The new TTNNLayoutAttr with the given memory layout. +// param context The MLIR context. +// param memLayout The new memory layout. +// return The new TTNNLayoutAttr with the given memory layout. 
TTNNLayoutAttr TTNNLayoutAttr::withMemoryLayout(::mlir::MLIRContext *context, TensorMemoryLayout memLayout) { return TTNNLayoutAttr::get( context, getLinear(), getGrid(), buildMemRef( - context, getShardShape(), getElementType(), getBufferType()), + context, getScalarShardShape(), getElementType(), getBufferType()), memLayout); } @@ -375,9 +388,9 @@ TTNNLayoutAttr TTNNLayoutAttr::withMemoryLayout(::mlir::MLIRContext *context, // This function creates a deep copy of the current TTNNLayoutAttr and // replaces shard shape with the given one. // -// /param context The MLIR context. -// /param shardShape The new shard shape. -// /return The new TTNNLayoutAttr with the given shard shape. +// param context The MLIR context. +// param shardShape The new shard shape. +// return The new TTNNLayoutAttr with the given shard shape. TTNNLayoutAttr TTNNLayoutAttr::withShardShape(::mlir::MLIRContext *context, llvm::SmallVector shardShape) { @@ -392,14 +405,14 @@ TTNNLayoutAttr::withShardShape(::mlir::MLIRContext *context, // // This function constructs a new TTNNLayoutAttr with the given parameters. // -// /param context The MLIR context. -// /param tensorShape The shape of the tensor (i.e 6x10x10) -// /param elementType The type of the element i.e TileType/FloatType/IntegerType -// /param bufferType The type of the buffer -// /param grid The grid where the tensor will be placed (i.e 2x3) -// /param collapseIntervals The intervals to collapse (i.e. {{0, -1}}) -// /param memLayout The memory layout of the tensor -// /return The constructed TTNNLayoutAttr +// param context The MLIR context. +// param tensorShape The shape of the tensor (i.e 6x10x10) +// param elementType The type of the element i.e TileType/FloatType/IntegerType +// param bufferType The type of the buffer +// param grid The grid where the tensor will be placed (i.e 2x3) +// param collapseIntervals The intervals to collapse (i.e. 
{{0, -1}}) +// param memLayout The memory layout of the tensor +// return The constructed TTNNLayoutAttr TTNNLayoutAttr TTNNLayoutAttr::get( ::mlir::MLIRContext *context, ArrayRef tensorShape, Type elementType, BufferType bufferType, GridAttr grid, diff --git a/lib/Dialect/TTNN/Transforms/Optimizer.cpp b/lib/Dialect/TTNN/Transforms/Optimizer.cpp index 05ff417a6..e5d2f86d8 100644 --- a/lib/Dialect/TTNN/Transforms/Optimizer.cpp +++ b/lib/Dialect/TTNN/Transforms/Optimizer.cpp @@ -276,7 +276,7 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase { EmptyOp emptyOp = mlir::cast(op->getOperands().back().getDefiningOp()); - emptyOp.setDtype(layoutAttr.getDataTypeFromMemRef()); + emptyOp.setDtype(layoutAttr.getDataType()); if (layoutAttr.isTiled()) { emptyOp.setLayout(ttnn::Layout::Tile); } else { @@ -449,16 +449,17 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase { BufferType outputBufferType = consumerOpOutputLayout.getBufferType(); TensorMemoryLayout outputTensorMemoryLayout = consumerOpOutputLayout.getMemLayout(); - MemRefType outputMemref = consumerOpOutputLayout.getMemref(); + llvm::SmallVector shardShape = + consumerOpOutputLayout.getShardShape(); MemoryConfigAttr outputMemConfigAttr = MemoryConfigAttr::get( consumerOp->getContext(), TensorMemoryLayoutAttr::get(consumerOp->getContext(), outputTensorMemoryLayout), BufferTypeAttr::get(consumerOp->getContext(), outputBufferType), - ShardSpecAttr::get(consumerOp->getContext(), - ShapeAttr::get(consumerOp->getContext(), - outputMemref.getShape()))); + ShardSpecAttr::get( + consumerOp->getContext(), + ShapeAttr::get(consumerOp->getContext(), shardShape))); // If producerOp is a toLayoutOp, adjust its output layout(update // inplace) to reflect consumerOp's output layout. 
If producerOp is not a @@ -472,10 +473,9 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase { } else { OpBuilder builder(consumerOp); - DataTypeAttr outputDataType = - DataTypeAttr::get(consumerOp->getContext(), - utils::getDataTypeFromMemRef(outputMemref)); - Layout outputLayoutEnum = utils::getLayoutFromMemRef(outputMemref); + DataTypeAttr outputDataType = DataTypeAttr::get( + consumerOp->getContext(), consumerOpOutputLayout.getDataType()); + Layout outputLayoutEnum = consumerOpOutputLayout.getLayout(); LayoutAttr outputLayout = LayoutAttr::get(consumerOp->getContext(), outputLayoutEnum); Operation *memoryReconfigOp = builder.create( diff --git a/lib/Dialect/TTNN/Transforms/Passes.cpp b/lib/Dialect/TTNN/Transforms/Passes.cpp index 79bfeb404..e22540a7d 100644 --- a/lib/Dialect/TTNN/Transforms/Passes.cpp +++ b/lib/Dialect/TTNN/Transforms/Passes.cpp @@ -198,24 +198,12 @@ class TTNNDecomposeLayouts } }; - ttnn::Layout getLayoutFromMemRef(mlir::MemRefType memref) const { - ttnn::Layout ttnnLayoutEnum = ttnn::Layout::RowMajor; - Type elementType = memref.getElementType(); - if (llvm::isa(elementType)) { - ttnnLayoutEnum = ttnn::Layout::Tile; - } else { - ttnnLayoutEnum = ttnn::Layout::RowMajor; - } - return ttnnLayoutEnum; - } - std::pair getInputOutputLayouts(ttnn::ToLayoutOp op) const { LayoutInfo input, output; auto inputLayoutAttr = mlir::cast(op.getInput().getType().getEncoding()); - auto inputMemref = inputLayoutAttr.getMemref(); assert(op.getMemoryConfig().has_value()); MemoryConfigAttr outputMemoryConfig = op.getMemoryConfig().value(); @@ -223,10 +211,10 @@ class TTNNDecomposeLayouts input.bufferType = inputLayoutAttr.getBufferType(); output.bufferType = outputMemoryConfig.getBufferType().getValue(); - input.layoutEnum = getLayoutFromMemRef(inputMemref); + input.layoutEnum = inputLayoutAttr.getLayout(); output.layoutEnum = op.getLayout(); - input.dataType = ttnn::utils::getDataTypeFromMemRef(inputMemref); + input.dataType = inputLayoutAttr.getDataType(); 
assert(op.getDtype().has_value()); output.dataType = op.getDtype().value(); @@ -234,7 +222,7 @@ class TTNNDecomposeLayouts output.tensorMemoryLayout = outputMemoryConfig.getTensorMemoryLayout().getValue(); - input.shardShape = inputMemref.getShape(); + input.shardShape = inputLayoutAttr.getShardShape(); output.shardShape = outputMemoryConfig.getShardShapeArray(); return {input, output}; }