
Enabling op model interface for constraints and L1 usage. #1554

Merged: 32 commits from mbezulj/2411-opmodel-plumbing-mnist-ops into main on Dec 31, 2024

Conversation

@mbezuljTT (Contributor) commented Dec 10, 2024

This PR plumbs OpModelInterface to the underlying tt-metal op queries for validation and L1 memory consumption.

TTNNOpModelInterface.td: getOpConstraints takes the input and output TTNNLayoutAttrs and returns a tuple of three values (see the sketch below this list):

  1. A boolean indicating whether the op is legal for the given input/output layouts.
  2. If the op is legal, a tuple of three values estimating the op's L1 memory usage in bytes:
    • the CB L1 peak allocation in bytes,
    • the tensor L1 peak allocation in bytes,
    • the output L1 buffer allocation in bytes.
  3. If the op is illegal, a string describing the failure.
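
As a rough illustration of that return shape, here is a minimal plain-C++ sketch; it only models the description above and is not the interface actually generated from TTNNOpModelInterface.td:

```cpp
// Sketch only: models the return shape described above, not the generated
// TTNNOpModelInterface. Names here are illustrative.
#include <cstddef>
#include <optional>
#include <string>
#include <tuple>

// CB L1 peak, tensor L1 peak, and output L1 buffer allocation, in bytes.
using L1Usage = std::tuple<std::size_t, std::size_t, std::size_t>;

struct OpConstraints {
  bool legal;                       // is the op legal for the given layouts?
  std::optional<L1Usage> l1Usage;   // populated only when legal
  std::optional<std::string> error; // populated only when illegal
};
```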

TTNNOpModelInterface.cpp implements hooks into the wrapper library TTNNOpModelLib (where the tt-metal API lives). For each op, the implementation takes

  • the tensor shapes (llvm::ArrayRef<>) of its operands,
  • the worker grid (used for virtual-to-physical core conversion),
  • op-specific parameters (e.g. the softmax dimension), and
  • the TTNNLayoutAttr layouts,

and passes them to the wrapper library TTNNOpModelLib (a sketch follows this list).
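
Below is a minimal sketch of what one such per-op hook might do, using toy stand-in types for the MLIR attributes and a hypothetical querySoftmaxConstraints wrapper entry point (softmax is only used as an example op); the real hooks live in TTNNOpModelInterface.cpp and forward to TTNNOpModelLib:

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <tuple>
#include <vector>

// Toy stand-ins for the MLIR/TTNN types involved; illustrative only.
struct TTNNLayoutAttr { /* buffer type, memory layout, shard spec, ... */ };
struct GridAttr { std::uint32_t rows = 0, cols = 0; };    // worker grid
struct Operand { std::vector<std::int64_t> shape; };      // ranked tensor shape

using L1Usage = std::tuple<std::size_t, std::size_t, std::size_t>;
using QueryResult =
    std::tuple<bool, std::optional<L1Usage>, std::optional<std::string>>;

// Hypothetical wrapper entry point; the real one lives in TTNNOpModelLib and
// calls into tt-metal. Stubbed here so the sketch is self-contained.
QueryResult querySoftmaxConstraints(const std::vector<std::int64_t> &inputShape,
                                    int dimArg,
                                    const TTNNLayoutAttr &inputLayout,
                                    const TTNNLayoutAttr &outputLayout,
                                    const GridAttr &workerGrid) {
  return {true, L1Usage{0, 0, 0}, std::nullopt};
}

// Sketch of a per-op hook: pull the operand shape, the op-specific parameter
// (softmax dimension), the layouts, and the worker grid, then forward them to
// the wrapper library.
QueryResult softmaxGetOpConstraints(const Operand &input, int dimArg,
                                    const TTNNLayoutAttr &inputLayout,
                                    const TTNNLayoutAttr &outputLayout,
                                    const GridAttr &workerGrid) {
  return querySoftmaxConstraints(input.shape, dimArg, inputLayout,
                                 outputLayout, workerGrid);
}
```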

TTNNOpModelLib converts MLIR structures into tt-metal structures and calls into the underlying tt-metal op interface.

The underlying tt-metal op interface, ::ttnn::graph::query_op_constraints(..), consumes a target op (e.g. ttnn::relu) and its arguments in the order of the op's ::invoke function that we are targeting.
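
Very roughly, a call from the wrapper library into tt-metal might then look like the snippet below. Only the function name and the "arguments in ::invoke order" convention come from the description above; the device argument, the concrete argument list, and the variable names are assumptions:

```cpp
// Hedged sketch, not verified against tt-metal headers: the device argument
// and the exact per-op argument list below are assumptions.
auto query = ::ttnn::graph::query_op_constraints(
    ::ttnn::relu,      // target op
    device,            // opened device (see SingletonDeviceContext below)
    input_tensor_spec  // remaining args follow the targeted ::invoke order
);
```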

Implemented SingletonDeviceContext to avoid constantly opening/closing the device. Once a mock device is implemented on the tt-metal side (tenstorrent/tt-metal#14000), this class should ensure the opened device is the mock device.
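
A minimal sketch of such a singleton, assuming a simple open-once-and-reuse policy; the Device type and open() call here are toy stand-ins, not the tt-metal API:

```cpp
#include <memory>

// Toy stand-in for the tt-metal device handle; illustrative only.
struct Device {
  static std::unique_ptr<Device> open() { return std::make_unique<Device>(); }
};

// Keeps one device open for the lifetime of the process so that each op
// query does not pay the cost of opening/closing the device.
class SingletonDeviceContext {
public:
  static SingletonDeviceContext &get() {
    static SingletonDeviceContext instance;
    return instance;
  }
  Device &device() { return *m_device; }

  SingletonDeviceContext(const SingletonDeviceContext &) = delete;
  SingletonDeviceContext &operator=(const SingletonDeviceContext &) = delete;

private:
  SingletonDeviceContext() : m_device(Device::open()) {}
  std::unique_ptr<Device> m_device;
};
```

Once tt-metal provides the mock device, the constructor would be the single place to switch to it.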

Added 3 types of unit tests:

  • TestConversion - tests conversion of MLIR types to TTNN types
  • TestOpModelLib - tests the interface to the tt-metal API
  • TestOpModelInterface - tests the op model interface built into the ops

Due to differences between the tt-metal and LLVM project setups (compiler standard, exceptions), these are implemented as plain Google unit tests, unlike the other unit tests, which are also Google unit tests but are wrapped into LLVM (and invoked via llvm-lit).

As these tests require TT hardware (until the mock device is implemented), the Build tt-mlir op_model flavour was changed to use n300 runners.

Additionally, the op model interface is wired into the ShardSolver; mnist_sharded.mlir compiles and runs. @odjuricicTT confirmed the solution found is the one we expected.

An internal doc describing more details can be found here.

@mbezuljTT (Contributor, Author) commented

Q: If validation fails, we get an exception message from tt-metal. Should we wire this back to the API caller? Shall we make IsLegal return a tuple<bool, optional<std::string>>? @nobradovictt @odjuricicTT

@odjuricicTT (Contributor) commented

> Q: If validation fails, we get an exception message from tt-metal. Should we wire this back to the API caller? Shall we make IsLegal return a tuple<bool, optional<std::string>>?

I think it makes sense in the long run; we had something similar in Buda. Though I don't know what the scope of a change like this would be. We should definitely prioritize having something working e2e first.

@mbezuljTT (Contributor, Author) commented

> Q: If validation fails, we get an exception message from tt-metal. Should we wire this back to the API caller? Shall we make IsLegal return a tuple<bool, optional<std::string>>?
>
> I think it makes sense in the long run; we had something similar in Buda. Though I don't know what the scope of a change like this would be. We should definitely prioritize having something working e2e first.

It's a simple change. The error message is still a human-readable string, unusable to the compiler but perhaps useful to the people running/debugging the compiler. Moving to a program-friendly error message would be a much harder problem.

@mbezuljTT force-pushed the mbezulj/2411-opmodel-plumbing-mnist-ops branch 6 times, most recently from b63a86a to 99e5c29, on December 24, 2024 16:30
@mbezuljTT mbezuljTT marked this pull request as ready for review December 24, 2024 17:18
@odjuricicTT (Contributor) commented

> The third value is the Output L1 buffer allocation in bytes.

@mbezuljTT What will the third value be if the op is DPS (the output tensor is pre-allocated and passed in as an arg)?

@odjuricicTT (Contributor) left a review comment:

Went through half the changes, will continue after lunch.

Review thread on .github/workflows/build-and-test.yml (outdated, resolved).
@mbezuljTT (Contributor, Author) commented

> The third value is the Output L1 buffer allocation in bytes.
>
> @mbezuljTT What will the third value be if the op is DPS (the output tensor is pre-allocated and passed in as an arg)?

The op query functions implemented in TTNNOpModelLib are not DPS; therefore, it would be the size of the output tensor anyway. However, peak usage might be wrong in this case, as it might include the output tensor size (which is probably not what you want).

When the ops really become DPS, you would want to change how the op is invoked in TTNNOpModelLib.cpp to DPS as well. When you do that, the third value would become zero, but you can use another graph capture around create_device_tensor for the output allocation to get its size (a conceptual sketch follows).
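
A conceptual sketch of that idea, with toy types and hypothetical names rather than the actual tt-metal graph-capture API:

```cpp
// Conceptual sketch only (hypothetical names): measure the output allocation
// separately by capturing only the create-output step.
#include <cstddef>

struct CaptureScope {
  std::size_t allocatedBytes = 0;
  void onAllocate(std::size_t bytes) { allocatedBytes += bytes; }
};

// Stand-in for create_device_tensor: reports its allocation to the capture.
void createOutputTensor(CaptureScope &capture, std::size_t tensorBytes) {
  capture.onAllocate(tensorBytes);
}

std::size_t measureOutputAllocation(std::size_t tensorBytes) {
  CaptureScope capture;                     // begin capture
  createOutputTensor(capture, tensorBytes); // only the output allocation
  return capture.allocatedBytes;            // end capture -> output L1 size
}
```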

@odjuricicTT (Contributor) left a review comment:

Looks good! Thanks for pushing this all the way :)

I have a few more details to look over tomorrow.

Review thread on lib/Dialect/TTNN/Analysis/ShardSolver.cpp (outdated, resolved).
Comment on lines +544 to +567
```cpp
for (uint32_t i = 0; i < numOperands; i++) {
  auto operand = consumerOp->getOperand(i);
  auto input = mlir::cast<RankedTensorType>(operand.getType());

  if ((inputUnderCheckFound == false) &&
      (inputUnderCheck.getShape() == input.getShape())) {
    // this is the input we are checking compatibility for
    inputUnderCheckFound = true;
    inputLayouts.push_back(producerLayout);
  } else {
    // this is the other input that we DRAM interleave

    // what if it is tilized already?
    auto elementType =
        TileType::get(consumerOp->getContext(), input.getElementType());

    auto layout = TTNNLayoutAttr::get(
        consumerOp->getContext(), input.getShape(), elementType,
        BufferType::DRAM, workerGrid,
        TensorMemoryLayoutAttr::get(consumerOp->getContext(),
                                    TensorMemoryLayout::Interleaved));
    inputLayouts.push_back(layout);
  }
}
```
As discussed offline, a bit of cleanup is needed here. I'll do this in a follow-up PR.

Review thread on lib/Dialect/TTNN/Analysis/ShardSolver.cpp (resolved).
@mbezuljTT force-pushed the mbezulj/2411-opmodel-plumbing-mnist-ops branch from bcfb884 to 7c8b8a5 on December 30, 2024 14:06
@vmilosevic merged commit 3745a88 into main on Dec 31, 2024. 20 checks passed.
@vmilosevic deleted the mbezulj/2411-opmodel-plumbing-mnist-ops branch on December 31, 2024 11:21.