Skip to content

Commit

Permalink
Formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Sep 25, 2023
1 parent 7cddde3 commit ecfe4c4
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,24 +28,26 @@ function FiniteDifferences.to_vec(t::AbstractTensorMap)
return vcat(real(v), imag(v))
end
end

function from_vec(x)
t′ = similar(t)
T = scalartype(t)
ctr = 0
for (c, b) in blocks(t′)
n = length(b)
if T <: Real
copyto!(b, reshape(x[ctr+1:ctr+n], size(b)) ./ sqrt(dim(c)))
copyto!(b, reshape(x[(ctr + 1):(ctr + n)], size(b)) ./ sqrt(dim(c)))
else
v = x[ctr+1:ctr+2n]
copyto!(b, complex.(x[ctr+1:ctr+n], x[ctr+n+1:ctr+2n]) ./ sqrt(dim(c)))
v = x[(ctr + 1):(ctr + 2n)]
copyto!(b,
complex.(x[(ctr + 1):(ctr + n)], x[(ctr + n + 1):(ctr + 2n)]) ./
sqrt(dim(c)))
end
ctr += T <: Real ? n : 2n
end
return t′
end

return vec, from_vec
end
FiniteDifferences.to_vec(t::TensorKit.AdjointTensorMap) = to_vec(copy(t))
Expand Down Expand Up @@ -144,7 +146,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
E = TensorMap(randn, T, (V[1:i]...) (V[1:i]...))
test_rrule(LinearAlgebra.tr, E)
end

A = TensorMap(randn, T, V[1] V[2] V[3] V[4] V[5])
test_rrule(LinearAlgebra.adjoint, A)
test_rrule(LinearAlgebra.norm, A, 2)
Expand Down

0 comments on commit ecfe4c4

Please sign in to comment.