Replies: 1 comment
-
I think this is a flax-specific question and not really connected to netket. Ask it on their forum! I think it does not work because of the way flax handles functions in those modules. Wrapping |
-
I am constructing the log of the overlap between the variational state and the basis.
From the tutorials, I can see that most models rely on built-in `jax.numpy` (`jnp`) functions to handle the batch dimension automatically. For example, the mean-field ansatz of the Ising model:
However, when I tried to construct a model myself, I ran into bugs. (I am a beginner with JAX, so these are probably just JAX coding mistakes on my side.)
As shown above, I tried to run the computation on x with a batch dimension, but a TypeError was raised, saying that "compute_energy required 4 parameters but 5 are given". I guess this might have something to do with defining the function inside the class.
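Since the original code is not visible here, the following is only a guess at one common cause of this kind of message in plain Python: a function defined inside a class but written without a `self` parameter receives the instance as an extra first argument when called as `self.compute_energy(...)`. The class `Model`, the method bodies, and the `_fixed` variant are hypothetical, and Flax modules may add behavior of their own on top of this.

```python
class Model:
    # Hypothetical helper written like a free function: no `self` parameter.
    def compute_energy(a, b, c, d):
        return a + b + c + d

    def call_broken(self, a, b, c, d):
        # Calling through `self` passes the instance implicitly,
        # so the function receives 5 arguments instead of 4.
        return self.compute_energy(a, b, c, d)

    # One possible fix: mark the helper as a static method
    # (alternatively, add `self` as its first parameter).
    @staticmethod
    def compute_energy_fixed(a, b, c, d):
        return a + b + c + d

    def call_fixed(self, a, b, c, d):
        return self.compute_energy_fixed(a, b, c, d)


model = Model()

try:
    model.call_broken(1, 2, 3, 4)
    error_message = None
except TypeError as err:
    # The message reports 4 expected positional arguments but 5 given.
    error_message = str(err)

fixed_result = model.call_fixed(1, 2, 3, 4)
```

If the message goes away after adding `self` or `@staticmethod`, this was the cause; if not, the problem is more likely in how the function is wrapped inside the module.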
I'd appreciate it if anyone could help me fix the bug.
(BTW, I code with PyTorch a lot, and I'm very used to unsqueezing the data to get a batch dimension and then relying on broadcasting for the computation. Is there a similar way to do this in JAX?)
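On the unsqueeze question, a minimal sketch of three options that exist in JAX: `jnp.expand_dims` (or indexing with `None`) as the counterpart of `unsqueeze`, reducing over `axis=-1` so that leading batch dimensions broadcast through, and `jax.vmap` to batch a function written for a single sample. The function `log_amplitude`, the parameter `lam`, and the sample values are assumptions made up for illustration, not the code from the question.

```python
import jax
import jax.numpy as jnp


# Hypothetical single-sample function: takes one configuration of shape (n_sites,).
def log_amplitude(lam, sigma):
    return 0.5 * jnp.sum(jax.nn.log_sigmoid(lam * sigma))


lam = 0.3
single = jnp.array([1.0, -1.0, 1.0, 1.0])     # shape (4,)
batch = jnp.stack([single, -single, single])  # shape (3, 4)

# 1) Counterpart of torch.unsqueeze(0): add a leading axis of size 1.
unsqueezed = jnp.expand_dims(single, axis=0)  # same result as single[None, :]

# 2) Broadcasting: reduce over the last axis only, so any number of
#    leading batch dimensions passes through unchanged.
def log_amplitude_batched(lam, sigmas):
    return 0.5 * jnp.sum(jax.nn.log_sigmoid(lam * sigmas), axis=-1)

# 3) vmap: keep the single-sample function and map it over axis 0 of the
#    second argument; `None` means `lam` is shared across the batch.
log_amplitude_vmapped = jax.vmap(log_amplitude, in_axes=(None, 0))

out_broadcast = log_amplitude_batched(lam, batch)  # shape (3,)
out_vmap = log_amplitude_vmapped(lam, batch)       # shape (3,)
```

Both batched versions give the same values here. Writing reductions with `axis=-1` is usually enough for simple models, while `jax.vmap` helps when the single-sample code is hard to rewrite for batches by hand.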