Skip to content

Commit

Permalink
first cut of symbolics side. working on solver
Browse files Browse the repository at this point in the history
  • Loading branch information
dfridovi committed Nov 15, 2024
1 parent 1616d9a commit 517b9cd
Show file tree
Hide file tree
Showing 6 changed files with 274 additions and 1 deletion.
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@ version = "0.1.0"
[deps]
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
TrajectoryGamesBase = "ac1ac542-73eb-4349-ae1b-660ab3609574"

[compat]
BlockArrays = "0.16.43"
DataStructures = "0.18.20"
FastDifferentiation = "0.4.2"
SparseArrays = "1.11.0"
Symbolics = "6.19.0"
TrajectoryGamesBase = "0.3.10"
10 changes: 9 additions & 1 deletion src/MCPSolver.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
module MCPSolver

greet() = print("Hello World!")
using SparseArrays: SparseArrays
using FastDifferentiation: FastDifferentiation as FD
using Symbolics: Symbolics

include("SymbolicUtils.jl")
include("sparse_utils.jl")

include("mcp.jl")
include("solver.jl")

end # module MCPSolver
119 changes: 119 additions & 0 deletions src/SymbolicUtils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""
Minimal abstraction on top of `Symbolics.jl` and `FastDifferentiation.jl` to make switching between the two easier.
Taken from: ParametricMCPs.jl.
"""
module SymbolicUtils

using Symbolics: Symbolics
using FastDifferentiation: FastDifferentiation as FD

export SymbolicsBackend, FastDifferentiationBackend, make_variables, build_function

struct SymbolicsBackend end
struct FastDifferentiationBackend end

"""
make_variables(backend, name, dimension)
Creates a vector of `dimension` where each element is a scalar symbolic variable from `backend` with the given `name`.
"""
function make_variables end

function make_variables(::SymbolicsBackend, name::Symbol, dimension::Int)
vars = Symbolics.@variables($name[1:dimension]) |> only |> Symbolics.scalarize

if isempty(vars)
vars = Symbolics.Num[]
end

vars
end

function make_variables(::FastDifferentiationBackend, name::Symbol, dimension::Int)
FD.make_variables(name, dimension)
end

"""
build_function(backend, f_symbolic, args_symbolic...; in_place, options)
Builds a callable function from a symbolic expression `f_symbolic` with the given `args_symbolic` as arguments.
Depending on the `in_place` flag, the function will be built as in-place `f!(result, args...)` or out-of-place variant `restult = f(args...)`.
`backend_options` will be forwarded to the backend specific function and differ between backends.
"""
function build_function end

function build_function(
f_symbolic::AbstractArray{T},
args_symbolic...;
in_place,
backend_options = (;),
) where {T<:Symbolics.Num}
f_callable, f_callable! = Symbolics.build_function(
f_symbolic,
args_symbolic...;
expression = Val{false},
# slightly saner defaults...
(; parallel = Symbolics.ShardedForm(), backend_options...)...,
)
in_place ? f_callable! : f_callable
end

function build_function(
f_symbolic::AbstractArray{T},
args_symbolic...;
in_place,
backend_options = (;),
) where {T<:FD.Node}
FD.make_function(f_symbolic, args_symbolic...; in_place, backend_options...)
end

"""
gradient(f_symbolic, x_symbolic)
Computes the symbolic gradient of `f_symbolic` with respect to `x_symbolic`.
"""
function gradient end

function gradient(f_symbolic::T, x_symbolic::Vector{T}) where {T<:Symbolics.Num}
Symbolics.gradient(f_symbolic, x_symbolic)
end

function gradient(f_symbolic::T, x_symbolic::Vector{T}) where {T<:FD.Node}
# FD does not have a gradient utility so we just flatten the jacobian here
vec(FD.jacobian([f_symbolic], x_symbolic))
end

"""
jacobian(f_symbolic, x_symbolic)
Computes the symbolic Jacobian of `f_symbolic` with respect to `x_symbolic`.
"""
function jacobian end

function jacobian(f_symbolic::Vector{T}, x_symbolic::Vector{T}) where {T<:Symbolics.Num}
Symbolics.jacobian(f_symbolic, x_symbolic)
end

function jacobian(f_symbolic::Vector{T}, x_symbolic::Vector{T}) where {T<:FD.Node}
FD.jacobian([f_symbolic], x_symbolic)
end

"""
sparse_jacobian(f_symbolic, x_symbolic)
Computes the symbolic Jacobian of `f_symbolic` with respect to `x_symbolic` in a sparse format.
"""
function sparse_jacobian end

function sparse_jacobian(f_symbolic::Vector{T}, x_symbolic::Vector{T}) where {T<:Symbolics.Num}
Symbolics.sparsejacobian(f_symbolic, x_symbolic)
end

function sparse_jacobian(f_symbolic::Vector{T}, x_symbolic::Vector{T}) where {T<:FD.Node}
FD.sparse_jacobian(f_symbolic, x_symbolic)
end

end
70 changes: 70 additions & 0 deletions src/mcp.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
""" Store key elements of the primal-dual KKT system for a MCP composed of
functions G(.) and H(.) such that
0 = G(x, y)
0 ≤ H(x, y) ⟂ y ≥ 0.
The primal-dual system arises when we introduce slack variable `s` and set
G(x, y) = 0
H(x, y) - s = 0
sᵀy - ϵ = 0
for some ϵ > 0. Define the function `F(z; ϵ)` to return the left hand side of this
system of equations, where `z = [x; y; s]`.
"""
struct PrimalDualMCP{T1,T2}
"A callable `F(z; ϵ)` which computes the KKT error in the primal-dual system."
F::T1
"A callable `∇F(z; ϵ)` which stores the Jacobian of the KKT error wrt z."
∇F::T2
end

"Construct a PrimalDualMCP from symbolic expressions of G(.) and H(.)."
function PrimalDualMCP(
G_symbolic::Vector{T},
H_symbolic::Vector{T},
x_symbolic::Vector{T},
y_symbolic::Vector{T},
backend = SymbolicUtils.SymbolicsBackend(),
backend_options = (;)
) where {T<:Union{FD.Node,Symbolics.Num}}
# Create symbolic slack variable `s` and parameter `ϵ`.
s_symbolic = SymbolicUtils.make_variables(backend, :s, length(y_symbolic))
ϵ_symbolic = SymbolicUtils.make_variables(backend, , 1)
z_symbolic = [x_symbolic; y_symbolic; s_symbolic]

F_symbolic = [
G_symbolic;
H_symbolic - s_symbolic;
sum(s_symbolic .* y_symbolic) - ϵ_symbolic
]

F = let
_F = SymbolicUtils.build_function(
F_symbolic,
[z_symbolic; ϵ_symbolic];
in_place = false,
backend_options,
)

(x, y, s; ϵ) -> _F([x; y; s; ϵ])
end

∇F = let
∇F_symbolic = SymbolicUtils.sparse_jacobian(F_symbolic, z_symbolic)
_∇F = SymbolicUtils.build_function(
∇F_symbolic,
[z_symbolic; ϵ_symbolic];
in_place = false,
backend_options,
)

# rows, cols, _ = SparseArrays.findnz(∇F_symbolic)
# constant_entries = get_constant_entries(∇F_symbolic, z_symbolic)
# SparseFunction(rows, cols, size(∇F_symbolic), constant_entries) do x, y, s, ϵ
# _∇F([x; y; z; ϵ])
# end

(x, y, s; ϵ) -> _∇F([x; y; s; ϵ])
end

PrimalDualMCP(F, ∇F)
end
20 changes: 20 additions & 0 deletions src/solver.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
abstract type SolverType end
struct InteriorPoint <: SolverType end

""" Basic interior point solver, based on Nocedal & Wright, ch. 19.
Computes step directions `δz` by solving the relaxed primal-dual system, i.e.
∇F(z; ϵ) δz = -F(z; ϵ).
Given a step direction `δz`, performs a "fraction to the boundary" linesearch,
i.e., for `(x, s)` it chooses step size `α_s` such that
α_s = max(α ∈ [0, 1] : s + δs ≥ (1 - τ) s)
and for `y` it chooses step size `α_s` such that
α_y = max(α ∈ [0, 1] : y + δy ≥ (1 - τ) y).
A typical value of τ is 0.995. Once we converge to ||F(z; \epsilon)|| ≤ ϵ,
we typically decrease ϵ by a factor of 0.1 or 0.2, with smaller values chosen
when the previous subproblem is solved in fewer iterations.
"""
function solve(::InteriorPoint, mcp::MCP, x₀, y₀)
# TODO!
end
52 changes: 52 additions & 0 deletions src/sparse_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
""" Sparse utils from ParametricMCPs.jl. """

struct SparseFunction{T1,T2}
_f::T1
result_buffer::T2
rows::Vector{Int}
cols::Vector{Int}
size::Tuple{Int,Int}
constant_entries::Vector{Int}
function SparseFunction(_f::T1, rows, cols, size, constant_entries = Int[]) where {T1}
length(constant_entries) <= length(rows) ||
throw(ArgumentError("More constant entries than non-zero entries."))
result_buffer = get_result_buffer(rows, cols, size)
new{T1,typeof(result_buffer)}(_f, result_buffer, rows, cols, size, constant_entries)
end
end

(f::SparseFunction)(args...) = f._f(args...)
SparseArrays.nnz(f::SparseFunction) = length(f.rows)

function get_result_buffer(rows::Vector{Int}, cols::Vector{Int}, size::Tuple{Int,Int})
data = zeros(length(rows))
SparseArrays.sparse(rows, cols, data, size...)
end

function get_result_buffer(f::SparseFunction)
get_result_buffer(f.rows, f.cols, f.size)
end

"Get the (sparse) linear indices of all entries that are constant in the symbolic matrix M w.r.t. symbolic vector z."
function get_constant_entries(
M_symbolic::AbstractMatrix{<:Symbolics.Num},
z_symbolic::AbstractVector{<:Symbolics.Num},
)
_z_syms = Symbolics.tosymbol.(z_symbolic)
findall(SparseArrays.nonzeros(M_symbolic)) do v
_vars_syms = Symbolics.tosymbol.(Symbolics.get_variables(v))
isempty(intersect(_vars_syms, _z_syms))
end
end

function get_constant_entries(
M_symbolic::AbstractMatrix{<:FD.Node},
z_symbolic::AbstractVector{<:FD.Node},
)
_z_syms = FD.variables(z_symbolic)
# find all entries that are not a function of any of the symbols in z
findall(SparseArrays.nonzeros(M_symbolic)) do v
_vars_syms = FD.variables(v)
isempty(intersect(_vars_syms, _z_syms))
end
end

0 comments on commit 517b9cd

Please sign in to comment.