Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: grad for indexed ops with axes #1360

Merged
merged 6 commits into from
Nov 11, 2023
Merged

Conversation

polvalente
Copy link
Contributor

No description provided.

@polvalente polvalente self-assigned this Nov 9, 2023
Comment on lines 2107 to 2112
inverse_permutation =
if diff > 0 do
Enum.to_list(0..(diff - 1)) ++ Enum.map(inverse_permutation, &(&1 + diff))
else
inverse_permutation
end
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not totally sure if this is the correct fix. Basically the test for grad(gather) was failing and this fixed it

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will we ever hit the else branch?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems we can. Keeping only the do branch works if we change the range to 0..(diff-1)//1 because it seems that for some cases the diff is 0. I think those are cases where the index tensor as many axes as the input tensor.

The Enum.filter(... in Nx.axes(out)) takes care of the case where the index tensor has fewer axes, and the do block takes care of when there are more axes than in the input.

@josevalim josevalim merged commit bb76359 into main Nov 11, 2023
8 checks passed
@josevalim josevalim deleted the pv-fix/indexed-axes-grad branch November 11, 2023 20:35
@josevalim
Copy link
Collaborator

💚 💙 💜 💛 ❤️

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants