Add pass to emit helper funcs so that we have common call signature #1600

Open
vwellsTT wants to merge 4 commits into main from vwells/llvm_helper_transform

vwellsTT
Contributor

@vwellsTT vwellsTT commented Dec 13, 2024

Goal: The end-to-end goal is to integrate a path to compile and execute specific ops or sets of ops on the CPU.

Context:

The entire task will be split into (tentatively) 7 PRs, as follows:

  1. Hoist specific ops into isolated funcs in a separate module
  2. Convert TTIR ops to linalg ops within the module of hoisted funcs
  3. Build a pipeline that lowers linalg to LLVM using existing conversion passes
  4. Translate LLVM Dialect into a dynamic library for packing into flatbuffer
  5. Generate helper functions so that we can call all of our hoisted funcs with a common signature
  6. Insert TTNN instructions to move operands to host before executing hoisted func, then back to device afterwards
  7. Update the ttir-to-ttnn and ttnn-to-flatbuffer pipelines to use the new passes, generate dylibs, and embed them into output flatbuffers, and update the runtime to consume dylibs from flatbuffers

This PR represents the 5th subtask above. The goal is to have a single entry point for calling all of our dylib functions from the runtime side. To that end, each of our LLVM-lowered funcs needs a helper func that takes a generic pointer to a tensor_wrapper and unpacks the args (including a variable number of size + stride args, depending on tensor rank). We accomplish this by annotating the lowered functions with their tensor input ranks (the arg_ranks attr), then using that attr to infer how many sizes + strides to unpack per tensor, as well as the total number of input tensors.
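The rank-based inference described above can be sketched roughly as follows. This assumes each ranked tensor lowers to the usual MLIR memref calling convention (allocated pointer, aligned pointer, offset, then one size and one stride per dimension); the function names are hypothetical and are not taken from this PR.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Flat arguments produced by one ranked tensor after lowering:
// allocated ptr + aligned ptr + offset, then `rank` sizes and `rank` strides.
// (Hypothetical helper name, for illustration only.)
std::size_t flatArgsForRank(std::size_t rank) { return 3 + 2 * rank; }

// Total flat arguments implied by an `arg_ranks`-style list, one entry per
// tensor argument. (Hypothetical helper name, for illustration only.)
std::size_t flatArgCount(const std::vector<std::size_t> &argRanks) {
  std::size_t total = 0;
  for (std::size_t rank : argRanks) {
    total += flatArgsForRank(rank);
  }
  return total;
}
```

For arg_ranks = [2, 2, 2] this gives 3 * (3 + 2 * 2) = 21, which matches %arg0 through %arg20 on @add in the example below.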

Example

Input:

module attributes {ttir.cpu_module} {
  llvm.func @add(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: !llvm.ptr, %arg8: !llvm.ptr, %arg9: i64, %arg10: i64, %arg11: i64, %arg12: i64, %arg13: i64, %arg14: !llvm.ptr, %arg15: !llvm.ptr, %arg16: i64, %arg17: i64, %arg18: i64, %arg19: i64, %arg20: i64) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> attributes {arg_ranks = [2, 2, 2]} {
    %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
    llvm.return %0 : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
  }
}

Output:

module attributes {ttir.cpu_module} {
  llvm.func @add(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: !llvm.ptr, %arg8: !llvm.ptr, %arg9: i64, %arg10: i64, %arg11: i64, %arg12: i64, %arg13: i64, %arg14: !llvm.ptr, %arg15: !llvm.ptr, %arg16: i64, %arg17: i64, %arg18: i64, %arg19: i64, %arg20: i64) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> attributes {arg_ranks = [2, 2, 2]} {
    %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
    llvm.return %0 : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
  }
  llvm.func @add_helper(%arg0: !llvm.ptr) {
    %0 = llvm.mlir.constant(0 : i32) : i32
    %1 = llvm.getelementptr %arg0[%0] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.ptr
    %2 = llvm.mlir.constant(0 : i32) : i32
    %3 = llvm.getelementptr %1[%2] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.ptr
    %4 = llvm.mlir.constant(1 : i32) : i32
    %5 = llvm.getelementptr %1[%4] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.ptr
    %6 = llvm.mlir.constant(2 : i32) : i32
    %7 = llvm.getelementptr %1[%6] : (!llvm.ptr, i32) -> !llvm.ptr, i64
    %8 = llvm.ptrtoint %7 : !llvm.ptr to i64
    %9 = llvm.mlir.constant(3 : i32) : i32
    %10 = llvm.getelementptr %1[%9] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.ptr
    %11 = llvm.mlir.constant(0 : i64) : i64
    %12 = llvm.getelementptr %10[%11] : (!llvm.ptr, i64) -> !llvm.ptr, i64
    %13 = llvm.load %12 : !llvm.ptr -> i64
    %14 = llvm.mlir.constant(1 : i64) : i64
    %15 = llvm.getelementptr %10[%14] : (!llvm.ptr, i64) -> !llvm.ptr, i64
    %16 = llvm.load %15 : !llvm.ptr -> i64
    %17 = llvm.mlir.constant(2 : i64) : i64
    %18 = llvm.getelementptr %10[%17] : (!llvm.ptr, i64) -> !llvm.ptr, i64
    %19 = llvm.load %18 : !llvm.ptr -> i64
    %20 = llvm.mlir.constant(3 : i64) : i64
    %21 = llvm.getelementptr %10[%20] : (!llvm.ptr, i64) -> !llvm.ptr, i64
    %22 = llvm.load %21 : !llvm.ptr -> i64
    %23 = llvm.mlir.constant(1 : i32) : i32
    %24 = llvm.getelementptr %arg0[%23] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.ptr
    %25 = llvm.mlir.constant(0 : i32) : i32
    %26 = llvm.getelementptr %24[%25] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.ptr
    %27 = llvm.mlir.constant(1 : i32) : i32
    %28 = llvm.getelementptr %24[%27] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.ptr
    %29 = llvm.mlir.constant(2 : i32) : i32
    %30 = llvm.getelementptr %24[%29] : (!llvm.ptr, i32) -> !llvm.ptr, i64
    %31 = llvm.ptrtoint %30 : !llvm.ptr to i64
    %32 = llvm.mlir.constant(3 : i32) : i32
    %33 = llvm.getelementptr %24[%32] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.ptr
    %34 = llvm.mlir.constant(0 : i64) : i64
    %35 = llvm.getelementptr %33[%34] : (!llvm.ptr, i64) -> !llvm.ptr, i64
    %36 = llvm.load %35 : !llvm.ptr -> i64
    %37 = llvm.mlir.constant(1 : i64) : i64
    %38 = llvm.getelementptr %33[%37] : (!llvm.ptr, i64) -> !llvm.ptr, i64
    %39 = llvm.load %38 : !llvm.ptr -> i64
    %40 = llvm.mlir.constant(2 : i64) : i64
    %41 = llvm.getelementptr %33[%40] : (!llvm.ptr, i64) -> !llvm.ptr, i64
    %42 = llvm.load %41 : !llvm.ptr -> i64
    %43 = llvm.mlir.constant(3 : i64) : i64
    %44 = llvm.getelementptr %33[%43] : (!llvm.ptr, i64) -> !llvm.ptr, i64
    %45 = llvm.load %44 : !llvm.ptr -> i64
    %46 = llvm.mlir.constant(2 : i32) : i32
    %47 = llvm.getelementptr %arg0[%46] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.ptr
    %48 = llvm.mlir.constant(0 : i32) : i32
    %49 = llvm.getelementptr %47[%48] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.ptr
    %50 = llvm.mlir.constant(1 : i32) : i32
    %51 = llvm.getelementptr %47[%50] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.ptr
    %52 = llvm.mlir.constant(2 : i32) : i32
    %53 = llvm.getelementptr %47[%52] : (!llvm.ptr, i32) -> !llvm.ptr, i64
    %54 = llvm.ptrtoint %53 : !llvm.ptr to i64
    %55 = llvm.mlir.constant(3 : i32) : i32
    %56 = llvm.getelementptr %47[%55] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.ptr
    %57 = llvm.mlir.constant(0 : i64) : i64
    %58 = llvm.getelementptr %56[%57] : (!llvm.ptr, i64) -> !llvm.ptr, i64
    %59 = llvm.load %58 : !llvm.ptr -> i64
    %60 = llvm.mlir.constant(1 : i64) : i64
    %61 = llvm.getelementptr %56[%60] : (!llvm.ptr, i64) -> !llvm.ptr, i64
    %62 = llvm.load %61 : !llvm.ptr -> i64
    %63 = llvm.mlir.constant(2 : i64) : i64
    %64 = llvm.getelementptr %56[%63] : (!llvm.ptr, i64) -> !llvm.ptr, i64
    %65 = llvm.load %64 : !llvm.ptr -> i64
    %66 = llvm.mlir.constant(3 : i64) : i64
    %67 = llvm.getelementptr %56[%66] : (!llvm.ptr, i64) -> !llvm.ptr, i64
    %68 = llvm.load %67 : !llvm.ptr -> i64
    %69 = llvm.call @add(%3, %5, %8, %13, %16, %19, %22, %26, %28, %31, %36, %39, %42, %45, %49, %51, %54, %59, %62, %65, %68) : (!llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
    llvm.return
  }
}
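
For readers who prefer C++ to LLVM dialect IR, here is a rough sketch of the per-tensor unpacking that a helper like @add_helper is meant to perform. The TensorWrapper layout shown (two pointers, an offset, and a pointer to the sizes followed by the strides) is an assumption based on the field indices 0 through 3 used above and is not confirmed by the PR; all names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical wrapper layout (an assumption, not the PR's definition).
struct TensorWrapper {
  float *allocated;
  float *aligned;
  int64_t offset;
  int64_t *sizesAndStrides;  // `rank` sizes followed by `rank` strides
};

// One tensor's worth of flat arguments, in the order the lowered func
// expects them.
struct UnpackedTensor {
  float *allocated;
  float *aligned;
  int64_t offset;
  std::vector<int64_t> sizes;
  std::vector<int64_t> strides;
};

// Unpack one wrapper; `rank` would come from the matching arg_ranks entry.
UnpackedTensor unpackTensor(const TensorWrapper &wrapper, std::size_t rank) {
  assert(rank == 0 || wrapper.sizesAndStrides != nullptr);
  UnpackedTensor out{wrapper.allocated, wrapper.aligned, wrapper.offset, {}, {}};
  for (std::size_t i = 0; i < rank; ++i) {
    out.sizes.push_back(wrapper.sizesAndStrides[i]);
    out.strides.push_back(wrapper.sizesAndStrides[rank + i]);
  }
  return out;
}
```

Under this reading, a rank-2 tensor contributes 2 sizes and 2 strides, which lines up with the four i64 loads per tensor in the helper above.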

@@ -3,3 +3,4 @@ add_subdirectory(TTIR)
add_subdirectory(TTNN)
add_subdirectory(TTMetal)
add_subdirectory(TTKernel)
add_subdirectory(LLVM)
Contributor Author

Note: imo this directory structure is the most sensible. It's a little ugly to add a dir for "Dialect/LLVM", because LLVM isn't a new dialect and we only want to define a single transform, which is (softly) dependent on other TTIR passes to create the module + funcs with the expected attrs. On the other hand, jamming this transform into TTIR or some other existing dir feels even worse to me, because this is strictly an LLVMDialect -> LLVMDialect transform.

Contributor

Yeah, I'm not sure what's conventional here. @sdjordjevicTT, was anything like this covered at the MLIR conf?

I generally agree with you, though I could see a loose argument for putting it in TTIR, since it defines the TTIR->CPU calling convention.

Contributor

I am honestly not sure; we didn't come across anything similar during the conference. I looked at a few MLIR repos and none of them had LLVM-specific folders, though I'm not sure whether any of them needed to implement something similar.

Maybe the best place to get advice on this topic is their Discord channel. @vwellsTT, you can ask them this specific question here:
https://discord.com/channels/636084430946959380/642426447167881246

They are mostly responsive and helpful :)

Contributor Author

Ok, the Discord link doesn't seem to take me to a channel properly for some reason. Can you give me the channel name or something?

Contributor Author

Ah, nvm, needed to join LLVM group first, now it works I think

@vwellsTT vwellsTT force-pushed the vwells/llvm_helper_transform branch 3 times, most recently from 608634e to 9a6d461 Compare December 13, 2024 21:31
mlir::tensor::TensorDialect, mlir::linalg::LinalgDialect,
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect,
mlir::tosa::TosaDialect, mlir::vector::VectorDialect,
mlir::emitc::EmitCDialect, mlir::bufferization::BufferizationDialect,
Contributor Author

This line is also modified by an earlier PR, which leads to some ugliness. Strictly speaking, I only want to add mlir::LLVM::LLVMDialect in this PR to make testing easier; the earlier PRs will add Bufferization etc. when they merge.

@vwellsTT
Contributor Author

Perhaps adding my dylib shim at this stage would be useful for testing. So far, I've manually tested that the transform runs and produces legal output (w.r.t. types, number of args, etc.), and that seems fine, but I wouldn't be surprised at all if there's some logical bug in the argument unpacking here.
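
As a rough illustration of what such a shim-style check could look like, the sketch below calls a helper through the common void(void *) signature and compares the result against known values. Everything here is a hand-written stand-in: the TensorWrapper layout, the rank-1 add, and add_helper are assumptions for illustration, not the code generated by this PR, and a real shim would load the helper from the dylib rather than link it directly.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical wrapper layout, assumed for illustration only.
struct TensorWrapper {
  float *allocated;
  float *aligned;
  int64_t offset;
  int64_t *sizesAndStrides;  // `rank` sizes followed by `rank` strides
};

// Common signature shared by every helper.
using HelperFn = void (*)(void *);

// Stand-in for an LLVM-lowered rank-1 add: 3 + 2 * 1 = 5 flat args per
// tensor, 15 in total (two inputs and one output).
void add(float *, float *a, int64_t aOff, int64_t n, int64_t aStride,
         float *, float *b, int64_t bOff, int64_t, int64_t bStride,
         float *, float *out, int64_t outOff, int64_t, int64_t outStride) {
  for (int64_t i = 0; i < n; ++i) {
    out[outOff + i * outStride] = a[aOff + i * aStride] + b[bOff + i * bStride];
  }
}

// Stand-in for the generated helper: unpack three wrappers, then call add.
void add_helper(void *raw) {
  assert(raw != nullptr);
  auto *t = static_cast<TensorWrapper *>(raw);
  add(t[0].allocated, t[0].aligned, t[0].offset,
      t[0].sizesAndStrides[0], t[0].sizesAndStrides[1],
      t[1].allocated, t[1].aligned, t[1].offset,
      t[1].sizesAndStrides[0], t[1].sizesAndStrides[1],
      t[2].allocated, t[2].aligned, t[2].offset,
      t[2].sizesAndStrides[0], t[2].sizesAndStrides[1]);
}

// Shim-style check: call through the common signature and verify the output.
bool checkAddHelper() {
  float a[3] = {1, 2, 3};
  float b[3] = {10, 20, 30};
  float out[3] = {0, 0, 0};
  int64_t meta[2] = {3, 1};  // size 3, stride 1
  TensorWrapper tensors[3] = {{a, a, 0, meta}, {b, b, 0, meta}, {out, out, 0, meta}};
  HelperFn fn = &add_helper;
  fn(tensors);
  return out[0] == 11 && out[1] == 22 && out[2] == 33;
}
```

A check like this exercises the unpacking order end to end, which is the part that type and arg-count validation alone cannot catch.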

@vwellsTT vwellsTT force-pushed the vwells/llvm_helper_transform branch 8 times, most recently from 7f71e72 to 2ca4886 Compare December 16, 2024 22:47
@vwellsTT vwellsTT force-pushed the vwells/llvm_helper_transform branch 4 times, most recently from 7ea7691 to 6348c92 Compare December 17, 2024 20:20
@vwellsTT vwellsTT force-pushed the vwells/llvm_helper_transform branch from 6348c92 to b3e7289 Compare December 17, 2024 20:25
Contributor

@nsmithtt nsmithtt left a comment

Things look good, but let's hold off on landing until we get other opinions on the directory structure.
