Skip to content

Commit

Permalink
working on AD. test not passing
Browse files Browse the repository at this point in the history
  • Loading branch information
dfridovi committed Nov 26, 2024
1 parent 801a505 commit ac11389
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 7 deletions.
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,25 @@ BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Infiltrator = "5903a43b-9cc3-4c30-8d17-598619ec4e9b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
TrajectoryGamesBase = "ac1ac542-73eb-4349-ae1b-660ab3609574"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
BlockArrays = "0.16.43"
ChainRulesCore = "1.25.0"
DataStructures = "0.18.20"
FastDifferentiation = "0.4.2"
FiniteDiff = "2.26.2"
ForwardDiff = "0.10.38"
Infiltrator = "1.8.3"
LinearAlgebra = "1.11.0"
SparseArrays = "1.11.0"
Symbolics = "6.19.0"
TrajectoryGamesBase = "0.3.10"
Zygote = "0.6.73"
31 changes: 24 additions & 7 deletions src/AutoDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ module AutoDiff
using ..MCPSolver: MCPSolver
using ChainRulesCore: ChainRulesCore
using ForwardDiff: ForwardDiff
using LinearAlgebra: LinearAlgebra
using SparseArrays: SparseArrays

using Infiltrator

function _solve_jacobian_θ(mcp::MCPSolver.PrimalDualMCP, solution, θ)
!isnothing(mcp.∇F_θ) || throw(
Expand All @@ -21,21 +25,34 @@ function _solve_jacobian_θ(mcp::MCPSolver.PrimalDualMCP, solution, θ)
)

(; x, y, s, ϵ) = solution
-mcp.∇F_z(x, y, s; θ, ϵ) \ mcp.∇F_θ(x, y, s; θ, ϵ)
∂z∂θ = -mcp.∇F_z(x, y, s; θ, ϵ) \ mcp.∇F_θ(x, y, s; θ, ϵ)

SparseArrays.sparse(∂z∂θ)
end

function ChainRulesCore.rrule(::typeof(MCPSolver.solve), mcp, θ; kwargs...)
solution = MCPSolver.solve(mcp, θ; kwargs...)
function ChainRulesCore.rrule(
::typeof(MCPSolver.solve),
solver_type::MCPSolver.SolverType,
mcp::MCPSolver.PrimalDualMCP;
θ,
kwargs...,
)
solution = MCPSolver.solve(solver_type, mcp; θ, kwargs...)
project_to_θ = ChainRulesCore.ProjectTo(θ)

function solve_pullback(∂solution)
no_grad_args =
(; ∂self = ChainRulesCore.NoTangent(), ∂problem = ChainRulesCore.NoTangent())
no_grad_args = (;
∂self = ChainRulesCore.NoTangent(),
∂solver_type = ChainRulesCore.NoTangent(),
∂mcp = ChainRulesCore.NoTangent(),
)

∂θ = ChainRulesCore.@thunk let
∂z∂θ = _solve_jacobian_θ(mcp, solution, θ)
∂l∂z = ∂solution.z
project_to_θ(∂z∂θ' * ∂l∂z)
∂l∂x = ∂solution.x
∂l∂y = ∂solution.y
∂l∂s = ∂solution.s
project_to_θ(∂z∂θ' * [∂l∂x; ∂l∂y; ∂l∂s])
end

no_grad_args..., ∂θ
Expand Down
3 changes: 3 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
[deps]
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
24 changes: 24 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ using Test: @testset, @test

using MCPSolver
using BlockArrays: BlockArray, Block, mortar, blocks
using Zygote: Zygote
using FiniteDiff: FiniteDiff

@testset "QPTestProblem" begin
""" Test for the following QP:
Expand Down Expand Up @@ -59,6 +61,28 @@ using BlockArrays: BlockArray, Block, mortar, blocks

check_solution(sol)
end

@testset "AutodifferentationTests" begin
mcp = MCPSolver.PrimalDualMCP(
G,
H;
unconstrained_dimension = size(M, 1),
constrained_dimension = length(b),
parameter_dimension = size(M, 1),
compute_sensitivities = true
)

function f(θ)
sol = MCPSolver.solve(MCPSolver.InteriorPoint(), mcp; θ)
sum(sol.x .^ 2) + sum(sol.y .^ 2)
end

∇_autodiff_reverse = only(Zygote.gradient(f, θ))
∇_autodiff_forward = only(Zygote.gradient-> Zygote.forwarddiff(f, θ), θ))
∇_finitediff = FiniteDiff.finite_difference_gradient(f, θ)
@test isapprox(∇_autodiff_reverse, ∇_finitediff; atol = 1e-3)
@test isapprox(∇_autodiff_reverse, ∇_autodiff_forward; atol = 1e-3)
end
end

@testset "ParametricGameTests" begin
Expand Down

0 comments on commit ac11389

Please sign in to comment.