
Runtime refactor to support runtime stitching #448

Closed (wants to merge 2 commits)

Conversation

@jnie-TT (Contributor) commented Aug 20, 2024

Addresses #103 (Runtime Stitching).
First iteration/prototype:

  • Added a toLayout API that takes a device, binary, program index, input index, and input tensor, and returns a tensor whose layout matches the memory descriptor for that input in the binary.
    • This lets the user hold a handle to a tensor that is reused across multiple subsequent program runs, e.g. the model weights when running inference: call toLayout once to get a device-resident handle to the weights, then pass that handle to every forward run (see the sketch after this list).
  • Updated runProgram to implicitly convert inputs/outputs to the layouts described in the binary.
  • Submit now returns a vector of tensors instead of accepting output containers from the user. Each tensor carries an event that can be waited on (not implemented yet).
    • With this, the user is now responsible for deallocating these tensors. I added a basic deallocate API; alternatively, we could move deallocation into the destructor of the Tensor class.
    • The returned tensors can reside anywhere (host, device DRAM, device L1) and can have different layouts/memory configs, all according to the MemoryDesc in the flatbuffer.
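
A minimal usage sketch of the flow described above. The type names, stub bodies, and program/input indices are illustrative assumptions, not the real runtime signatures:

#include <cstdint>
#include <vector>

namespace tt::runtime {
struct Device;
struct Binary;
struct Tensor {
  void deallocate() {} // stub; the real API releases host/device storage
};

// Convert `input` to the layout the binary expects for
// (programIndex, inputIndex); the returned handle can be reused across runs.
inline Tensor toLayout(Device &, const Binary &, std::uint32_t /*programIndex*/,
                       std::uint32_t /*inputIndex*/, const Tensor &input) {
  return input; // stub; the real API converts per the binary's MemoryDesc
}

// Run a program; outputs are allocated by the runtime and returned, each
// eventually carrying an event the caller can wait on.
inline std::vector<Tensor> submit(Device &, const Binary &,
                                  std::uint32_t /*programIndex*/,
                                  const std::vector<Tensor> &inputs) {
  return std::vector<Tensor>(inputs.size()); // stub outputs
}
} // namespace tt::runtime

// Hypothetical inference loop: convert the weights once, reuse the handle for
// every forward run, and deallocate whatever the runtime hands back.
void runForwardLoop(tt::runtime::Device &device,
                    const tt::runtime::Binary &binary,
                    const tt::runtime::Tensor &hostWeights,
                    const std::vector<tt::runtime::Tensor> &batches) {
  constexpr std::uint32_t kForwardProgram = 0; // assumed program index
  constexpr std::uint32_t kWeightsInput = 1;   // assumed input index

  tt::runtime::Tensor deviceWeights = tt::runtime::toLayout(
      device, binary, kForwardProgram, kWeightsInput, hostWeights);

  for (const tt::runtime::Tensor &batch : batches) {
    std::vector<tt::runtime::Tensor> outputs = tt::runtime::submit(
        device, binary, kForwardProgram, {batch, deviceWeights});
    for (tt::runtime::Tensor &out : outputs)
      out.deallocate(); // the caller owns returned tensors
  }
  deviceWeights.deallocate();
}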

TODOs:

  • Add metal support. Currently only ttnn is supported, as a prototype.
  • Polish tensor lifetimes, events, and allocate/deallocate. I'm not very familiar with how tensors are deallocated in metal and can look into it more, but I imagine this gets complicated once we introduce async execution/events and multi-device, so it would be great to settle on a clean routine from the start.
  • Testing. I haven't run any tests with this routine yet; I first want input on whether the overall structure/implementation makes sense to stakeholders. Once the prototype is finalized I'll update the runtime tests/ttrt and test against existing flatbuffers.

Please let me know what you think, any suggestions are appreciated!

tensor.deallocate();
return;
}
#elif defined(TT_RUNTIME_ENABLE_TTMETAL)
Contributor:

If I understand correctly, we can build for both runtimes, so this should be an #if?

Contributor Author (@jnie-TT):

Yes, you're right, this should be an #if; will update!
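
An illustrative restructuring of the fragment above, under the assumption that both runtimes can be enabled in one build. Binary, Tensor, and the isTTNN()/isTTMetal() dispatch checks are hypothetical stand-ins for the real runtime types:

void deallocateTensor(const Binary &binary, Tensor &tensor) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
  // Independent guard: compiled in whenever ttnn is enabled.
  if (binary.isTTNN()) {
    tensor.deallocate();
    return;
  }
#endif
#if defined(TT_RUNTIME_ENABLE_TTMETAL)
  // Second independent guard instead of an exclusive #elif branch.
  if (binary.isTTMetal()) {
    // ttmetal deallocation path goes here
    return;
  }
#endif
}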

@@ -34,6 +34,29 @@ ttnn::Tensor untilize(ttnn::Tensor const &input) {

namespace tt::runtime::ttnn {

static bool isOnHost(const ::ttnn::Tensor &tensor) {
Contributor:

Are there util methods like these in tt-metal? If there are, we should probably use those.

Contributor Author (@jnie-TT):

AFAIK metal just compares the storage type, like we do here. A custom API also lets us catch cases that are not supported yet.
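
For reference, a sketch of the storage-type comparison being discussed; the enumerator names follow tt-metal's StorageType at the time and may have changed since:

static bool isOnHost(const ::ttnn::Tensor &tensor) {
  // Owned/borrowed storage lives in host memory; DEVICE storage does not.
  // Anything else (e.g. multi-device storage) is treated as not-on-host here.
  ::tt::tt_metal::StorageType storage = tensor.storage_type();
  return storage == ::tt::tt_metal::StorageType::OWNED ||
         storage == ::tt::tt_metal::StorageType::BORROWED;
}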

output->desc()->layout()->memory_desc();
::ttnn::Tensor layoutUpdatedOutputTensor =
updateTensorMemoryConfig(device, *outputTensor, outputDesc);
outputs.push_back(toTypeErasedTensor(std::move(layoutUpdatedOutputTensor)));
Contributor:

This way, output tensors will be either on the host or on device, depending on the outputDesc in the flatbuffer?

How will the FE runtime move them to host?

Contributor Author (@jnie-TT):

That's correct, it could be on host or on device depending on outputDesc. With the current implementation, if the FE wants to move a tensor to host, it would need a program whose output desc is host.

If desired, I can overload toLayout to accept layout descriptors directly, decoupling it from the flatbuffer.
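
A hypothetical shape for that overload, sketching the decoupling offered above; LayoutDesc and the stub types below are assumptions for illustration, not part of this PR:

#include <optional>

// Stubs standing in for the runtime's own types.
enum class StorageType { Host, DeviceDRAM, DeviceL1 };
enum class Layout { RowMajor, Tile };
struct MemoryConfig {};
struct Device {};
struct Tensor {};

struct LayoutDesc {
  StorageType storage;                      // host, device DRAM, or device L1
  Layout layout;                            // row-major vs. tiled
  std::optional<MemoryConfig> memoryConfig; // ignored for host tensors
};

// Converts `input` to the explicitly requested layout, independent of the
// flatbuffer's MemoryDesc.
Tensor toLayout(Device &device, const Tensor &input, const LayoutDesc &desc);

// The FE-side move-to-host would then read:
//   Tensor hostOut = toLayout(device, deviceOut,
//                             {StorageType::Host, Layout::RowMajor, std::nullopt});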

@@ -24,6 +24,14 @@ target_include_directories(TTBinary
)
add_dependencies(TTBinary FBS_GENERATION)

add_library(TTRuntimeTypes STATIC types.cpp)
Contributor:

Please add this static library (TTRuntimeTypes) to the shared lib used by tt-forge (in lib/SharedLib/CMakeLists.txt).
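
A possible form of that edit in lib/SharedLib/CMakeLists.txt; the shared-lib target name (TTMLIRRuntime here) is an assumption, and only the TTRuntimeTypes entry is the suggested addition:

target_link_libraries(TTMLIRRuntime PRIVATE
  TTRuntimeTypes   # static lib added in this PR
  # ...existing runtime libraries...
)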
