diff --git a/docs/make.jl b/docs/make.jl index ccfc75e8122..89ad5f7d1c7 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -245,7 +245,15 @@ jump_api_reference = DocumenterReference.automatic_reference_documentation(; "Containers.nested" => DocumenterReference.DOCTYPE_FUNCTION, "Containers.vectorized_product" => DocumenterReference.DOCTYPE_FUNCTION, - "Containers.build_ref_sets" => + "Containers.build_error_fn" => + DocumenterReference.DOCTYPE_FUNCTION, + "Containers.parse_macro_arguments" => + DocumenterReference.DOCTYPE_FUNCTION, + "Containers.parse_ref_sets" => + DocumenterReference.DOCTYPE_FUNCTION, + "Containers.build_name_expr" => + DocumenterReference.DOCTYPE_FUNCTION, + "Containers.add_additional_args" => DocumenterReference.DOCTYPE_FUNCTION, "Containers.container_code" => DocumenterReference.DOCTYPE_FUNCTION, diff --git a/src/Containers/macro.jl b/src/Containers/macro.jl index 95c7d04fe24..997ec64e751 100644 --- a/src/Containers/macro.jl +++ b/src/Containers/macro.jl @@ -3,20 +3,8 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at https://mozilla.org/MPL/2.0/. -_get_name(c::Union{Symbol,Nothing}) = c -_get_name(c) = error("Expression `$c::$(typeof(c))` cannot be used as a name.") - -function _get_name(c::Expr) - if Meta.isexpr(c, :vcat) || Meta.isexpr(c, :vect) - return nothing - elseif Meta.isexpr(c, :ref) || Meta.isexpr(c, :typed_vcat) - return _get_name(c.args[1]) - end - return error("Expression $c cannot be used as a name.") -end - function _reorder_parameters(args) - if !Meta.isexpr(args[1], :parameters) + if isempty(args) || !Meta.isexpr(args[1], :parameters) return args end args = collect(args) @@ -29,7 +17,12 @@ function _reorder_parameters(args) end """ - parse_macro_arguments(error_fn::Function, args) + parse_macro_arguments( + error_fn::Function, + args; + valid_kwargs::Union{Nothing,Vector{Symbol}} = nothing, + num_positional_args::Union{Nothing,Int,UnitRange{Int}} = nothing, + ) Returns a `Tuple{Vector{Any},Dict{Symbol,Any}}` containing the ordered positional arguments and a dictionary mapping the keyword arguments. @@ -37,24 +30,52 @@ positional arguments and a dictionary mapping the keyword arguments. This specially handles the distinction of `@foo(key = value)` and `@foo(; key = value)` in macros. -Throws an error if mulitple keyword arguments are passed with the same name. +An error is thrown if multiple keyword arguments are passed with the same key. + +If `valid_kwargs` is a `Vector{Symbol}`, an error is thrown if a keyword is not +in `valid_kwargs`. + +If `num_positional_args` is not nothing, an error is thrown if the number of +positional arguments is not in `num_positional_args`. """ -function parse_macro_arguments(error_fn::Function, args) - pos_args, kw_args = Any[], Dict{Symbol,Any}() +function parse_macro_arguments( + error_fn::Function, + args; + valid_kwargs::Union{Nothing,Vector{Symbol}} = nothing, + num_positional_args::Union{Nothing,Int,UnitRange{Int}} = nothing, +) + pos_args, kwargs = Any[], Dict{Symbol,Any}() for arg in _reorder_parameters(args) if Meta.isexpr(arg, :(=), 2) - if haskey(kw_args, arg.args[1]) + if haskey(kwargs, arg.args[1]) error_fn( "the keyword argument `$(arg.args[1])` was given " * "multiple times.", ) + elseif valid_kwargs !== nothing && !(arg.args[1] in valid_kwargs) + error_fn("unsupported keyword argument `$(arg.args[1])`.") end - kw_args[arg.args[1]] = arg.args[2] + kwargs[arg.args[1]] = arg.args[2] else push!(pos_args, arg) end end - return pos_args, kw_args + if num_positional_args isa Int + n = length(pos_args) + if n != num_positional_args + error_fn( + "expected $num_positional_args positional arguments, got $n.", + ) + end + elseif num_positional_args isa UnitRange{Int} + if !(length(pos_args) in num_positional_args) + a, b = num_positional_args.start, num_positional_args.stop + error_fn( + "expected $a to $b positional arguments, got $(length(pos_args)).", + ) + end + end + return pos_args, kwargs end """ @@ -204,8 +225,27 @@ function _has_dependent_sets(index_vars::Vector{Any}, index_sets::Vector{Any}) return false end +function _container_name(error_fn::Function, x) + return error_fn("Expression `$x::$(typeof(x))` cannot be used as a name.") +end + +_container_name(::Function, expr::Union{Symbol,Nothing}) = expr + +function _container_name(error_fn::Function, expr::Expr) + if Meta.isexpr(expr, (:vcat, :vect)) + return nothing + elseif Meta.isexpr(expr, (:ref, :typed_vcat)) + return _container_name(error_fn, expr.args[1]) + end + return error_fn("Expression $expr cannot be used as a name.") +end + """ - build_ref_sets(error_fn::Function, expr) + parse_ref_sets( + error_fn::Function, + expr; + invalid_index_variables::Vector{Symbol} = Symbol[], + ) Helper function for macros to construct container objects. @@ -223,16 +263,64 @@ Helper function for macros to construct container objects. ## Returns - 1. `index_vars`: a `Vector{Any}` of names for the index variables, e.g., + 1. `name`: the name of the container, if given, otherwise `nothing` + 2. `index_vars`: a `Vector{Any}` of names for the index variables, e.g., `[:i, gensym(), :k]`. These may also be expressions, like `:((i, j))` from a call like `:(x[(i, j) in S])`. - 2. `indices`: an iterator over the indices, for example, + 3. `indices`: an iterator over the indices, for example, [`Containers.NestedIterator`](@ref) ## Example See [`container_code`](@ref) for a worked example. """ +function parse_ref_sets( + error_fn::Function, + expr::Union{Nothing,Symbol,Expr}; + invalid_index_variables::Vector = Symbol[], +) + name = _container_name(error_fn, expr) + index_vars, indices = build_ref_sets(error_fn, expr) + for name in invalid_index_variables + if name in index_vars + error_fn( + "the index name `$name` conflicts with another variable in " * + "this scope. Use a different name for the index.", + ) + end + end + return name, index_vars, indices +end + +# This method is needed because Julia v1.10 prints LineNumberNode in the string +# representation of an expression. +function _strip_LineNumberNode(x::Expr) + if Meta.isexpr(x, :block) + return Expr(:block, filter(!Base.Fix2(isa, LineNumberNode), x.args)...) + end + return x +end + +_strip_LineNumberNode(x) = x + +""" + build_error_fn(macro_name, args, source) + +Return a function that can be used in place of `Base.error`, but which +additionally prints the macro from which it was called. +""" +function build_error_fn(macro_name, args, source) + str_args = join(_strip_LineNumberNode.(args), ", ") + msg = "At $(source.file):$(source.line): `@$macro_name($str_args)`: " + error_fn(str...) = error(msg, str...) + return error_fn +end + +""" + build_ref_sets(error_fn::Function, expr) + +This function is deprecated. Use [`parse_ref_sets`](@ref) instead. +""" function build_ref_sets(error_fn::Function, expr) index_vars, index_sets, condition = _parse_ref_sets(error_fn, expr) if any(_expr_is_splat, index_sets) @@ -263,16 +351,107 @@ function build_ref_sets(error_fn::Function, expr) return index_vars, indices end +""" + add_additional_args( + call::Expr, + args::Vector, + kwargs::Dict{Symbol,Any}; + kwarg_exclude::Vector{Symbol} = Symbol[], + ) + +Add the positional arguments `args` to the function call expression `call`, +escaping each argument expression. + +This function is able to incorporate additional positional arguments to `call`s +that already have keyword arguments. +""" +function add_additional_args( + call::Expr, + args::Vector, + kwargs::Dict{Symbol,Any}; + kwarg_exclude::Vector{Symbol} = Symbol[], +) + call_args = call.args + if Meta.isexpr(call, :.) + # call is broadcasted + call_args = call.args[2].args + end + # Cache all keyword arguments + kw_args = filter(Base.Fix2(Meta.isexpr, :kw), call_args) + # Remove keyowrd arguments from the end + filter!(!Base.Fix2(Meta.isexpr, :kw), call_args) + # Add the new positional arguments + append!(call_args, esc.(args)) + # Re-add the cached keyword arguments back to the end + append!(call_args, kw_args) + for (key, value) in kwargs + if !(key in kwarg_exclude) + push!(call_args, esc(Expr(:kw, key, value))) + end + end + return +end + +""" + build_name_expr( + name::Union{Symbol,Nothing}, + index_vars::Vector, + kwargs::Dict{Symbol,Any}, + ) + +Returns an expression for the name of a container element, where `name` and +`index_vars` are the values returned by [`parse_ref_sets`](@ref) and `kwargs` +is the dictionary returned by [`parse_macro_arguments`](@ref). + +This assumes that the key in `kwargs` used to over-ride the name choice is +`:base_name`. + +## Examples + +```jldoctest +julia> Containers.build_name_expr(:x, [:i, :j], Dict{Symbol,Any}()) +:(string("x", "[", string(\$(Expr(:escape, :i))), ",", string(\$(Expr(:escape, :j))), "]")) + +julia> Containers.build_name_expr(nothing, [:i, :j], Dict{Symbol,Any}()) +"" + +julia> Containers.build_name_expr(:y, [:i, :j], Dict{Symbol,Any}(:base_name => "y")) +:(string("y", "[", string(\$(Expr(:escape, :i))), ",", string(\$(Expr(:escape, :j))), "]")) +``` +""" +function build_name_expr( + name::Union{Symbol,Nothing}, + index_vars::Vector, + kwargs::Dict{Symbol,Any}, +) + base_name = get(kwargs, :base_name, string(something(name, ""))) + if base_name isa Expr + base_name = esc(base_name) + end + if isempty(index_vars) || base_name == "" + return base_name + end + expr = Expr(:call, :string, base_name, "[") + for index in index_vars + # Converting the arguments to strings before concatenating is faster: + # https://github.com/JuliaLang/julia/issues/29550. + push!(expr.args, :(string($(esc(index))))) + push!(expr.args, ",") + end + expr.args[end] = "]" + return expr +end + """ container_code( index_vars::Vector{Any}, indices::Expr, code, - requested_container::Union{Symbol,Expr}, + requested_container::Union{Symbol,Expr,Dict{Symbol,Any}}, ) Used in macros to construct a call to [`container`](@ref). This should be used -in conjunction with [`build_ref_sets`](@ref). +in conjunction with [`parse_ref_sets`](@ref). ## Arguments @@ -285,7 +464,8 @@ in conjunction with [`build_ref_sets`](@ref). * `requested_container`: passed to the third argument of [`container`](@ref). For built-in JuMP types, choose one of `:Array`, `:DenseAxisArray`, `:SparseAxisArray`, or `:Auto`. For a user-defined container, this expression - must evaluate to the correct type. + must evaluate to the correct type. You may also pass the `kwargs` dictionary + from [`parse_macro_arguments`](@ref). !!! warning In most cases, you should `esc(code)` before passing it to `container_code`. @@ -294,17 +474,20 @@ in conjunction with [`build_ref_sets`](@ref). ```jldoctest julia> macro foo(ref_sets, code) - index_vars, indices = Containers.build_ref_sets(error, ref_sets) - return Containers.container_code( - index_vars, - indices, - esc(code), - :Auto, - ) + name, index_vars, indices = + Containers.parse_ref_sets(error, ref_sets) + @assert name !== nothing # Anonymous container not supported + container = + Containers.container_code(index_vars, indices, esc(code), :Auto) + return quote + \$(esc(name)) = \$container + end end @foo (macro with 1 method) -julia> @foo(x[i=1:2, j=["A", "B"]], j^i) +julia> @foo(x[i=1:2, j=["A", "B"]], j^i); + +julia> x 2-dimensional DenseAxisArray{String,2,...} with index sets: Dimension 1, Base.OneTo(2) Dimension 2, ["A", "B"] @@ -342,6 +525,16 @@ function container_code( return Expr(:call, container, f, indices, container_type, index_vars) end +function container_code( + index_vars::Vector{Any}, + indices::Expr, + code, + kwargs::Dict{Symbol,Any}, +) + container = get(kwargs, :container, :Auto) + return container_code(index_vars, indices, code, container) +end + """ @container([i=..., j=..., ...], expr[, container = :Auto]) @@ -400,15 +593,14 @@ SparseAxisArray{Int64, 2, Tuple{Int64, Int64}} with 6 entries: ``` """ macro container(input_args...) - args, kw_args = parse_macro_arguments(error, input_args) - container = get(kw_args, :container, :Auto) - @assert length(args) == 2 - for key in keys(kw_args) - @assert key == :container - end - index_vars, indices = build_ref_sets(error, args[1]) - code = container_code(index_vars, indices, esc(args[2]), container) - name = _get_name(args[1]) + args, kwargs = parse_macro_arguments( + error, + input_args; + num_positional_args = 2, + valid_kwargs = [:container], + ) + name, index_vars, indices = parse_ref_sets(error, args[1]) + code = container_code(index_vars, indices, esc(args[2]), kwargs) if name === nothing return code end diff --git a/src/macros.jl b/src/macros.jl index 884de13748c..d597c311fea 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -323,47 +323,6 @@ function model_convert( return model_convert.(model, x) end -""" - _add_additional_args( - call::Expr, - args::Vector, - kwargs::Dict{Symbol,Any}; - kwarg_exclude::Vector{Symbol} = Symbol[], - ) - -Add the positional arguments `args` to the function call expression `call`, -escaping each argument expression. - -This function is able to incorporate additional positional arguments to `call`s -that already have keyword arguments. -""" -function _add_additional_args( - call::Expr, - args::Vector, - kwargs::Dict{Symbol,Any}; - kwarg_exclude::Vector{Symbol} = Symbol[], -) - call_args = call.args - if Meta.isexpr(call, :.) - # call is broadcasted - call_args = call.args[2].args - end - # Cache all keyword arguments - kw_args = filter(Base.Fix2(Meta.isexpr, :kw), call_args) - # Remove keyowrd arguments from the end - filter!(!Base.Fix2(Meta.isexpr, :kw), call_args) - # Add the new positional arguments - append!(call_args, esc.(args)) - # Re-add the cached keyword arguments back to the end - append!(call_args, kw_args) - for (key, value) in kwargs - if !(key in kwarg_exclude) - push!(call_args, esc(Expr(:kw, key, value))) - end - end - return -end - _valid_model(::AbstractModel, ::Any) = nothing function _valid_model(m::M, name) where {M} @@ -435,40 +394,6 @@ function _error_if_cannot_register(model::AbstractModel, name::Symbol) return end -# This method is needed because Julia v1.10 prints LineNumberNode in the string -# representation of an expression. -function _strip_LineNumberNode(x::Expr) - if Meta.isexpr(x, :block) - return Expr(:block, filter(!Base.Fix2(isa, LineNumberNode), x.args)...) - end - return x -end - -_strip_LineNumberNode(x) = x - -function _macro_error(macro_name, args, source, str...) - str_args = join(_strip_LineNumberNode.(args), ", ") - return error( - "At $(source.file):$(source.line): `@$macro_name($str_args)`: ", - str..., - ) -end - -function _base_name_with_indices(base_name, index_vars::Vector) - if isempty(index_vars) || base_name == "" - return base_name - end - expr = Expr(:call, :string, base_name, "[") - for index in index_vars - # Converting the arguments to strings before concatenating is faster: - # https://github.com/JuliaLang/julia/issues/29550. - push!(expr.args, :(string($(esc(index))))) - push!(expr.args, ",") - end - expr.args[end] = "]" - return expr -end - """ _replace_zero(model::M, x) where {M<:AbstractModel} diff --git a/src/macros/@NL.jl b/src/macros/@NL.jl index 060b6560830..a51e77acba7 100644 --- a/src/macros/@NL.jl +++ b/src/macros/@NL.jl @@ -273,9 +273,8 @@ Subject to ``` """ macro NLobjective(model, sense, x) - function error_fn(str...) - return _macro_error(:NLobjective, (model, sense, x), __source__, str...) - end + error_fn = + Containers.build_error_fn(:NLobjective, (model, sense, x), __source__) sense_expr = _parse_moi_sense(error_fn, sense) esc_model = esc(model) parsing_code, expr = _parse_nonlinear_expression(esc_model, x) @@ -313,9 +312,8 @@ julia> @NLconstraint(model, [i = 1:3], sin(i * x) <= 1 / i) ``` """ macro NLconstraint(m, x, args...) - function error_fn(str...) - return _macro_error(:NLconstraint, (m, x, args...), __source__, str...) - end + error_fn = + Containers.build_error_fn(:NLconstraint, (m, x, args...), __source__) esc_m = esc(m) if Meta.isexpr(x, :block) error_fn("Invalid syntax. Did you mean to use `@NLconstraints`?") @@ -332,13 +330,8 @@ macro NLconstraint(m, x, args...) con = length(extra) == 1 ? extra[1] : x # Strategy: build up the code for non-macro add_constraint, and if needed # we will wrap in loops to assign to the ConstraintRefs - idxvars, indices = Containers.build_ref_sets(error_fn, c) - if m in idxvars - error_fn( - "Index $(m) is the same symbol as the model. Use a different " * - "name for the index.", - ) - end + name, idxvars, indices = + Containers.parse_ref_sets(error_fn, c; invalid_index_variables = [m]) parsing_code, expr = _parse_nonlinear_expression(esc_m, con) code = quote $parsing_code @@ -354,7 +347,7 @@ macro NLconstraint(m, x, args...) esc_m, creation_code, __source__; - register_name = Containers._get_name(c), + register_name = name, ) end @@ -436,7 +429,7 @@ subexpression[5]: log(1.0 + (exp(subexpression[2]) + exp(subexpression[3]))) ``` """ macro NLexpression(args...) - error_fn(str...) = _macro_error(:NLexpression, args, __source__, str...) + error_fn = Containers.build_error_fn(:NLexpression, args, __source__) args, kw_args, requested_container = _extract_kw_args(args) if length(args) <= 1 error_fn( @@ -454,13 +447,12 @@ macro NLexpression(args...) if length(args) > 3 || length(kw_args) > 0 error_fn("To many arguments ($(length(args))).") end - idxvars, indices = Containers.build_ref_sets(error_fn, c) - if args[1] in idxvars - error_fn( - "Index $(args[1]) is the same symbol as the model. Use a " * - "different name for the index.", - ) - end + name, idxvars, indices = Containers.parse_ref_sets(error_fn, c) + name, idxvars, indices = Containers.parse_ref_sets( + error_fn, + c; + invalid_index_variables = [args[1]], + ) esc_m = esc(m) parsing_code, expr = _parse_nonlinear_expression(esc_m, x) code = quote @@ -473,7 +465,7 @@ macro NLexpression(args...) esc_m, creation_code, __source__; - register_name = Containers._get_name(c), + register_name = name, ) end @@ -598,9 +590,8 @@ julia> value(y[2]) """ macro NLparameter(model, args...) esc_m = esc(model) - function error_fn(str...) - return _macro_error(:NLparameter, (model, args...), __source__, str...) - end + error_fn = + Containers.build_error_fn(:NLparameter, (model, args...), __source__) pos_args, kw_args, requested_container = _extract_kw_args(args) value = missing for arg in kw_args @@ -634,13 +625,11 @@ macro NLparameter(model, args...) if ismissing(value) param, value = pos_args[1].args[2], pos_args[1].args[3] end - index_vars, index_values = Containers.build_ref_sets(error_fn, param) - if model in index_vars - error_fn( - "Index $(model) is the same symbol as the model. Use a different " * - "name for the index.", - ) - end + name, index_vars, index_values = Containers.parse_ref_sets( + error_fn, + param; + invalid_index_variables = [model], + ) code = quote if !isa($(esc(value)), Number) $(esc(error_fn))("Parameter value is not a number.") @@ -657,7 +646,7 @@ macro NLparameter(model, args...) esc_m, creation_code, __source__; - register_name = Containers._get_name(param), + register_name = name, ) end diff --git a/src/macros/@constraint.jl b/src/macros/@constraint.jl index 2e128567080..b33468c587b 100644 --- a/src/macros/@constraint.jl +++ b/src/macros/@constraint.jl @@ -59,7 +59,7 @@ which are not listed here. Other keyword arguments may be supported by JuMP extensions. """ macro constraint(input_args...) - error_fn(str...) = _macro_error(:constraint, input_args, __source__, str...) + error_fn = Containers.build_error_fn(:constraint, input_args, __source__) args, kwargs = Containers.parse_macro_arguments(error_fn, input_args) if length(args) < 2 && !isempty(kwargs) error_fn( @@ -67,7 +67,7 @@ macro constraint(input_args...) "construct an equality constraint, use `==` instead of `=`.", ) elseif length(args) < 2 - error_fn("Not enough arguments") + error_fn("expected 2 to 4 positional arguments, got $(length(args)).") elseif Meta.isexpr(args[2], :block) error_fn("Invalid syntax. Did you mean to use `@constraints`?") end @@ -92,31 +92,20 @@ macro constraint(input_args...) if length(extra) > 1 error_fn("Cannot specify more than 1 additional positional argument.") end - index_vars, indices = Containers.build_ref_sets(error_fn, c) - if args[1] in index_vars - error_fn( - "Index $(args[1]) is the same symbol as the model. Use a " * - "different name for the index.", - ) - end + name, index_vars, indices = Containers.parse_ref_sets( + error_fn, + c; + invalid_index_variables = [args[1]], + ) is_vectorized, parse_code, build_call = parse_constraint(error_fn, x) - _add_additional_args( + Containers.add_additional_args( build_call, extra, kwargs; kwarg_exclude = [:base_name, :container, :set_string_name], ) - # ; base_name - default_base_name = string(something(Containers._get_name(c), "")) - base_name = get(kwargs, :base_name, default_base_name) - if base_name isa Expr - base_name = esc(base_name) - end - # ; container - # There is no need to escape this one. - container = get(kwargs, :container, :Auto) # ; set_string_name - name_expr = _base_name_with_indices(base_name, index_vars) + name_expr = Containers.build_name_expr(name, index_vars, kwargs) if name_expr != "" set_string_name = if haskey(kwargs, :set_string_name) esc(kwargs[:set_string_name]) @@ -146,9 +135,9 @@ macro constraint(input_args...) end return _finalize_macro( model, - Containers.container_code(index_vars, indices, code, container), + Containers.container_code(index_vars, indices, code, kwargs), __source__; - register_name = Containers._get_name(c), + register_name = name, wrap_let = true, ) end @@ -218,9 +207,7 @@ ScalarConstraint{AffExpr, MathOptInterface.GreaterThan{Float64}}(2 x, MathOptInt ``` """ macro build_constraint(arg) - function error_fn(str...) - return _macro_error(:build_constraint, (arg,), __source__, str...) - end + error_fn = Containers.build_error_fn(:build_constraint, (arg,), __source__) _, parse_code, build_call = parse_constraint(error_fn, arg) return quote $parse_code diff --git a/src/macros/@expression.jl b/src/macros/@expression.jl index 1cdc3ad1e9d..610ccf4204f 100644 --- a/src/macros/@expression.jl +++ b/src/macros/@expression.jl @@ -64,27 +64,22 @@ julia> expr = @expression(model, [i in 1:3], i * sum(x[j] for j in 1:3)) ``` """ macro expression(input_args...) - error_fn(str...) = _macro_error(:expression, input_args, __source__, str...) - args, kwargs = Containers.parse_macro_arguments(error_fn, input_args) - if !(2 <= length(args) <= 3) - error_fn("expected 2 or 3 positional arguments, got $(length(args)).") - elseif Meta.isexpr(args[2], :block) + error_fn = Containers.build_error_fn(:expression, input_args, __source__) + args, kwargs = Containers.parse_macro_arguments( + error_fn, + input_args; + num_positional_args = 2:3, + valid_kwargs = [:container], + ) + if Meta.isexpr(args[2], :block) error_fn("Invalid syntax. Did you mean to use `@expressions`?") - elseif !isempty(kwargs) - for key in keys(kwargs) - if key != :container - error_fn("unsupported keyword argument `$key`.") - end - end end name_expr = length(args) == 3 ? args[2] : nothing - index_vars, indices = Containers.build_ref_sets(error_fn, name_expr) - if args[1] in index_vars - error_fn( - "Index $(args[1]) is the same symbol as the model. Use a " * - "different name for the index.", - ) - end + name, index_vars, indices = Containers.parse_ref_sets( + error_fn, + name_expr; + invalid_index_variables = [args[1]], + ) model = esc(args[1]) expr, build_code = _rewrite_expression(args[end]) code = quote @@ -93,12 +88,11 @@ macro expression(input_args...) # other structure that returns `_MA.Zero()`. _replace_zero($model, $expr) end - container = get(kwargs, :container, :Auto) return _finalize_macro( model, - Containers.container_code(index_vars, indices, code, container), + Containers.container_code(index_vars, indices, code, kwargs), __source__; - register_name = Containers._get_name(name_expr), + register_name = name, wrap_let = true, ) end diff --git a/src/macros/@objective.jl b/src/macros/@objective.jl index 4f35db77f55..d3a7fe0a7de 100644 --- a/src/macros/@objective.jl +++ b/src/macros/@objective.jl @@ -52,15 +52,13 @@ x² - 2 x + 1 ``` """ macro objective(input_args...) - error_fn(str...) = _macro_error(:objective, input_args, __source__, str...) - args, kwargs = Containers.parse_macro_arguments(error_fn, input_args) - if length(args) != 3 - error_fn("expected 3 positional arguments, got $(length(args)).") - elseif !isempty(kwargs) - for key in keys(kwargs) - error_fn("unsupported keyword argument `$key`.") - end - end + error_fn = Containers.build_error_fn(:objective, input_args, __source__) + args, kwargs = Containers.parse_macro_arguments( + error_fn, + input_args; + num_positional_args = 3, + valid_kwargs = Symbol[], + ) esc_model = esc(args[1]) sense = _parse_moi_sense(error_fn, args[2]) expr, parse_code = _rewrite_expression(args[3]) diff --git a/src/macros/@variable.jl b/src/macros/@variable.jl index 09d148ac63c..72d9ba0cccb 100644 --- a/src/macros/@variable.jl +++ b/src/macros/@variable.jl @@ -140,7 +140,7 @@ julia> @variable(model, z[i=1:3], set_string_name = false) ``` """ macro variable(input_args...) - error_fn(str...) = _macro_error(:variable, input_args, __source__, str...) + error_fn = Containers.build_error_fn(:variable, input_args, __source__) args, kwargs = Containers.parse_macro_arguments(error_fn, input_args) if length(args) >= 2 && Meta.isexpr(args[2], :block) error_fn("Invalid syntax. Did you mean to use `@variables`?") @@ -188,16 +188,14 @@ macro variable(input_args...) end end # if var === nothing, then the variable is anonymous - if !(var isa Symbol || var isa Expr || var === nothing) + if !(var isa Union{Nothing,Symbol,Expr}) error_fn("Expected $var to be a variable name") end - index_vars, indices = Containers.build_ref_sets(error_fn, var) - if model_sym in index_vars - error_fn( - "Index $model_sym is the same symbol as the model. Use a " * - "different name for the index.", - ) - end + name, index_vars, indices = Containers.parse_ref_sets( + error_fn, + var; + invalid_index_variables = [model_sym], + ) # Handle special keyword arguments # ; set set_kw = get(kwargs, :set, nothing) @@ -210,17 +208,8 @@ macro variable(input_args...) end set = set_kw end - # ; base_name - default_base_name = string(something(Containers._get_name(var), "")) - base_name = get(kwargs, :base_name, default_base_name) - if base_name isa Expr - base_name = esc(base_name) - end - # ; container - # There is no need to escape this one. - container = get(kwargs, :container, :Auto) # ; set_string_name - name_expr = _base_name_with_indices(base_name, index_vars) + name_expr = Containers.build_name_expr(name, index_vars, kwargs) if name_expr != "" set_string_name = if haskey(kwargs, :set_string_name) esc(kwargs[:set_string_name]) @@ -268,7 +257,7 @@ macro variable(input_args...) end filter!(ex -> !(ex in (:Int, :Bin, :PSD, :Symmetric, :Hermitian)), args) build_code = :(build_variable($error_fn, $(_constructor_expr(info_expr)))) - _add_additional_args( + Containers.add_additional_args( build_code, args, kwargs; @@ -288,7 +277,7 @@ macro variable(input_args...) variable = model_convert($model, $build_code) add_variable($model, variable, $name_expr) end, - container, + kwargs, ) elseif any(Base.Fix2(Containers.depends_on, set), index_vars) # This is for calls in which the set depends on the indices. @@ -300,7 +289,7 @@ macro variable(input_args...) build = build_variable($error_fn, $build_code, $set) add_variable($model, model_convert($model, build), $name_expr) end, - container, + kwargs, ) else # This is for calls in which the set does not depend on the indices. @@ -312,13 +301,13 @@ macro variable(input_args...) index_vars, indices, build_code, - container, + kwargs, ) name_expr = Containers.container_code( index_vars, indices, name_expr, - container, + kwargs, ) quote build = build_variable($error_fn, $build_code, $set) @@ -329,7 +318,7 @@ macro variable(input_args...) model, code, __source__; - register_name = Containers._get_name(var), + register_name = name, wrap_let = true, ) end diff --git a/test/Containers/test_macro.jl b/test/Containers/test_macro.jl index 64c71085fae..9227e240fe0 100644 --- a/test/Containers/test_macro.jl +++ b/test/Containers/test_macro.jl @@ -231,4 +231,39 @@ function test__MyContainer2() return end +function test_parse_macro_arguments() + args, kwargs = Containers.parse_macro_arguments(error, ()) + @test args == Any[] + @test isempty(kwargs) + return +end + +function test_add_additional_args() + call = :(f(1; a = 2)) + kwargs = Dict{Symbol,Any}() + @test Containers.add_additional_args(call, [:(foo)], kwargs) === nothing + @test call == :(f(1, $(Expr(:escape, :foo)); a = 2)) + call = :(f(1)) + Containers.add_additional_args(call, [2, 3], kwargs) + @test call == :(f(1, $(esc(2)), $(esc(3)))) + call = :(f.(1)) + Containers.add_additional_args(call, [2, 3], kwargs) + @test call == :(f.(1, $(esc(2)), $(esc(3)))) + call = :(f(1; a = 4)) + Containers.add_additional_args(call, [2, 3], kwargs) + @test call == :(f(1, $(esc(2)), $(esc(3)); a = 4)) + call = :(f.(1; a = 4)) + Containers.add_additional_args(call, [2, 3], kwargs) + @test call == :(f.(1, $(esc(2)), $(esc(3)); a = 4)) + call = :(f.(1, a = 4)) + kwargs = Dict{Symbol,Any}(:b => 4, :c => false) + Containers.add_additional_args(call, Any[2], kwargs; kwarg_exclude = [:b]) + @test call == Expr( + :., + :f, + Expr(:tuple, 1, esc(2), Expr(:kw, :a, 4), esc(Expr(:kw, :c, false))), + ) + return +end + end # module diff --git a/test/test_macros.jl b/test/test_macros.jl index b70382f4e23..25464e13967 100644 --- a/test/test_macros.jl +++ b/test/test_macros.jl @@ -166,34 +166,6 @@ function test_Check_Julia_condition_expression_parsing() return end -function test_add_additional_args() - call = :(f(1; a = 2)) - kwargs = Dict{Symbol,Any}() - @test JuMP._add_additional_args(call, [:(MyObject)], kwargs) isa Nothing - @test call == :(f(1, $(Expr(:escape, :MyObject)); a = 2)) - call = :(f(1)) - JuMP._add_additional_args(call, [2, 3], kwargs) - @test call == :(f(1, $(esc(2)), $(esc(3)))) - call = :(f.(1)) - JuMP._add_additional_args(call, [2, 3], kwargs) - @test call == :(f.(1, $(esc(2)), $(esc(3)))) - call = :(f(1; a = 4)) - JuMP._add_additional_args(call, [2, 3], kwargs) - @test call == :(f(1, $(esc(2)), $(esc(3)); a = 4)) - call = :(f.(1; a = 4)) - JuMP._add_additional_args(call, [2, 3], kwargs) - @test call == :(f.(1, $(esc(2)), $(esc(3)); a = 4)) - call = :(f.(1, a = 4)) - kwargs = Dict{Symbol,Any}(:b => 4, :c => false) - JuMP._add_additional_args(call, Any[2], kwargs; kwarg_exclude = [:b]) - @test call == Expr( - :., - :f, - Expr(:tuple, 1, esc(2), Expr(:kw, :a, 4), esc(Expr(:kw, :c, false))), - ) - return -end - function test_MutableArithmetics_Zero_Issue_2187() model = Model() c = @constraint(model, sum(1 for _ in 1:0) == sum(1 for _ in 1:0)) @@ -964,7 +936,7 @@ end function test_Model_as_index() m = Model() @variable(m, x) - msg = "Index m is the same symbol as the model. Use a different name for the index." + msg = "the index name `m` conflicts with another variable in this scope. Use a different name for the index." @test_throws_parsetime( ErrorException("In `@variable(m, y[m = 1:2] <= m)`: $(msg)"), @variable(m, y[m = 1:2] <= m), @@ -1395,7 +1367,9 @@ end function test_invalid_name_errors() model = Model() @test_throws_parsetime( - ErrorException("Expression x.y cannot be used as a name."), + ErrorException( + "In `@variable(model, x.y)`: Expression x.y cannot be used as a name.", + ), @variable(model, x.y), ) return @@ -1404,7 +1378,9 @@ end function test_invalid_name_errors_denseaxisarray() model = Model() @test_throws_parsetime( - ErrorException("Expression x.y cannot be used as a name."), + ErrorException( + "In `@variable(model, x.y[2:3, 1:2])`: Expression x.y cannot be used as a name.", + ), @variable(model, x.y[2:3, 1:2]), ) return @@ -1413,7 +1389,9 @@ end function test_invalid_name_errors_sparseaxisarray() model = Model() @test_throws_parsetime( - ErrorException("Expression x.y cannot be used as a name."), + ErrorException( + "In `@variable(model, x.y[i = 1:3; isodd(i)])`: Expression x.y cannot be used as a name.", + ), @variable(model, x.y[i = 1:3; isodd(i)]), ) return @@ -1680,7 +1658,9 @@ end function test_constraint_not_enough_arguments() model = Model() @test_throws_parsetime( - ErrorException("In `@constraint(model)`: Not enough arguments"), + ErrorException( + "In `@constraint(model)`: expected 2 to 4 positional arguments, got 1.", + ), @constraint(model), ) return @@ -1713,7 +1693,7 @@ function test_expression_not_enough_arguments() model = Model() @test_throws_parsetime( ErrorException( - "In `@expression(model)`: expected 2 or 3 positional arguments, got 1.", + "In `@expression(model)`: expected 2 to 3 positional arguments, got 1.", ), @expression(model), )