Enabling op model interface for constraints and L1 usage. #1554
Conversation
Q: If validation fails, we get an exception message from tt-metal. Should we wire this back to the API caller? Shall we make IsLegal return a std::tuple<bool, std::optional<std::string>>? @nobradovictt @odjuricicTT
I think it does make sense in the long run; we had something similar on Buda. Though I don't know what the scope of a change like this is? We should definitely prioritize having something working e2e first.
It's a simple change. The error message is still a human-readable string, unusable to the compiler, but maybe useful to the people running/debugging the compiler. Moving to a program-friendly error message would be a much harder problem.
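A minimal sketch of the proposed signature change, assuming the query calls into tt-metal and catches its exception; the function and parameter placeholders here are illustrative, not the actual tt-mlir API:

```cpp
#include <optional>
#include <string>
#include <tuple>

// Hypothetical shape of the proposed IsLegal change: on failure, carry the
// human-readable tt-metal exception message back to the API caller.
std::tuple<bool, std::optional<std::string>>
isLegal(/* op, input/output layouts, ... */) {
  try {
    // ... run the tt-metal validation query here ...
    return {true, std::nullopt};
  } catch (const std::exception &e) {
    // Useful for people debugging the compiler; not machine-actionable.
    return {false, e.what()};
  }
}
```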
Force-pushed from b63a86a to 99e5c29.
@mbezuljTT What will the third value be if the op is DPS (the output tensor is pre-allocated and passed in as an arg)?
Went through half the changes, will continue after lunch.
Op query functions implemented in TTNNOpModelLib are not DPS; therefore, it would be the size of the output tensor anyway. However, peak usage might be wrong in this case, as it might include the output tensor size (which is probably not what you want). When ops really become DPS, you would want to change how the op is invoked in TTNNOpModelLib.cpp to DPS as well. When you do that, the third value would become zero, but you can use another graph capture around create_device_tensor for the output allocation to get its size.
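For illustration, a hedged sketch of that second capture, assuming the tt-metal graph-capture API roughly as described here (exact names and signatures may differ by version):

```cpp
// Capture only the output-tensor allocation to recover its L1 footprint
// once the op itself runs as DPS. API names here are approximate.
ttnn::graph::GraphProcessor::begin_graph_capture(
    ttnn::graph::GraphProcessor::RunMode::NO_DISPATCH);
auto output = tt::tt_metal::create_device_tensor(
    shape, dataType, tileLayout, device, memoryConfig);
auto trace = ttnn::graph::GraphProcessor::end_graph_capture();
auto outputSize = ttnn::graph::extract_peak_L1_memory_usage(trace);
```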
Looks good! Thanks for pushing this all the way :)
I have a few more details to look over tomorrow.
```cpp
for (uint32_t i = 0; i < numOperands; i++) {
  auto operand = consumerOp->getOperand(i);
  auto input = mlir::cast<RankedTensorType>(operand.getType());

  if ((inputUnderCheckFound == false) &&
      (inputUnderCheck.getShape() == input.getShape())) {
    // this is the input we are checking compatibility for
    inputUnderCheckFound = true;
    inputLayouts.push_back(producerLayout);
  } else {
    // this is the other input that we DRAM interleave

    // what if it is tilized already?
    auto elementType =
        TileType::get(consumerOp->getContext(), input.getElementType());

    auto layout = TTNNLayoutAttr::get(
        consumerOp->getContext(), input.getShape(), elementType,
        BufferType::DRAM, workerGrid,
        TensorMemoryLayoutAttr::get(consumerOp->getContext(),
                                    TensorMemoryLayout::Interleaved));
    inputLayouts.push_back(layout);
  }
}
```
As discussed offline, a bit of cleanup is needed here. I'll do this in a follow-up PR.
CI fixes: force-pushed from bcfb884 to 7c8b8a5.
This PR plumbs OpModelInterface to the underlying tt-metal op queries for validation and L1 memory consumption.
`TTNNOpModelInterface.td`: `getOpConstraints` takes input(s) and output `TTNNLayoutAttr` and returns a tuple of three values (a plausible signature is sketched below).
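A hedged sketch of what that interface method might look like; the meaning of the three values is inferred from the DPS discussion above and is an assumption, not the verbatim tt-mlir signature:

```cpp
#include <tuple>
#include "llvm/ADT/ArrayRef.h"

// Assumed meaning of the three values, in order: circular-buffer size,
// peak L1 usage, and output tensor L1 size (names are assumptions).
std::tuple<size_t, size_t, size_t>
getOpConstraints(llvm::ArrayRef<TTNNLayoutAttr> inputLayouts,
                 TTNNLayoutAttr outputLayout);
```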
`TTNNOpModelInterface.cpp` implements hooks to the wrapper library `TTNNOpModelLib` (where the metal API is). For each op, the implementation takes the `TTNNLayoutAttr`s (as `llvm::ArrayRef<>`) from its operands and passes them to the wrapper library `TTNNOpModelLib`.
`TTNNOpModelLib` converts MLIR structures to metal structures and calls into the underlying tt-metal op interface.

The underlying tt-metal op interface `::ttnn::graph::query_op_constraints(..)` consumes a target op (e.g. `ttnn::relu`) and its arguments in the order of the op's implemented `::invoke` function that we are targeting; see the sketch below.
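A hedged sketch of such a call for `ttnn::relu`; the exact argument list is an assumption here, but per the description it mirrors the targeted `::invoke` signature:

```cpp
// Illustrative only: query validity and L1 usage for a relu op.
// Arguments after the device follow the order of ttnn::relu's ::invoke.
auto response = ::ttnn::graph::query_op_constraints(
    ::ttnn::relu,   // target op
    device,         // device the query is traced against
    inputSpec);     // input tensor argument(s), in ::invoke order
```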
Implemented `SingletonDeviceContext` to avoid constantly opening/closing the device. This class should ensure the opened device is a mockup device once that is implemented on the tt-metal side (tenstorrent/tt-metal#14000); a sketch follows.
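A minimal sketch of such a singleton, assuming a tt-metal-style open/close device API; type and function names are illustrative:

```cpp
// Hypothetical sketch: keep one device open for the lifetime of the
// process so repeated op-model queries don't reopen hardware each time.
class SingletonDeviceContext {
public:
  static SingletonDeviceContext &get() {
    static SingletonDeviceContext instance; // opened once, lazily
    return instance;
  }
  tt::tt_metal::Device *device() { return m_device; }

private:
  SingletonDeviceContext()
      : m_device(tt::tt_metal::CreateDevice(/*device_id=*/0)) {}
  ~SingletonDeviceContext() { tt::tt_metal::CloseDevice(m_device); }

  tt::tt_metal::Device *m_device;
};
```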
Added 3 types of unit tests. Due to differences in tt-metal and LLVM project setups (compiler standard, exceptions), these are implemented as plain Google unit tests, unlike other unit tests, which are also Google unit tests but wrapped into LLVM (and invoked using llvm-lit). A minimal example follows.
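For illustration, a skeleton of such a standalone Google test; the test and referenced helpers are hypothetical:

```cpp
#include <gtest/gtest.h>

// Plain gtest binary, not wrapped into llvm-lit, so it can be built with
// tt-metal's compiler settings (standard, exceptions enabled).
TEST(TTNNOpModelLib, ReluConstraintsAreQueryable) {
  // ... build input/output TTNNLayoutAttr, call getOpConstraints, and
  // check the op is legal with plausible L1 usage ...
  SUCCEED();
}
```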
As these tests require TT hardware (until the mockup device is implemented), changed the "Build tt-mlir" op_model flavour to use n300 runners.
Additionally, wired the op model interface into the ShardSolver; mnist_sharded.mlir compiles and runs. @odjuricicTT confirmed the found solution is the one we expected.
An internal doc describing more details can be found here.