Skip to content

Commit

Permalink
[skip ci] tests not passing for game solver. infinite loop
Browse files Browse the repository at this point in the history
  • Loading branch information
dfridovi committed Nov 21, 2024
1 parent 79e8a48 commit 9ac246d
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 23 deletions.
3 changes: 3 additions & 0 deletions src/MCPSolver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@ using SparseArrays: SparseArrays
using FastDifferentiation: FastDifferentiation as FD
using Symbolics: Symbolics
using LinearAlgebra: I, norm, eigvals
using BlockArrays: blocks, blocksizes
using TrajectoryGamesBase: to_blockvector

using Infiltrator

include("SymbolicUtils.jl")
include("mcp.jl")
include("solver.jl")
include("game.jl")

end # module MCPSolver
22 changes: 5 additions & 17 deletions src/game.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ function ParametricGame(;
problems,
shared_equality = nothing,
shared_inequality = nothing,
parametric_mcp_options = (;),
)
N = length(problems)
@assert N == length(blocks(test_point))
Expand Down Expand Up @@ -155,30 +154,19 @@ function solve(
game::ParametricGame,
θ;
solver_type = InteriorPoint(),
x₀ = nothing,
y₀ = nothing,
tol = 1e-4
x₀ = zeros(sum(game.dims.x) + sum(game.dims.λ) + game.dims.λ̃),
y₀ = ones(sum(game.dims.μ) + game.dims.μ̃),
tol = 1e-4,
)
initial_x =
!isnothing(x₀) ? x₀ : zeros(sum(game.dims.x) + sum(game.dims.λ) + game.dims.λ̃)
initial_y = !isnothing(y₀) ? y₀ : zeros(sum(game.dims.μ) + game.dims.μ̃)

(; x, y, s, kkt_error) = solve(
solver_type,
game.mcp;
θ,
x₀ = initial_x,
y₀ = initial_y,
tol
)
(; x, y, s, kkt_error) = solve(solver_type, game.mcp; θ, x₀, y₀, tol)

# Unpack primals per-player for ease of access later.
end_dims = cumsum(game.dims.x)
primals = map(1:num_players(game)) do ii
(ii == 1) ? x[1:end_dims[ii]] : x[(end_dims[ii - 1] + 1):end_dims[ii]]
end

(; primals, variables = (; x, y), kkt_error)
(; primals, variables = (; x, y, s), kkt_error)
end

"Return the number of players in this game."
Expand Down
2 changes: 2 additions & 0 deletions src/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ function solve(
y += α_y * δy

kkt_error = maximum(abs.(F))

@info iters, kkt_error
iters += 1
end

Expand Down
12 changes: 6 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,15 @@ end
@testset "ParametricGameTests" begin
lim = 0.5

game = ParametricGame(;
game = MCPSolver.ParametricGame(;
test_point = mortar([[1, 1], [1, 1]]),
test_parameter = mortar([[1, 1], [1, 1]]),
problems = [
OptimizationProblem(;
MCPSolver.OptimizationProblem(;
objective = (x, θi) -> sum((x[Block(1)] - θi) .^ 2),
private_inequality = (xi, θi) -> -abs.(xi) .+ lim,
),
OptimizationProblem(;
MCPSolver.OptimizationProblem(;
objective = (x, θi) -> sum((x[Block(2)] - θi) .^ 2),
private_inequality = (xi, θi) -> -abs.(xi) .+ lim,
),
Expand All @@ -75,9 +75,9 @@ end


θ = mortar([[-1, 0], [1, 1]])
(; primals, variables, kkt_error) = solve(
game;
θ,
(; primals, variables, kkt_error) = MCPSolver.solve(
game,
θ;
tol = 1e-4,
)

Expand Down

0 comments on commit 9ac246d

Please sign in to comment.