From 19818ef4bf47b6ea4612b1a3a4bd8b2a2fe89e15 Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Sun, 10 Dec 2023 09:24:47 +1300 Subject: [PATCH] Fix text/latex printing of GenericNonlinearExpr (#3609) --- src/nlp_expr.jl | 22 +++++++++++++++++++--- test/test_nlp_expr.jl | 19 ++++++++++++++++++- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/src/nlp_expr.jl b/src/nlp_expr.jl index 96f3b120fa0..84b4c6f2bb3 100644 --- a/src/nlp_expr.jl +++ b/src/nlp_expr.jl @@ -147,12 +147,26 @@ end _parens(::MIME) = "(", ")", "", "", "" _parens(::MIME"text/latex") = "\\left(", "\\right)", "{", "}", "\\textsf" +""" + op_string(mime::MIME, x::GenericNonlinearExpr, ::Val{op}) where {op} + +Return the string that should be printed for the operator `op` when +[`function_string`](@ref) is called with `mime` and `x`. +""" +op_string(::MIME, ::GenericNonlinearExpr, ::Val{op}) where {op} = string(op) +op_string(::MIME"text/latex", ::GenericNonlinearExpr, ::Val{:&&}) = "\\wedge" +op_string(::MIME"text/latex", ::GenericNonlinearExpr, ::Val{:||}) = "\\vee" +op_string(::MIME"text/latex", ::GenericNonlinearExpr, ::Val{:<=}) = "\\le" +op_string(::MIME"text/latex", ::GenericNonlinearExpr, ::Val{:>=}) = "\\ge" +op_string(::MIME"text/latex", ::GenericNonlinearExpr, ::Val{:(==)}) = "=" + function function_string(mime::MIME, x::GenericNonlinearExpr) p_left, p_right, p_open, p_close, p_textsf = _parens(mime) io, stack = IOBuffer(), Any[x] while !isempty(stack) arg = pop!(stack) if arg isa GenericNonlinearExpr + op = op_string(mime, arg, Val(arg.head)) if arg.head in _PREFIX_OPERATORS && length(arg.args) > 1 print(io, p_open) push!(stack, p_close) @@ -166,7 +180,7 @@ function function_string(mime::MIME, x::GenericNonlinearExpr) stack, _terms_omitted(mime, length(skip_indices)), ) - push!(stack, " $(arg.head) $p_open") + push!(stack, " $op $p_open") end continue elseif _needs_parentheses(arg.args[i]) @@ -177,11 +191,11 @@ function function_string(mime::MIME, x::GenericNonlinearExpr) push!(stack, arg.args[i]) end if i > 1 - push!(stack, "$p_close $(arg.head) $p_open") + push!(stack, "$p_close $op $p_open") end end else - print(io, p_textsf, p_open, arg.head, p_close, p_left, p_open) + print(io, p_textsf, p_open, op, p_close, p_left, p_open) push!(stack, p_close * p_right) for i in length(arg.args):-1:2 push!(stack, arg.args[i]) @@ -191,6 +205,8 @@ function function_string(mime::MIME, x::GenericNonlinearExpr) push!(stack, arg.args[1]) end end + elseif arg isa AbstractJuMPScalar + print(io, function_string(mime, arg)) else print(io, arg) end diff --git a/test/test_nlp_expr.jl b/test/test_nlp_expr.jl index da60d3a13bd..747ce29d344 100644 --- a/test/test_nlp_expr.jl +++ b/test/test_nlp_expr.jl @@ -155,6 +155,23 @@ function test_extension_latex(ModelType = Model, VariableRefType = VariableRef) return end +function test_extension_latex2(ModelType = Model, VariableRefType = VariableRef) + model = ModelType() + @variable(model, x[1:2]) + @test function_string(MIME("text/latex"), sin(x[1])) == + raw"\textsf{sin}\left({x_{1}}\right)" + @test function_string(MIME("text/latex"), sin(x[1])^x[2]) == + raw"{\textsf{sin}\left({x_{1}}\right)} ^ {x_{2}}" + @expression( + model, + expr, + (x[1] <= x[2]) || (x[1] >= x[2]) && (x[1] == x[1]), + ) + @test function_string(MIME("text/latex"), expr) == + raw"{\left({x_{1}} \le {x_{2}}\right)} \vee {\left({\left({x_{1}} \ge {x_{2}}\right)} \wedge {\left({x_{1}} = {x_{1}}\right)}\right)}" + return +end + function test_extension_expression_addmul( ModelType = Model, VariableRefType = VariableRef, @@ -987,7 +1004,7 @@ function test_printing_truncation() function_string(MIME("text/plain"), y), ) @test occursin( - "{\\left({\\textsf{sin}\\left({x[72]}\\right)} * {2.0}\\right) + {[[\\ldots\\text{41 terms omitted}\\ldots]]} + {\\left({\\textsf{sin}\\left({x[30]}\\right)} * {2.0}\\right)} + {\\left({\\textsf{sin}\\left({x[29]}\\right)} * {2.0}\\right)} + {\\left({\\textsf{sin}\\left({x[28]}\\right)} * {2.0}\\right)} + {\\left({\\textsf{sin}\\left({x[27]}\\right)} * {2.0}\\right)} + {\\left({\\textsf{sin}\\left({x[26]}\\right)} * {2.0}\\right)}", + "{\\left({\\textsf{sin}\\left({x_{72}}\\right)} * {2.0}\\right) + {[[\\ldots\\text{41 terms omitted}\\ldots]]} + {\\left({\\textsf{sin}\\left({x_{30}}\\right)} * {2.0}\\right)}", function_string(MIME("text/latex"), y), ) return