From 9bdac76b2a6c74fa040b47d86c52703800271db3 Mon Sep 17 00:00:00 2001 From: lassepe Date: Fri, 6 Dec 2024 14:50:47 +0100 Subject: [PATCH] Fix edge-cases in symbolic backend handling --- src/game.jl | 5 +++-- src/mcp.jl | 8 +++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/game.jl b/src/game.jl index 86a203d..cb8b934 100644 --- a/src/game.jl +++ b/src/game.jl @@ -74,7 +74,8 @@ function ParametricGame(; end # Build MCP representation. - F = Vector{Symbolics.Num}( + symbolic_type = eltype(x) + F = Vector{symbolic_type}( filter!( !isnothing, [ @@ -87,7 +88,7 @@ function ParametricGame(; ), ) - z = Vector{Symbolics.Num}( + z = Vector{symbolic_type}( filter!( !isnothing, [ diff --git a/src/mcp.jl b/src/mcp.jl index 34ae6b1..b89e091 100644 --- a/src/mcp.jl +++ b/src/mcp.jl @@ -60,10 +60,16 @@ function PrimalDualMCP( y_symbolic::Vector{T}, θ_symbolic::Vector{T}; compute_sensitivities = false, - backend = SymbolicUtils.SymbolicsBackend(), backend_options = (;), ) where {T<:Union{FD.Node,Symbolics.Num}} # Create symbolic slack variable `s` and parameter `ϵ`. + if T == FD.Node + backend = SymbolicUtils.FastDifferentiationBackend() + else + @assert T === Symbolics.Num + backend = SymbolicUtils.SymbolicsBackend() + end + s_symbolic = SymbolicUtils.make_variables(backend, :s, length(y_symbolic)) ϵ_symbolic = only(SymbolicUtils.make_variables(backend, :ϵ, 1)) z_symbolic = [x_symbolic; y_symbolic; s_symbolic]