
#0: Add dim 0 support for repeat backward #5596

Merged
merged 1 commit into main from ruthresh/repeat_backward on Mar 28, 2024
Conversation

ruthreshx (Contributor)

Added dim 0 support for repeat backward; dim 1 support is yet to be added.

@ruthreshx force-pushed the ruthresh/repeat_backward branch 2 times, most recently from dd4c872 to 9b67f2a on February 26, 2024 11:48
ruthreshx (Contributor, Author) commented Feb 26, 2024

Hi @tarafdarTT,

  • With the help of a reference link I derived the repeat backward formula in Python, and it works as expected in Colab for various inputs (see the sketch after this list).

  • The bottleneck is that the intermediate grad_size and sum_dims are incompatible with TT support: the general formula needs to reshape to more than 4 dims and sum over them.

  • Because of this we were unable to add full support, but we did add partial support for dims 0 & 1: N and C must be 1, the repeat along the chosen dim can be any number, and the other repeats should be 1.

  • For dims 2 & 3, the shape would need to be TILE-aligned; instead, one of our constraints is that the respective dim should be 1.
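For reference, here is a minimal sketch of the derived formula, assuming PyTorch repeat semantics (the helper name repeat_backward is illustrative, not code from this PR): each output dim of size r * s is split into (r, s), and the copy axes are then summed away. Note the intermediate tensor is 8D for a 4D input, which is exactly the >4D bottleneck described above.

import torch

def repeat_backward(grad, input_shape, repeats):
    # Split every output dim (r * s) into (r, s), then sum over the copy
    # axes. For a 4D input this yields an 8D intermediate tensor, which
    # general TT ops cannot reshape and sum over.
    grad_size, sum_dims = [], []
    for r, s in zip(repeats, input_shape):
        sum_dims.append(len(grad_size))  # index of the copy axis
        grad_size.extend([r, s])
    return grad.reshape(grad_size).sum(dim=sum_dims)

# Quick check against autograd:
x = torch.randn(1, 1, 32, 32, requires_grad=True)
repeats = (2, 1, 1, 1)
y = x.repeat(*repeats)
g = torch.randn_like(y)
y.backward(g)
assert torch.allclose(x.grad, repeat_backward(g, x.shape, repeats))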

Supported shapes
Ex:

Input shapes:
 [1, 1, 32, 32], [1, 1, 320, 320] (N & C should be 1)
Repeats:
 [12, 1, 1, 1], [6, 1, 1, 1], [1, 24, 1, 1], [1, 3, 1, 1] (the repeat along the supported dim can be any number; the others should be 1)

For dim 0, with input [1, 1, 32, 32], the repeat for that dim can be any number:
 [32, 1, 1, 1], [2, 1, 1, 1], [6, 1, 1, 1]

For dim 1, with input [1, 1, 64, 64], the repeat for that dim can be any number:
 [1, 24, 1, 1], [1, 30, 1, 1], [1, 14, 1, 1]
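The constraints above amount to a small validity check; a hypothetical Python helper (not part of the PR) makes them concrete:

def is_supported(input_shape, repeats):
    # Partial support only: N and C of the input must be 1, repeats along
    # dims 2 & 3 must be 1, and at most one of dims 0/1 may repeat.
    n, c, h, w = input_shape
    if n != 1 or c != 1:
        return False
    if repeats[2] != 1 or repeats[3] != 1:
        return False
    return repeats[0] == 1 or repeats[1] == 1

# e.g. is_supported([1, 1, 32, 32], [12, 1, 1, 1]) -> True
# e.g. is_supported([1, 1, 64, 64], [2, 3, 1, 1])  -> False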

For more info, please see the doc link below:
https://docs.google.com/document/d/1kjQPPqD8uG-M14qz5AJlWoSY_otV6u6TJGvw9LvQPsE/edit

@ruthreshx requested a review from ntarafdar on March 27, 2024 08:15
ntarafdar (Contributor)
@ruthreshx this is awesome ty! This LGTM, I see you have added support at the tt_lib level. In a future PR we could look at adding this at the ttnn level as well.

std::vector<Tensor> _repeat_bw(const Tensor& grad, const Tensor& input, const Shape& shape, const MemoryConfig& output_mem_config) {
    std::vector<Tensor> grad_tensor;
    auto shape_wh = input.shape();
    TT_ASSERT(shape_wh[0] == 1 && "input shape[0] should be 1");
Contributor

All the TT_ASSERT should be switched to TT_FATAL here


@ruthreshx force-pushed the ruthresh/repeat_backward branch from 5bf77e8 to a079bfe on March 28, 2024 08:30
@ruthreshx merged commit a365131 into main on Mar 28, 2024
38 checks passed
The github-actions bot deleted the ruthresh/repeat_backward branch on December 13, 2024 03:27