Skip to content

Commit

Permalink
fixed ad with s missing from obj
Browse files Browse the repository at this point in the history
  • Loading branch information
dfridovi committed Nov 26, 2024
1 parent b8525d7 commit 308d02f
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
9 changes: 8 additions & 1 deletion src/AutoDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,14 @@ function ChainRulesCore.rrule(
∂l∂x = ∂solution.x
∂l∂y = ∂solution.y
∂l∂s = ∂solution.s
project_to_θ(∂z∂θ' * [∂l∂x; ∂l∂y; ∂l∂s])

@views project_to_θ(
∂z∂θ[1:(mcp.unconstrained_dimension), :]' * ∂l∂x +
∂z∂θ[(mcp.unconstrained_dimension + 1):(mcp.unconstrained_dimension + mcp.constrained_dimension), :]' *
∂l∂y +
∂z∂θ[(mcp.unconstrained_dimension + mcp.constrained_dimension + 1):end, :]' *
∂l∂s,
)
end

no_grad_args..., ∂θ
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ using FiniteDiff: FiniteDiff
function f(θ)
sol = MCPSolver.solve(MCPSolver.InteriorPoint(), mcp, θ)
#@infiltrate
sum(sol.x .^ 2) + sum(sol.y .^ 2) + sum(sol.s .^ 2)
sum(sol.x .^ 2) + sum(sol.y .^ 2)
end

#@infiltrate
Expand Down

0 comments on commit 308d02f

Please sign in to comment.