Skip to content

Commit

Permalink
Use SymbolicTracingUtils.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
lassepe committed Dec 16, 2024
1 parent e8618bc commit 608800b
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 228 deletions.
6 changes: 2 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,23 @@ version = "0.1.4"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
SymbolicTracingUtils = "77ddf47f-b2ab-4ded-95ee-54f4fa148129"
TrajectoryGamesBase = "ac1ac542-73eb-4349-ae1b-660ab3609574"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
BlockArrays = "0.16.43, 1"
ChainRulesCore = "1.25.0"
DataStructures = "0.18.20"
FastDifferentiation = "0.4.2"
FiniteDiff = "2.26.2"
ForwardDiff = "0.10.38"
LinearAlgebra = "1.11.0"
SparseArrays = "1.11.0"
Symbolics = "6.19.0"
SymbolicTracingUtils = "0.1.1"
TrajectoryGamesBase = "0.3.10"
Zygote = "0.6.73"
julia = "1.11"
5 changes: 3 additions & 2 deletions src/AutoDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using ..MixedComplementarityProblems: MixedComplementarityProblems
using ChainRulesCore: ChainRulesCore
using ForwardDiff: ForwardDiff
using LinearAlgebra: LinearAlgebra
using SymbolicTracingUtils: SymbolicTracingUtils

