Update layout passes to accept device tilize for f32 (#1647)
ttnn now allows tilizing on device for fp32; updated the layout passes to mark this as a valid case.
jnie-TT authored Dec 23, 2024
1 parent 9520cbb commit b4405d0
Showing 3 changed files with 74 additions and 58 deletions.
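
For context, the change funnels the "can this data type be tilized on device?" question through a single predicate that both the compiler passes and the runtime now consult, instead of comparing against BFloat16 directly. Below is a minimal, self-contained sketch of that predicate and its effect; the DataType enum is a simplified stand-in for the real ttmlir/ttnn enums, and main() exists only for illustration.

#include <cstdio>

// Simplified stand-in for the DataType enum used in the diffs below;
// the real definitions live in the ttmlir and ttnn headers.
enum class DataType { Float32, BFloat16, UInt32 };

// TTNN supports device tilize for bf16 and, with this change, fp32.
static bool canTilizeDataTypeOnDevice(DataType dataType) {
  return dataType == DataType::BFloat16 || dataType == DataType::Float32;
}

int main() {
  // Before this commit only BFloat16 took the on-device tilize path;
  // Float32 was tilized on host and moved back to device afterwards.
  std::printf("bf16: %d\n", canTilizeDataTypeOnDevice(DataType::BFloat16)); // prints 1
  std::printf("f32:  %d\n", canTilizeDataTypeOnDevice(DataType::Float32));  // prints 1
  std::printf("u32:  %d\n", canTilizeDataTypeOnDevice(DataType::UInt32));   // prints 0
  return 0;
}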
78 changes: 45 additions & 33 deletions lib/Dialect/TTNN/Transforms/Passes.cpp
@@ -30,6 +30,11 @@ namespace mlir::tt::ttnn {
#define GEN_PASS_DEF_TTNNMODIFYSIGNATURESFORDYLIB
#include "ttmlir/Dialect/TTNN/Transforms/Passes.h.inc"

// TTNN supports device tilize for bf16 and fp32
static bool canTilizeDataTypeOnDevice(DataType dataType) {
return dataType == DataType::BFloat16 or dataType == DataType::Float32;
}

class TTNNDeallocate : public impl::TTNNDeallocateBase<TTNNDeallocate> {

public:
@@ -427,9 +432,9 @@ class TTNNDecomposeLayouts
return;
}

/* If we should tilize, and the data type is bfloat16, we can tilize on
* device */
if (info.shouldTilize() and output.dataType == DataType::BFloat16) {
/* If we should tilize, and the data type can be tilized on device, tilize
* on device */
if (info.shouldTilize() and canTilizeDataTypeOnDevice(output.dataType)) {
currentInput =
this->createToDeviceOpIfNeeded(op, rewriter, currentInput, info);
currentInput =
@@ -440,9 +445,10 @@
return;
}

/* If we should tilize, and the data type is not bfloat16, we tilize on host
*/
if (info.shouldTilize() and output.dataType != DataType::BFloat16) {
/* If we should tilize, and the data type cannot be tilized on device,
* tilize on host */
if (info.shouldTilize() and
not canTilizeDataTypeOnDevice(output.dataType)) {
currentInput =
this->createToLayoutOpIfNeeded(op, rewriter, currentInput, info);
currentInput =
@@ -513,9 +519,9 @@
return;
}

/* If we need to tilize and the input datatype is bfloat16
/* If we need to tilize and the input datatype is tilizeable on device,
we can tilize on device and then typecast afterwards */
if (info.shouldTilize() and input.dataType == DataType::BFloat16) {
if (info.shouldTilize() and canTilizeDataTypeOnDevice(input.dataType)) {
currentInput =
this->createToDeviceOpIfNeeded(op, rewriter, currentInput, info);
currentInput =
@@ -528,9 +534,9 @@
return;
}

/* if we need to tilize and the output data type is bfloat16
/* if we need to tilize and the output data type can be tilized on device,
we can typecast on host and tilize on device */
if (info.shouldTilize() and output.dataType == DataType::BFloat16) {
if (info.shouldTilize() and canTilizeDataTypeOnDevice(output.dataType)) {
currentInput =
this->createTypecastOpIfNeeded(op, rewriter, currentInput, info);
currentInput =
@@ -543,10 +549,11 @@
return;
}

/* if we need to tilize and the input/ output data types are not bfloat16 do
* everything on host */
if (info.shouldTilize() and input.dataType != DataType::BFloat16 and
output.dataType != DataType::BFloat16) {
/* if we need to tilize and the input/output data types cannot be tilized on
* device, do everything on host */
if (info.shouldTilize() and
not canTilizeDataTypeOnDevice(input.dataType) and
not canTilizeDataTypeOnDevice(output.dataType)) {
currentInput =
this->createTypecastOpIfNeeded(op, rewriter, currentInput, info);
currentInput =
@@ -639,9 +646,10 @@
return;
}

/* If we should tilize and the input data type is bfloat16, tilize on device
/* If we should tilize and the input data type can be tilized on device,
* tilize on device
*/
if (info.shouldTilize() and input.dataType == DataType::BFloat16) {
if (info.shouldTilize() and canTilizeDataTypeOnDevice(input.dataType)) {
currentInput =
this->createToLayoutOpIfNeeded(op, rewriter, currentInput, info);
currentInput = this->createToMemoryConfigOpIfNeeded(op, rewriter,
@@ -652,9 +660,10 @@
return;
}

/* If we should tilize and the input data type is not bfloat16, tilize on
* host */
if (info.shouldTilize() and input.dataType != DataType::BFloat16 and
/* If we should tilize and the input data type cannot be tilized on device,
* tilize on host */
if (info.shouldTilize() and
not canTilizeDataTypeOnDevice(input.dataType) and
opsToCreate.createFromDeviceOp) {
currentInput =
this->createFromDeviceOpIfNeeded(op, rewriter, currentInput, info);
@@ -664,9 +673,10 @@
return;
}

/* If we want to tilize a device tensor that is not bfloat16, we need to
* tilize on host and move it back */
if (info.shouldTilize() and input.dataType != DataType::BFloat16 and
/* If we want to tilize a device tensor whose data type cannot be tilized on
* device, we need to tilize on host and move it back */
if (info.shouldTilize() and
not canTilizeDataTypeOnDevice(input.dataType) and
not opsToCreate.createFromDeviceOp) {
// Force-create a FromDeviceOp
currentInput =
@@ -781,9 +791,9 @@
return;
}

/* If we should tilize and the input data type is bfloat16, tilize and
* typecast on device */
if (info.shouldTilize() and input.dataType == DataType::BFloat16) {
/* If we should tilize and the input data type can be tilized on device,
* tilize and typecast on device */
if (info.shouldTilize() and canTilizeDataTypeOnDevice(input.dataType)) {
currentInput =
this->createToLayoutOpIfNeeded(op, rewriter, currentInput, info);
currentInput =
@@ -796,9 +806,10 @@
return;
}

/* If we should tilize and the input data type is not bfloat16 and we want
to read back from device do everything on host */
if (info.shouldTilize() and input.dataType != DataType::BFloat16 and
/* If we should tilize and the input data type cannot be tilized on device,
and we want to read back from device, do everything on host */
if (info.shouldTilize() and
not canTilizeDataTypeOnDevice(input.dataType) and
opsToCreate.createFromDeviceOp) {
currentInput =
this->createFromDeviceOpIfNeeded(op, rewriter, currentInput, info);
@@ -810,10 +821,11 @@
return;
}

/* If we should tilize and the input data type is not bfloat 16 and we don't
want to read back from device: tilize on host, move back to device, and
typecast on device */
if (info.shouldTilize() and input.dataType != DataType::BFloat16 and
/* If we should tilize and the input data type cannot be tilized on device,
and we don't want to read back from device - tilize on host, move back to
device, and typecast on device */
if (info.shouldTilize() and
not canTilizeDataTypeOnDevice(input.dataType) and
not opsToCreate.createFromDeviceOp) {
// Force-create a FromDeviceOp
currentInput =
@@ -863,7 +875,7 @@ class TTNNDecomposeLayouts
/*
* Logic for creating ops. Conditions/constraints include:
* - When possible, we want to execute operations on device.
* - Tilize on device requires dataformat of BFLOAT16.
* - Tilize on device requires dataformat of BFLOAT16 or FLOAT32.
* - Typecast on device requires TILIZED tensor.
* - Untilize on device requires even width, and page size >
* sizeof(uint32_t). For now, we will always untilize on host. We rarely
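The constraint list above (device execution preferred, on-device tilize limited to BFLOAT16/FLOAT32, on-device typecast requiring a tilized tensor) is what orders the branches in TTNNDecomposeLayouts: every former dataType == DataType::BFloat16 test becomes canTilizeDataTypeOnDevice(dataType), and every != test becomes its negation. A condensed, illustrative sketch of the host-input-with-typecast ordering follows; Plan and planHostInputTypecast are hypothetical names for this sketch only, it reuses DataType and canTilizeDataTypeOnDevice from the sketch above, and the real pass instead chains its create*IfNeeded helpers.

// Which sequence of ops the pass emits when the input starts on host,
// needs tilizing, and needs a typecast.
enum class Plan {
  TilizeAndTypecastOnDevice,        // to_device, to_layout(tile), typecast
  TypecastOnHostThenTilizeOnDevice, // typecast, to_device, to_layout(tile)
  EverythingOnHost                  // typecast, to_layout(tile), then to_device
};

static Plan planHostInputTypecast(DataType input, DataType output) {
  if (canTilizeDataTypeOnDevice(input))
    return Plan::TilizeAndTypecastOnDevice;
  if (canTilizeDataTypeOnDevice(output))
    return Plan::TypecastOnHostThenTilizeOnDevice;
  return Plan::EverythingOnHost;
}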
35 changes: 20 additions & 15 deletions runtime/lib/ttnn/include/tt/runtime/ttnn/types.cpp
@@ -8,6 +8,10 @@

namespace tt::runtime::ttnn {

static bool canTilizeDataTypeOnDevice(::ttnn::DataType dataType) {
return dataType == ::ttnn::DataType::BFLOAT16 or
dataType == ::ttnn::DataType::FLOAT32;
}
//
// LayoutConverter APIs
//
@@ -103,14 +107,14 @@ ::ttnn::Tensor LayoutConverter::handleHostInputLayoutNoTypecast(
return out;
}

if (shouldTilize and outputDesc.dataType == ::ttnn::DataType::BFLOAT16) {
if (shouldTilize and canTilizeDataTypeOnDevice(outputDesc.dataType)) {
::ttnn::Tensor out = toDeviceIfNeeded(input, targetDevice);
out = toLayoutIfNeeded(out);
out = toMemoryConfigIfNeeded(out);
return out;
}

if (shouldTilize and outputDesc.dataType != ::ttnn::DataType::BFLOAT16) {
if (shouldTilize and not canTilizeDataTypeOnDevice(outputDesc.dataType)) {
::ttnn::Tensor out = toLayoutIfNeeded(input);
out = toDeviceIfNeeded(out, targetDevice);
out = toMemoryConfigIfNeeded(out);
@@ -147,24 +151,24 @@ ::ttnn::Tensor LayoutConverter::handleHostInputLayoutTypecast(
return out;
}

if (shouldTilize and inputDesc.dataType == ::ttnn::DataType::BFLOAT16) {
if (shouldTilize and canTilizeDataTypeOnDevice(inputDesc.dataType)) {
::ttnn::Tensor out = toDeviceIfNeeded(input, targetDevice);
out = toLayoutIfNeeded(out);
out = typecastIfNeeded(out);
out = toMemoryConfigIfNeeded(out);
return out;
}

if (shouldTilize and outputDesc.dataType == ::ttnn::DataType::BFLOAT16) {
if (shouldTilize and canTilizeDataTypeOnDevice(outputDesc.dataType)) {
::ttnn::Tensor out = typecastIfNeeded(input);
out = toDeviceIfNeeded(out, targetDevice);
out = toLayoutIfNeeded(input);
out = toMemoryConfigIfNeeded(out);
return out;
}

if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and
outputDesc.dataType != ::ttnn::DataType::BFLOAT16) {
if (shouldTilize and not canTilizeDataTypeOnDevice(inputDesc.dataType) and
not canTilizeDataTypeOnDevice(outputDesc.dataType)) {
::ttnn::Tensor out = typecastIfNeeded(input);
out = toLayoutIfNeeded(out);
out = toDeviceIfNeeded(out, targetDevice);
@@ -217,25 +221,26 @@ ::ttnn::Tensor LayoutConverter::handleDeviceInputLayoutNoTypecast(
return out;
}

/* If we should tilize and the input data type is bfloat16, tilize on device
/* If we should tilize and the input data type can be tilized on device,
* tilize on device
*/
if (shouldTilize and inputDesc.dataType == ::ttnn::DataType::BFLOAT16) {
if (shouldTilize and canTilizeDataTypeOnDevice(inputDesc.dataType)) {
::ttnn::Tensor out = toLayoutIfNeeded(input);
out = toMemoryConfigIfNeeded(out);
out = fromDeviceIfNeeded(out);
return out;
}

/* If we should tilize and the input data type is not bfloat16, tilize on
* host */
if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and
/* If we should tilize and the input data type cannot be tilized on device,
* tilize on host */
if (shouldTilize and not canTilizeDataTypeOnDevice(inputDesc.dataType) and
shouldFromDevice) {
::ttnn::Tensor out = fromDeviceIfNeeded(input);
out = toLayoutIfNeeded(out);
return out;
}

if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and
if (shouldTilize and not canTilizeDataTypeOnDevice(inputDesc.dataType) and
not shouldFromDevice) {
LOG_WARNING("Currently no constraint checking for on-device tilize.");
::ttnn::Tensor out = toLayoutIfNeeded(input);
@@ -287,23 +292,23 @@ LayoutConverter::handleDeviceInputLayoutTypecast(const ::ttnn::Tensor &input) {
return out;
}

if (shouldTilize and inputDesc.dataType == ::ttnn::DataType::BFLOAT16) {
if (shouldTilize and canTilizeDataTypeOnDevice(inputDesc.dataType)) {
::ttnn::Tensor out = toLayoutIfNeeded(input);
out = typecastIfNeeded(out);
out = toMemoryConfigIfNeeded(out);
out = fromDeviceIfNeeded(out);
return out;
}

if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and
if (shouldTilize and not canTilizeDataTypeOnDevice(inputDesc.dataType) and
shouldFromDevice) {
::ttnn::Tensor out = fromDeviceIfNeeded(input);
out = toLayoutIfNeeded(out);
out = typecastIfNeeded(out);
return out;
}

if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and
if (shouldTilize and not canTilizeDataTypeOnDevice(inputDesc.dataType) and
not shouldFromDevice) {
LOG_WARNING("Currently no constraint checking for on-device tilize.");
::ttnn::Tensor out = toLayoutIfNeeded(input);
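The runtime LayoutConverter applies the same predicate to tensors that are already on device. A rough sketch of its no-typecast decision follows; DevicePlan and planDeviceInputNoTypecast are illustrative names only, the predicate is reused from the first sketch, and the real code chains toLayoutIfNeeded, fromDeviceIfNeeded, and toMemoryConfigIfNeeded.

// Which path the runtime takes when a device-resident tensor needs tilizing
// and no typecast.
enum class DevicePlan {
  TilizeOnDevice,           // to_layout(tile) on device, read back afterwards if requested
  ReadBackThenTilizeOnHost, // from_device first, then to_layout(tile) on host
  TilizeOnDeviceUnchecked   // stays on device; tilizes with a warning, no constraint checking yet
};

static DevicePlan planDeviceInputNoTypecast(DataType input, bool shouldFromDevice) {
  if (canTilizeDataTypeOnDevice(input))
    return DevicePlan::TilizeOnDevice;
  if (shouldFromDevice)
    return DevicePlan::ReadBackThenTilizeOnHost;
  return DevicePlan::TilizeOnDeviceUnchecked;
}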
@@ -7,16 +7,15 @@
#dram = #ttnn.buffer_type<dram>
#ttnn_layout = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xbf16, #system_memory>>
#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xbf16, #dram>, <interleaved>>

module attributes {tt.device = #device} {
func.func @add(%arg0: tensor<64x128xbf16, #ttnn_layout1>, %arg1: tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout> {
%0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
%1 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout<tile>}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1>
%2 = "ttnn.to_layout"(%arg1) <{layout = #ttnn.layout<tile>}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1>
%3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#dram, <<64x128>>, <interleaved>>, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout2>
%4 = "ttnn.add"(%1, %2, %3) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout2>
%5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout>
%3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<2x4>>, <interleaved>>, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout1>
%4 = "ttnn.add"(%1, %2, %3) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1>
%5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout>
%6 = "ttnn.to_layout"(%5) <{layout = #ttnn.layout<row_major>}> : (tensor<64x128xbf16, #ttnn_layout>) -> tensor<64x128xbf16, #ttnn_layout>
return %6 : tensor<64x128xbf16, #ttnn_layout>
}
@@ -27,9 +26,9 @@ module attributes {tt.device = #device} {
%0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
%1 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout<tile>}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1>
%2 = "ttnn.to_layout"(%arg1) <{layout = #ttnn.layout<tile>}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1>
%3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#dram, <<64x128>>, <interleaved>>, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout2>
%4 = "ttnn.multiply"(%1, %2, %3) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout2>
%5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout>
%3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<2x4>>, <interleaved>>, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout1>
%4 = "ttnn.multiply"(%1, %2, %3) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1>
%5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout>
%6 = "ttnn.to_layout"(%5) <{layout = #ttnn.layout<row_major>}> : (tensor<64x128xbf16, #ttnn_layout>) -> tensor<64x128xbf16, #ttnn_layout>
return %6 : tensor<64x128xbf16, #ttnn_layout>
}
@@ -40,9 +39,9 @@ module attributes {tt.device = #device} {
%0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
%1 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout<tile>}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1>
%2 = "ttnn.to_layout"(%arg1) <{layout = #ttnn.layout<tile>}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1>
%3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#dram, <<64x128>>, <interleaved>>, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout2>
%4 = "ttnn.subtract"(%1, %2, %3) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout2>
%5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout>
%3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<2x4>>, <interleaved>>, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout1>
%4 = "ttnn.subtract"(%1, %2, %3) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1>
%5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout>
%6 = "ttnn.to_layout"(%5) <{layout = #ttnn.layout<row_major>}> : (tensor<64x128xbf16, #ttnn_layout>) -> tensor<64x128xbf16, #ttnn_layout>
return %6 : tensor<64x128xbf16, #ttnn_layout>
}
