From b4405d0fdb860345e7c21699043bafa768ae085b Mon Sep 17 00:00:00 2001 From: Jackson Nie Date: Mon, 23 Dec 2024 09:47:53 -0500 Subject: [PATCH] Update layout passes to accept device tilize for f32 (#1647) ttnn now allows tilizing on device for fp32, updated layout passes to mark this as a valid case --- lib/Dialect/TTNN/Transforms/Passes.cpp | 78 +++++++++++-------- .../ttnn/include/tt/runtime/ttnn/types.cpp | 35 +++++---- .../eltwise_binary_op_chain.mlir | 19 +++-- 3 files changed, 74 insertions(+), 58 deletions(-) diff --git a/lib/Dialect/TTNN/Transforms/Passes.cpp b/lib/Dialect/TTNN/Transforms/Passes.cpp index af28f6535..2f06efb82 100644 --- a/lib/Dialect/TTNN/Transforms/Passes.cpp +++ b/lib/Dialect/TTNN/Transforms/Passes.cpp @@ -30,6 +30,11 @@ namespace mlir::tt::ttnn { #define GEN_PASS_DEF_TTNNMODIFYSIGNATURESFORDYLIB #include "ttmlir/Dialect/TTNN/Transforms/Passes.h.inc" +// TTNN supports device tilize for bf16 and fp32 +static bool canTilizeDataTypeOnDevice(DataType dataType) { + return dataType == DataType::BFloat16 or dataType == DataType::Float32; +} + class TTNNDeallocate : public impl::TTNNDeallocateBase { public: @@ -427,9 +432,9 @@ class TTNNDecomposeLayouts return; } - /* If we should tilize, and the data type is bfloat16, we can tilize on - * device */ - if (info.shouldTilize() and output.dataType == DataType::BFloat16) { + /* If we should tilize, and the data type can be tilized on device, tilize + * on device */ + if (info.shouldTilize() and canTilizeDataTypeOnDevice(output.dataType)) { currentInput = this->createToDeviceOpIfNeeded(op, rewriter, currentInput, info); currentInput = @@ -440,9 +445,10 @@ class TTNNDecomposeLayouts return; } - /* If we should tilize, and the data type is not bfloat16, we tilize on host - */ - if (info.shouldTilize() and output.dataType != DataType::BFloat16) { + /* If we should tilize, and the data type cannot be tilized on device, + * tilize on host */ + if (info.shouldTilize() and + not canTilizeDataTypeOnDevice(output.dataType)) { currentInput = this->createToLayoutOpIfNeeded(op, rewriter, currentInput, info); currentInput = @@ -513,9 +519,9 @@ class TTNNDecomposeLayouts return; } - /* If we need to tilize and the input datatype is bfloat16 + /* If we need to tilize and the input datatype is tilizeable on device, we can tilize on device and then typecast afterwards */ - if (info.shouldTilize() and input.dataType == DataType::BFloat16) { + if (info.shouldTilize() and canTilizeDataTypeOnDevice(input.dataType)) { currentInput = this->createToDeviceOpIfNeeded(op, rewriter, currentInput, info); currentInput = @@ -528,9 +534,9 @@ class TTNNDecomposeLayouts return; } - /* if we need to tilize and the output data type is bfloat16 + /* if we need to tilize and the output data type can be tilized on device, we can typecast on host and tilize on device */ - if (info.shouldTilize() and output.dataType == DataType::BFloat16) { + if (info.shouldTilize() and canTilizeDataTypeOnDevice(output.dataType)) { currentInput = this->createTypecastOpIfNeeded(op, rewriter, currentInput, info); currentInput = @@ -543,10 +549,11 @@ class TTNNDecomposeLayouts return; } - /* if we need to tilize and the input/ output data types are not bfloat16 do - * everything on host */ - if (info.shouldTilize() and input.dataType != DataType::BFloat16 and - output.dataType != DataType::BFloat16) { + /* if we need to tilize and the input/output data types cannot be tilized on + * device, do everything on host */ + if (info.shouldTilize() and + not 
canTilizeDataTypeOnDevice(input.dataType) and + not canTilizeDataTypeOnDevice(output.dataType)) { currentInput = this->createTypecastOpIfNeeded(op, rewriter, currentInput, info); currentInput = @@ -639,9 +646,10 @@ class TTNNDecomposeLayouts return; } - /* If we should tilize and the input data type is bfloat16, tilize on device + /* If we should tilize and the input data type can be tilized on device, + * tilize on device */ - if (info.shouldTilize() and input.dataType == DataType::BFloat16) { + if (info.shouldTilize() and canTilizeDataTypeOnDevice(input.dataType)) { currentInput = this->createToLayoutOpIfNeeded(op, rewriter, currentInput, info); currentInput = this->createToMemoryConfigOpIfNeeded(op, rewriter, @@ -652,9 +660,10 @@ class TTNNDecomposeLayouts return; } - /* If we should tilize and the input data type is not bfloat16, tilize on - * host */ - if (info.shouldTilize() and input.dataType != DataType::BFloat16 and + /* If we should tilize and the input data type cannot be tilized on device, + * tilize on host */ + if (info.shouldTilize() and + not canTilizeDataTypeOnDevice(input.dataType) and opsToCreate.createFromDeviceOp) { currentInput = this->createFromDeviceOpIfNeeded(op, rewriter, currentInput, info); @@ -664,9 +673,10 @@ class TTNNDecomposeLayouts return; } - /* If we want to tilize a device tensor that is not bfloat16, we need to - * tilize on host and move it back */ - if (info.shouldTilize() and input.dataType != DataType::BFloat16 and + /* If we want to tilize a device tensor whose data type cannot be tilized on + * device, we need to tilize on host and move it back */ + if (info.shouldTilize() and + not canTilizeDataTypeOnDevice(input.dataType) and not opsToCreate.createFromDeviceOp) { // Force-create a FromDeviceOp currentInput = @@ -781,9 +791,9 @@ class TTNNDecomposeLayouts return; } - /* If we should tilize and the input data type is bfloat16, tilize and - * typecast on device */ - if (info.shouldTilize() and input.dataType == DataType::BFloat16) { + /* If we should tilize and the input data type can be tilized on device, + * tilize and typecast on device */ + if (info.shouldTilize() and canTilizeDataTypeOnDevice(input.dataType)) { currentInput = this->createToLayoutOpIfNeeded(op, rewriter, currentInput, info); currentInput = @@ -796,9 +806,10 @@ class TTNNDecomposeLayouts return; } - /* If we should tilize and the input data type is not bfloat16 and we want - to read back from device do everything on host */ - if (info.shouldTilize() and input.dataType != DataType::BFloat16 and + /* If we should tilize and the input data type cannot be tilized on device, + and we want to read back from device, do everything on host */ + if (info.shouldTilize() and + not canTilizeDataTypeOnDevice(input.dataType) and opsToCreate.createFromDeviceOp) { currentInput = this->createFromDeviceOpIfNeeded(op, rewriter, currentInput, info); @@ -810,10 +821,11 @@ class TTNNDecomposeLayouts return; } - /* If we should tilize and the input data type is not bfloat 16 and we don't - want to read back from device: tilize on host, move back to device, and - typecast on device */ - if (info.shouldTilize() and input.dataType != DataType::BFloat16 and + /* If we should tilize and the input data type cannot be tilized on device, + and we don't want to read back from device - tilize on host, move back to + device, and typecast on device */ + if (info.shouldTilize() and + not canTilizeDataTypeOnDevice(input.dataType) and not opsToCreate.createFromDeviceOp) { // Force-create a FromDeviceOp currentInput 
= @@ -863,7 +875,7 @@ class TTNNDecomposeLayouts /* * Logic for creating ops. Conditions/constraints include: * - When possible, we want to execute operations on device. - * - Tilize on device requires dataformat of BFLOAT16. + * - Tilize on device requires dataformat of BFLOAT16 or FLOAT32. * - Typecast on device requires TILIZED tensor. * - Untilize on device requires even width, and page size > * sizeof(uint32_t). For now, we will always untilize on host. We rarely diff --git a/runtime/lib/ttnn/include/tt/runtime/ttnn/types.cpp b/runtime/lib/ttnn/include/tt/runtime/ttnn/types.cpp index 87d081599..2f7159a88 100644 --- a/runtime/lib/ttnn/include/tt/runtime/ttnn/types.cpp +++ b/runtime/lib/ttnn/include/tt/runtime/ttnn/types.cpp @@ -8,6 +8,10 @@ namespace tt::runtime::ttnn { +static bool canTilizeDataTypeOnDevice(::ttnn::DataType dataType) { + return dataType == ::ttnn::DataType::BFLOAT16 or + dataType == ::ttnn::DataType::FLOAT32; +} // // LayoutConverter APIs // @@ -103,14 +107,14 @@ ::ttnn::Tensor LayoutConverter::handleHostInputLayoutNoTypecast( return out; } - if (shouldTilize and outputDesc.dataType == ::ttnn::DataType::BFLOAT16) { + if (shouldTilize and canTilizeDataTypeOnDevice(outputDesc.dataType)) { ::ttnn::Tensor out = toDeviceIfNeeded(input, targetDevice); out = toLayoutIfNeeded(out); out = toMemoryConfigIfNeeded(out); return out; } - if (shouldTilize and outputDesc.dataType != ::ttnn::DataType::BFLOAT16) { + if (shouldTilize and not canTilizeDataTypeOnDevice(outputDesc.dataType)) { ::ttnn::Tensor out = toLayoutIfNeeded(input); out = toDeviceIfNeeded(out, targetDevice); out = toMemoryConfigIfNeeded(out); return out; } @@ -147,7 +151,7 @@ ::ttnn::Tensor LayoutConverter::handleHostInputLayoutTypecast( return out; } - if (shouldTilize and inputDesc.dataType == ::ttnn::DataType::BFLOAT16) { + if (shouldTilize and canTilizeDataTypeOnDevice(inputDesc.dataType)) { ::ttnn::Tensor out = toDeviceIfNeeded(input, targetDevice); out = toLayoutIfNeeded(out); out = typecastIfNeeded(out); out = toMemoryConfigIfNeeded(out); return out; } @@ -155,7 +159,7 @@ ::ttnn::Tensor LayoutConverter::handleHostInputLayoutTypecast( return out; } - if (shouldTilize and outputDesc.dataType == ::ttnn::DataType::BFLOAT16) { + if (shouldTilize and canTilizeDataTypeOnDevice(outputDesc.dataType)) { ::ttnn::Tensor out = typecastIfNeeded(input); out = toDeviceIfNeeded(out, targetDevice); out = toLayoutIfNeeded(input); return out; } @@ -163,8 +167,8 @@ ::ttnn::Tensor LayoutConverter::handleHostInputLayoutTypecast( return out; } - if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and - outputDesc.dataType != ::ttnn::DataType::BFLOAT16) { + if (shouldTilize and not canTilizeDataTypeOnDevice(inputDesc.dataType) and + not canTilizeDataTypeOnDevice(outputDesc.dataType)) { ::ttnn::Tensor out = typecastIfNeeded(input); out = toLayoutIfNeeded(out); out = toDeviceIfNeeded(out, targetDevice); @@ -217,25 +221,26 @@ ::ttnn::Tensor LayoutConverter::handleDeviceInputLayoutNoTypecast( return out; } - /* If we should tilize and the input data type is bfloat16, tilize on device + /* If we should tilize and the input data type can be tilized on device, + * tilize on device */ - if (shouldTilize and inputDesc.dataType == ::ttnn::DataType::BFLOAT16) { + if (shouldTilize and canTilizeDataTypeOnDevice(inputDesc.dataType)) { ::ttnn::Tensor out = toLayoutIfNeeded(input); out = toMemoryConfigIfNeeded(out); out = fromDeviceIfNeeded(out); return out; } - /* If we should tilize and the input data type is not bfloat16, tilize on - * host */ - if (shouldTilize and inputDesc.dataType != 
::ttnn::DataType::BFLOAT16 and + /* If we should tilize and the input data type cannot be tilized on device, + * tilize on host */ + if (shouldTilize and not canTilizeDataTypeOnDevice(inputDesc.dataType) and shouldFromDevice) { ::ttnn::Tensor out = fromDeviceIfNeeded(input); out = toLayoutIfNeeded(out); return out; } - if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and + if (shouldTilize and not canTilizeDataTypeOnDevice(inputDesc.dataType) and not shouldFromDevice) { LOG_WARNING("Currently no constraint checking for on-device tilize."); ::ttnn::Tensor out = toLayoutIfNeeded(input); @@ -287,7 +292,7 @@ LayoutConverter::handleDeviceInputLayoutTypecast(const ::ttnn::Tensor &input) { return out; } - if (shouldTilize and inputDesc.dataType == ::ttnn::DataType::BFLOAT16) { + if (shouldTilize and canTilizeDataTypeOnDevice(inputDesc.dataType)) { ::ttnn::Tensor out = toLayoutIfNeeded(input); out = typecastIfNeeded(out); out = toMemoryConfigIfNeeded(out); @@ -295,7 +300,7 @@ LayoutConverter::handleDeviceInputLayoutTypecast(const ::ttnn::Tensor &input) { return out; } - if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and + if (shouldTilize and not canTilizeDataTypeOnDevice(inputDesc.dataType) and shouldFromDevice) { ::ttnn::Tensor out = fromDeviceIfNeeded(input); out = toLayoutIfNeeded(out); @@ -303,7 +308,7 @@ LayoutConverter::handleDeviceInputLayoutTypecast(const ::ttnn::Tensor &input) { return out; } - if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and + if (shouldTilize and not canTilizeDataTypeOnDevice(inputDesc.dataType) and not shouldFromDevice) { LOG_WARNING("Currently no constraint checking for on-device tilize."); ::ttnn::Tensor out = toLayoutIfNeeded(input); diff --git a/test/ttmlir/Runtime/TTNN/runtime_stitching/eltwise_binary_op_chain.mlir b/test/ttmlir/Runtime/TTNN/runtime_stitching/eltwise_binary_op_chain.mlir index 35b4d9063..a5a100cfb 100644 --- a/test/ttmlir/Runtime/TTNN/runtime_stitching/eltwise_binary_op_chain.mlir +++ b/test/ttmlir/Runtime/TTNN/runtime_stitching/eltwise_binary_op_chain.mlir @@ -7,16 +7,15 @@ #dram = #ttnn.buffer_type #ttnn_layout = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xbf16, #system_memory>> #ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #dram>, > -#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xbf16, #dram>, > module attributes {tt.device = #device} { func.func @add(%arg0: tensor<64x128xbf16, #ttnn_layout1>, %arg1: tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout> { %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> %1 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1> %2 = "ttnn.to_layout"(%arg1) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1> - %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<64x128>>, >, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout2> - %4 = "ttnn.add"(%1, %2, %3) <{operandSegmentSizes = array}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout2> - %5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout> + %3 = "ttnn.empty"(%0) <{dtype = 
#tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<2x4>>, >, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout1> + %4 = "ttnn.add"(%1, %2, %3) <{operandSegmentSizes = array}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1> + %5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout> %6 = "ttnn.to_layout"(%5) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout>) -> tensor<64x128xbf16, #ttnn_layout> return %6 : tensor<64x128xbf16, #ttnn_layout> } @@ -27,9 +26,9 @@ module attributes {tt.device = #device} { %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> %1 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1> %2 = "ttnn.to_layout"(%arg1) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1> - %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<64x128>>, >, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout2> - %4 = "ttnn.multiply"(%1, %2, %3) <{operandSegmentSizes = array}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout2> - %5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout> + %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<2x4>>, >, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout1> + %4 = "ttnn.multiply"(%1, %2, %3) <{operandSegmentSizes = array}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1> + %5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout> %6 = "ttnn.to_layout"(%5) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout>) -> tensor<64x128xbf16, #ttnn_layout> return %6 : tensor<64x128xbf16, #ttnn_layout> } @@ -40,9 +39,9 @@ module attributes {tt.device = #device} { %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> %1 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1> %2 = "ttnn.to_layout"(%arg1) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1> - %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<64x128>>, >, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout2> - %4 = "ttnn.subtract"(%1, %2, %3) <{operandSegmentSizes = array}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout2> - %5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout> + %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<2x4>>, >, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout1> + %4 = 
"ttnn.subtract"(%1, %2, %3) <{operandSegmentSizes = array}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1> + %5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout> %6 = "ttnn.to_layout"(%5) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout>) -> tensor<64x128xbf16, #ttnn_layout> return %6 : tensor<64x128xbf16, #ttnn_layout> }