function _solve_jacobian_θ(mcp::MixedComplementarityProblems.PrimalDualMCP, solution, θ)
!isnothing(mcp.∇F_θ!) || throw(
Expand All @@ -24,13 +25,13 @@ function _solve_jacobian_θ(mcp::MixedComplementarityProblems.PrimalDualMCP, sol
(; x, y, s, ϵ) = solution

∇F_z = let
∇F = MixedComplementarityProblems.get_result_buffer(mcp.∇F_z!)
∇F = SymbolicTracingUtils.get_result_buffer(mcp.∇F_z!)
mcp.∇F_z!(∇F, x, y, s; θ, ϵ)
∇F
end

∇F_θ = let
∇F = MixedComplementarityProblems.get_result_buffer(mcp.∇F_θ!)
∇F = SymbolicTracingUtils.get_result_buffer(mcp.∇F_θ!)
mcp.∇F_θ!(∇F, x, y, s; θ, ϵ)
∇F
end
Expand Down
5 changes: 1 addition & 4 deletions src/MixedComplementarityProblems.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
module MixedComplementarityProblems

using SparseArrays: SparseArrays
using FastDifferentiation: FastDifferentiation as FD
using Symbolics: Symbolics
using LinearAlgebra: LinearAlgebra, I, norm, eigvals
using BlockArrays: blocks, blocksizes
using TrajectoryGamesBase: to_blockvector
using SymbolicTracingUtils: SymbolicTracingUtils as SymbolicTracingUtils

include("SymbolicUtils.jl")
include("sparse_utils.jl")
include("mcp.jl")
include("solver.jl")
include("game.jl")
Expand Down
122 changes: 0 additions & 122 deletions src/SymbolicUtils.jl

This file was deleted.

24 changes: 16 additions & 8 deletions src/game.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,21 @@ function ParametricGame(;

# Define primal and dual variables for the game, and game parameters..
# Note that BlockArrays can handle blocks of zero size.
backend = SymbolicUtils.SymbolicsBackend()
x = SymbolicUtils.make_variables(backend, :x, sum(dims.x)) |> to_blockvector(dims.x)
λ = SymbolicUtils.make_variables(backend, , sum(dims.λ)) |> to_blockvector(dims.λ)
μ = SymbolicUtils.make_variables(backend, , sum(dims.μ)) |> to_blockvector(dims.μ)
λ̃ = SymbolicUtils.make_variables(backend, :λ̃, dims.λ̃)
μ̃ = SymbolicUtils.make_variables(backend, :μ̃, dims.μ̃)
θ = SymbolicUtils.make_variables(backend, , sum(dims.θ)) |> to_blockvector(dims.θ)
backend = SymbolicTracingUtils.SymbolicsBackend()
x =
SymbolicTracingUtils.make_variables(backend, :x, sum(dims.x)) |>
to_blockvector(dims.x)
λ =
SymbolicTracingUtils.make_variables(backend, , sum(dims.λ)) |>
to_blockvector(dims.λ)
μ =
SymbolicTracingUtils.make_variables(backend, , sum(dims.μ)) |>
to_blockvector(dims.μ)
λ̃ = SymbolicTracingUtils.make_variables(backend, :λ̃, dims.λ̃)
μ̃ = SymbolicTracingUtils.make_variables(backend, :μ̃, dims.μ̃)
θ =
SymbolicTracingUtils.make_variables(backend, , sum(dims.θ)) |>
to_blockvector(dims.θ)

# Build symbolic expressions for objectives and constraints.
fs = map(problems, blocks(θ)) do p, θi
Expand All @@ -70,7 +78,7 @@ function ParametricGame(;
L =
f - (isnothing(g) ? 0 : sum(λi .* g)) - (isnothing(h) ? 0 : sum(μi .* h)) - (isnothing(g̃) ? 0 : sum(λ̃ .* g̃)) -
(isnothing(h̃) ? 0 : sum(μ̃ .* h̃))
SymbolicUtils.gradient(L, xi)
SymbolicTracingUtils.gradient(L, xi)
end

# Build MCP representation.
Expand Down
50 changes: 26 additions & 24 deletions src/mcp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ function PrimalDualMCP(
constrained_dimension,
parameter_dimension,
compute_sensitivities = false,
backend = SymbolicUtils.SymbolicsBackend(),
backend = SymbolicTracingUtils.SymbolicsBackend(),
backend_options = (;),
)
x_symbolic = SymbolicUtils.make_variables(backend, :x, unconstrained_dimension)
y_symbolic = SymbolicUtils.make_variables(backend, :y, constrained_dimension)
θ_symbolic = SymbolicUtils.make_variables(backend, , parameter_dimension)
x_symbolic = SymbolicTracingUtils.make_variables(backend, :x, unconstrained_dimension)
y_symbolic = SymbolicTracingUtils.make_variables(backend, :y, constrained_dimension)
θ_symbolic = SymbolicTracingUtils.make_variables(backend, , parameter_dimension)
G_symbolic = G(x_symbolic, y_symbolic; θ = θ_symbolic)
H_symbolic = H(x_symbolic, y_symbolic; θ = θ_symbolic)

Expand All @@ -60,17 +60,17 @@ function PrimalDualMCP(
θ_symbolic::Vector{T};
compute_sensitivities = false,
backend_options = (;),
) where {T<:Union{FD.Node,Symbolics.Num}}
) where {T<:Union{SymbolicTracingUtils.FD.Node,SymbolicTracingUtils.Symbolics.Num}}
# Create symbolic slack variable `s` and parameter `ϵ`.
if T == FD.Node
backend = SymbolicUtils.FastDifferentiationBackend()
if T == SymbolicTracingUtils.FD.Node
backend = SymbolicTracingUtils.FastDifferentiationBackend()
else
@assert T === Symbolics.Num
backend = SymbolicUtils.SymbolicsBackend()
@assert T === SymbolicTracingUtils.Symbolics.Num
backend = SymbolicTracingUtils.SymbolicsBackend()
end

s_symbolic = SymbolicUtils.make_variables(backend, :s, length(y_symbolic))
ϵ_symbolic = only(SymbolicUtils.make_variables(backend, , 1))
s_symbolic = SymbolicTracingUtils.make_variables(backend, :s, length(y_symbolic))
ϵ_symbolic = only(SymbolicTracingUtils.make_variables(backend, , 1))
z_symbolic = [x_symbolic; y_symbolic; s_symbolic]

F_symbolic = [
Expand All @@ -80,7 +80,7 @@ function PrimalDualMCP(
]

F! = let
_F! = SymbolicUtils.build_function(
_F! = SymbolicTracingUtils.build_function(
F_symbolic,
x_symbolic,
y_symbolic,
Expand All @@ -95,8 +95,8 @@ function PrimalDualMCP(
end

∇F_z! = let
∇F_symbolic = SymbolicUtils.sparse_jacobian(F_symbolic, z_symbolic)
_∇F! = SymbolicUtils.build_function(
∇F_symbolic = SymbolicTracingUtils.sparse_jacobian(F_symbolic, z_symbolic)
_∇F! = SymbolicTracingUtils.build_function(
∇F_symbolic,
x_symbolic,
y_symbolic,
Expand All @@ -108,8 +108,9 @@ function PrimalDualMCP(
)

rows, cols, _ = SparseArrays.findnz(∇F_symbolic)
constant_entries = get_constant_entries(∇F_symbolic, z_symbolic)
SparseFunction(
constant_entries =
SymbolicTracingUtils.get_constant_entries(∇F_symbolic, z_symbolic)
SymbolicTracingUtils.SparseFunction(
(result, x, y, s; θ, ϵ) -> _∇F!(result, x, y, s, θ, ϵ),
rows,
cols,
Expand All @@ -121,8 +122,8 @@ function PrimalDualMCP(
∇F_θ! =
!compute_sensitivities ? nothing :
let
∇F_symbolic = SymbolicUtils.sparse_jacobian(F_symbolic, θ_symbolic)
_∇F! = SymbolicUtils.build_function(
∇F_symbolic = SymbolicTracingUtils.sparse_jacobian(F_symbolic, θ_symbolic)
_∇F! = SymbolicTracingUtils.build_function(
∇F_symbolic,
x_symbolic,
y_symbolic,
Expand All @@ -134,8 +135,9 @@ function PrimalDualMCP(
)

rows, cols, _ = SparseArrays.findnz(∇F_symbolic)
constant_entries = get_constant_entries(∇F_symbolic, θ_symbolic)
SparseFunction(
constant_entries =
SymbolicTracingUtils.get_constant_entries(∇F_symbolic, θ_symbolic)
SymbolicTracingUtils.SparseFunction(
(result, x, y, s; θ, ϵ) -> _∇F!(result, x, y, s, θ, ϵ),
rows,
cols,
Expand All @@ -156,11 +158,11 @@ function PrimalDualMCP(
upper_bounds::Vector;
parameter_dimension,
compute_sensitivities = false,
backend = SymbolicUtils.SymbolicsBackend(),
backend = SymbolicTracingUtils.SymbolicsBackend(),
backend_options = (;),
)
z_symbolic = SymbolicUtils.make_variables(backend, :z, length(lower_bounds))
θ_symbolic = SymbolicUtils.make_variables(backend, , parameter_dimension)
z_symbolic = SymbolicTracingUtils.make_variables(backend, :z, length(lower_bounds))
θ_symbolic = SymbolicTracingUtils.make_variables(backend, , parameter_dimension)
K_symbolic = K(z_symbolic; θ = θ_symbolic)

PrimalDualMCP(
Expand All @@ -185,7 +187,7 @@ function PrimalDualMCP(
upper_bounds::Vector;
compute_sensitivities = false,
backend_options = (;),
) where {T<:Union{FD.Node,Symbolics.Num}}
) where {T<:Union{SymbolicTracingUtils.FD.Node,SymbolicTracingUtils.Symbolics.Num}}
@assert all(isinf.(upper_bounds)) && all(isinf.(lower_bounds) .|| lower_bounds .== 0)

unconstrained_indices = findall(isinf, lower_bounds)
Expand Down
2 changes: 1 addition & 1 deletion src/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ function solve(
max_outer_iters = 50,
)
# Set up common memory.
∇F = get_result_buffer(mcp.∇F_z!)
∇F = SymbolicTracingUtils.get_result_buffer(mcp.∇F_z!)
F = zeros(mcp.unconstrained_dimension + 2mcp.constrained_dimension)
δz = zeros(mcp.unconstrained_dimension + 2mcp.constrained_dimension)
δx = @view δz[1:(mcp.unconstrained_dimension)]
Expand Down
Loading

0 comments on commit 608800b

Please sign in to comment.