
[TTNN] Adding support for data type workarounds and introducing Embedding workarounds #1583

Merged — 1 commit merged into main from sdjordjevic/add_data_format_workaround_infra on Dec 23, 2024

Conversation

@sdjordjevicTT (Contributor) commented on Dec 12, 2024

This PR introduces a solution for handling data type workarounds on operation operands and results. For input operand workarounds, we insert a ToLayoutOp between the input operand and the operation itself, casting the input to the desired data type. If a workaround changes the data type of the output result, we insert a ToLayoutOp after the operation's output to revert it to the original data type.

Additionally, this PR provides the workarounds necessary for the embedding operation to function correctly. Specifically, it changes the input to a row-major (RM) layout and casts both the weight operand and the output to bf16. Other ops will be onboarded onto this type of workaround in a separate PR.
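For illustration, here is a minimal sketch of the input-operand half of this mechanism using generic MLIR rewriting APIs. The helper createCastedOperandType, the simplified ttnn::ToLayoutOp builder arguments, and the abbreviated namespaces are assumptions made for the sketch; they do not mirror the actual code in lib/Dialect/TTNN/Transforms/TTNNWorkarounds.cpp.

// Sketch only: cast one input operand of `op` to `targetElementType` by
// inserting a ttnn.to_layout op in front of it, as described above.
// `createCastedOperandType` is a hypothetical helper that clones the
// operand's tensor type (and its TTNN layout encoding) with the new
// element type.
void applyInputDataTypeWorkaround(mlir::Operation *op, unsigned operandIdx,
                                  mlir::Type targetElementType,
                                  mlir::OpBuilder &builder) {
  mlir::Value operand = op->getOperand(operandIdx);

  // Place the cast immediately before the consuming op.
  builder.setInsertionPoint(op);

  mlir::RankedTensorType castedType =
      createCastedOperandType(operand.getType(), targetElementType);

  // The real ttnn::ToLayoutOp builder takes additional layout/memory-config
  // attributes; they are elided here.
  auto toLayout = builder.create<ttnn::ToLayoutOp>(op->getLoc(), castedType,
                                                   operand /*, ... */);

  // Rewire only this operand; other users of `operand` are untouched.
  op->setOperand(operandIdx, toLayout.getResult());
}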

Example of IR today:

module attributes {tt.device = #device, tt.system_desc = #system_desc} {
  func.func @forward(%arg0: tensor<32x32xf32, #ttnn_layout>, %arg1: tensor<512x128xf32, #ttnn_layout1>) -> tensor<32x32x128xf32, #ttnn_layout2> {
    %0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
    %1 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<32x4>>, <interleaved>>, shape = #ttnn.shape<32x32x128>}> : (!tt.device<#device>) -> tensor<32x32x128xf32, #ttnn_layout3>
    %2 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#dram, <<32x32>>, <interleaved>>}> : (tensor<32x32xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xf32, #ttnn_layout4>
    %3 = "ttnn.to_layout"(%arg1, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<16x4>>, <interleaved>>}> : (tensor<512x128xf32, #ttnn_layout1>, !tt.device<#device>) -> tensor<512x128xf32, #ttnn_layout5>
    %4 = "ttnn.to_layout"(%1, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<32x4>>, <interleaved>>}> : (tensor<32x32x128xf32, #ttnn_layout3>, !tt.device<#device>) -> tensor<32x32x128xf32, #ttnn_layout6>
    %5 = "ttnn.embedding"(%2, %3, %4) : (tensor<32x32xf32, #ttnn_layout4>, tensor<512x128xf32, #ttnn_layout5>, tensor<32x32x128xf32, #ttnn_layout6>) -> tensor<32x32x128xf32, #ttnn_layout6>
    %6 = "ttnn.to_layout"(%5) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#system_memory, <<1024x128>>>}> : (tensor<32x32x128xf32, #ttnn_layout6>) -> tensor<32x32x128xf32, #ttnn_layout2>
    return %6 : tensor<32x32x128xf32, #ttnn_layout2>
  }
}

An example of the IR with this change, where the embedding op has the bf16 workaround applied to its weight operand and output:

module attributes {tt.device = #device, tt.system_desc = #system_desc} {
  func.func @forward(%arg0: tensor<32x32xf32, #ttnn_layout>, %arg1: tensor<512x128xf32, #ttnn_layout1>) -> tensor<32x32x128xf32, #ttnn_layout2> {
    %0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
    %1 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<32x4>>, <interleaved>>, shape = #ttnn.shape<32x32x128>}> : (!tt.device<#device>) -> tensor<32x32x128xf32, #ttnn_layout3>
    %2 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#dram, <<32x32>>, <interleaved>>}> : (tensor<32x32xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xf32, #ttnn_layout4>
    %3 = "ttnn.to_layout"(%arg1, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<16x4>>, <interleaved>>}> : (tensor<512x128xf32, #ttnn_layout1>, !tt.device<#device>) -> tensor<512x128xbf16, #ttnn_layout5>
    %4 = "ttnn.to_layout"(%1, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<32x4>>, <interleaved>>}> : (tensor<32x32x128xf32, #ttnn_layout3>, !tt.device<#device>) -> tensor<32x32x128xbf16, #ttnn_layout6>
    %5 = "ttnn.embedding"(%2, %3, %4) : (tensor<32x32xf32, #ttnn_layout4>, tensor<512x128xbf16, #ttnn_layout5>, tensor<32x32x128xbf16, #ttnn_layout6>) -> tensor<32x32x128xbf16, #ttnn_layout6>
    %6 = "ttnn.to_layout"(%5) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#system_memory, <<1024x128>>>}> : (tensor<32x32x128xbf16, #ttnn_layout6>) -> tensor<32x32x128xf32, #ttnn_layout2>
    return %6 : tensor<32x32x128xf32, #ttnn_layout2>
  }
}
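The trailing ttnn.to_layout (%6 above) is the output side of the mechanism: the workaround makes the embedding op produce a bf16 result, which is then cast back to f32 to preserve the original function signature. Below is a minimal sketch of that revert step, again with a hypothetical createCastedResultType helper, abbreviated namespaces, and a simplified ttnn::ToLayoutOp builder; it is not the actual TTNNWorkarounds.cpp code.

// Sketch only: if a workaround changed the element type of an op result,
// insert a ttnn.to_layout op after the op to cast the result back to its
// original data type, so downstream users are unaffected.
void revertResultDataType(mlir::Operation *op, unsigned resultIdx,
                          mlir::Type originalElementType,
                          mlir::OpBuilder &builder) {
  mlir::Value result = op->getResult(resultIdx);

  // Place the revert cast immediately after the worked-around op.
  builder.setInsertionPointAfter(op);

  // Hypothetical helper: clone the result's tensor type with the
  // original element type.
  mlir::RankedTensorType revertedType =
      createCastedResultType(result.getType(), originalElementType);

  // The real ttnn::ToLayoutOp builder arguments are simplified here.
  auto revert = builder.create<ttnn::ToLayoutOp>(op->getLoc(), revertedType,
                                                 result /*, ... */);

  // All original users now consume the reverted value; the new to_layout
  // op itself keeps using the workaround-typed result.
  result.replaceAllUsesExcept(revert.getResult(), revert.getOperation());
}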

@mtopalovicTT (Contributor) commented:

@sdjordjevicTT can you add to the PR description an example of how the IR looks today (without this change) and how it will look with the change?

@sdjordjevicTT sdjordjevicTT marked this pull request as draft December 17, 2024 11:14
@sdjordjevicTT sdjordjevicTT force-pushed the sdjordjevic/add_data_format_workaround_infra branch 3 times, most recently from 4df3134 to 6da4ac9 Compare December 20, 2024 14:21
@sdjordjevicTT sdjordjevicTT self-assigned this Dec 20, 2024
@sdjordjevicTT sdjordjevicTT marked this pull request as ready for review December 20, 2024 14:23
@sdjordjevicTT sdjordjevicTT force-pushed the sdjordjevic/add_data_format_workaround_infra branch from 6da4ac9 to 2ab9e8a Compare December 20, 2024 14:32
Two review threads on lib/Dialect/TTNN/Transforms/TTNNWorkarounds.cpp (outdated, resolved)
@sdjordjevicTT sdjordjevicTT force-pushed the sdjordjevic/add_data_format_workaround_infra branch 3 times, most recently from b595a45 to 0816699 Compare December 23, 2024 09:37
@sdjordjevicTT sdjordjevicTT force-pushed the sdjordjevic/add_data_format_workaround_infra branch from 0816699 to f8a4157 Compare December 23, 2024 09:42
@sdjordjevicTT sdjordjevicTT merged commit 9520cbb into main Dec 23, 2024
21 checks passed
@sdjordjevicTT sdjordjevicTT deleted the sdjordjevic/add_data_format_workaround_infra branch December 23, 2024 10:38