Skip to content

Commit

Permalink
Fix bug printing scientific numbers in MIME"text/latex" (jump-dev#3838)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Oct 7, 2024
1 parent 1aa0ff7 commit e78a3a4
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 91 deletions.
67 changes: 40 additions & 27 deletions src/print.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,27 +68,37 @@ function _is_im_for_printing(coef::Complex)
return _is_zero_for_printing(r) && _is_one_for_printing(i)
end

_escape_if_scientific(::MIME, x::String) = x

function _escape_if_scientific(::MIME"text/latex", x::String)
m = match(r"([0-9]+.[0-9]+)e(-?[0-9]+)", x)
if m === nothing
return x
end
return "$(m[1]) \\times 10^{$(m[2])}"
end

# Helper function that rounds carefully for the purposes of printing Reals
# for example, 5.3 => 5.3, and 1.0 => 1
function _string_round(x::Union{Float32,Float64})
function _string_round(mode, x::Union{Float32,Float64})
if isinteger(x) && typemin(Int64) <= x <= typemax(Int64)
return string(round(Int64, x))
end
return string(x)
return _escape_if_scientific(mode, string(x))
end

_string_round(::typeof(abs), x::Real) = _string_round(abs(x))
_string_round(mode, ::typeof(abs), x::Real) = _string_round(mode, abs(x))

_sign_string(x::Real) = x < zero(x) ? " - " : " + "

function _string_round(::typeof(abs), x::Complex)
function _string_round(mode, ::typeof(abs), x::Complex)
r, i = reim(x)
if _is_zero_for_printing(r)
return _string_round(Complex(r, abs(i)))
return _string_round(mode, Complex(r, abs(i)))
elseif _is_zero_for_printing(i)
return _string_round(Complex(abs(r), i))
return _string_round(mode, Complex(abs(r), i))
else
return _string_round(x)
return _string_round(mode, x)
end
end

Expand All @@ -105,15 +115,15 @@ end

# Fallbacks for other number types

_string_round(x::Any) = string(x)
_string_round(mode, x::Any) = string(x)

_string_round(::typeof(abs), x::Any) = _string_round(x)
_string_round(mode, ::typeof(abs), x::Any) = _string_round(mode, x)

_sign_string(::Any) = " + "

function _string_round(x::Complex)
function _string_round(mode, x::Complex)
r, i = reim(x)
r_str = _string_round(r)
r_str = _string_round(mode, r)
if _is_zero_for_printing(i)
return r_str
elseif _is_zero_for_printing(r)
Expand All @@ -124,12 +134,13 @@ function _string_round(x::Complex)
return "im"
end
else
return string(_string_round(i), "im")
return string(_string_round(mode, i), "im")
end
elseif _is_one_for_printing(i)
return string("(", r_str, _sign_string(i), "im)")
else
return string("(", r_str, _sign_string(i), _string_round(abs, i), "im)")
abs_i = _string_round(mode, abs, i)
return string("(", r_str, _sign_string(i), abs_i, "im)")
end
end

Expand Down Expand Up @@ -587,11 +598,11 @@ function nonlinear_constraint_string(
body = nonlinear_expr_string(model, mode, constraint.expression)
lhs = _set_lhs(constraint.set)
rhs = _set_rhs(constraint.set)
output = "$body $(_math_symbol(mode, rhs[1])) $(_string_round(rhs[2]))"
output = "$body $(_math_symbol(mode, rhs[1])) $(_string_round(mode, rhs[2]))"
if lhs === nothing
return output
end
return "$(_string_round(lhs[2])) $(_math_symbol(mode, lhs[1])) $output"
return "$(_string_round(mode, lhs[2])) $(_math_symbol(mode, lhs[1])) $output"
end

"""
Expand Down Expand Up @@ -825,13 +836,13 @@ function function_string(mode::MIME"text/latex", v::AbstractVariableRef)
return var_name
end

function _term_string(coef, factor)
function _term_string(mode, coef, factor)
if _is_one_for_printing(coef)
return factor
elseif _is_im_for_printing(coef)
return string(factor, " ", _string_round(abs, coef))
return string(factor, " ", _string_round(mode, abs, coef))
else
return string(_string_round(abs, coef), " ", factor)
return string(_string_round(mode, abs, coef), " ", factor)
end
end

Expand Down Expand Up @@ -873,20 +884,20 @@ end
# TODO(odow): remove show_constant in JuMP 1.0
function function_string(mode, a::GenericAffExpr, show_constant = true)
if length(linear_terms(a)) == 0
return show_constant ? _string_round(a.constant) : "0"
return show_constant ? _string_round(mode, a.constant) : "0"
end
terms = fill("", 2 * length(linear_terms(a)))
for (elm, (coef, var)) in enumerate(linear_terms(a))
terms[2*elm-1] = _sign_string(coef)
terms[2*elm] = _term_string(coef, function_string(mode, var))
terms[2*elm] = _term_string(mode, coef, function_string(mode, var))
end
terms[1] = terms[1] == " - " ? "-" : ""
ret = _terms_to_truncated_string(mode, terms)
if show_constant && !_is_zero_for_printing(a.constant)
ret = string(
ret,
_sign_string(a.constant),
_string_round(abs, a.constant),
_string_round(mode, abs, a.constant),
)
end
return ret
Expand All @@ -907,7 +918,7 @@ function function_string(mode, q::GenericQuadExpr)
times = mode == MIME("text/latex") ? "\\times " : "*"
factor = string(x, times, y)
end
terms[2*elm] = _term_string(coef, factor)
terms[2*elm] = _term_string(mode, coef, factor)
end
terms[1] = terms[1] == " - " ? "-" : ""
ret = _terms_to_truncated_string(mode, terms)
Expand Down Expand Up @@ -1032,25 +1043,27 @@ julia> in_set_string(MIME("text/plain"), MOI.Interval(1.0, 2.0))
function in_set_string end

function in_set_string(mode::MIME, set::MOI.LessThan)
return string(_math_symbol(mode, :leq), " ", _string_round(set.upper))
return string(_math_symbol(mode, :leq), " ", _string_round(mode, set.upper))
end

function in_set_string(mode::MIME, set::MOI.GreaterThan)
return string(_math_symbol(mode, :geq), " ", _string_round(set.lower))
return string(_math_symbol(mode, :geq), " ", _string_round(mode, set.lower))
end

function in_set_string(mode::MIME, set::MOI.EqualTo)
return string(_math_symbol(mode, :eq), " ", _string_round(set.value))
return string(_math_symbol(mode, :eq), " ", _string_round(mode, set.value))
end

function in_set_string(::MIME"text/latex", set::MOI.Interval)
lower, upper = _string_round(set.lower), _string_round(set.upper)
lower = _string_round(mode, set.lower)
upper = _string_round(mode, set.upper)
return string("\\in [", lower, ", ", upper, "]")
end

function in_set_string(mode::MIME"text/plain", set::MOI.Interval)
in = _math_symbol(mode, :in)
lower, upper = _string_round(set.lower), _string_round(set.upper)
lower = _string_round(mode, set.lower)
upper = _string_round(mode, set.upper)
return string("$in [", lower, ", ", upper, "]")
end

Expand Down
136 changes: 72 additions & 64 deletions test/test_print.jl
Original file line number Diff line number Diff line change
Expand Up @@ -887,73 +887,66 @@ function test_print_hermitian_psd_cone()
end

function test_print_complex_string_round()
@test JuMP._string_round(1.0 + 0.0im) == "1"
@test JuMP._string_round(-1.0 + 0.0im) == "-1"
@test JuMP._string_round(1.0 - 0.0im) == "1"
@test JuMP._string_round(-1.0 - 0.0im) == "-1"
@test JuMP._string_round(0.0 + 1.0im) == "im"
@test JuMP._string_round(-0.0 + 1.0im) == "im"
@test JuMP._string_round(0.0 - 1.0im) == "-im"
@test JuMP._string_round(-0.0 - 1.0im) == "-im"
@test JuMP._string_round(1.0 + 2.0im) == "(1 + 2im)"
@test JuMP._string_round(1.0 - 2.0im) == "(1 - 2im)"
@test JuMP._string_round(-1.0 + 2.0im) == "(-1 + 2im)"
@test JuMP._string_round(-1.0 - 2.0im) == "(-1 - 2im)"
@test JuMP._string_round(1.0 + 1.0im) == "(1 + im)"
@test JuMP._string_round(1.0 - 1.0im) == "(1 - im)"
@test JuMP._string_round(-1.0 + 1.0im) == "(-1 + im)"
@test JuMP._string_round(-1.0 - 1.0im) == "(-1 - im)"
for (test, result) in Any[
1.0+0.0im=>"1",
-1.0+0.0im=>"-1",
1.0-0.0im=>"1",
-1.0-0.0im=>"-1",
0.0+1.0im=>"im",
-0.0+1.0im=>"im",
0.0-1.0im=>"-im",
-0.0-1.0im=>"-im",
1.0+2.0im=>"(1 + 2im)",
1.0-2.0im=>"(1 - 2im)",
-1.0+2.0im=>"(-1 + 2im)",
-1.0-2.0im=>"(-1 - 2im)",
1.0+1.0im=>"(1 + im)",
1.0-1.0im=>"(1 - im)",
-1.0+1.0im=>"(-1 + im)",
-1.0-1.0im=>"(-1 - im)",
]
@test JuMP._string_round(MIME("text/plain"), test) == result
end
return
end

function test_print_huge_integer_string_round()
@test JuMP._string_round(-1 + Float32(typemax(Int32))) == "2147483648"
@test JuMP._string_round(-1 + Float32(typemin(Int32))) == "-2147483648"
@test JuMP._string_round(-1 + Float64(typemax(Int32))) == "2147483646"
@test JuMP._string_round(-1 + Float64(typemin(Int32))) == "-2147483649"

@test JuMP._string_round(Float32(typemax(Int32))) == "2147483648"
@test JuMP._string_round(Float32(typemin(Int32))) == "-2147483648"
@test JuMP._string_round(Float64(typemax(Int32))) == "2147483647"
@test JuMP._string_round(Float64(typemin(Int32))) == "-2147483648"

@test JuMP._string_round(1 + Float32(typemax(Int32))) == "2147483648"
@test JuMP._string_round(1 + Float32(typemin(Int32))) == "-2147483648"
@test JuMP._string_round(1 + Float64(typemax(Int32))) == "2147483648"
@test JuMP._string_round(1 + Float64(typemin(Int32))) == "-2147483647"

@test JuMP._string_round(2 * Float32(typemax(Int32))) == "4294967296"
@test JuMP._string_round(2 * Float32(typemin(Int32))) == "-4294967296"
@test JuMP._string_round(2 * Float64(typemax(Int32))) == "4294967294"
@test JuMP._string_round(2 * Float64(typemin(Int32))) == "-4294967296"

@test JuMP._string_round(-1 + Float32(typemax(Int64))) == "9.223372e18"
@test JuMP._string_round(-1 + Float32(typemin(Int64))) ==
"-9223372036854775808"
@test JuMP._string_round(-1 + Float64(typemax(Int64))) ==
"9.223372036854776e18"
@test JuMP._string_round(-1 + Float64(typemin(Int64))) ==
"-9223372036854775808"

@test JuMP._string_round(Float32(typemax(Int64))) == "9.223372e18"
@test JuMP._string_round(Float32(typemin(Int64))) == "-9223372036854775808"
@test JuMP._string_round(Float64(typemax(Int64))) == "9.223372036854776e18"
@test JuMP._string_round(Float64(typemin(Int64))) == "-9223372036854775808"

@test JuMP._string_round(1 + Float32(typemax(Int64))) == "9.223372e18"
@test JuMP._string_round(1 + Float32(typemin(Int64))) ==
"-9223372036854775808"
@test JuMP._string_round(1 + Float64(typemax(Int64))) ==
"9.223372036854776e18"
@test JuMP._string_round(1 + Float64(typemin(Int64))) ==
"-9223372036854775808"

@test JuMP._string_round(2 * Float32(typemax(Int64))) == "1.8446744e19"
@test JuMP._string_round(2 * Float32(typemin(Int64))) == "-1.8446744e19"
@test JuMP._string_round(2 * Float64(typemax(Int64))) ==
"1.8446744073709552e19"
@test JuMP._string_round(2 * Float64(typemin(Int64))) ==
"-1.8446744073709552e19"
for (test, result) in Any[
-1+Float32(typemax(Int32))=>"2147483648",
-1+Float32(typemin(Int32))=>"-2147483648",
-1+Float64(typemax(Int32))=>"2147483646",
-1+Float64(typemin(Int32))=>"-2147483649",
Float32(typemax(Int32))=>"2147483648",
Float32(typemin(Int32))=>"-2147483648",
Float64(typemax(Int32))=>"2147483647",
Float64(typemin(Int32))=>"-2147483648",
1+Float32(typemax(Int32))=>"2147483648",
1+Float32(typemin(Int32))=>"-2147483648",
1+Float64(typemax(Int32))=>"2147483648",
1+Float64(typemin(Int32))=>"-2147483647",
2*Float32(typemax(Int32))=>"4294967296",
2*Float32(typemin(Int32))=>"-4294967296",
2*Float64(typemax(Int32))=>"4294967294",
2*Float64(typemin(Int32))=>"-4294967296",
-1+Float32(typemax(Int64))=>"9.223372e18",
-1+Float32(typemin(Int64))=>"-9223372036854775808",
-1+Float64(typemax(Int64))=>"9.223372036854776e18",
-1+Float64(typemin(Int64))=>"-9223372036854775808",
Float32(typemax(Int64))=>"9.223372e18",
Float32(typemin(Int64))=>"-9223372036854775808",
Float64(typemax(Int64))=>"9.223372036854776e18",
Float64(typemin(Int64))=>"-9223372036854775808",
1+Float32(typemax(Int64))=>"9.223372e18",
1+Float32(typemin(Int64))=>"-9223372036854775808",
1+Float64(typemax(Int64))=>"9.223372036854776e18",
1+Float64(typemin(Int64))=>"-9223372036854775808",
2*Float32(typemax(Int64))=>"1.8446744e19",
2*Float32(typemin(Int64))=>"-1.8446744e19",
2*Float64(typemax(Int64))=>"1.8446744073709552e19",
2*Float64(typemin(Int64))=>"-1.8446744073709552e19",
]
@test JuMP._string_round(MIME("text/plain"), test) == result
end
return
end

Expand All @@ -965,7 +958,7 @@ function test_print_model_with_huge_integers()
@test sprint(io -> show(io, MIME("text/plain"), c)) == "1.0e20 x $eq 42"
eq = JuMP._math_symbol(MIME("text/latex"), :eq)
@test sprint(io -> show(io, MIME("text/latex"), c)) ==
"\$\$ 1.0e20 x $eq 42 \$\$"
"\$\$ 1.0 \\times 10^{20} x $eq 42 \$\$"
return
end

Expand Down Expand Up @@ -1124,4 +1117,19 @@ function test_show_generic_model_bigfloat()
return
end

function test_small_number_latex()
model = Model()
@variable(model, x)
y = 1e-8 * x
@test function_string(MIME("text/latex"), y) == "1.0 \\times 10^{-8} x"
@test function_string(MIME("text/plain"), y) == "1.0e-8 x"
y = 0.23e-8 * x
@test function_string(MIME("text/latex"), y) == "2.3 \\times 10^{-9} x"
@test function_string(MIME("text/plain"), y) == "2.3e-9 x"
y = 1.23 * x
@test function_string(MIME("text/latex"), y) == "1.23 x"
@test function_string(MIME("text/plain"), y) == "1.23 x"
return
end

end # TestPrint

0 comments on commit e78a3a4

Please sign in to comment.