Skip to content

Commit

Permalink
Minor API fixes for TTNN encoding attribute (#1390)
Browse files Browse the repository at this point in the history
This PR adds a couple of convenience methods to the TTNN tensor encoding attribute and also removes redundant utils functions.

Renaming/Adding some new functions...

* `getDataType` to get scalar data type:
    * `memref<2x2x!tt.tile<32x32xf32>>` returns float data type
    * `memref<128x128xi32>` returns int data type
* `getElementType` to get type from memref:
    * `memref<2x2x!tt.tile<32x32xf32>>` returns TileType
    * `memref<128x128xi32>` returns IntegerType
* `getLayout` - gets layout of encoding i.e Tile/RowMajor
* `getShardShape`:
    * `memref<2x2x!tt.tile<32x32xf32>>` returns `(2, 2)`
    * `memref<128x128xi32>` returns `(128, 128)`
* `getScalarShardShape`:
    * `memref<2x2x!tt.tile<32x32xf32>>` returns `(64, 64)`
    * `memref<128x128xi32>` returns `(128, 128)`
  • Loading branch information
mtopalovicTT authored Nov 27, 2024
1 parent b198668 commit d22057f
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 138 deletions.
15 changes: 11 additions & 4 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,13 @@ def TTNN_TTNNLayoutAttr: TTNN_Attr<"TTNNLayout", "ttnn_layout"> {
let summary = "Tensor encoding attribute used for types in ttnn";
let description = [{
Layout attribute in ttnn. This attribute is used to encode different information about tensor memory layout.
Here is how a tensor will look after the layout is applied: tensor<32x32x64xf32, #ttnn.ttnn_layout<linear, grid, memref, mem_layout>>
Let's break down what each parameter means:
- linear: An affine map that defines how the logical tensor dimensions map to physical space.
- grid: The grid shape (of tensix cores) across which the tensor is divided.
- memref: A memref is used to describe the shard size and memory space. The shard size is calculated by dividing the tensor size by the grid size.
- mem_layout: The layout of the tensor in memory. For a tensor on the host it should be None. For a tensor on a device
it can be interleaved or sharded.
}];

let parameters = (ins AttrParameter<"AffineMap", "An affine map that defines how the logical tensor dimensions map to a grid shape.">:$linear,
Expand Down Expand Up @@ -142,15 +149,15 @@ def TTNN_TTNNLayoutAttr: TTNN_Attr<"TTNNLayout", "ttnn_layout"> {
bool hasShardedL1TensorMemoryLayout() const;
bool hasInterleavedL1TensorMemoryLayout() const;
bool isTiled() const;
Layout getLayout() const;
Type getElementType() const;
DataType getDataTypeFromMemRef() const;
DataType getDataType() const;
uint64_t getElementSizeBytes() const;
int64_t getTensorSizeInBytes(ArrayRef<int64_t> tensorShape, ::mlir::tt::DeviceAttr device) const;
llvm::SmallVector<int64_t> getStride(ArrayRef<int64_t> logicalShape) const;
llvm::SmallVector<int64_t> getPhysicalShape(ArrayRef<int64_t> logicalShape) const;
llvm::SmallVector<int64_t> getShardShape(bool convertTileToScalar = true) const;
llvm::SmallVector<int64_t> getShardShape() const;
llvm::SmallVector<int64_t> getScalarShardShape() const;
AffineMap replaceMemoryMapSymbolsWithShardShape(AffineMap physicalMemoryMap) const;
AffineMap projectOnto(AffineMap linearMap, AffineMap physicalMemoryMap) const;
AffineMap getIdentityTileLinearMap() const;
llvm::SmallVector<int64_t> getTiledShape(ArrayRef<int64_t> logicalTensorShape) const;
}];
Expand Down
4 changes: 0 additions & 4 deletions include/ttmlir/Dialect/TTNN/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ mlir::tt::TensorMemoryLayout toTTTensorMemoryLayout(
mlir::tt::MemorySpace
toTTMemorySpace(const mlir::tt::ttnn::BufferType bufferType);

DataType getDataTypeFromMemRef(mlir::MemRefType memref);

Layout getLayoutFromMemRef(mlir::MemRefType memref);

mlir::Type createRowMajorTypeFromDtype(::mlir::MLIRContext *context,
DataType dtype);

Expand Down
42 changes: 17 additions & 25 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,20 +65,15 @@ class TensorEmptyConversionPattern

// Get the shape of the tensor, tensor layout, and data type
//
mlir::MemRefType memref = layoutAttr.getMemref();
ttnn::ShapeAttr shapeAttr = ttnn::ShapeAttr::get(
rewriter.getContext(),
mlir::cast<RankedTensorType>(op->getResult(0).getType()).getShape());
Type elementType = memref.getElementType();
DataType dtype = DataType::Float32;
DataType dtype = layoutAttr.getDataType();
ttnn::Layout ttnnLayoutEnum = ttnn::Layout::RowMajor;
if (llvm::isa<TileType>(elementType)) {
if (layoutAttr.isTiled()) {
ttnnLayoutEnum = ttnn::Layout::Tile;
auto tileType = mlir::cast<TileType>(elementType);
dtype = tileType.getDataType();
} else {
ttnnLayoutEnum = ttnn::Layout::RowMajor;
dtype = elementTypeToDataType(elementType);
}
DataTypeAttr dTypeAttr = DataTypeAttr::get(rewriter.getContext(), dtype);
ttnn::LayoutAttr tensorLayoutAttr =
Expand All @@ -101,13 +96,14 @@ class TensorEmptyConversionPattern
// Create MemoryConfigAttr
//
auto device = getOrInsertDevice(rewriter, op);
llvm::SmallVector<int64_t> shardShape = layoutAttr.getShardShape();
ttnn::MemoryConfigAttr memoryConfigAttr = ttnn::MemoryConfigAttr::get(
op.getContext(),
ttnn::TensorMemoryLayoutAttr::get(op.getContext(), memLayout),
ttnn::BufferTypeAttr::get(op.getContext(), bufferType),
ttnn::ShardSpecAttr::get(
op.getContext(),
ttnn::ShapeAttr::get(op.getContext(), memref.getShape())));
ttnn::ShapeAttr::get(op.getContext(), shardShape)));

rewriter.replaceOpWithNewOp<ttnn::EmptyOp>(
op, this->getTypeConverter()->convertType(op.getType()), device,
Expand Down Expand Up @@ -137,18 +133,15 @@ class ToLayoutOpConversionPattern
auto outputLayoutAttr = mlir::cast<ttnn::TTNNLayoutAttr>(
op.getResult().getType().getEncoding());

auto outputMemref = outputLayoutAttr.getMemref();

// Determine the output data type
DataType dtype = ttnn::utils::getDataTypeFromMemRef(outputMemref);
DataType dtype = outputLayoutAttr.getDataType();
DataTypeAttr outputDataType =
DataTypeAttr::get(rewriter.getContext(), dtype);

// Determine the output layout (tile or row major)
ttnn::BufferType outputBufferType = outputLayoutAttr.getBufferType();

ttnn::Layout outputLayoutEnum =
ttnn::utils::getLayoutFromMemRef(outputMemref);
ttnn::Layout outputLayoutEnum = outputLayoutAttr.getLayout();

bool isOutputOnHost = (outputBufferType == ttnn::BufferType::SystemMemory);

Expand Down Expand Up @@ -176,13 +169,14 @@ class ToLayoutOpConversionPattern
op.getResult().setType(result);
outputLayoutAttr =
mlir::cast<ttnn::TTNNLayoutAttr>(result.getEncoding());
outputMemref = outputLayoutAttr.getMemref();
outputLayoutEnum = newOutputLayoutEnum;
}
}

ttnn::LayoutAttr outputLayout =
ttnn::LayoutAttr::get(rewriter.getContext(), outputLayoutEnum);
llvm::SmallVector<int64_t> outputShardShape =
outputLayoutAttr.getShardShape();

// Determine output memory config attr
ttnn::TensorMemoryLayout outputTensorMemoryLayout =
Expand All @@ -193,8 +187,8 @@ class ToLayoutOpConversionPattern
outputTensorMemoryLayout),
ttnn::BufferTypeAttr::get(rewriter.getContext(), outputBufferType),
ttnn::ShardSpecAttr::get(
op.getContext(), ttnn::ShapeAttr::get(rewriter.getContext(),
outputMemref.getShape())));
op.getContext(),
ttnn::ShapeAttr::get(rewriter.getContext(), outputShardShape)));

