Insert To/FromDevice ops around hoisted ops to ensure they're moved to proper device #1648

Closed

vwellsTT (Contributor):

...

jnie-TT and others added 9 commits December 18, 2024 16:49
[tt-xla](tenstorrent/tt-xla#107),
[tt-torch](tenstorrent/tt-torch#108) have
updated to the new submit API.

[tt-forge-fe](tenstorrent/tt-forge-fe#925) will
be updated soon.

Therefore removing the legacy submit API from tt-mlir. Also did some
minor cleanup in runtime. Will merge this in after tt-forge is updated.
This PR fixes the issue of incorrect `memRef` for L1 Interleaved
layouts.

Closes issue: #1292
When serializing to FlatBuffer, we did not treat the embedding op as a
DPS op. Instead, its output was treated as a separate tensor and
assigned a different global ID than the DPS init, which later caused a
data mismatch at runtime.

So in this example:
```
%3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<1x100>>, <interleaved>>, shape = #ttnn.shape<1x12x3200>}> : (!tt.device<#device>) -> tensor<1x12x3200xbf16, #ttnn_layout5>
%4 = "ttnn.embedding"(%1, %2, %3) : (tensor<1x12xi32, #ttnn_layout3>, tensor<1x12x3200xbf16, #ttnn_layout5>, tensor<32000x3200xbf16, #ttnn_layout4>) -> tensor<1x12x3200xbf16, #ttnn_layout5>
%5 = "ttnn.from_device"(%4) : (tensor<1x12x3200xbf16, #ttnn_layout5>) -> tensor<1x12x3200xbf16, #ttnn_layout2>
```
Here’s what happens:
- The "ttnn.empty" operation produces a tensor with a global ID, say 5.
- The "ttnn.embedding" operation, instead of reusing global ID 5 for its output (as would be expected for a DPS operation), is assigned a new global ID, say 6. As a result, its output is not written to the memory chunk associated with global ID 5.
- When the runtime executes the "ttnn.from_device" operation, it expects its input to have global ID 5 (since it follows the DPS convention). However, because nothing was written to global ID 5 due to the mismatch in how "ttnn.embedding" was handled, the runtime instead reads a random or uninitialized tensor from that location. This leads to a data mismatch.

The root cause of this issue is the incorrect FlatBuffer serialization
logic for the embedding operation:
```
::flatbuffers::Offset<::tt::target::ttnn::EmbeddingOp>
createEmbeddingOp(FlatbufferObjectCache &cache, EmbeddingOp op) {
  auto in0 =
      cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput()));
  auto in1 = cache.at<::tt::target::TensorRef>(
      getOperandThroughDPSOps(op.getWeight()));
  auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
                                  kHostAllocatedAddress, kHostAllocatedSize);
  return ::tt::target::ttnn::CreateEmbeddingOp(*cache.fbb, in0, in1, output);
}
```

To fix this, we should replace the line:
`auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer, kHostAllocatedAddress, kHostAllocatedSize);`

with:
`auto out = cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getResult()));`

This change ensures that "ttnn.embedding" writes to the output of
"ttnn.empty" with global ID 5 (as in the example) rather than allocating
a new buffer with a different ID.
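
With that replacement applied, the serialization function would look roughly like the following sketch (same helpers as in the snippet above):
```
::flatbuffers::Offset<::tt::target::ttnn::EmbeddingOp>
createEmbeddingOp(FlatbufferObjectCache &cache, EmbeddingOp op) {
  auto in0 =
      cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput()));
  auto in1 = cache.at<::tt::target::TensorRef>(
      getOperandThroughDPSOps(op.getWeight()));
  // Reuse the TensorRef of the DPS init (the "ttnn.empty" result) so the op
  // writes into the already-allocated buffer with the expected global ID.
  auto out = cache.at<::tt::target::TensorRef>(
      getOperandThroughDPSOps(op.getResult()));
  return ::tt::target::ttnn::CreateEmbeddingOp(*cache.fbb, in0, in1, out);
}
```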

Note: This bug could potentially be present in other operations as well.
Will check and address them accordingly.

closes #1404
Fixes #1330.
Part of the solution for #1142 (tt-torch and tt-xla tests need to be added in separate PRs).

Not implemented end to end, since this op should not exist e2e but should instead be folded into the transposed conv op.
Add a concurrency check: new workflow runs triggered on the same branch should cancel older runs still in progress.
#1619
Use the pull request sha (github.event.pull_request.head.sha) when triggering the downstream check instead of github.sha. github.sha points to the temporary merge commit created to run the PR workflow, not the original commit.
Currently, output workarounds can change an op's output layout, which can cause problems if the result with the changed layout is the input to an op that doesn't define workarounds. For example, consider this scenario:
```mlir
module attributes {tt.device = #device, tt.system_desc = #system_desc} {
  func.func @forward(%arg0: tensor<1x128x128x32xbf16, #ttnn_layout>) -> tensor<1x64x64x32xbf16, #ttnn_layout1> {
    %0 = "ttnn.get_device..."
    %1 = "ttnn.to_layout"(%arg0, %0) -> tile output
    %2 = "ttnn.reshape"(%1) -> tile input, tile output
    %3 = "ttnn.empty"(%0) -> tile output
    %4 = "ttnn.to_layout"(%2, %0) -> row major output (transforms %2)
    %5 = "ttnn.to_layout"(%3, %0) -> row major output (transforms %3)
    %6 = "ttnn.max_pool2d"(%4, %5, %0) -> accepts row major input, produces row major output
    %7 = "ttnn.reshape"(%6) -> !!! PROBLEM HERE !!! expects input in tile, but we workaround output to row_major
    %8 = "ttnn.to_layout"(%7) -> moves output to host
    return %8 : tensor<1x64x64x32xbf16, #ttnn_layout1>
  }
}
```
With this change, we insert a toLayout op after the op whenever the output workaround changes the output result layout. This keeps the workaround's effect local to the op on which the workaround is applied; a sketch of the containment step follows the example below.

```mlir
module attributes {tt.device = #device, tt.system_desc = #system_desc} {
  func.func @forward(%arg0: tensor<1x128x128x32xbf16, #ttnn_layout>) -> tensor<1x64x64x32xbf16, #ttnn_layout1> {
    %0 = "ttnn.get_device..."
    %1 = "ttnn.to_layout"(%arg0, %0) -> tile output
    %2 = "ttnn.reshape"(%1) -> tile input, tile output
    %3 = "ttnn.empty"(%0) -> tile output
    %4 = "ttnn.to_layout"(%2, %0) -> row major output (transforms %2)
    %5 = "ttnn.to_layout"(%3, %0) -> row major output (transforms %3)
    %6 = "ttnn.max_pool2d"(%4, %5, %0) -> accepts row major input, produces row major output
    %7 = "ttnn.to_layout"(%6, %0) -> !!! PROBLEM SOLVED HERE !!! transform the layout back to tile
    %8 = "ttnn.reshape"(%7) -> tile input, tile output
    %9 = "ttnn.to_layout"(%8) -> moves output to host
    return %9 : tensor<1x64x64x32xbf16, #ttnn_layout1>
  }
}
```
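
For illustration only, the containment step could look roughly like the sketch below; `createToLayoutOp` is a hypothetical stand-in for whatever helper builds the reverting "ttnn.to_layout", not the actual implementation:
```
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"

// Hypothetical helper, assumed to build a "ttnn.to_layout" op that converts
// `input` back to `targetType`; the real TTNN builder API may differ.
mlir::Operation *createToLayoutOp(mlir::OpBuilder &builder, mlir::Location loc,
                                  mlir::Value input, mlir::Type targetType);

// Sketch: after applying an output workaround, revert the result layout
// locally so downstream ops still see the layout they expect.
void containOutputWorkaround(mlir::OpBuilder &builder, mlir::Operation *op,
                             mlir::Type originalResultType) {
  mlir::Value result = op->getResult(0);
  if (result.getType() == originalResultType)
    return; // The workaround did not change the output layout.

  builder.setInsertionPointAfter(op);
  mlir::Operation *revert =
      createToLayoutOp(builder, op->getLoc(), result, originalResultType);
  // Route every other user of the result through the reverting to_layout op.
  result.replaceAllUsesExcept(revert->getResult(0), revert);
}
```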

I will follow up with proper tests. At the moment I want to get review input on the change! :)

Closes #1614
### Goal
The end-to-end goal is to integrate a path to compile and execute specific ops or sets of ops on the CPU.

### Context:
The entire task will be split into (tentatively) 7 PRs, as follows:
1. Hoist specific ops into isolated funcs in a separate module
2. Convert TTIR ops to linalg ops within the module of hoisted funcs
3. **Build a pipeline to lower linalg to llvm from existing conversion
passes**
4. Translate LLVM Dialect into a dynamic library for packing into
flatbuffer
5. Generate helper functions so that we can call all of our hoisted
funcs with a common signature
6. Insert TTNN instructions to move operands to host before executing
hoisted func, then back to device afterwards
7. Update ttir-to-ttnn and ttnn-to-flatbuffer pipelines to use the new passes, generate dylibs and embed them into output flatbuffers, and update the runtime to consume dylibs from flatbuffers

This PR represents the 3rd PR above. Here, we build a pipeline from existing passes to lower the linalg dialect into the LLVM dialect so that we can proceed to compile into an executable .so in later stages.
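
As a rough illustration only (upstream MLIR pass-creation APIs and names vary across LLVM versions, and the actual tt-mlir pipeline registration may differ), such a pipeline can be assembled from existing conversion passes along these lines:
```
#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Pass/PassManager.h"

// Assumed sketch: the exact pass set, ordering, and factory names depend on
// the MLIR version in use.
void buildLinalgToLLVMPipeline(mlir::PassManager &pm) {
  // Bufferize: tensors -> memrefs.
  pm.addPass(mlir::bufferization::createOneShotBufferizePass());
  // Lower linalg ops to explicit loop nests.
  pm.addPass(mlir::createConvertLinalgToLoopsPass());
  // Structured control flow -> unstructured CFG.
  pm.addPass(mlir::createConvertSCFToCFPass());
  // Lower the remaining dialects into the LLVM dialect.
  pm.addPass(mlir::createArithToLLVMConversionPass());
  pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass());
  pm.addPass(mlir::createConvertControlFlowToLLVMPass());
  pm.addPass(mlir::createConvertFuncToLLVMPass());
  // Clean up any unrealized_conversion_cast ops left between dialects.
  pm.addPass(mlir::createReconcileUnrealizedCastsPass());
}
```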

### Example
Input:
```
module {
  func.func @add(
    %arg0: tensor<32x32xf32>,  // First input tensor
    %arg1: tensor<32x32xf32>,  // Second input tensor
    %arg2: tensor<32x32xf32>   // Output tensor (result stored here)
  ) -> tensor<32x32xf32> {
    // Perform linalg.add and store the result in %arg2
    %1 = linalg.add ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%arg2 : tensor<32x32xf32>) -> tensor<32x32xf32>
    return %1 : tensor<32x32xf32>
  }
}
```

Output:
```
module {
  llvm.func @memrefCopy(i64, !llvm.ptr, !llvm.ptr)
  llvm.func @malloc(i64) -> !llvm.ptr
  llvm.func @add(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: !llvm.ptr, %arg8: !llvm.ptr, %arg9: i64, %arg10: i64, %arg11: i64, %arg12: i64, %arg13: i64, %arg14: !llvm.ptr, %arg15: !llvm.ptr, %arg16: i64, %arg17: i64, %arg18: i64, %arg19: i64, %arg20: i64) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> {
    %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
    %1 = llvm.mlir.zero : !llvm.ptr
    %2 = llvm.mlir.constant(2 : i64) : i64
    %3 = llvm.mlir.undef : !llvm.struct<(i64, ptr)>
    %4 = llvm.mlir.constant(1 : index) : i64
    %5 = llvm.mlir.constant(32 : index) : i64
    %6 = llvm.mlir.constant(0 : index) : i64
    %7 = llvm.insertvalue %arg14, %0[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %8 = llvm.insertvalue %arg15, %7[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %9 = llvm.insertvalue %arg16, %8[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %10 = llvm.insertvalue %arg17, %9[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %11 = llvm.insertvalue %arg19, %10[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %12 = llvm.insertvalue %arg18, %11[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %13 = llvm.insertvalue %arg20, %12[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    llvm.br ^bb1(%6 : i64)
  ^bb1(%14: i64):  // 2 preds: ^bb0, ^bb5
    %15 = llvm.icmp "slt" %14, %5 : i64
    llvm.cond_br %15, ^bb2, ^bb6
  ^bb2:  // pred: ^bb1
    llvm.br ^bb3(%6 : i64)
  ^bb3(%16: i64):  // 2 preds: ^bb2, ^bb4
    %17 = llvm.icmp "slt" %16, %5 : i64
    llvm.cond_br %17, ^bb4, ^bb5
  ^bb4:  // pred: ^bb3
    %18 = llvm.getelementptr %arg1[%arg2] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %19 = llvm.mul %14, %arg5 : i64
    %20 = llvm.mul %16, %arg6 : i64
    %21 = llvm.add %19, %20 : i64
    %22 = llvm.getelementptr %18[%21] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %23 = llvm.load %22 : !llvm.ptr -> f32
    %24 = llvm.getelementptr %arg8[%arg9] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %25 = llvm.mul %14, %arg12 : i64
    %26 = llvm.mul %16, %arg13 : i64
    %27 = llvm.add %25, %26 : i64
    %28 = llvm.getelementptr %24[%27] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %29 = llvm.load %28 : !llvm.ptr -> f32
    %30 = llvm.fadd %23, %29  : f32
    %31 = llvm.getelementptr %arg15[%arg16] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %32 = llvm.mul %14, %arg19 : i64
    %33 = llvm.mul %16, %arg20 : i64
    %34 = llvm.add %32, %33 : i64
    %35 = llvm.getelementptr %31[%34] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    llvm.store %30, %35 : f32, !llvm.ptr
    %36 = llvm.add %16, %4 : i64
    llvm.br ^bb3(%36 : i64)
  ^bb5:  // pred: ^bb3
    %37 = llvm.add %14, %4 : i64
    llvm.br ^bb1(%37 : i64)
  ^bb6:  // pred: ^bb1
    %38 = llvm.getelementptr %1[1024] : (!llvm.ptr) -> !llvm.ptr, f32
    %39 = llvm.ptrtoint %38 : !llvm.ptr to i64
    %40 = llvm.call @malloc(%39) : (i64) -> !llvm.ptr
    %41 = llvm.insertvalue %40, %0[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %42 = llvm.insertvalue %40, %41[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %43 = llvm.insertvalue %6, %42[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %44 = llvm.insertvalue %5, %43[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %45 = llvm.insertvalue %5, %44[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %46 = llvm.insertvalue %5, %45[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %47 = llvm.insertvalue %4, %46[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %48 = llvm.intr.stacksave : !llvm.ptr
    %49 = llvm.alloca %4 x !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> : (i64) -> !llvm.ptr
    llvm.store %13, %49 : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>, !llvm.ptr
    %50 = llvm.insertvalue %2, %3[0] : !llvm.struct<(i64, ptr)> 
    %51 = llvm.insertvalue %49, %50[1] : !llvm.struct<(i64, ptr)> 
    %52 = llvm.alloca %4 x !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> : (i64) -> !llvm.ptr
    llvm.store %47, %52 : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>, !llvm.ptr
    %53 = llvm.insertvalue %52, %50[1] : !llvm.struct<(i64, ptr)> 
    %54 = llvm.alloca %4 x !llvm.struct<(i64, ptr)> : (i64) -> !llvm.ptr
    llvm.store %51, %54 : !llvm.struct<(i64, ptr)>, !llvm.ptr
    %55 = llvm.alloca %4 x !llvm.struct<(i64, ptr)> : (i64) -> !llvm.ptr
    llvm.store %53, %55 : !llvm.struct<(i64, ptr)>, !llvm.ptr
    %56 = llvm.getelementptr %1[1] : (!llvm.ptr) -> !llvm.ptr, f32
    %57 = llvm.ptrtoint %56 : !llvm.ptr to i64
    llvm.call @memrefCopy(%57, %54, %55) : (i64, !llvm.ptr, !llvm.ptr) -> ()
    llvm.intr.stackrestore %48 : !llvm.ptr
    llvm.return %47 : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
  }
}
```
- Defined ops in TTIR and TTNN dialects.
- Implemented StableHLO to TTIR conversion (unified with logical ops)
- Implemented TTIR to TTNN conversion
- Added tests

Fixes #1202.

Half of the solution (tt-xla and tt-torch tests are the second half) for issues:
#1051
#1053
#1054
#1055

Left asserts in runtime code due to
tenstorrent/tt-metal#13582.
@vwellsTT closed this Dec 19, 2024

@vwellsTT (Contributor, Author):

whoops, wrong order