Skip to content

Commit

Permalink
tests pass, but getting occasional segfaults
Browse files Browse the repository at this point in the history
  • Loading branch information
dfridovi committed Dec 9, 2024
1 parent 2b116db commit a00d205
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 27 deletions.
31 changes: 22 additions & 9 deletions src/AutoDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,27 @@ using ForwardDiff: ForwardDiff
using LinearAlgebra: LinearAlgebra

function _solve_jacobian_θ(mcp::MixedComplementarityProblems.PrimalDualMCP, solution, θ)
!isnothing(mcp.∇F_θ) || throw(
!isnothing(mcp.∇F_θ!) || throw(
ArgumentError(
"Missing sensitivities. Set `compute_sensitivities = true` when constructing the PrimalDualMCP.",
),
)

(; x, y, s, ϵ) = solution
∂z∂θ =
LinearAlgebra.qr(-collect(mcp.∇F_z(x, y, s; θ, ϵ)), LinearAlgebra.ColumnNorm()) \
collect(mcp.∇F_θ(x, y, s; θ, ϵ))

∂z∂θ
∇F_z = let
∇F = MixedComplementarityProblems.get_result_buffer(mcp.∇F_z!)
mcp.∇F_z!(∇F, x, y, s; θ, ϵ)
∇F
end

∇F_θ = let
∇F = MixedComplementarityProblems.get_result_buffer(mcp.∇F_θ!)
mcp.∇F_θ!(∇F, x, y, s; θ, ϵ)
∇F
end

LinearAlgebra.qr(-collect(∇F_z), LinearAlgebra.ColumnNorm()) \ collect(∇F_θ)
end

function ChainRulesCore.rrule(
Expand Down Expand Up @@ -54,10 +63,14 @@ function ChainRulesCore.rrule(

@views project_to_θ(
∂z∂θ[1:(mcp.unconstrained_dimension), :]' * ∂l∂x +
∂z∂θ[(mcp.unconstrained_dimension + 1):(mcp.unconstrained_dimension + mcp.constrained_dimension), :]' *
∂l∂y +
∂z∂θ[(mcp.unconstrained_dimension + mcp.constrained_dimension + 1):end, :]' *
∂l∂s,
∂z∂θ[
(mcp.unconstrained_dimension + 1):(mcp.unconstrained_dimension + mcp.constrained_dimension),
:,
]' * ∂l∂y +
∂z∂θ[
(mcp.unconstrained_dimension + mcp.constrained_dimension + 1):end,
:,
]' * ∂l∂s,
)
end

Expand Down
2 changes: 1 addition & 1 deletion src/MixedComplementarityProblems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module MixedComplementarityProblems
using SparseArrays: SparseArrays
using FastDifferentiation: FastDifferentiation as FD
using Symbolics: Symbolics
using LinearAlgebra: I, norm, eigvals
using LinearAlgebra: LinearAlgebra, I, norm, eigvals
using BlockArrays: blocks, blocksizes
using TrajectoryGamesBase: to_blockvector

Expand Down
19 changes: 11 additions & 8 deletions src/mcp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ function PrimalDualMCP(
backend_options,
)

(result, x, y, s; θ, ϵ) -> _F(result, [x; y; s; θ; ϵ])
(result, x, y, s; θ, ϵ) -> _F!(result, [x; y; s; θ; ϵ])
end

∇F_z! = let
Expand All @@ -101,9 +101,13 @@ function PrimalDualMCP(

rows, cols, _ = SparseArrays.findnz(∇F_symbolic)
constant_entries = get_constant_entries(∇F_symbolic, z_symbolic)
SparseFunction(rows, cols, size(∇F_symbolic), constant_entries) do (result, x, y, s; θ, ϵ)
_∇F!(result, [x; y; s; θ; ϵ])
end
SparseFunction(
(result, x, y, s; θ, ϵ) -> _∇F!(result, [x; y; s; θ; ϵ]),
rows,
cols,
size(∇F_symbolic),
constant_entries,
)
end

∇F_θ! =
Expand All @@ -113,20 +117,19 @@ function PrimalDualMCP(
_∇F! = SymbolicUtils.build_function(
∇F_symbolic,
[z_symbolic; θ_symbolic; ϵ_symbolic];
in_place = false,
in_place = true,
backend_options,
)

rows, cols, _ = SparseArrays.findnz(∇F_symbolic)
constant_entries = get_constant_entries(∇F_symbolic, θ_symbolic)
SparseFunction(
(result, x, y, s; θ, ϵ) -> _∇F!(result, [x; y; s; θ; ϵ]),
rows,
cols,
size(∇F_symbolic),
constant_entries,
) do (result, x, y, s; θ, ϵ)
_∇F!(result, [x; y; s; θ; ϵ])
end
)
end

PrimalDualMCP(F!, ∇F_z!, ∇F_θ!, length(x_symbolic), length(y_symbolic))
Expand Down
25 changes: 17 additions & 8 deletions src/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ function solve(
max_inner_iters = 20,
max_outer_iters = 50,
)
# Set up common memory.
∇F = get_result_buffer(mcp.∇F_z!)
F = zeros(mcp.unconstrained_dimension + 2mcp.constrained_dimension)
δz = zeros(mcp.unconstrained_dimension + 2mcp.constrained_dimension)
δx = @view δz[1:(mcp.unconstrained_dimension)]
δy =
@view δz[(mcp.unconstrained_dimension + 1):(mcp.unconstrained_dimension + mcp.constrained_dimension)]
δs = @view δz[(mcp.unconstrained_dimension + mcp.constrained_dimension + 1):end]

# Main solver loop.
x = x₀
y = y₀
s = s₀
Expand All @@ -40,16 +50,15 @@ function solve(
while kkt_error > ϵ && inner_iters < max_inner_iters
# Compute the Newton step.
# TODO! Can add some adaptive regularization.
F = mcp.F(x, y, s; θ, ϵ)
δz = -mcp.∇F_z(x, y, s; θ, ϵ) \ F
mcp.F!(F, x, y, s; θ, ϵ)
mcp.∇F_z!(∇F, x, y, s; θ, ϵ)
LinearAlgebra.ldiv!(
δz,
LinearAlgebra.qr(-collect(∇F), LinearAlgebra.ColumnNorm()),
collect(F),
)

# Fraction to the boundary linesearch.
δx = @view δz[1:(mcp.unconstrained_dimension)]
δy =
@view δz[(mcp.unconstrained_dimension + 1):(mcp.unconstrained_dimension + mcp.constrained_dimension)]
δs =
@view δz[(mcp.unconstrained_dimension + mcp.constrained_dimension + 1):end]

α_s = fraction_to_the_boundary_linesearch(s, δs; tol)
α_y = fraction_to_the_boundary_linesearch(y, δy; tol)

Expand Down
26 changes: 25 additions & 1 deletion src/sparse_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ struct SparseFunction{T1,T2}
end
end

(f::SparseFunction)(args...) = f._f(args...)
(f::SparseFunction)(args...; kwargs...) = f._f(args...; kwargs...)
SparseArrays.nnz(f::SparseFunction) = length(f.rows)

function get_result_buffer(rows::Vector{Int}, cols::Vector{Int}, size::Tuple{Int,Int})
Expand All @@ -37,3 +37,27 @@ end
function get_result_buffer(f::SparseFunction)
get_result_buffer(f.rows, f.cols, f.size)
end

"Get the (sparse) linear indices of all entries that are constant in the symbolic matrix M w.r.t. symbolic vector z."
function get_constant_entries(
M_symbolic::AbstractMatrix{<:Symbolics.Num},
z_symbolic::AbstractVector{<:Symbolics.Num},
)
_z_syms = Symbolics.tosymbol.(z_symbolic)
findall(SparseArrays.nonzeros(M_symbolic)) do v
_vars_syms = Symbolics.tosymbol.(Symbolics.get_variables(v))
isempty(intersect(_vars_syms, _z_syms))
end
end

function get_constant_entries(
M_symbolic::AbstractMatrix{<:FD.Node},
z_symbolic::AbstractVector{<:FD.Node},
)
_z_syms = [zs.node_value for zs in FD.variables(z_symbolic)]
# find all entries that are not a function of any of the symbols in z
findall(SparseArrays.nonzeros(M_symbolic)) do v
_vars_syms = [vs.node_value for vs in FD.variables(v)]
isempty(intersect(_vars_syms, _z_syms))
end
end

0 comments on commit a00d205

Please sign in to comment.