Support boxdot with n neighboring indices #22

Open
wants to merge 2 commits into master

Conversation

KeitaNakamura
Contributor

Hi, @mcabbott.
I have implemented boxdot with n neighboring indices, as I mentioned on Discourse. I'm not sure if this aligns with your intended function, but I would appreciate any feedback you may have.

julia> using TensorCore

julia> A = rand(3,3,3);

julia> B = rand(3,3,3);

julia> A ⊡₂ B ≈ reshape(A, 3,:) * reshape(B, :,3)
true

julia> boxdot(A, B, Val(3)) ≈ transpose(vec(A)) * vec(B)
true

@mcabbott
Collaborator

At first glance this looks good!

I'd like to look closely at how adjoint vectors get handled, as that was the tricky case before.

I wonder whether boxdot!(C, A, B) can just infer Val(2) when necessary? It should be known from the types.
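
For illustration only (infer_contraction and boxdot_auto! are invented names, and this is not code from the PR): since ndims(C) == ndims(A) + ndims(B) - 2N, the contraction order could in principle be recovered from the argument types alone, e.g.

# Hypothetical sketch, assuming the PR's boxdot!(C, A, B, Val(N)) method.
# Building Val from a runtime integer relies on constant propagation for type stability.
infer_contraction(C, A, B) = Val((ndims(A) + ndims(B) - ndims(C)) ÷ 2)
boxdot_auto!(C, A, B) = boxdot!(C, A, B, infer_contraction(C, A, B))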

@KeitaNakamura
Contributor Author

KeitaNakamura commented Oct 27, 2024

> I'd like to look closely at how adjoint vectors get handled, as that was the tricky case before.

Yes. I kept your single-contraction implementation for handling adjoint vectors, via a Val{1} specialization.
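
To make that dispatch structure concrete, here is a hypothetical sketch (boxdot_n is an invented name; this is not the PR's actual code): Val(1) falls back to the existing two-argument boxdot, which already handles adjoint vectors, while Val(N) takes a generic reshape-based path that contracts the last N indices of A with the first N indices of B.

using TensorCore  # for the existing two-argument boxdot

# Hypothetical sketch, not the PR's methods.
boxdot_n(A::AbstractArray, B::AbstractArray, ::Val{1}) = boxdot(A, B)
function boxdot_n(A::AbstractArray, B::AbstractArray, ::Val{N}) where {N}
    sA, sB = size(A), size(B)
    sA[end-N+1:end] == sB[1:N] || throw(DimensionMismatch("neighboring indices must match"))
    k = prod(sB[1:N])
    reshape(reshape(A, :, k) * reshape(B, k, :), sA[1:end-N]..., sB[N+1:end]...)
end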

> I wonder whether boxdot!(C, A, B) can just infer Val(2) when necessary? It should be known from the types.

I'm not sure I fully understand your suggestion, but currently the implementation does not check the size of the C tensor, so any contraction order works as long as C has the correct length. For example:

julia> boxdot!(similar(A, 81), A, B, Val(1)); # works

julia> boxdot!(similar(A, 9,9), A, B, Val(1)); # works

julia> boxdot!(similar(A, 9), A, B, Val(2)); # works

julia> boxdot!(similar(A, 1), A, B, Val(3)); # works

Are you suggesting we check the length of C to automatically apply the appropriate contraction? Or should we instead check the size of C (or perhaps just ndims) and select the contraction accordingly? In both cases, I'm concerned that if someone mistakenly provides an incorrect C tensor or simply forgets to pass Val(N), the function might not throw an error.

(Edit)
From the above examples, I think we should at least check the ndims of the tensor C. I would actually prefer to check the size of C strictly, though.
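
As one concrete way to make the stricter option explicit (check_boxdot_size is a hypothetical helper, not code from this PR), boxdot! could verify that size(C) matches the uncontracted axes of A and B before writing into it:

# Hypothetical sketch: a strict destination-size check, so a wrong C
# (or a forgotten Val(N)) errors instead of silently "working".
function check_boxdot_size(C, A, B, ::Val{N}) where {N}
    want = (size(A)[1:end-N]..., size(B)[N+1:end]...)
    size(C) == want || throw(DimensionMismatch(
        "destination has size $(size(C)), expected $want for Val($N)"))
    return nothing
end

With such a check, boxdot!(similar(A, 9), A, B, Val(2)) above would throw, whereas a destination of size (3, 3) would be accepted.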
