Skip to content

Commit

Permalink
Truncate printing of large expressions (#3575)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Nov 27, 2023
1 parent 9ce7e43 commit 96b4147
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 63 deletions.
91 changes: 30 additions & 61 deletions src/nlp_expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,79 +144,48 @@ function _needs_parentheses(x::GenericNonlinearExpr)
return x.head in _PREFIX_OPERATORS && length(x.args) > 1
end

function function_string(::MIME"text/plain", x::GenericNonlinearExpr)
io, stack = IOBuffer(), Any[x]
while !isempty(stack)
arg = pop!(stack)
if arg isa GenericNonlinearExpr
if arg.head in _PREFIX_OPERATORS && length(arg.args) > 1
if _needs_parentheses(arg.args[1])
print(io, "(")
end
if _needs_parentheses(arg.args[end])
push!(stack, ")")
end
for i in length(arg.args):-1:2
push!(stack, arg.args[i])
if _needs_parentheses(arg.args[i])
push!(stack, "(")
end
push!(stack, " $(arg.head) ")
if _needs_parentheses(arg.args[i-1])
push!(stack, ")")
end
end
push!(stack, arg.args[1])
else
print(io, arg.head, "(")
push!(stack, ")")
for i in length(arg.args):-1:2
push!(stack, arg.args[i])
push!(stack, ", ")
end
if length(arg.args) >= 1
push!(stack, arg.args[1])
end
end
else
print(io, arg)
end
end
seekstart(io)
return read(io, String)
end
_parens(::MIME) = "(", ")", "", "", ""
_parens(::MIME"text/latex") = "\\left(", "\\right)", "{", "}", "\\textsf"

function function_string(::MIME"text/latex", x::GenericNonlinearExpr)
function function_string(mime::MIME, x::GenericNonlinearExpr)
p_left, p_right, p_open, p_close, p_textsf = _parens(mime)
io, stack = IOBuffer(), Any[x]
while !isempty(stack)
arg = pop!(stack)
if arg isa GenericNonlinearExpr
if arg.head in _PREFIX_OPERATORS && length(arg.args) > 1
print(io, "{")
push!(stack, "}")
if _needs_parentheses(arg.args[1])
print(io, "\\left(")
end
if _needs_parentheses(arg.args[end])
push!(stack, "\\right)")
end
for i in length(arg.args):-1:2
push!(stack, arg.args[i])
if _needs_parentheses(arg.args[i])
push!(stack, "\\left(")
print(io, p_open)
push!(stack, p_close)
l = ceil(_TERM_LIMIT_FOR_PRINTING[] / 2)
r = floor(_TERM_LIMIT_FOR_PRINTING[] / 2)
skip_indices = (1+l):(length(arg.args)-r)
for i in length(arg.args):-1:1
if i in skip_indices
if i == skip_indices[end]
push!(
stack,
_terms_omitted(mime, length(skip_indices)),
)
push!(stack, " $(arg.head) $p_open")
end
continue
elseif _needs_parentheses(arg.args[i])
push!(stack, p_right)
push!(stack, arg.args[i])
push!(stack, p_left)
else
push!(stack, arg.args[i])
end
push!(stack, "} $(arg.head) {")
if _needs_parentheses(arg.args[i-1])
push!(stack, "\\right)")
if i > 1
push!(stack, "$p_close $(arg.head) $p_open")
end
end
push!(stack, arg.args[1])
else
print(io, "\\textsf{", arg.head, "}\\left({")
push!(stack, "}\\right)")
print(io, p_textsf, p_open, arg.head, p_close, p_left, p_open)
push!(stack, p_close * p_right)
for i in length(arg.args):-1:2
push!(stack, arg.args[i])
push!(stack, "}, {")
push!(stack, "$p_close, $p_open")
end
if length(arg.args) >= 1
push!(stack, arg.args[1])
Expand Down
39 changes: 37 additions & 2 deletions src/print.jl
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,41 @@ function _term_string(coef, factor)
end
end

"""
const _TERM_LIMIT_FOR_PRINTING = Ref{Int}(60)
A global constant used to control when terms are omitted when printing
expressions.
Get and set this value using `_TERM_LIMIT_FOR_PRINTING[]`.
```julia
julia> _TERM_LIMIT_FOR_PRINTING[]
60
julia> _TERM_LIMIT_FOR_PRINTING[] = 10
10
```
"""
const _TERM_LIMIT_FOR_PRINTING = Ref{Int}(60)

_terms_omitted(::MIME, n::Int) = "[[...$n terms omitted...]]"

function _terms_omitted(::MIME"text/latex", n::Int)
return "[[\\ldots\\text{$n terms omitted}\\ldots]]"
end

function _terms_to_truncated_string(mode, terms)
m = _TERM_LIMIT_FOR_PRINTING[]
if length(terms) <= 2 * m
return join(terms)
end
k_l = iseven(m) ? m + 1 : m + 2
k_r = iseven(m) ? m - 1 : m - 2
block = _terms_omitted(mode, div(length(terms), 2) - m)
return string(join(terms[1:k_l]), block, join(terms[(end-k_r):end]))
end

# TODO(odow): remove show_constant in JuMP 1.0
function function_string(mode, a::GenericAffExpr, show_constant = true)
if length(linear_terms(a)) == 0
Expand All @@ -616,7 +651,7 @@ function function_string(mode, a::GenericAffExpr, show_constant = true)
terms[2*elm] = _term_string(coef, function_string(mode, var))
end
terms[1] = terms[1] == " - " ? "-" : ""
ret = join(terms)
ret = _terms_to_truncated_string(mode, terms)
if show_constant && !_is_zero_for_printing(a.constant)
ret = string(
ret,
Expand Down Expand Up @@ -645,7 +680,7 @@ function function_string(mode, q::GenericQuadExpr)
terms[2*elm] = _term_string(coef, factor)
end
terms[1] = terms[1] == " - " ? "-" : ""
ret = join(terms)
ret = _terms_to_truncated_string(mode, terms)
aff_str = function_string(mode, q.aff)
if aff_str == "0"
return ret
Expand Down
15 changes: 15 additions & 0 deletions test/test_nlp_expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -978,4 +978,19 @@ function test_variable_ref_type()
return
end

function test_printing_truncation()
model = Model()
@variable(model, x[1:100])
y = @expression(model, sum(sin.(x) .* 2))
@test occursin(
"(sin(x[72]) * 2.0) + [[...41 terms omitted...]] + (sin(x[30]) * 2.0)",
function_string(MIME("text/plain"), y),
)
@test occursin(
"{\\left({\\textsf{sin}\\left({x[72]}\\right)} * {2.0}\\right) + {[[\\ldots\\text{41 terms omitted}\\ldots]]} + {\\left({\\textsf{sin}\\left({x[30]}\\right)} * {2.0}\\right)} + {\\left({\\textsf{sin}\\left({x[29]}\\right)} * {2.0}\\right)} + {\\left({\\textsf{sin}\\left({x[28]}\\right)} * {2.0}\\right)} + {\\left({\\textsf{sin}\\left({x[27]}\\right)} * {2.0}\\right)} + {\\left({\\textsf{sin}\\left({x[26]}\\right)} * {2.0}\\right)}",
function_string(MIME("text/latex"), y),
)
return
end

end # module
20 changes: 20 additions & 0 deletions test/test_print.jl
Original file line number Diff line number Diff line change
Expand Up @@ -967,4 +967,24 @@ function test_print_text_latex_interval_set()
return
end

function test_truncated_printing()
model = Model()
@variable(model, x[1:1000])
y = sum(x)
s = function_string(MIME("text/plain"), y)
@test occursin("x[30] + [[...940 terms omitted...]] + x[971]", s)
@test occursin(
"x_{30} + [[\\ldots\\text{940 terms omitted}\\ldots]] + x_{971}",
function_string(MIME("text/latex"), y),
)
ret = JuMP._TERM_LIMIT_FOR_PRINTING[]
JuMP._TERM_LIMIT_FOR_PRINTING[] = 3
@test function_string(MIME("text/plain"), y) ==
"x[1] + x[2] + [[...997 terms omitted...]] + x[1000]"
@test function_string(MIME("text/latex"), y) ==
"x_{1} + x_{2} + [[\\ldots\\text{997 terms omitted}\\ldots]] + x_{1000}"
JuMP._TERM_LIMIT_FOR_PRINTING[] = ret
return
end

end

0 comments on commit 96b4147

Please sign in to comment.