From 81cc04bbdcd0ac92e8e5543da6a10d72ce274f27 Mon Sep 17 00:00:00 2001 From: jalving Date: Wed, 31 Jan 2024 13:58:24 -0800 Subject: [PATCH] fix link constraints with operators --- src/node_constraints.jl | 30 +----------------------------- src/optiedge.jl | 4 ++++ src/optigraph.jl | 3 ++- 3 files changed, 7 insertions(+), 30 deletions(-) diff --git a/src/node_constraints.jl b/src/node_constraints.jl index 990ca4a..5aae189 100644 --- a/src/node_constraints.jl +++ b/src/node_constraints.jl @@ -116,34 +116,6 @@ function _check_node_variables( JuMP.GenericNonlinearExpr } ) - return isempty(setdiff(_node_variables(jump_func), JuMP.all_variables(node))) + return isempty(setdiff(_extract_variables(jump_func), JuMP.all_variables(node))) end -function _node_variables(jump_func::NodeVariableRef) - return [jump_func] -end - -function _node_variables(jump_func::JuMP.GenericAffExpr) - vars = [term[2] for term in JuMP.linear_terms(jump_func)] - return vars -end - -function _node_variables(jump_func::JuMP.GenericQuadExpr) - vars_aff = [term[2] for term in JuMP.linear_terms(jump_func)] - vars_quad = vcat([[term[2], term[3]] for term in JuMP.quad_terms(jump_func)]...) - vars_unique = unique([vars_aff;vars_quad]) - return vars_unique -end - -function _node_variables(jump_func::JuMP.GenericNonlinearExpr) - vars = NodeVariableRef[] - for i = 1:length(jump_func.args) - jump_arg = jump_func.args[i] - if typeof(jump_arg) == JuMP.GenericNonlinearExpr - append!(vars, _node_variables(jump_arg)) - elseif typeof(jump_arg) == NodeVariableRef - push!(vars, jump_arg) - end - end - return vars -end \ No newline at end of file diff --git a/src/optiedge.jl b/src/optiedge.jl index 76bd5cc..da846c3 100644 --- a/src/optiedge.jl +++ b/src/optiedge.jl @@ -169,6 +169,10 @@ end ### Utilities for querying variables used in constraints +function _extract_variables(jump_func::NodeVariableRef) + return [jump_func] +end + function _extract_variables(ref::ConstraintRef) func = JuMP.jump_function(JuMP.constraint_object(ref)) return _extract_variables(func) diff --git a/src/optigraph.jl b/src/optigraph.jl index 3b0385b..53d99af 100644 --- a/src/optigraph.jl +++ b/src/optigraph.jl @@ -236,6 +236,7 @@ function JuMP.add_constraint( graph::OptiGraph, con::JuMP.AbstractConstraint, name::String="" ) nodes = _collect_nodes(JuMP.jump_function(con)) + @assert length(nodes) > 0 length(nodes) > 1 || error("Cannot create a linking constraint on a single node") edge = add_edge(graph, nodes...) con = JuMP.model_convert(edge, con) @@ -251,7 +252,7 @@ function _collect_nodes( JuMP.GenericNonlinearExpr } ) - vars = _node_variables(jump_func) + vars = _extract_variables(jump_func) nodes = JuMP.owner_model.(vars) return collect(nodes) end