-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
first cut of symbolics side. working on solver
- Loading branch information
Showing
6 changed files
with
274 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,13 @@ | ||
module MCPSolver | ||
|
||
greet() = print("Hello World!") | ||
using SparseArrays: SparseArrays | ||
using FastDifferentiation: FastDifferentiation as FD | ||
using Symbolics: Symbolics | ||
|
||
include("SymbolicUtils.jl") | ||
include("sparse_utils.jl") | ||
|
||
include("mcp.jl") | ||
include("solver.jl") | ||
|
||
end # module MCPSolver |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
""" | ||
Minimal abstraction on top of `Symbolics.jl` and `FastDifferentiation.jl` to make switching between the two easier. | ||
Taken from: ParametricMCPs.jl. | ||
""" | ||
module SymbolicUtils | ||
|
||
using Symbolics: Symbolics | ||
using FastDifferentiation: FastDifferentiation as FD | ||
|
||
export SymbolicsBackend, FastDifferentiationBackend, make_variables, build_function | ||
|
||
struct SymbolicsBackend end | ||
struct FastDifferentiationBackend end | ||
|
||
""" | ||
make_variables(backend, name, dimension) | ||
Creates a vector of `dimension` where each element is a scalar symbolic variable from `backend` with the given `name`. | ||
""" | ||
function make_variables end | ||
|
||
function make_variables(::SymbolicsBackend, name::Symbol, dimension::Int) | ||
vars = Symbolics.@variables($name[1:dimension]) |> only |> Symbolics.scalarize | ||
|
||
if isempty(vars) | ||
vars = Symbolics.Num[] | ||
end | ||
|
||
vars | ||
end | ||
|
||
function make_variables(::FastDifferentiationBackend, name::Symbol, dimension::Int) | ||
FD.make_variables(name, dimension) | ||
end | ||
|
||
""" | ||
build_function(backend, f_symbolic, args_symbolic...; in_place, options) | ||
Builds a callable function from a symbolic expression `f_symbolic` with the given `args_symbolic` as arguments. | ||
Depending on the `in_place` flag, the function will be built as in-place `f!(result, args...)` or out-of-place variant `restult = f(args...)`. | ||
`backend_options` will be forwarded to the backend specific function and differ between backends. | ||
""" | ||
function build_function end | ||
|
||
function build_function( | ||
f_symbolic::AbstractArray{T}, | ||
args_symbolic...; | ||
in_place, | ||
backend_options = (;), | ||
) where {T<:Symbolics.Num} | ||
f_callable, f_callable! = Symbolics.build_function( | ||
f_symbolic, | ||
args_symbolic...; | ||
expression = Val{false}, | ||
# slightly saner defaults... | ||
(; parallel = Symbolics.ShardedForm(), backend_options...)..., | ||
) | ||
in_place ? f_callable! : f_callable | ||
end | ||
|
||
function build_function( | ||
f_symbolic::AbstractArray{T}, | ||
args_symbolic...; | ||
in_place, | ||
backend_options = (;), | ||
) where {T<:FD.Node} | ||
FD.make_function(f_symbolic, args_symbolic...; in_place, backend_options...) | ||
end | ||
|
||
""" | ||
gradient(f_symbolic, x_symbolic) | ||
Computes the symbolic gradient of `f_symbolic` with respect to `x_symbolic`. | ||
""" | ||
function gradient end | ||
|
||
function gradient(f_symbolic::T, x_symbolic::Vector{T}) where {T<:Symbolics.Num} | ||
Symbolics.gradient(f_symbolic, x_symbolic) | ||
end | ||
|
||
function gradient(f_symbolic::T, x_symbolic::Vector{T}) where {T<:FD.Node} | ||
# FD does not have a gradient utility so we just flatten the jacobian here | ||
vec(FD.jacobian([f_symbolic], x_symbolic)) | ||
end | ||
|
||
""" | ||
jacobian(f_symbolic, x_symbolic) | ||
Computes the symbolic Jacobian of `f_symbolic` with respect to `x_symbolic`. | ||
""" | ||
function jacobian end | ||
|
||
function jacobian(f_symbolic::Vector{T}, x_symbolic::Vector{T}) where {T<:Symbolics.Num} | ||
Symbolics.jacobian(f_symbolic, x_symbolic) | ||
end | ||
|
||
function jacobian(f_symbolic::Vector{T}, x_symbolic::Vector{T}) where {T<:FD.Node} | ||
FD.jacobian([f_symbolic], x_symbolic) | ||
end | ||
|
||
""" | ||
sparse_jacobian(f_symbolic, x_symbolic) | ||
Computes the symbolic Jacobian of `f_symbolic` with respect to `x_symbolic` in a sparse format. | ||
""" | ||
function sparse_jacobian end | ||
|
||
function sparse_jacobian(f_symbolic::Vector{T}, x_symbolic::Vector{T}) where {T<:Symbolics.Num} | ||
Symbolics.sparsejacobian(f_symbolic, x_symbolic) | ||
end | ||
|
||
function sparse_jacobian(f_symbolic::Vector{T}, x_symbolic::Vector{T}) where {T<:FD.Node} | ||
FD.sparse_jacobian(f_symbolic, x_symbolic) | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
""" Store key elements of the primal-dual KKT system for a MCP composed of | ||
functions G(.) and H(.) such that | ||
0 = G(x, y) | ||
0 ≤ H(x, y) ⟂ y ≥ 0. | ||
The primal-dual system arises when we introduce slack variable `s` and set | ||
G(x, y) = 0 | ||
H(x, y) - s = 0 | ||
sᵀy - ϵ = 0 | ||
for some ϵ > 0. Define the function `F(z; ϵ)` to return the left hand side of this | ||
system of equations, where `z = [x; y; s]`. | ||
""" | ||
struct PrimalDualMCP{T1,T2} | ||
"A callable `F(z; ϵ)` which computes the KKT error in the primal-dual system." | ||
F::T1 | ||
"A callable `∇F(z; ϵ)` which stores the Jacobian of the KKT error wrt z." | ||
∇F::T2 | ||
end | ||
|
||
"Construct a PrimalDualMCP from symbolic expressions of G(.) and H(.)." | ||
function PrimalDualMCP( | ||
G_symbolic::Vector{T}, | ||
H_symbolic::Vector{T}, | ||
x_symbolic::Vector{T}, | ||
y_symbolic::Vector{T}, | ||
backend = SymbolicUtils.SymbolicsBackend(), | ||
backend_options = (;) | ||
) where {T<:Union{FD.Node,Symbolics.Num}} | ||
# Create symbolic slack variable `s` and parameter `ϵ`. | ||
s_symbolic = SymbolicUtils.make_variables(backend, :s, length(y_symbolic)) | ||
ϵ_symbolic = SymbolicUtils.make_variables(backend, :ϵ, 1) | ||
z_symbolic = [x_symbolic; y_symbolic; s_symbolic] | ||
|
||
F_symbolic = [ | ||
G_symbolic; | ||
H_symbolic - s_symbolic; | ||
sum(s_symbolic .* y_symbolic) - ϵ_symbolic | ||
] | ||
|
||
F = let | ||
_F = SymbolicUtils.build_function( | ||
F_symbolic, | ||
[z_symbolic; ϵ_symbolic]; | ||
in_place = false, | ||
backend_options, | ||
) | ||
|
||
(x, y, s; ϵ) -> _F([x; y; s; ϵ]) | ||
end | ||
|
||
∇F = let | ||
∇F_symbolic = SymbolicUtils.sparse_jacobian(F_symbolic, z_symbolic) | ||
_∇F = SymbolicUtils.build_function( | ||
∇F_symbolic, | ||
[z_symbolic; ϵ_symbolic]; | ||
in_place = false, | ||
backend_options, | ||
) | ||
|
||
# rows, cols, _ = SparseArrays.findnz(∇F_symbolic) | ||
# constant_entries = get_constant_entries(∇F_symbolic, z_symbolic) | ||
# SparseFunction(rows, cols, size(∇F_symbolic), constant_entries) do x, y, s, ϵ | ||
# _∇F([x; y; z; ϵ]) | ||
# end | ||
|
||
(x, y, s; ϵ) -> _∇F([x; y; s; ϵ]) | ||
end | ||
|
||
PrimalDualMCP(F, ∇F) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
abstract type SolverType end | ||
struct InteriorPoint <: SolverType end | ||
|
||
""" Basic interior point solver, based on Nocedal & Wright, ch. 19. | ||
Computes step directions `δz` by solving the relaxed primal-dual system, i.e. | ||
∇F(z; ϵ) δz = -F(z; ϵ). | ||
Given a step direction `δz`, performs a "fraction to the boundary" linesearch, | ||
i.e., for `(x, s)` it chooses step size `α_s` such that | ||
α_s = max(α ∈ [0, 1] : s + δs ≥ (1 - τ) s) | ||
and for `y` it chooses step size `α_s` such that | ||
α_y = max(α ∈ [0, 1] : y + δy ≥ (1 - τ) y). | ||
A typical value of τ is 0.995. Once we converge to ||F(z; \epsilon)|| ≤ ϵ, | ||
we typically decrease ϵ by a factor of 0.1 or 0.2, with smaller values chosen | ||
when the previous subproblem is solved in fewer iterations. | ||
""" | ||
function solve(::InteriorPoint, mcp::MCP, x₀, y₀) | ||
# TODO! | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
""" Sparse utils from ParametricMCPs.jl. """ | ||
|
||
struct SparseFunction{T1,T2} | ||
_f::T1 | ||
result_buffer::T2 | ||
rows::Vector{Int} | ||
cols::Vector{Int} | ||
size::Tuple{Int,Int} | ||
constant_entries::Vector{Int} | ||
function SparseFunction(_f::T1, rows, cols, size, constant_entries = Int[]) where {T1} | ||
length(constant_entries) <= length(rows) || | ||
throw(ArgumentError("More constant entries than non-zero entries.")) | ||
result_buffer = get_result_buffer(rows, cols, size) | ||
new{T1,typeof(result_buffer)}(_f, result_buffer, rows, cols, size, constant_entries) | ||
end | ||
end | ||
|
||
(f::SparseFunction)(args...) = f._f(args...) | ||
SparseArrays.nnz(f::SparseFunction) = length(f.rows) | ||
|
||
function get_result_buffer(rows::Vector{Int}, cols::Vector{Int}, size::Tuple{Int,Int}) | ||
data = zeros(length(rows)) | ||
SparseArrays.sparse(rows, cols, data, size...) | ||
end | ||
|
||
function get_result_buffer(f::SparseFunction) | ||
get_result_buffer(f.rows, f.cols, f.size) | ||
end | ||
|
||
"Get the (sparse) linear indices of all entries that are constant in the symbolic matrix M w.r.t. symbolic vector z." | ||
function get_constant_entries( | ||
M_symbolic::AbstractMatrix{<:Symbolics.Num}, | ||
z_symbolic::AbstractVector{<:Symbolics.Num}, | ||
) | ||
_z_syms = Symbolics.tosymbol.(z_symbolic) | ||
findall(SparseArrays.nonzeros(M_symbolic)) do v | ||
_vars_syms = Symbolics.tosymbol.(Symbolics.get_variables(v)) | ||
isempty(intersect(_vars_syms, _z_syms)) | ||
end | ||
end | ||
|
||
function get_constant_entries( | ||
M_symbolic::AbstractMatrix{<:FD.Node}, | ||
z_symbolic::AbstractVector{<:FD.Node}, | ||
) | ||
_z_syms = FD.variables(z_symbolic) | ||
# find all entries that are not a function of any of the symbols in z | ||
findall(SparseArrays.nonzeros(M_symbolic)) do v | ||
_vars_syms = FD.variables(v) | ||
isempty(intersect(_vars_syms, _z_syms)) | ||
end | ||
end |