fix: grad for indexed ops with axes #1360
Conversation
nx/lib/nx/binary_backend.ex
    inverse_permutation =
      if diff > 0 do
        Enum.to_list(0..(diff - 1)) ++ Enum.map(inverse_permutation, &(&1 + diff))
      else
        inverse_permutation
      end
I'm not totally sure this is the correct fix. Basically, the test for grad(gather) was failing and this change fixed it.
Will we ever hit the else branch?
It seems we can. Keeping only the do branch works if we change the range to 0..(diff - 1)//1, because in some cases diff is 0. I think those are the cases where the index tensor has as many axes as the input tensor. The Enum.filter(... in Nx.axes(out)) takes care of the case where the index tensor has fewer axes, and the do block takes care of the case where it has more axes than the input.
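To make the discussion above concrete, here is a small Python sketch (not the actual Elixir code from the backend) of the padding logic: identity entries are prepended for the `diff` leading axes and the original permutation entries are shifted by `diff`. When `diff == 0` the prefix is empty and the permutation passes through unchanged, which is why the do branch alone suffices once the range uses an explicit step, mirroring Elixir's empty `0..(diff - 1)//1` range.

```python
def pad_inverse_permutation(inverse_permutation, diff):
    """Prepend identity axes for `diff` leading axes, shifting the rest.

    With diff == 0 the prefix list is empty, so the input permutation
    is returned unchanged -- the else branch is never needed.
    """
    return list(range(diff)) + [axis + diff for axis in inverse_permutation]

# diff > 0: two leading identity axes, original axes shifted by 2
print(pad_inverse_permutation([1, 0], 2))  # [0, 1, 3, 2]

# diff == 0: permutation passes through untouched
print(pad_inverse_permutation([1, 0], 0))  # [1, 0]
```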
💚 💙 💜 💛 ❤️