From 308d02f98012bb70e6876aa55754c9f5dfd9884c Mon Sep 17 00:00:00 2001 From: dfridovi Date: Tue, 26 Nov 2024 08:49:20 -0600 Subject: [PATCH] fixed ad with s missing from obj --- src/AutoDiff.jl | 9 ++++++++- test/runtests.jl | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/AutoDiff.jl b/src/AutoDiff.jl index 546da6e..bd03cb9 100644 --- a/src/AutoDiff.jl +++ b/src/AutoDiff.jl @@ -52,7 +52,14 @@ function ChainRulesCore.rrule( ∂l∂x = ∂solution.x ∂l∂y = ∂solution.y ∂l∂s = ∂solution.s - project_to_θ(∂z∂θ' * [∂l∂x; ∂l∂y; ∂l∂s]) + + @views project_to_θ( + ∂z∂θ[1:(mcp.unconstrained_dimension), :]' * ∂l∂x + + ∂z∂θ[(mcp.unconstrained_dimension + 1):(mcp.unconstrained_dimension + mcp.constrained_dimension), :]' * + ∂l∂y + + ∂z∂θ[(mcp.unconstrained_dimension + mcp.constrained_dimension + 1):end, :]' * + ∂l∂s, + ) end no_grad_args..., ∂θ diff --git a/test/runtests.jl b/test/runtests.jl index 600d5e5..3c220f2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -75,7 +75,7 @@ using FiniteDiff: FiniteDiff function f(θ) sol = MCPSolver.solve(MCPSolver.InteriorPoint(), mcp, θ) #@infiltrate - sum(sol.x .^ 2) + sum(sol.y .^ 2) + sum(sol.s .^ 2) + sum(sol.x .^ 2) + sum(sol.y .^ 2) end #@infiltrate