diff --git a/src/MCPSolver.jl b/src/MCPSolver.jl index ac7858d..2c24e98 100644 --- a/src/MCPSolver.jl +++ b/src/MCPSolver.jl @@ -4,11 +4,14 @@ using SparseArrays: SparseArrays using FastDifferentiation: FastDifferentiation as FD using Symbolics: Symbolics using LinearAlgebra: I, norm, eigvals +using BlockArrays: blocks, blocksizes +using TrajectoryGamesBase: to_blockvector using Infiltrator include("SymbolicUtils.jl") include("mcp.jl") include("solver.jl") +include("game.jl") end # module MCPSolver diff --git a/src/game.jl b/src/game.jl index 7adf9fa..141c858 100644 --- a/src/game.jl +++ b/src/game.jl @@ -29,7 +29,6 @@ function ParametricGame(; problems, shared_equality = nothing, shared_inequality = nothing, - parametric_mcp_options = (;), ) N = length(problems) @assert N == length(blocks(test_point)) @@ -155,22 +154,11 @@ function solve( game::ParametricGame, θ; solver_type = InteriorPoint(), - x₀ = nothing, - y₀ = nothing, - tol = 1e-4 + x₀ = zeros(sum(game.dims.x) + sum(game.dims.λ) + game.dims.λ̃), + y₀ = ones(sum(game.dims.μ) + game.dims.μ̃), + tol = 1e-4, ) - initial_x = - !isnothing(x₀) ? x₀ : zeros(sum(game.dims.x) + sum(game.dims.λ) + game.dims.λ̃) - initial_y = !isnothing(y₀) ? y₀ : zeros(sum(game.dims.μ) + game.dims.μ̃) - - (; x, y, s, kkt_error) = solve( - solver_type, - game.mcp; - θ, - x₀ = initial_x, - y₀ = initial_y, - tol - ) + (; x, y, s, kkt_error) = solve(solver_type, game.mcp; θ, x₀, y₀, tol) # Unpack primals per-player for ease of access later. end_dims = cumsum(game.dims.x) @@ -178,7 +166,7 @@ function solve( (ii == 1) ? x[1:end_dims[ii]] : x[(end_dims[ii - 1] + 1):end_dims[ii]] end - (; primals, variables = (; x, y), kkt_error) + (; primals, variables = (; x, y, s), kkt_error) end "Return the number of players in this game." diff --git a/src/solver.jl b/src/solver.jl index 65d650f..d77bd30 100644 --- a/src/solver.jl +++ b/src/solver.jl @@ -58,6 +58,8 @@ function solve( y += α_y * δy kkt_error = maximum(abs.(F)) + + @info iters, kkt_error iters += 1 end diff --git a/test/runtests.jl b/test/runtests.jl index 7e9ba71..f95cf8d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -58,15 +58,15 @@ end @testset "ParametricGameTests" begin lim = 0.5 - game = ParametricGame(; + game = MCPSolver.ParametricGame(; test_point = mortar([[1, 1], [1, 1]]), test_parameter = mortar([[1, 1], [1, 1]]), problems = [ - OptimizationProblem(; + MCPSolver.OptimizationProblem(; objective = (x, θi) -> sum((x[Block(1)] - θi) .^ 2), private_inequality = (xi, θi) -> -abs.(xi) .+ lim, ), - OptimizationProblem(; + MCPSolver.OptimizationProblem(; objective = (x, θi) -> sum((x[Block(2)] - θi) .^ 2), private_inequality = (xi, θi) -> -abs.(xi) .+ lim, ), @@ -75,9 +75,9 @@ end θ = mortar([[-1, 0], [1, 1]]) - (; primals, variables, kkt_error) = solve( - game; - θ, + (; primals, variables, kkt_error) = MCPSolver.solve( + game, + θ; tol = 1e-4, )