diff --git a/Project.toml b/Project.toml index 8e604bb..aca515a 100644 --- a/Project.toml +++ b/Project.toml @@ -5,8 +5,10 @@ version = "0.1.0" [deps] BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Infiltrator = "5903a43b-9cc3-4c30-8d17-598619ec4e9b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -15,8 +17,10 @@ TrajectoryGamesBase = "ac1ac542-73eb-4349-ae1b-660ab3609574" [compat] BlockArrays = "0.16.43" +ChainRulesCore = "1.25.0" DataStructures = "0.18.20" FastDifferentiation = "0.4.2" +ForwardDiff = "0.10.38" Infiltrator = "1.8.3" LinearAlgebra = "1.11.0" SparseArrays = "1.11.0" diff --git a/src/AutoDiff.jl b/src/AutoDiff.jl index 4a8bf35..6099ddc 100644 --- a/src/AutoDiff.jl +++ b/src/AutoDiff.jl @@ -62,8 +62,8 @@ function MCPSolver.solve( # glue forward and backward pass together into dual number types z_d = ForwardDiff.Dual{T}.(solution.z, z_p) x_d = @view z_d[1:(mcp.unconstrained_dimension)] - y_d = @view - z_d[(mcp.unconstrained_dimension + 1):(mcp.unconstrained_dimension + mcp.constrained_dimension)] + y_d = + (@view z_d[(mcp.unconstrained_dimension + 1):(mcp.unconstrained_dimension + mcp.constrained_dimension)]) s_d = @view z_d[(mcp.unconstrained_dimension + mcp.constrained_dimension + 1):end] (; solution.status, solution.kkt_error, solution.ϵ, x = x_d, y = y_d, s = s_d)