From f7fb42bdf459425ecdc27e933f4c1d0c48215111 Mon Sep 17 00:00:00 2001 From: Roman Lebedev Date: Mon, 27 Nov 2023 06:03:12 +0300 Subject: [PATCH] Add a few missing methods for `AbstractJuMPScalar` to support e.g. `Distances.jl` (#3585) --- src/JuMP.jl | 4 ++++ src/aff_expr.jl | 6 ++++++ src/variables.jl | 6 ++++++ test/test_variable.jl | 25 +++++++++++++++++++++++++ 4 files changed, 41 insertions(+) diff --git a/src/JuMP.jl b/src/JuMP.jl index 1ceb3e88831..659709c6b41 100644 --- a/src/JuMP.jl +++ b/src/JuMP.jl @@ -1049,6 +1049,9 @@ function owner_model end Base.ndims(::Type{<:AbstractJuMPScalar}) = 0 Base.ndims(::AbstractJuMPScalar) = 0 +Base.IteratorEltype(::Type{<:AbstractJuMPScalar}) = Base.HasEltype() +Base.eltype(::Type{T}) where {T<:AbstractJuMPScalar} = T + # These are required to create symmetric containers of AbstractJuMPScalars. LinearAlgebra.symmetric_type(::Type{T}) where {T<:AbstractJuMPScalar} = T LinearAlgebra.hermitian_type(::Type{T}) where {T<:AbstractJuMPScalar} = T @@ -1059,6 +1062,7 @@ LinearAlgebra.adjoint(scalar::AbstractJuMPScalar) = conj(scalar) Base.iterate(x::AbstractJuMPScalar) = (x, true) Base.iterate(::AbstractJuMPScalar, state) = nothing Base.isempty(::AbstractJuMPScalar) = false +Base.length(::AbstractJuMPScalar) = 1 # Check if two arrays of AbstractJuMPScalars are equal. Useful for testing. function isequal_canonical( diff --git a/src/aff_expr.jl b/src/aff_expr.jl index 86570a9ad2b..4c76f6ddc2e 100644 --- a/src/aff_expr.jl +++ b/src/aff_expr.jl @@ -204,8 +204,14 @@ function Base.one(::Type{GenericAffExpr{C,V}}) where {C,V} return GenericAffExpr{C,V}(one(C), OrderedDict{V,C}()) end +function Base.oneunit(::Type{GenericAffExpr{C,V}}) where {C,V} + return GenericAffExpr{C,V}(oneunit(C), OrderedDict{V,C}()) +end + Base.one(a::GenericAffExpr) = one(typeof(a)) +Base.oneunit(a::GenericAffExpr) = oneunit(typeof(a)) + Base.copy(a::GenericAffExpr) = GenericAffExpr(copy(a.constant), copy(a.terms)) Base.broadcastable(a::GenericAffExpr) = Ref(a) diff --git a/src/variables.jl b/src/variables.jl index 203722e9e00..7a4a16b5088 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -361,10 +361,16 @@ end Base.one(v::AbstractVariableRef) = one(typeof(v)) +Base.oneunit(v::AbstractVariableRef) = oneunit(typeof(v)) + function Base.one(::Type{V}) where {V<:AbstractVariableRef} return one(GenericAffExpr{value_type(V),V}) end +function Base.oneunit(::Type{V}) where {V<:AbstractVariableRef} + return oneunit(GenericAffExpr{value_type(V),V}) +end + """ coefficient(v1::GenericVariableRef{T}, v2::GenericVariableRef{T}) where {T} diff --git a/test/test_variable.jl b/test/test_variable.jl index 6f3f50205ae..22c8d951144 100644 --- a/test/test_variable.jl +++ b/test/test_variable.jl @@ -1601,4 +1601,29 @@ function test_bad_bound_types() return end +function test_variable_length() + model = Model() + @variable(model, x) + @test length(x) == 1 + return +end + +function test_variable_eltype() + model = Model() + @variable(model, x) + @test Base.IteratorEltype(x) == Base.HasEltype() + @test Base.eltype(typeof(x)) == typeof(x) + return +end + +function test_variable_one() + model = Model() + @variable(model, x) + @test one(x) == AffExpr(1.0) + @test one(2 * x) == AffExpr(1.0) + @test oneunit(x) == AffExpr(1.0) + @test oneunit(2 * x) == AffExpr(1.0) + return +end + end # module TestVariable