rewriter.replaceOpWithNewOp<ttnn::ToLayoutOp>(
op, this->getTypeConverter()->convertType(result), adaptor.getInput(),
Expand Down Expand Up @@ -222,15 +216,16 @@ class ToLayoutOpConversionPattern
ttnn::Layout newOutputLayoutEnum) const {
auto oldOutputLayoutAttr =
mlir::cast<ttnn::TTNNLayoutAttr>(oldOutput.getEncoding());
auto oldOutputMemref = oldOutputLayoutAttr.getMemref();
DataType outputDtype = ttnn::utils::getDataTypeFromMemRef(oldOutputMemref);
llvm::ArrayRef<std::int64_t> oldShardShape = oldOutputMemref.getShape();
DataType outputDtype = oldOutputLayoutAttr.getDataType();
SmallVector<std::int64_t> oldShardShape =
oldOutputLayoutAttr.getShardShape();
size_t shardShapeSize = oldShardShape.size();
assert(shardShapeSize >= 2 && "expected at least 2D shape");

if (newOutputLayoutEnum == ttnn::Layout::RowMajor) {
// Set shard shape to match convention of row major layout
auto tileType = mlir::cast<TileType>(oldOutputMemref.getElementType());
auto tileType =
mlir::cast<TileType>(oldOutputLayoutAttr.getElementType());
llvm::SmallVector<int64_t> newShardShape(oldShardShape.begin(),
oldShardShape.end());
newShardShape[shardShapeSize - 2] =
Expand Down Expand Up @@ -804,9 +799,7 @@ class TypecastOpConversionPattern
ttnn::TTNNLayoutAttr outputLayoutAttr =
mlir::cast<ttnn::TTNNLayoutAttr>(result.getType().getEncoding());

mlir::MemRefType outputMemref = outputLayoutAttr.getMemref();

DataType outputDataType = ttnn::utils::getDataTypeFromMemRef(outputMemref);
DataType outputDataType = outputLayoutAttr.getDataType();

if (op->getUsers().empty()) {
return rewriter.notifyMatchFailure(
Expand Down Expand Up @@ -950,8 +943,7 @@ class ArangeOpConversionPattern : public OpConversionPattern<ttir::ArangeOp> {
layoutAttr.getMemLayout()),
rewriter.getAttr<ttnn::BufferTypeAttr>(layoutAttr.getBufferType()),
rewriter.getAttr<ttnn::ShardSpecAttr>(
rewriter.getAttr<ttnn::ShapeAttr>(
layoutAttr.getMemref().getShape())));
rewriter.getAttr<ttnn::ShapeAttr>(layoutAttr.getShardShape())));

rewriter.replaceOpWithNewOp<ttnn::ArangeOp>(
op, outputType, adaptor.getStart(), adaptor.getEnd(), adaptor.getStep(),
Expand Down
17 changes: 2 additions & 15 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,25 +190,12 @@ ::mlir::LogicalResult mlir::tt::ttnn::EmptyOp::verify() {

// DataType and Layout
//
mlir::MemRefType memref = layoutAttr.getMemref();
Type elementType = memref.getElementType();
if (getLayout().has_value()) {
ttnn::Layout ttnnLayoutEnum;
if (llvm::isa<TileType>(elementType)) {
ttnnLayoutEnum = ttnn::Layout::Tile;
} else {
ttnnLayoutEnum = ttnn::Layout::RowMajor;
}
ttnn::Layout ttnnLayoutEnum = layoutAttr.getLayout();
assert(ttnnLayoutEnum == getLayoutAttr().getValue());
}
if (getDtype().has_value()) {
tt::DataType dtype;
if (llvm::isa<TileType>(elementType)) {
auto tileType = mlir::cast<TileType>(elementType);
dtype = tileType.getDataType();
} else {
dtype = elementTypeToDataType(elementType);
}
tt::DataType dtype = layoutAttr.getDataType();
assert(dtype == getDtype());
}

Expand Down
Loading

0 comments on commit d22057f

Please sign in to comment.