I am trying to compute the gradient of a function that takes a matrix as input and contracts it with another tensor. If both tensors have Float64 entries, or both have ComplexF64 entries, the Zygote gradient works. However, if one has Float64 and the other ComplexF64 entries, it fails with an InexactError: Float64() error. Below is an MWE that shows the same behaviour as the actual function I need.
using LinearAlgebra, TensorKit, Zygote

normDiff(a, b) = norm(a - b);

# Both Float64: the gradient works
A = TensorMap(randn, Float64, ComplexSpace(2), ComplexSpace(2));
B = TensorMap(randn, Float64, ComplexSpace(2), ComplexSpace(2));
Zygote.gradient((x, y) -> normDiff(x, y), A, B)

# Both ComplexF64: the gradient works
A = TensorMap(randn, ComplexF64, ComplexSpace(2), ComplexSpace(2));
B = TensorMap(randn, ComplexF64, ComplexSpace(2), ComplexSpace(2));
Zygote.gradient((x, y) -> normDiff(x, y), A, B)

# Mixed Float64 and ComplexF64: throws InexactError: Float64()
A = TensorMap(randn, Float64, ComplexSpace(2), ComplexSpace(2));
B = TensorMap(randn, ComplexF64, ComplexSpace(2), ComplexSpace(2));
Zygote.gradient((x, y) -> normDiff(x, y), A, B)
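For context, the error itself is what Julia raises when a complex number with a nonzero imaginary part is converted to Float64, which is plausibly what happens when a complex cotangent is written into the gradient of the real tensor A. The sketch below uses plain Julia arrays rather than TensorKit tensors, and assumes (this is not confirmed by the thread) that mixed-scalartype rules handle this the way ChainRulesCore usually does: ProjectTo maps a complex cotangent for a real input onto its real part. It assumes LinearAlgebra, Zygote and ChainRulesCore are all available in the active environment.

```julia
using LinearAlgebra
using Zygote
using ChainRulesCore

# Converting a complex number with nonzero imaginary part to Float64
# throws the same kind of error as reported above.
err = try
    Float64(1.0 + 2.0im)
catch e
    e
end
@show err isa InexactError

# ProjectTo maps a cotangent back onto the type of the primal value:
# for a real primal, a complex cotangent is reduced to its real part.
proj = ProjectTo(1.0)
@show proj(2.0 + 3.0im)

# With plain arrays, the ChainRules rules used by Zygote apply this projection,
# so mixing a real and a complex input yields a real gradient for the real one.
normDiff(a, b) = norm(a - b)
A = randn(Float64, 2, 2)
B = randn(ComplexF64, 2, 2)
gA, gB = Zygote.gradient(normDiff, A, B)
@show eltype(gA) eltype(gB)
```

Here gA should come out as a real Matrix{Float64} while gB stays complex, which is the behaviour one would expect from the TensorMap version once mixed scalar types are supported.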
Could you test this with the master version of TensorKit.jl? There is a fix for the chain rules with mixed scalar types on master that is not yet included in the latest registered version. We will release a version update soon.
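One way to try the master branch, sketched with the standard Pkg API (the REPL equivalent is `] add TensorKit#master`). This assumes the branch is still called master; once the new version is registered, `Pkg.free("TensorKit")` switches back to tracking registered releases.

```julia
using Pkg

# Install TensorKit.jl from its master branch instead of the latest registered release.
Pkg.add(name = "TensorKit", rev = "master")

# Restart Julia afterwards so the new version is loaded,
# then rerun the mixed Float64/ComplexF64 example.
```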