-
Notifications
You must be signed in to change notification settings - Fork 87
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
#0: Add the dim 0 support repeat backward #5596
Conversation
dd4c872
to
9b67f2a
Compare
Hi @tarafdarTT ,
Supported shapes
For more info, please find the below doc link |
@ruthreshx this is awesome ty! This LGTM, I see you have added support at the tt_lib level. In a future PR we could look at adding this at the ttnn level as well. |
std::vector<Tensor> _repeat_bw(const Tensor& grad, const Tensor& input, const Shape& shape, const MemoryConfig& output_mem_config) { | ||
std::vector<Tensor> grad_tensor; | ||
auto shape_wh = input.shape(); | ||
TT_ASSERT( shape_wh[0] == 1 && "input shape[0] should be 1"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All the TT_ASSERT should be switched to TT_FATAL here
e1524a3
to
5bf77e8
Compare
5bf77e8
to
a079bfe
Compare
Added the dim 0 support for repeat backward.
yet to add the support for dim 1.