diff --git a/Project.toml b/Project.toml index 95873c18..fc0c157c 100644 --- a/Project.toml +++ b/Project.toml @@ -14,23 +14,20 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8" WignerSymbols = "9f57e263-0b3d-5e2e-b1be-24f2bb48858b" -[weakdeps] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - -[extensions] -TensorKitChainRulesCoreExt = "ChainRulesCore" - [compat] HalfIntegers = "1" LRUCache = "1.0.2" PackageExtensionCompat = "1" Strided = "2" -TensorOperations = "4.0.6" +TensorOperations = "4.0.6 - 4.0.7" TupleTools = "1.1" VectorInterface = "0.4" WignerSymbols = "1,2" julia = "1.6" +[extensions] +TensorKitChainRulesCoreExt = "ChainRulesCore" + [extras] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" @@ -46,3 +43,6 @@ WignerSymbols = "9f57e263-0b3d-5e2e-b1be-24f2bb48858b" [targets] test = ["Combinatorics", "HalfIntegers", "LinearAlgebra", "Random", "TensorOperations", "Test", "TestExtras", "WignerSymbols", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences"] + +[weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/auxiliary/linalg.jl b/src/auxiliary/linalg.jl index bc57d4c6..815f9900 100644 --- a/src/auxiliary/linalg.jl +++ b/src/auxiliary/linalg.jl @@ -56,6 +56,14 @@ using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, checksquare using ..TensorKit: OrthogonalFactorizationAlgorithm, QL, QLpos, QR, QRpos, LQ, LQpos, RQ, RQpos, SVD, SDD, Polar +# only defined in >v1.7 +@static if VERSION < v"1.7-" + _rf_findmax((fm, im), (fx, ix)) = isless(fm, fx) ? (fx, ix) : (fm, im) + _argmax(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmax, domain)[2] +else + _argmax(f, domain) = argmax(f, domain) +end + # TODO: define for CuMatrix if we support this function one!(A::DenseMatrix) Threads.@threads for j in 1:size(A, 2) @@ -273,12 +281,12 @@ function eig!(A::StridedMatrix{T}; permute::Bool=true, scale::Bool=true) where { while j <= n if DI[j] == 0 vr = view(VR, :, j) - s = conj(sign(argmax(abs, vr))) + s = conj(sign(_argmax(abs, vr))) V[:, j] .= s .* vr else vr = view(VR, :, j) vi = view(VR, :, j + 1) - s = conj(sign(argmax(abs, vr))) # vectors coming from lapack have already real absmax component + s = conj(sign(_argmax(abs, vr))) # vectors coming from lapack have already real absmax component V[:, j] .= s .* (vr .+ im .* vi) V[:, j + 1] .= s .* (vr .- im .* vi) j += 1 @@ -296,7 +304,7 @@ function eig!(A::StridedMatrix{T}; permute::Bool=true, A)[[2, 4]] for j in 1:n v = view(V, :, j) - s = conj(sign(argmax(abs, v))) + s = conj(sign(_argmax(abs, v))) v .*= s end return D, V @@ -308,7 +316,7 @@ function eigh!(A::StridedMatrix{T}) where {T<:BlasFloat} D, V = LAPACK.syevr!('V', 'A', 'U', A, 0.0, 0.0, 0, 0, -1.0) for j in 1:n v = view(V, :, j) - s = conj(sign(argmax(abs, v))) + s = conj(sign(_argmax(abs, v))) v .*= s end return D, V diff --git a/src/planar/macros.jl b/src/planar/macros.jl index 1d76cc48..e10308d2 100644 --- a/src/planar/macros.jl +++ b/src/planar/macros.jl @@ -23,7 +23,7 @@ function planarparser(planarexpr, kwargs...) # braiding tensors need to be instantiated before kwargs are processed push!(parser.preprocessors, _construct_braidingtensors) - + for (name, val) in kwargs if name == :order isexpr(val, :tuple) || @@ -62,7 +62,7 @@ function planarparser(planarexpr, kwargs...) throw(ArgumentError("Unknown keyword argument `name`.")) end end - + treebuilder = parser.contractiontreebuilder treesorter = parser.contractiontreesorter costcheck = parser.contractioncostcheck diff --git a/test/ad.jl b/test/ad.jl index 3b82c0b4..d686a834 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -61,7 +61,7 @@ end # Float32 and finite differences don't mix well precision(::Type{<:Union{Float32,Complex{Float32}}}) = 1e-2 -precision(::Type{<:Union{Float64,Complex{Float64}}}) = 1e-8 +precision(::Type{<:Union{Float64,Complex{Float64}}}) = 1e-6 # rrules for functions that destroy inputs # ---------------------------------------- @@ -111,7 +111,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), ℂ[Z2Irrep](0 => 3, 1 => 2)', ℂ[Z2Irrep](0 => 2, 1 => 3), ℂ[Z2Irrep](0 => 2, 1 => 2)), - (ℂ[U1Irrep](0 => 1, 1 => 2, -1 => 2), + (ℂ[U1Irrep](0 => 2, 1 => 2, -1 => 2), ℂ[U1Irrep](0 => 3, 1 => 1, -1 => 1), ℂ[U1Irrep](0 => 2, 1 => 2, -1 => 1)', ℂ[U1Irrep](0 => 1, 1 => 2, -1 => 2), @@ -227,7 +227,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), C = TensorMap(randn, T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]) H = TensorMap(randn, T, V[3] ⊗ V[4] ← V[3] ⊗ V[4]) H = (H + H') / 2 - atol = 1e-6 + atol = precision(T) for alg in (TensorKit.QR(), TensorKit.QRpos()) test_rrule(leftorth, A; fkwargs=(; alg=alg), atol) @@ -312,7 +312,8 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V) test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0)) - c, = argmax(x -> sqrt(dim(x[1])) * maximum(diag(x[2])), blocks(S)) + c, = TensorKit.MatrixAlgebra._argmax(x -> sqrt(dim(x[1])) * maximum(diag(x[2])), + blocks(S)) U, S, V, ϵ = tsvd(C; trunc=truncdim(2 * dim(c))) ΔU = TensorMap(randn, scalartype(U), space(U)) ΔS = TensorMap(randn, scalartype(S), space(S)) diff --git a/test/planar.jl b/test/planar.jl index 062e41b6..bf3efc19 100644 --- a/test/planar.jl +++ b/test/planar.jl @@ -71,29 +71,29 @@ end @testset "@planar" verbose = true begin T = ComplexF64 - + @testset "contractcheck" begin V = ℂ^2 A = TensorMap(rand, T, V ⊗ V ← V) B = TensorMap(rand, T, V ⊗ V ← V') @tensor C1[i j; k l] := A[i j; m] * B[k l; m] - @tensor contractcheck=true C2[i j; k l] := A[i j; m] * B[k l; m] + @tensor contractcheck = true C2[i j; k l] := A[i j; m] * B[k l; m] @test C1 ≈ C2 B2 = TensorMap(rand, T, V ⊗ V ← V) # wrong duality for third space @test_throws SpaceMismatch("incompatible spaces for m: $V ≠ $(V')") begin @tensor contractcheck = true C3[i j; k l] := A[i j; m] * B2[k l; m] end - + A = TensorMap(rand, T, V ← V ⊗ V) B = TensorMap(rand, T, V ⊗ V ← V) @planar C1[i; j] := A[i; k l] * τ[k l; m n] * B[m n; j] - @planar contractcheck=true C2[i; j] := A[i; k l] * τ[k l; m n] * B[m n; j] + @planar contractcheck = true C2[i; j] := A[i; k l] * τ[k l; m n] * B[m n; j] @test C1 ≈ C2 @test_throws SpaceMismatch("incompatible spaces for l: $V ≠ $(V')") begin - @planar contractcheck=true C3[i; j] := A[i; k l] * τ[k l; m n] * B[n j; m] + @planar contractcheck = true C3[i; j] := A[i; k l] * τ[k l; m n] * B[n j; m] end end - + @testset "MPS networks" begin P = ℂ^2 Vmps = ℂ^12