From 9a93f0e9d789ce49d29cdd5b05284a0db8f2ba50 Mon Sep 17 00:00:00 2001 From: odow Date: Thu, 12 Oct 2023 15:15:03 +0200 Subject: [PATCH] Restrict rhs to Bool --- src/constraints.jl | 25 +++++++++---------- test/test_constraint.jl | 54 ++++++++--------------------------------- 2 files changed, 21 insertions(+), 58 deletions(-) diff --git a/src/constraints.jl b/src/constraints.jl index 5db80f7ca12..dcb76d0d1c1 100644 --- a/src/constraints.jl +++ b/src/constraints.jl @@ -1549,25 +1549,22 @@ end model_convert(::AbstractModel, set::_DoNotConvertSet) = set -function moi_set(constraint::ScalarConstraint{F,<:_DoNotConvertSet}) where {F} - return constraint.set.set -end - -function _build_boolean_equal_to(::Function, lhs, rhs) - set = _DoNotConvertSet(MOI.EqualTo(true)) - return ScalarConstraint(op_equal_to(lhs, rhs), set) -end +moi_set(c::ScalarConstraint{F,<:_DoNotConvertSet}) where {F} = c.set.set -function _build_boolean_equal_to(::Function, lhs::Bool, rhs) - return ScalarConstraint(rhs, _DoNotConvertSet(MOI.EqualTo(lhs))) +function _build_boolean_equal_to(::Function, lhs::AbstractJuMPScalar, rhs::Bool) + return ScalarConstraint(lhs, _DoNotConvertSet(MOI.EqualTo(rhs))) end -function _build_boolean_equal_to(::Function, lhs, rhs::Bool) - return ScalarConstraint(lhs, _DoNotConvertSet(MOI.EqualTo(rhs))) +function _build_boolean_equal_to(error_fn::Function, ::AbstractJuMPScalar, rhs) + return error_fn( + "cannot add the `:=` constraint. The right-hand side must be a `Bool`", + ) end -function _build_boolean_equal_to(error_fn::Function, lhs::Bool, rhs::Bool) - return error_fn("cannot add the trivial constraint `$lhs := $rhs`") +function _build_boolean_equal_to(error_fn::Function, lhs, ::Any) + return error_fn( + "cannot add the `:=` constraint with left-hand side of type `::$(typeof(lhs))`", + ) end function parse_constraint_head(error_fn::Function, ::Val{:(:=)}, lhs, rhs) diff --git a/test/test_constraint.jl b/test/test_constraint.jl index 02e5dd009e8..de98fd8b331 100644 --- a/test/test_constraint.jl +++ b/test/test_constraint.jl @@ -1721,14 +1721,11 @@ function test_triangle_vec() return end -function test_def_equal_to_operator() - model = GenericModel{Bool}() +function _test_def_equal_to_operator_T(::Type{T}) where {T} + model = GenericModel{T}() @variable(model, x[1:3]) # x[1] := x[2] - c = @constraint(model, x[1] := x[2]) - o = constraint_object(c) - @test isequal_canonical(o.func, op_equal_to(x[1], x[2])) - @test o.set == MOI.EqualTo(true) + @test_throws ErrorException @constraint(model, x[1] := x[2]) # x[1] == x[2] := false c = @constraint(model, x[1] == x[2] := false) o = constraint_object(c) @@ -1752,48 +1749,17 @@ function test_def_equal_to_operator() @test o.set == MOI.EqualTo(y) # y := x[1] || x[2] y = true - c = @constraint(model, y := x[1] || x[2]) - o = constraint_object(c) - @test isequal_canonical(o.func, op_or(x[1], x[2])) - @test o.set == MOI.EqualTo(y) + @test_throws ErrorException @constraint(model, y := x[1] || x[2]) return end function test_def_equal_to_operator_float() - model = Model() - @variable(model, x[1:3]) - # x[1] := x[2] - c = @constraint(model, x[1] := x[2]) - o = constraint_object(c) - @test isequal_canonical(o.func, op_equal_to(x[1], x[2])) - @test o.set == MOI.EqualTo(true) - # x[1] == x[2] := false - c = @constraint(model, x[1] == x[2] := false) - o = constraint_object(c) - @test isequal_canonical(o.func, op_equal_to(x[1], x[2])) - @test o.set == MOI.EqualTo(false) - # x[1] && x[2] := false - c = @constraint(model, x[1] && x[2] := false) - o = constraint_object(c) - @test isequal_canonical(o.func, op_and(x[1], x[2])) - @test o.set == MOI.EqualTo(false) - # x[1] && x[2] := true - c = @constraint(model, x[1] && x[2] := true) - o = constraint_object(c) - @test isequal_canonical(o.func, op_and(x[1], x[2])) - @test o.set == MOI.EqualTo(true) - # x[1] || x[2] := y - y = true - c = @constraint(model, x[1] || x[2] := y) - o = constraint_object(c) - @test isequal_canonical(o.func, op_or(x[1], x[2])) - @test o.set == MOI.EqualTo(true) - # y := x[1] || x[2] - y = true - c = @constraint(model, y := x[1] || x[2]) - o = constraint_object(c) - @test isequal_canonical(o.func, op_or(x[1], x[2])) - @test o.set == MOI.EqualTo(true) + _test_def_equal_to_operator_T(Float64) + return +end + +function test_def_equal_to_operator_bool() + _test_def_equal_to_operator_T(Bool) return end