From ac1138955a24c98a8edfd86d8d1580b1bba87881 Mon Sep 17 00:00:00 2001 From: dfridovi Date: Mon, 25 Nov 2024 20:03:00 -0600 Subject: [PATCH] working on AD. test not passing --- Project.toml | 4 ++++ src/AutoDiff.jl | 31 ++++++++++++++++++++++++------- test/Project.toml | 3 +++ test/runtests.jl | 24 ++++++++++++++++++++++++ 4 files changed, 55 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index aca515a..2410ee1 100644 --- a/Project.toml +++ b/Project.toml @@ -8,21 +8,25 @@ BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be" +FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Infiltrator = "5903a43b-9cc3-4c30-8d17-598619ec4e9b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" TrajectoryGamesBase = "ac1ac542-73eb-4349-ae1b-660ab3609574" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] BlockArrays = "0.16.43" ChainRulesCore = "1.25.0" DataStructures = "0.18.20" FastDifferentiation = "0.4.2" +FiniteDiff = "2.26.2" ForwardDiff = "0.10.38" Infiltrator = "1.8.3" LinearAlgebra = "1.11.0" SparseArrays = "1.11.0" Symbolics = "6.19.0" TrajectoryGamesBase = "0.3.10" +Zygote = "0.6.73" diff --git a/src/AutoDiff.jl b/src/AutoDiff.jl index 6099ddc..3e71e14 100644 --- a/src/AutoDiff.jl +++ b/src/AutoDiff.jl @@ -12,6 +12,10 @@ module AutoDiff using ..MCPSolver: MCPSolver using ChainRulesCore: ChainRulesCore using ForwardDiff: ForwardDiff +using LinearAlgebra: LinearAlgebra +using SparseArrays: SparseArrays + +using Infiltrator function _solve_jacobian_θ(mcp::MCPSolver.PrimalDualMCP, solution, θ) !isnothing(mcp.∇F_θ) || throw( @@ -21,21 +25,34 @@ function _solve_jacobian_θ(mcp::MCPSolver.PrimalDualMCP, solution, θ) ) (; x, y, s, ϵ) = solution - -mcp.∇F_z(x, y, s; θ, ϵ) \ mcp.∇F_θ(x, y, s; θ, ϵ) + ∂z∂θ = -mcp.∇F_z(x, y, s; θ, ϵ) \ mcp.∇F_θ(x, y, s; θ, ϵ) + + SparseArrays.sparse(∂z∂θ) end -function ChainRulesCore.rrule(::typeof(MCPSolver.solve), mcp, θ; kwargs...) - solution = MCPSolver.solve(mcp, θ; kwargs...) +function ChainRulesCore.rrule( + ::typeof(MCPSolver.solve), + solver_type::MCPSolver.SolverType, + mcp::MCPSolver.PrimalDualMCP; + θ, + kwargs..., +) + solution = MCPSolver.solve(solver_type, mcp; θ, kwargs...) project_to_θ = ChainRulesCore.ProjectTo(θ) function solve_pullback(∂solution) - no_grad_args = - (; ∂self = ChainRulesCore.NoTangent(), ∂problem = ChainRulesCore.NoTangent()) + no_grad_args = (; + ∂self = ChainRulesCore.NoTangent(), + ∂solver_type = ChainRulesCore.NoTangent(), + ∂mcp = ChainRulesCore.NoTangent(), + ) ∂θ = ChainRulesCore.@thunk let ∂z∂θ = _solve_jacobian_θ(mcp, solution, θ) - ∂l∂z = ∂solution.z - project_to_θ(∂z∂θ' * ∂l∂z) + ∂l∂x = ∂solution.x + ∂l∂y = ∂solution.y + ∂l∂s = ∂solution.s + project_to_θ(∂z∂θ' * [∂l∂x; ∂l∂y; ∂l∂s]) end no_grad_args..., ∂θ diff --git a/test/Project.toml b/test/Project.toml index 289d603..0ed39c2 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,7 @@ [deps] BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" +FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/runtests.jl b/test/runtests.jl index 4af8a07..cb131c3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,8 @@ using Test: @testset, @test using MCPSolver using BlockArrays: BlockArray, Block, mortar, blocks +using Zygote: Zygote +using FiniteDiff: FiniteDiff @testset "QPTestProblem" begin """ Test for the following QP: @@ -59,6 +61,28 @@ using BlockArrays: BlockArray, Block, mortar, blocks check_solution(sol) end + + @testset "AutodifferentationTests" begin + mcp = MCPSolver.PrimalDualMCP( + G, + H; + unconstrained_dimension = size(M, 1), + constrained_dimension = length(b), + parameter_dimension = size(M, 1), + compute_sensitivities = true + ) + + function f(θ) + sol = MCPSolver.solve(MCPSolver.InteriorPoint(), mcp; θ) + sum(sol.x .^ 2) + sum(sol.y .^ 2) + end + + ∇_autodiff_reverse = only(Zygote.gradient(f, θ)) + ∇_autodiff_forward = only(Zygote.gradient(θ -> Zygote.forwarddiff(f, θ), θ)) + ∇_finitediff = FiniteDiff.finite_difference_gradient(f, θ) + @test isapprox(∇_autodiff_reverse, ∇_finitediff; atol = 1e-3) + @test isapprox(∇_autodiff_reverse, ∇_autodiff_forward; atol = 1e-3) + end end @testset "ParametricGameTests" begin