Reverting tilizing f32 on device #1669

Merged
2 commits merged on Dec 26, 2024
78 changes: 33 additions & 45 deletions lib/Dialect/TTNN/Transforms/Passes.cpp
@@ -30,11 +30,6 @@ namespace mlir::tt::ttnn {
#define GEN_PASS_DEF_TTNNMODIFYSIGNATURESFORDYLIB
#include "ttmlir/Dialect/TTNN/Transforms/Passes.h.inc"

// TTNN supports device tilize for bf16 and fp32
static bool canTilizeDataTypeOnDevice(DataType dataType) {
return dataType == DataType::BFloat16 or dataType == DataType::Float32;
}

class TTNNDeallocate : public impl::TTNNDeallocateBase<TTNNDeallocate> {

public:
@@ -432,9 +427,9 @@ class TTNNDecomposeLayouts
return;
}

/* If we should tilize, and the data type can be tilized on device, tilize
* on device */
if (info.shouldTilize() and canTilizeDataTypeOnDevice(output.dataType)) {
/* If we should tilize, and the data type is bfloat16, we can tilize on
* device */
if (info.shouldTilize() and output.dataType == DataType::BFloat16) {
currentInput =
this->createToDeviceOpIfNeeded(op, rewriter, currentInput, info);
currentInput =
@@ -445,10 +440,9 @@ class TTNNDecomposeLayouts
return;
}

/* If we should tilize, and the data type cannot be tilized on device,
* tilize on host */
if (info.shouldTilize() and
not canTilizeDataTypeOnDevice(output.dataType)) {
/* If we should tilize, and the data type is not bfloat16, we tilize on host
*/
if (info.shouldTilize() and output.dataType != DataType::BFloat16) {
currentInput =
this->createToLayoutOpIfNeeded(op, rewriter, currentInput, info);
currentInput =
@@ -519,9 +513,9 @@ class TTNNDecomposeLayouts
return;
}

/* If we need to tilize and the input datatype is tilizeable on device,
/* If we need to tilize and the input datatype is bfloat16
we can tilize on device and then typecast afterwards */
if (info.shouldTilize() and canTilizeDataTypeOnDevice(input.dataType)) {
if (info.shouldTilize() and input.dataType == DataType::BFloat16) {
currentInput =
this->createToDeviceOpIfNeeded(op, rewriter, currentInput, info);
currentInput =
@@ -534,9 +528,9 @@ class TTNNDecomposeLayouts
return;
}

/* if we need to tilize and the output data type can be tilized on device,
/* if we need to tilize and the output data type is bfloat16
we can typecast on host and tilize on device */
if (info.shouldTilize() and canTilizeDataTypeOnDevice(output.dataType)) {
if (info.shouldTilize() and output.dataType == DataType::BFloat16) {
currentInput =
this->createTypecastOpIfNeeded(op, rewriter, currentInput, info);
currentInput =
@@ -549,11 +543,10 @@ class TTNNDecomposeLayouts
return;
}

/* if we need to tilize and the input/output data types cannot be tilized on
* device, do everything on host */
if (info.shouldTilize() and
not canTilizeDataTypeOnDevice(input.dataType) and
not canTilizeDataTypeOnDevice(output.dataType)) {
/* if we need to tilize and the input/output data types are not bfloat16, do
* everything on host */
if (info.shouldTilize() and input.dataType != DataType::BFloat16 and
output.dataType != DataType::BFloat16) {
currentInput =
this->createTypecastOpIfNeeded(op, rewriter, currentInput, info);
currentInput =
@@ -646,10 +639,9 @@ class TTNNDecomposeLayouts
return;
}

/* If we should tilize and the input data type can be tilized on device,
* tilize on device
/* If we should tilize and the input data type is bfloat16, tilize on device
*/
if (info.shouldTilize() and canTilizeDataTypeOnDevice(input.dataType)) {
if (info.shouldTilize() and input.dataType == DataType::BFloat16) {
currentInput =
this->createToLayoutOpIfNeeded(op, rewriter, currentInput, info);
currentInput = this->createToMemoryConfigOpIfNeeded(op, rewriter,
@@ -660,10 +652,9 @@ class TTNNDecomposeLayouts
return;
}

/* If we should tilize and the input data type cannot be tilized on device,
* tilize on host */
if (info.shouldTilize() and
not canTilizeDataTypeOnDevice(input.dataType) and
/* If we should tilize and the input data type is not bfloat16, tilize on
* host */
if (info.shouldTilize() and input.dataType != DataType::BFloat16 and
opsToCreate.createFromDeviceOp) {
currentInput =
this->createFromDeviceOpIfNeeded(op, rewriter, currentInput, info);
@@ -673,10 +664,9 @@ class TTNNDecomposeLayouts
return;
}

/* If we want to tilize a device tensor whose data type cannot be tilized on
* device, we need to tilize on host and move it back */
if (info.shouldTilize() and
not canTilizeDataTypeOnDevice(input.dataType) and
/* If we want to tilize a device tensor that is not bfloat16, we need to
* tilize on host and move it back */
if (info.shouldTilize() and input.dataType != DataType::BFloat16 and
not opsToCreate.createFromDeviceOp) {
// Force-create a FromDeviceOp
currentInput =
@@ -791,9 +781,9 @@ class TTNNDecomposeLayouts
return;
}

/* If we should tilize and the input data type can be tilized on device,
* tilize and typecast on device */
if (info.shouldTilize() and canTilizeDataTypeOnDevice(input.dataType)) {
/* If we should tilize and the input data type is bfloat16, tilize and
* typecast on device */
if (info.shouldTilize() and input.dataType == DataType::BFloat16) {
currentInput =
this->createToLayoutOpIfNeeded(op, rewriter, currentInput, info);
currentInput =
@@ -806,10 +796,9 @@ class TTNNDecomposeLayouts
return;
}

/* If we should tilize and the input data type cannot be tilized on device,
and we want to read back from device, do everything on host */
if (info.shouldTilize() and
not canTilizeDataTypeOnDevice(input.dataType) and
/* If we should tilize and the input data type is not bfloat16 and we want
to read back from device do everything on host */
if (info.shouldTilize() and input.dataType != DataType::BFloat16 and
opsToCreate.createFromDeviceOp) {
currentInput =
this->createFromDeviceOpIfNeeded(op, rewriter, currentInput, info);
@@ -821,11 +810,10 @@ class TTNNDecomposeLayouts
return;
}

/* If we should tilize and the input data type cannot be tilized on device,
and we don't want to read back from device - tilize on host, move back to
device, and typecast on device */
if (info.shouldTilize() and
not canTilizeDataTypeOnDevice(input.dataType) and
/* If we should tilize and the input data type is not bfloat16 and we don't
want to read back from device: tilize on host, move back to device, and
typecast on device */
if (info.shouldTilize() and input.dataType != DataType::BFloat16 and
not opsToCreate.createFromDeviceOp) {
// Force-create a FromDeviceOp
currentInput =
@@ -875,7 +863,7 @@ class TTNNDecomposeLayouts
/*
* Logic for creating ops. Conditions/constraints include:
* - When possible, we want to execute operations on device.
* - Tilize on device requires dataformat of BFLOAT16 or FLOAT32.
* - Tilize on device requires dataformat of BFLOAT16.
* - Typecast on device requires TILIZED tensor.
* - Untilize on device requires even width, and page size >
* sizeof(uint32_t). For now, we will always untilize on host. We rarely
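For orientation, here is a minimal standalone sketch of the decision rule this revert restores in the TTNNDecomposeLayouts pass: with the canTilizeDataTypeOnDevice helper removed, only bfloat16 qualifies for on-device tilize, and every other data type, including f32, falls back to host tilize. The enum and function names below are illustrative stand-ins, not identifiers from this repository.

#include <iostream>

// Illustrative stand-ins for the sketch; not types from tt-mlir.
enum class DataType { BFloat16, Float32, UInt32 };
enum class TilizePlacement { Device, Host };

// After the revert, only bf16 is tilized on device; f32 is tilized on host.
static TilizePlacement decideTilizePlacement(DataType dataType) {
  return dataType == DataType::BFloat16 ? TilizePlacement::Device
                                        : TilizePlacement::Host;
}

int main() {
  bool f32OnHost =
      decideTilizePlacement(DataType::Float32) == TilizePlacement::Host;
  std::cout << "f32 tilize placement: " << (f32OnHost ? "host" : "device")
            << "\n";
  return 0;
}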
35 changes: 15 additions & 20 deletions runtime/lib/ttnn/include/tt/runtime/ttnn/types.cpp
@@ -8,10 +8,6 @@

namespace tt::runtime::ttnn {

static bool canTilizeDataTypeOnDevice(::ttnn::DataType dataType) {
return dataType == ::ttnn::DataType::BFLOAT16 or
dataType == ::ttnn::DataType::FLOAT32;
}
//
// LayoutConverter APIs
//
@@ -107,14 +103,14 @@ ::ttnn::Tensor LayoutConverter::handleHostInputLayoutNoTypecast(
return out;
}

if (shouldTilize and canTilizeDataTypeOnDevice(outputDesc.dataType)) {
if (shouldTilize and outputDesc.dataType == ::ttnn::DataType::BFLOAT16) {
::ttnn::Tensor out = toDeviceIfNeeded(input, targetDevice);
out = toLayoutIfNeeded(out);
out = toMemoryConfigIfNeeded(out);
return out;
}

if (shouldTilize and canTilizeDataTypeOnDevice(outputDesc.dataType)) {
if (shouldTilize and outputDesc.dataType != ::ttnn::DataType::BFLOAT16) {
::ttnn::Tensor out = toLayoutIfNeeded(input);
out = toDeviceIfNeeded(out, targetDevice);
out = toMemoryConfigIfNeeded(out);
@@ -151,24 +147,24 @@ ::ttnn::Tensor LayoutConverter::handleHostInputLayoutTypecast(
return out;
}

if (shouldTilize and canTilizeDataTypeOnDevice(inputDesc.dataType)) {
if (shouldTilize and inputDesc.dataType == ::ttnn::DataType::BFLOAT16) {
::ttnn::Tensor out = toDeviceIfNeeded(input, targetDevice);
out = toLayoutIfNeeded(out);
out = typecastIfNeeded(out);
out = toMemoryConfigIfNeeded(out);
return out;
}

if (shouldTilize and canTilizeDataTypeOnDevice(outputDesc.dataType)) {
if (shouldTilize and outputDesc.dataType == ::ttnn::DataType::BFLOAT16) {
::ttnn::Tensor out = typecastIfNeeded(input);
out = toDeviceIfNeeded(out, targetDevice);
out = toLayoutIfNeeded(input);
out = toMemoryConfigIfNeeded(out);
return out;
}

if (shouldTilize and not canTilizeDataTypeOnDevice(inputDesc.dataType) and
not canTilizeDataTypeOnDevice(outputDesc.dataType)) {
if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and
outputDesc.dataType != ::ttnn::DataType::BFLOAT16) {
::ttnn::Tensor out = typecastIfNeeded(input);
out = toLayoutIfNeeded(out);
out = toDeviceIfNeeded(out, targetDevice);
@@ -221,26 +217,25 @@ ::ttnn::Tensor LayoutConverter::handleDeviceInputLayoutNoTypecast(
return out;
}

/* If we should tilize and the input data type can be tilized on device,
* tilize on device
/* If we should tilize and the input data type is bfloat16, tilize on device
*/
if (shouldTilize and canTilizeDataTypeOnDevice(inputDesc.dataType)) {
if (shouldTilize and inputDesc.dataType == ::ttnn::DataType::BFLOAT16) {
::ttnn::Tensor out = toLayoutIfNeeded(input);
out = toMemoryConfigIfNeeded(out);
out = fromDeviceIfNeeded(out);
return out;
}

/* If we should tilize and the input data type cannot be tilized on device,
* tilize on host */
if (shouldTilize and not canTilizeDataTypeOnDevice(inputDesc.dataType) and
/* If we should tilize and the input data type is not bfloat16, tilize on
* host */
if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and
shouldFromDevice) {
::ttnn::Tensor out = fromDeviceIfNeeded(input);
out = toLayoutIfNeeded(out);
return out;
}

if (shouldTilize and not canTilizeDataTypeOnDevice(inputDesc.dataType) and
if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and
not shouldFromDevice) {
LOG_WARNING("Currently no constraint checking for on-device tilize.");
::ttnn::Tensor out = toLayoutIfNeeded(input);
@@ -292,23 +287,23 @@ LayoutConverter::handleDeviceInputLayoutTypecast(const ::ttnn::Tensor &input) {
return out;
}

if (shouldTilize and canTilizeDataTypeOnDevice(inputDesc.dataType)) {
if (shouldTilize and inputDesc.dataType == ::ttnn::DataType::BFLOAT16) {
::ttnn::Tensor out = toLayoutIfNeeded(input);
out = typecastIfNeeded(out);
out = toMemoryConfigIfNeeded(out);
out = fromDeviceIfNeeded(out);
return out;
}

if (shouldTilize and not canTilizeDataTypeOnDevice(inputDesc.dataType) and
if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and
shouldFromDevice) {
::ttnn::Tensor out = fromDeviceIfNeeded(input);
out = toLayoutIfNeeded(out);
out = typecastIfNeeded(out);
return out;
}

if (shouldTilize and not canTilizeDataTypeOnDevice(inputDesc.dataType) and
if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and
not shouldFromDevice) {
LOG_WARNING("Currently no constraint checking for on-device tilize.");
::ttnn::Tensor out = toLayoutIfNeeded(input);
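The runtime LayoutConverter changes mirror the same rule. As a rough sketch of the restored host-input path (assuming the two shouldTilize branches shown in the diff; Tensor, Device, and the helpers below are hypothetical placeholders, not tt-metal APIs), a bf16 tensor is moved to device first and tilized there, while any other data type is tilized on host before the transfer.

// Hypothetical placeholders for the sketch; not the real runtime types.
struct Tensor {};
struct Device {};

static Tensor toDevice(const Tensor &t, Device &) { return t; }
static Tensor tilize(const Tensor &t) { return t; }
static Tensor applyMemoryConfig(const Tensor &t) { return t; }

// Mirrors the shouldTilize branches of handleHostInputLayoutNoTypecast after
// the revert: device tilize only for bf16, host tilize for everything else.
static Tensor convertHostInput(const Tensor &input, Device &device,
                               bool isBFloat16) {
  if (isBFloat16) {
    Tensor out = toDevice(input, device);
    out = tilize(out);
    return applyMemoryConfig(out);
  }
  Tensor out = tilize(input);
  out = toDevice(out, device);
  return applyMemoryConfig(out);
}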
@@ -7,15 +7,16 @@
#dram = #ttnn.buffer_type<dram>
#ttnn_layout = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xbf16, #system_memory>>
#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xbf16, #dram>, <interleaved>>

module attributes {tt.device = #device} {
func.func @add(%arg0: tensor<64x128xbf16, #ttnn_layout1>, %arg1: tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout> {
%0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
%1 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout<tile>}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1>
%2 = "ttnn.to_layout"(%arg1) <{layout = #ttnn.layout<tile>}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1>
%3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<2x4>>, <interleaved>>, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout1>
%4 = "ttnn.add"(%1, %2, %3) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1>
%5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout>
%3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#dram, <<64x128>>, <interleaved>>, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout2>
%4 = "ttnn.add"(%1, %2, %3) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout2>
%5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout>
%6 = "ttnn.to_layout"(%5) <{layout = #ttnn.layout<row_major>}> : (tensor<64x128xbf16, #ttnn_layout>) -> tensor<64x128xbf16, #ttnn_layout>
return %6 : tensor<64x128xbf16, #ttnn_layout>
}
@@ -26,9 +27,9 @@ module attributes {tt.device = #device} {
%0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
%1 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout<tile>}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1>
%2 = "ttnn.to_layout"(%arg1) <{layout = #ttnn.layout<tile>}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1>
%3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<2x4>>, <interleaved>>, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout1>
%4 = "ttnn.multiply"(%1, %2, %3) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1>
%5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout>
%3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#dram, <<64x128>>, <interleaved>>, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout2>
%4 = "ttnn.multiply"(%1, %2, %3) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout2>
%5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout>
%6 = "ttnn.to_layout"(%5) <{layout = #ttnn.layout<row_major>}> : (tensor<64x128xbf16, #ttnn_layout>) -> tensor<64x128xbf16, #ttnn_layout>
return %6 : tensor<64x128xbf16, #ttnn_layout>
}
@@ -39,9 +40,9 @@ module attributes {tt.device = #device} {
%0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
%1 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout<tile>}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1>
%2 = "ttnn.to_layout"(%arg1) <{layout = #ttnn.layout<tile>}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1>
%3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<2x4>>, <interleaved>>, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout1>
%4 = "ttnn.subtract"(%1, %2, %3) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1>
%5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout>
%3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#dram, <<64x128>>, <interleaved>>, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout2>
%4 = "ttnn.subtract"(%1, %2, %3) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout2>
%5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout>
%6 = "ttnn.to_layout"(%5) <{layout = #ttnn.layout<row_major>}> : (tensor<64x128xbf16, #ttnn_layout>) -> tensor<64x128xbf16, #ttnn_layout>
return %6 : tensor<64x128xbf16, #ttnn_layout>
}
11 changes: 11 additions & 0 deletions test/ttmlir/Silicon/TTNN/matmul/llama_matmul.mlir
@@ -0,0 +1,11 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
module attributes {} {
func.func @forward(%arg0: tensor<1x11x2048xf32>, %arg1: tensor<2048x128256xf32>) -> tensor<1x11x128256xf32> {
%0 = tensor.empty() : tensor<1x11x128256xf32>
// CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]]
%1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<1x11x2048xf32>, tensor<2048x128256xf32>, tensor<1x11x128256xf32>) -> tensor<1x11x128256xf32>
return %1 : tensor<1x11x128256xf32>
}
}