diff --git a/src/optinode.jl b/src/optinode.jl index 87c92b2..bf23cdd 100644 --- a/src/optinode.jl +++ b/src/optinode.jl @@ -387,7 +387,8 @@ end function _copy_model_to!(node::OptiNode, model::JuMP.Model) if !(num_variables(node) == 0 && num_constraints(node) == 0) - error("An optinode must be empty to set a JuMP Model.") + error("An optinode must be empty to set a JuMP Model. Plasmo.jl does not + yet support re-writing an optinode.") end # get backends src = JuMP.backend(model) @@ -435,5 +436,102 @@ function _copy_model_to!(node::OptiNode, model::JuMP.Model) new_moi_obj_func = MOIU.map_indices(index_map, obj_func) new_obj_func = JuMP.jump_function(node, new_moi_obj_func) JuMP.set_objective(node, JuMP.objective_sense(model), new_obj_func) + + # copy object dictionary + # need to create equivalent containers + for (key, value) in JuMP.object_dictionary(model) + _copy_object_dict_data(index_map, node, key, value) + end + + return index_map +end + +function _copy_object_dict_data( + index_map::MOIU.IndexMap, node::OptiNode, symbol::Symbol, value +) + node[symbol] = value + return nothing +end + +function _copy_object_dict_data( + index_map::MOIU.IndexMap, + node::OptiNode, + symbol::Symbol, + value::Union{JuMP.VariableRef,JuMP.ConstraintRef}, +) + source_index = JuMP.index(value) + dest_index = index_map[source_index] + node_ref = graph_backend(node).graph_to_element_map[dest_index] + node[symbol] = node_ref + return nothing +end + +function _copy_object_dict_data( + index_map::MOIU.IndexMap, + node::OptiNode, + symbol::Symbol, + array::AbstractArray{<:JuMP.VariableRef}, +) + node_data = similar(array, NodeVariableRef) + source_inds = JuMP.index.(array) + for (i, source_index) in enumerate(source_inds) + dest_index = index_map[source_index] + node_ref = graph_backend(node).graph_to_element_map[dest_index] + node_data[i] = node_ref + end + node[symbol] = node_data + return nothing +end + +function _copy_object_dict_data( + index_map::MOIU.IndexMap, + node::OptiNode, + symbol::Symbol, + array::AbstractArray{<:JuMP.ConstraintRef}, +) + node_data = similar(array, ConstraintRef) + source_inds = JuMP.index.(array) + for (i, source_index) in enumerate(source_inds) + dest_index = index_map[source_index] + node_ref = graph_backend(node).graph_to_element_map[dest_index] + node_data[i] = node_ref + end + node[symbol] = node_data + return nothing +end + +function _copy_object_dict_data( + index_map::MOIU.IndexMap, + node::OptiNode, + symbol::Symbol, + daa::JuMP.Containers.DenseAxisArray{<:JuMP.VariableRef}, +) + node_data = similar(daa.data, NodeVariableRef) + source_inds = JuMP.index.(daa.data) + for (i, source_index) in enumerate(source_inds) + dest_index = index_map[source_index] + node_ref = graph_backend(node).graph_to_element_map[dest_index] + node_data[i] = node_ref + end + node_daa = JuMP.Containers.DenseAxisArray(node_data, daa.axes...) + node[symbol] = node_daa + return nothing +end + +function _copy_object_dict_data( + index_map::MOIU.IndexMap, + node::OptiNode, + symbol::Symbol, + daa::JuMP.Containers.DenseAxisArray{<:JuMP.ConstraintRef}, +) + node_data = similar(daa.data, ConstraintRef) + source_inds = JuMP.index.(daa.data) + for (i, source_index) in enumerate(source_inds) + dest_index = index_map[source_index] + node_ref = graph_backend(node).graph_to_element_map[dest_index] + node_data[i] = node_ref + end + node_daa = JuMP.Containers.DenseAxisArray(node_data, daa.axes...) + node[symbol] = node_daa return nothing end diff --git a/test/test_aggregate.jl b/test/test_aggregate.jl index 81b87dc..15dc396 100644 --- a/test/test_aggregate.jl +++ b/test/test_aggregate.jl @@ -41,8 +41,8 @@ function _create_test_model() model = Model() @variable(model, x[1:10] >= 0) @variable(model, y[1:5] >= 2) - @constraint(model, [j = 1:5], x[j] + y[j] <= 10) - @constraint(model, sum(x) <= y[1]^4) + @constraint(model, cons[j=1:5], x[j] + y[j] <= 10) + @constraint(model, sum_con_ref, sum(x) <= y[1]^4) @objective(model, Min, sum(x) + sum(y)^3) return model end @@ -79,6 +79,12 @@ function test_set_model() @test objective_value(m) == objective_value(graph, n1) @test value.(all_variables(m)) == value.(graph, all_variables(n1)) + + node_vars = all_variables(n1) + @test n1[:x] == node_vars[1:10] + @test n1[:y] == node_vars[11:15] + @test (n1, :cons) in keys(Plasmo.node_object_dictionary(n1)) + @test (n1, :sum_con_ref) in keys(Plasmo.node_object_dictionary(n1)) end function test_aggregate_to_depth()