Skip to content

Commit

Permalink
reverse mode works except for handling ZeroTangents
Browse files Browse the repository at this point in the history
  • Loading branch information
dfridovi committed Nov 26, 2024
1 parent ac11389 commit fa8b9eb
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 21 deletions.
22 changes: 10 additions & 12 deletions src/AutoDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,16 @@ function _solve_jacobian_θ(mcp::MCPSolver.PrimalDualMCP, solution, θ)
)

(; x, y, s, ϵ) = solution
∂z∂θ = -mcp.∇F_z(x, y, s; θ, ϵ) \ mcp.∇F_θ(x, y, s; θ, ϵ)
∂z∂θ =
LinearAlgebra.qr(-collect(mcp.∇F_z(x, y, s; θ, ϵ)), LinearAlgebra.ColumnNorm()) \
collect(mcp.∇F_θ(x, y, s; θ, ϵ))

SparseArrays.sparse(∂z∂θ)
∂z∂θ
end

function ChainRulesCore.rrule(
::typeof(MCPSolver.solve),
solver_type::MCPSolver.SolverType,
mcp::MCPSolver.PrimalDualMCP;
θ,
kwargs...,
)
solution = MCPSolver.solve(solver_type, mcp; θ, kwargs...)
function ChainRulesCore.rrule(::typeof(MCPSolver.solve), solver_type, mcp, θ; kwargs...)
println("yoyoyyo")
solution = MCPSolver.solve(solver_type, mcp, θ; kwargs...)
project_to_θ = ChainRulesCore.ProjectTo(θ)

function solve_pullback(∂solution)
Expand All @@ -52,6 +49,7 @@ function ChainRulesCore.rrule(
∂l∂x = ∂solution.x
∂l∂y = ∂solution.y
∂l∂s = ∂solution.s
# @infiltrate
project_to_θ(∂z∂θ' * [∂l∂x; ∂l∂y; ∂l∂s])
end

Expand All @@ -63,8 +61,8 @@ end

function MCPSolver.solve(
solver_type::MCPSolver.SolverType,
mcp::MCPSolver.PrimalDualMCP;
θ::AbstractVector{<:ForwardDiff.Dual{T}},
mcp::MCPSolver.PrimalDualMCP,
θ::AbstractVector{<:ForwardDiff.Dual{T}};
kwargs...,
) where {T}
# strip off the duals
Expand Down
2 changes: 1 addition & 1 deletion src/game.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ function solve(
y₀ = ones(sum(game.dims.μ) + game.dims.μ̃),
tol = 1e-4,
)
(; x, y, s, kkt_error, status) = solve(solver_type, game.mcp; θ, x₀, y₀, tol)
(; x, y, s, kkt_error, status) = solve(solver_type, game.mcp, θ; x₀, y₀, tol)

# Unpack primals per-player for ease of access later.
end_dims = cumsum(game.dims.x)
Expand Down
4 changes: 2 additions & 2 deletions src/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ when the previous subproblem is solved in fewer iterations.
"""
function solve(
::InteriorPoint,
mcp::PrimalDualMCP;
θ,
mcp::PrimalDualMCP,
θ;
x₀ = zeros(mcp.unconstrained_dimension),
y₀ = ones(mcp.constrained_dimension),
tol = 1e-4,
Expand Down
15 changes: 9 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ using FiniteDiff: FiniteDiff
constrained_dimension = length(b),
parameter_dimension = size(M, 1),
)
sol = MCPSolver.solve(MCPSolver.InteriorPoint(), mcp; θ)
sol = MCPSolver.solve(MCPSolver.InteriorPoint(), mcp, θ)

check_solution(sol)
end
Expand All @@ -57,7 +57,7 @@ using FiniteDiff: FiniteDiff
fill(Inf, size(M, 1) + length(b));
parameter_dimension = size(M, 1),
)
sol = MCPSolver.solve(MCPSolver.InteriorPoint(), mcp; θ)
sol = MCPSolver.solve(MCPSolver.InteriorPoint(), mcp, θ)

check_solution(sol)
end
Expand All @@ -73,15 +73,18 @@ using FiniteDiff: FiniteDiff
)

function f(θ)
sol = MCPSolver.solve(MCPSolver.InteriorPoint(), mcp; θ)
sum(sol.x .^ 2) + sum(sol.y .^ 2)
sol = MCPSolver.solve(MCPSolver.InteriorPoint(), mcp, θ)
#@infiltrate
sum(sol.x .^ 2) + sum(sol.y .^ 2) + sum(sol.s .^ 2)
end

#@infiltrate

∇_autodiff_reverse = only(Zygote.gradient(f, θ))
∇_autodiff_forward = only(Zygote.gradient-> Zygote.forwarddiff(f, θ), θ))
#∇_autodiff_forward = only(Zygote.gradient(θ -> Zygote.forwarddiff(f, θ), θ))
∇_finitediff = FiniteDiff.finite_difference_gradient(f, θ)
@test isapprox(∇_autodiff_reverse, ∇_finitediff; atol = 1e-3)
@test isapprox(∇_autodiff_reverse, ∇_autodiff_forward; atol = 1e-3)
#@test isapprox(∇_autodiff_reverse, ∇_autodiff_forward; atol = 1e-3)
end
end

Expand Down

0 comments on commit fa8b9eb

Please sign in to comment.