Skip to content

Commit

Permalink
fix link constraints with operators
Browse files Browse the repository at this point in the history
  • Loading branch information
jalving committed Jan 31, 2024
1 parent fc2ebbf commit 81cc04b
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 30 deletions.
30 changes: 1 addition & 29 deletions src/node_constraints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,34 +116,6 @@ function _check_node_variables(
JuMP.GenericNonlinearExpr
}
)
return isempty(setdiff(_node_variables(jump_func), JuMP.all_variables(node)))
return isempty(setdiff(_extract_variables(jump_func), JuMP.all_variables(node)))
end

function _node_variables(jump_func::NodeVariableRef)
return [jump_func]
end

function _node_variables(jump_func::JuMP.GenericAffExpr)
vars = [term[2] for term in JuMP.linear_terms(jump_func)]
return vars
end

function _node_variables(jump_func::JuMP.GenericQuadExpr)
vars_aff = [term[2] for term in JuMP.linear_terms(jump_func)]
vars_quad = vcat([[term[2], term[3]] for term in JuMP.quad_terms(jump_func)]...)
vars_unique = unique([vars_aff;vars_quad])
return vars_unique
end

function _node_variables(jump_func::JuMP.GenericNonlinearExpr)
vars = NodeVariableRef[]
for i = 1:length(jump_func.args)
jump_arg = jump_func.args[i]
if typeof(jump_arg) == JuMP.GenericNonlinearExpr
append!(vars, _node_variables(jump_arg))
elseif typeof(jump_arg) == NodeVariableRef
push!(vars, jump_arg)
end
end
return vars
end
4 changes: 4 additions & 0 deletions src/optiedge.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ end

### Utilities for querying variables used in constraints

function _extract_variables(jump_func::NodeVariableRef)
return [jump_func]
end

function _extract_variables(ref::ConstraintRef)
func = JuMP.jump_function(JuMP.constraint_object(ref))
return _extract_variables(func)
Expand Down
3 changes: 2 additions & 1 deletion src/optigraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ function JuMP.add_constraint(
graph::OptiGraph, con::JuMP.AbstractConstraint, name::String=""
)
nodes = _collect_nodes(JuMP.jump_function(con))
@assert length(nodes) > 0
length(nodes) > 1 || error("Cannot create a linking constraint on a single node")
edge = add_edge(graph, nodes...)
con = JuMP.model_convert(edge, con)
Expand All @@ -251,7 +252,7 @@ function _collect_nodes(
JuMP.GenericNonlinearExpr
}
)
vars = _node_variables(jump_func)
vars = _extract_variables(jump_func)
nodes = JuMP.owner_model.(vars)
return collect(nodes)
end
Expand Down

0 comments on commit 81cc04b

Please sign in to comment.