diff --git a/src/AutoDiff.jl b/src/AutoDiff.jl index e55c3f6..c3dcbde 100644 --- a/src/AutoDiff.jl +++ b/src/AutoDiff.jl @@ -15,18 +15,27 @@ using ForwardDiff: ForwardDiff using LinearAlgebra: LinearAlgebra function _solve_jacobian_θ(mcp::MixedComplementarityProblems.PrimalDualMCP, solution, θ) - !isnothing(mcp.∇F_θ) || throw( + !isnothing(mcp.∇F_θ!) || throw( ArgumentError( "Missing sensitivities. Set `compute_sensitivities = true` when constructing the PrimalDualMCP.", ), ) (; x, y, s, ϵ) = solution - ∂z∂θ = - LinearAlgebra.qr(-collect(mcp.∇F_z(x, y, s; θ, ϵ)), LinearAlgebra.ColumnNorm()) \ - collect(mcp.∇F_θ(x, y, s; θ, ϵ)) - ∂z∂θ + ∇F_z = let + ∇F = MixedComplementarityProblems.get_result_buffer(mcp.∇F_z!) + mcp.∇F_z!(∇F, x, y, s; θ, ϵ) + ∇F + end + + ∇F_θ = let + ∇F = MixedComplementarityProblems.get_result_buffer(mcp.∇F_θ!) + mcp.∇F_θ!(∇F, x, y, s; θ, ϵ) + ∇F + end + + LinearAlgebra.qr(-collect(∇F_z), LinearAlgebra.ColumnNorm()) \ collect(∇F_θ) end function ChainRulesCore.rrule( @@ -54,10 +63,14 @@ function ChainRulesCore.rrule( @views project_to_θ( ∂z∂θ[1:(mcp.unconstrained_dimension), :]' * ∂l∂x + - ∂z∂θ[(mcp.unconstrained_dimension + 1):(mcp.unconstrained_dimension + mcp.constrained_dimension), :]' * - ∂l∂y + - ∂z∂θ[(mcp.unconstrained_dimension + mcp.constrained_dimension + 1):end, :]' * - ∂l∂s, + ∂z∂θ[ + (mcp.unconstrained_dimension + 1):(mcp.unconstrained_dimension + mcp.constrained_dimension), + :, + ]' * ∂l∂y + + ∂z∂θ[ + (mcp.unconstrained_dimension + mcp.constrained_dimension + 1):end, + :, + ]' * ∂l∂s, ) end diff --git a/src/MixedComplementarityProblems.jl b/src/MixedComplementarityProblems.jl index cdb322d..d23eaf6 100644 --- a/src/MixedComplementarityProblems.jl +++ b/src/MixedComplementarityProblems.jl @@ -3,7 +3,7 @@ module MixedComplementarityProblems using SparseArrays: SparseArrays using FastDifferentiation: FastDifferentiation as FD using Symbolics: Symbolics -using LinearAlgebra: I, norm, eigvals +using LinearAlgebra: LinearAlgebra, I, norm, eigvals using BlockArrays: blocks, blocksizes using TrajectoryGamesBase: to_blockvector diff --git a/src/mcp.jl b/src/mcp.jl index 063ca70..4ca2bc3 100644 --- a/src/mcp.jl +++ b/src/mcp.jl @@ -87,7 +87,7 @@ function PrimalDualMCP( backend_options, ) - (result, x, y, s; θ, ϵ) -> _F(result, [x; y; s; θ; ϵ]) + (result, x, y, s; θ, ϵ) -> _F!(result, [x; y; s; θ; ϵ]) end ∇F_z! = let @@ -101,9 +101,13 @@ function PrimalDualMCP( rows, cols, _ = SparseArrays.findnz(∇F_symbolic) constant_entries = get_constant_entries(∇F_symbolic, z_symbolic) - SparseFunction(rows, cols, size(∇F_symbolic), constant_entries) do (result, x, y, s; θ, ϵ) - _∇F!(result, [x; y; s; θ; ϵ]) - end + SparseFunction( + (result, x, y, s; θ, ϵ) -> _∇F!(result, [x; y; s; θ; ϵ]), + rows, + cols, + size(∇F_symbolic), + constant_entries, + ) end ∇F_θ! = @@ -113,20 +117,19 @@ function PrimalDualMCP( _∇F! = SymbolicUtils.build_function( ∇F_symbolic, [z_symbolic; θ_symbolic; ϵ_symbolic]; - in_place = false, + in_place = true, backend_options, ) rows, cols, _ = SparseArrays.findnz(∇F_symbolic) constant_entries = get_constant_entries(∇F_symbolic, θ_symbolic) SparseFunction( + (result, x, y, s; θ, ϵ) -> _∇F!(result, [x; y; s; θ; ϵ]), rows, cols, size(∇F_symbolic), constant_entries, - ) do (result, x, y, s; θ, ϵ) - _∇F!(result, [x; y; s; θ; ϵ]) - end + ) end PrimalDualMCP(F!, ∇F_z!, ∇F_θ!, length(x_symbolic), length(y_symbolic)) diff --git a/src/solver.jl b/src/solver.jl index 3afde81..9a6125f 100644 --- a/src/solver.jl +++ b/src/solver.jl @@ -26,6 +26,16 @@ function solve( max_inner_iters = 20, max_outer_iters = 50, ) + # Set up common memory. + ∇F = get_result_buffer(mcp.∇F_z!) + F = zeros(mcp.unconstrained_dimension + 2mcp.constrained_dimension) + δz = zeros(mcp.unconstrained_dimension + 2mcp.constrained_dimension) + δx = @view δz[1:(mcp.unconstrained_dimension)] + δy = + @view δz[(mcp.unconstrained_dimension + 1):(mcp.unconstrained_dimension + mcp.constrained_dimension)] + δs = @view δz[(mcp.unconstrained_dimension + mcp.constrained_dimension + 1):end] + + # Main solver loop. x = x₀ y = y₀ s = s₀ @@ -40,16 +50,15 @@ function solve( while kkt_error > ϵ && inner_iters < max_inner_iters # Compute the Newton step. # TODO! Can add some adaptive regularization. - F = mcp.F(x, y, s; θ, ϵ) - δz = -mcp.∇F_z(x, y, s; θ, ϵ) \ F + mcp.F!(F, x, y, s; θ, ϵ) + mcp.∇F_z!(∇F, x, y, s; θ, ϵ) + LinearAlgebra.ldiv!( + δz, + LinearAlgebra.qr(-collect(∇F), LinearAlgebra.ColumnNorm()), + collect(F), + ) # Fraction to the boundary linesearch. - δx = @view δz[1:(mcp.unconstrained_dimension)] - δy = - @view δz[(mcp.unconstrained_dimension + 1):(mcp.unconstrained_dimension + mcp.constrained_dimension)] - δs = - @view δz[(mcp.unconstrained_dimension + mcp.constrained_dimension + 1):end] - α_s = fraction_to_the_boundary_linesearch(s, δs; tol) α_y = fraction_to_the_boundary_linesearch(y, δy; tol) diff --git a/src/sparse_utils.jl b/src/sparse_utils.jl index 8e21748..7fcd066 100644 --- a/src/sparse_utils.jl +++ b/src/sparse_utils.jl @@ -26,7 +26,7 @@ struct SparseFunction{T1,T2} end end -(f::SparseFunction)(args...) = f._f(args...) +(f::SparseFunction)(args...; kwargs...) = f._f(args...; kwargs...) SparseArrays.nnz(f::SparseFunction) = length(f.rows) function get_result_buffer(rows::Vector{Int}, cols::Vector{Int}, size::Tuple{Int,Int}) @@ -37,3 +37,27 @@ end function get_result_buffer(f::SparseFunction) get_result_buffer(f.rows, f.cols, f.size) end + +"Get the (sparse) linear indices of all entries that are constant in the symbolic matrix M w.r.t. symbolic vector z." +function get_constant_entries( + M_symbolic::AbstractMatrix{<:Symbolics.Num}, + z_symbolic::AbstractVector{<:Symbolics.Num}, +) + _z_syms = Symbolics.tosymbol.(z_symbolic) + findall(SparseArrays.nonzeros(M_symbolic)) do v + _vars_syms = Symbolics.tosymbol.(Symbolics.get_variables(v)) + isempty(intersect(_vars_syms, _z_syms)) + end +end + +function get_constant_entries( + M_symbolic::AbstractMatrix{<:FD.Node}, + z_symbolic::AbstractVector{<:FD.Node}, +) + _z_syms = [zs.node_value for zs in FD.variables(z_symbolic)] + # find all entries that are not a function of any of the symbols in z + findall(SparseArrays.nonzeros(M_symbolic)) do v + _vars_syms = [vs.node_value for vs in FD.variables(v)] + isempty(intersect(_vars_syms, _z_syms)) + end +end