Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add parse_macro_arguments to unify how we handle macro inputs #3616

Merged
merged 6 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 41 additions & 21 deletions src/Containers/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,46 @@ function _get_name(c::Expr)
return error("Expression $c cannot be used as a name.")
end

function _reorder_parameters(args)
if !Meta.isexpr(args[1], :parameters)
return args
end
args = collect(args)
p = popfirst!(args)
for arg in p.args
@assert arg.head == :kw
push!(args, Expr(:(=), arg.args[1], arg.args[2]))
end
return args
end

"""
_extract_kw_args(args)
parse_macro_arguments(error_fn::Function, args)

Process the arguments to a macro, separating out the keyword arguments.
Returns a `Tuple{Vector{Any},Dict{Symbol,Any}}` containing the ordered
positional arguments and a dictionary mapping the keyword arguments.

Return a tuple of (flat_arguments, keyword arguments, and requested_container),
where `requested_container` is a symbol to be passed to `container_code`.
This specially handles the distinction of `@foo(key = value)` and
`@foo(; key = value)` in macros.

Throws an error if mulitple keyword arguments are passed with the same name.
"""
function _extract_kw_args(args)
flat_args, kw_args, requested_container = Any[], Any[], :Auto
for arg in args
if Meta.isexpr(arg, :(=))
if arg.args[1] == :container
requested_container = arg.args[2]
else
push!(kw_args, arg)
function parse_macro_arguments(error_fn::Function, args)
pos_args, kw_args = Any[], Dict{Symbol,Any}()
for arg in _reorder_parameters(args)
if Meta.isexpr(arg, :(=), 2)
if haskey(kw_args, arg.args[1])
error_fn(
"the keyword argument `$(arg.args[1])` was given " *
"multiple times.",
)
end
kw_args[arg.args[1]] = arg.args[2]
else
push!(flat_args, arg)
push!(pos_args, arg)
end
end
return flat_args, kw_args, requested_container
return pos_args, kw_args
end

"""
Expand Down Expand Up @@ -381,14 +399,16 @@ SparseAxisArray{Int64, 2, Tuple{Int64, Int64}} with 6 entries:
[3, 3] = 6
```
"""
macro container(args...)
args, kw_args, requested_container = _extract_kw_args(args)
macro container(input_args...)
args, kw_args = parse_macro_arguments(error, input_args)
container = get(kw_args, :container, :Auto)
@assert length(args) == 2
@assert isempty(kw_args)
var, value = args
index_vars, indices = build_ref_sets(error, var)
code = container_code(index_vars, indices, esc(value), requested_container)
name = _get_name(var)
for key in keys(kw_args)
@assert key == :container
end
index_vars, indices = build_ref_sets(error, args[1])
code = container_code(index_vars, indices, esc(args[2]), container)
name = _get_name(args[1])
if name === nothing
return code
end
Expand Down
81 changes: 10 additions & 71 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -323,45 +323,24 @@ function model_convert(
return model_convert.(model, x)
end

"""
_add_kw_args(call, kw_args)

Add the keyword arguments `kw_args` to the function call expression `call`,
escaping the expressions. The elements of `kw_args` should be expressions of the
form `:(key = value)`. The `kw_args` vector can be extracted from the arguments
of a macro with [`Containers._extract_kw_args`](@ref).

## Example

```jldoctest
julia> call = :(f(1, a=2))
:(f(1, a = 2))

julia> JuMP._add_kw_args(call, [:(b=3), :(c=4)])

julia> call
:(f(1, a = 2, \$(Expr(:escape, :(\$(Expr(:kw, :b, 3))))), \$(Expr(:escape, :(\$(Expr(:kw, :c, 4)))))))
```
"""
function _add_kw_args(call, kw_args; exclude = Symbol[])
for kw in kw_args
@assert Meta.isexpr(kw, :(=))
if kw.args[1] in exclude
function _add_keyword_args(call::Expr, kwargs::Dict; exclude = Symbol[])
for (key, value) in kwargs
if key in exclude
continue
end
push!(call.args, esc(Expr(:kw, kw.args...)))
push!(call.args, esc(Expr(:kw, key, value)))
end
return
end

"""
_add_positional_args(call, args)::Nothing
_add_positional_args(call::Expr, args::Vector{Any})::Nothing

Add the positional arguments `args` to the function call expression `call`,
escaping each argument expression. The elements of `args` should be ones that
were extracted via [`Containers._extract_kw_args`](@ref) and had appropriate
arguments filtered out (e.g., the model argument). This is able to incorporate
additional positional arguments to `call`s that already have keyword arguments.
escaping each argument expression.

This function is able to incorporate additional positional arguments to `call`s
that already have keyword arguments.

## Example

Expand All @@ -375,7 +354,7 @@ julia> call
:(f(1, $(Expr(:escape, :x)), a = 2))
```
"""
function _add_positional_args(call, args)
function _add_positional_args(call::Expr, args::Vector)
call_args = call.args
if Meta.isexpr(call, :.)
# call is broadcasted
Expand All @@ -392,19 +371,6 @@ function _add_positional_args(call, args)
return
end

function _reorder_parameters(args)
if !Meta.isexpr(args[1], :parameters)
return args
end
args = collect(args)
p = popfirst!(args)
for arg in p.args
@assert arg.head == :kw
push!(args, Expr(:(=), arg.args[1], arg.args[2]))
end
return args
end

_valid_model(::AbstractModel, ::Any) = nothing

function _valid_model(m::M, name) where {M}
Expand Down Expand Up @@ -654,33 +620,6 @@ function _wrap_let(model, code)
return code
end

function _get_kwarg_value(
error_fn,
kwargs,
key::Symbol;
default = nothing,
escape::Bool = true,
)
index, count = 0, 0
for (i, kwarg) in enumerate(kwargs)
if kwarg.args[1] == key
count += 1
index = i
end
end
if count == 0
return default
elseif count == 1
if escape
return esc(kwargs[index].args[2])
else
return kwargs[index].args[2]
end
else
error_fn("`$key` keyword argument was given $count times.")
end
end

include("macros/@objective.jl")
include("macros/@expression.jl")
include("macros/@constraint.jl")
Expand Down
32 changes: 28 additions & 4 deletions src/macros/@NL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,30 @@ function _parse_generator_expression(code, x, operators)
return y
end

"""
_extract_kw_args(args)

Process the arguments to a macro, separating out the keyword arguments.

Return a tuple of (flat_arguments, keyword arguments, and requested_container),
where `requested_container` is a symbol to be passed to `container_code`.
"""
function _extract_kw_args(args)
flat_args, kw_args, requested_container = Any[], Any[], :Auto
for arg in args
if Meta.isexpr(arg, :(=))
if arg.args[1] == :container
requested_container = arg.args[2]
else
push!(kw_args, arg)
end
else
push!(flat_args, arg)
end
end
return flat_args, kw_args, requested_container
end

###
### @NLobjective(s)
###
Expand Down Expand Up @@ -252,7 +276,7 @@ macro NLobjective(model, sense, x)
function error_fn(str...)
return _macro_error(:NLobjective, (model, sense, x), __source__, str...)
end
sense_expr = _moi_sense(error_fn, sense)
sense_expr = _parse_moi_sense(error_fn, sense)
esc_model = esc(model)
parsing_code, expr = _parse_nonlinear_expression(esc_model, x)
code = quote
Expand Down Expand Up @@ -299,7 +323,7 @@ macro NLconstraint(m, x, args...)
# Two formats:
# - @NLconstraint(m, a*x <= 5)
# - @NLconstraint(m, myref[a=1:5], sin(x^a) <= 5)
extra, kw_args, requested_container = Containers._extract_kw_args(args)
extra, kw_args, requested_container = _extract_kw_args(args)
if length(extra) > 1 || length(kw_args) > 0
error_fn("too many arguments.")
end
Expand Down Expand Up @@ -413,7 +437,7 @@ subexpression[5]: log(1.0 + (exp(subexpression[2]) + exp(subexpression[3])))
"""
macro NLexpression(args...)
error_fn(str...) = _macro_error(:NLexpression, args, __source__, str...)
args, kw_args, requested_container = Containers._extract_kw_args(args)
args, kw_args, requested_container = _extract_kw_args(args)
if length(args) <= 1
error_fn(
"To few arguments ($(length(args))); must pass the model and nonlinear expression as arguments.",
Expand Down Expand Up @@ -577,7 +601,7 @@ macro NLparameter(model, args...)
function error_fn(str...)
return _macro_error(:NLparameter, (model, args...), __source__, str...)
end
pos_args, kw_args, requested_container = Containers._extract_kw_args(args)
pos_args, kw_args, requested_container = _extract_kw_args(args)
value = missing
for arg in kw_args
if arg.args[1] == :value
Expand Down
45 changes: 27 additions & 18 deletions src/macros/@constraint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ user syntax: `@constraint(model, ref[...], expr, my_arg, kwargs...)`.
"""
macro constraint(input_args...)
error_fn(str...) = _macro_error(:constraint, input_args, __source__, str...)
args, kwargs, container = Containers._extract_kw_args(input_args)
args, kwargs = Containers.parse_macro_arguments(error_fn, input_args)
if length(args) < 2 && !isempty(kwargs)
error_fn(
"No constraint expression detected. If you are trying to " *
Expand All @@ -102,13 +102,12 @@ macro constraint(input_args...)
# [1:2] | Expr | :vect
# [i = 1:2, j = 1:2; i + j >= 3] | Expr | :vcat
# a constraint expression | Expr | :call or :comparison
c, x = if y isa Symbol || Meta.isexpr(y, (:vect, :vcat, :ref, :typed_vcat))
c, x = nothing, y
if y isa Symbol || Meta.isexpr(y, (:vect, :vcat, :ref, :typed_vcat))
if length(extra) == 0
error_fn("No constraint expression was given.")
end
y, popfirst!(extra)
else
nothing, y
c, x = y, popfirst!(extra)
end
if length(extra) > 1
error_fn("Cannot specify more than 1 additional positional argument.")
Expand All @@ -122,20 +121,30 @@ macro constraint(input_args...)
end
is_vectorized, parse_code, build_call = parse_constraint(error_fn, x)
_add_positional_args(build_call, extra)
_add_kw_args(build_call, kwargs; exclude = [:base_name, :set_string_name])
base_name = _get_kwarg_value(
error_fn,
kwargs,
:base_name;
default = string(something(Containers._get_name(c), "")),
)
set_name_flag = _get_kwarg_value(
error_fn,
kwargs,
:set_string_name;
default = :(set_string_names_on_creation($model)),
_add_keyword_args(
build_call,
kwargs;
exclude = [:base_name, :container, :set_string_name],
)
name_expr = :($set_name_flag ? $(_name_call(base_name, index_vars)) : "")
# ; base_name
default_base_name = string(something(Containers._get_name(c), ""))
base_name = get(kwargs, :base_name, default_base_name)
if base_name isa Expr
base_name = esc(base_name)
end
# ; container
# There is no need to escape this one.
container = get(kwargs, :container, :Auto)
# ; set_string_name
name_expr = _name_call(base_name, index_vars)
if name_expr != ""
set_string_name = if haskey(kwargs, :set_string_name)
esc(kwargs[:set_string_name])
else
:(set_string_names_on_creation($model))
end
name_expr = :($set_string_name ? $name_expr : "")
end
code = if is_vectorized
quote
$parse_code
Expand Down
17 changes: 11 additions & 6 deletions src/macros/@expression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,17 @@ julia> expr = @expression(model, [i in 1:3], i * sum(x[j] for j in 1:3))
"""
macro expression(input_args...)
error_fn(str...) = _macro_error(:expression, input_args, __source__, str...)
args, kw_args, container = Containers._extract_kw_args(input_args)
args, kwargs = Containers.parse_macro_arguments(error_fn, input_args)
if !(2 <= length(args) <= 3)
error_fn("needs at least two arguments.")
elseif !isempty(kw_args)
error_fn("unrecognized keyword argument")
error_fn("expected 2 or 3 positional arguments, got $(length(args)).")
elseif Meta.isexpr(args[2], :block)
error_fn("Invalid syntax. Did you mean to use `@expressions`?")
elseif !isempty(kwargs)
for key in keys(kwargs)
if key != :container
error_fn("unsupported keyword argument `$key`.")
end
end
end
name_expr = length(args) == 3 ? args[2] : nothing
index_vars, indices = Containers.build_ref_sets(error_fn, name_expr)
Expand All @@ -70,14 +74,15 @@ macro expression(input_args...)
"different name for the index.",
)
end
expr_var, build_code = _rewrite_expression(args[end])
model = esc(args[1])
expr, build_code = _rewrite_expression(args[end])
code = quote
$build_code
# Don't leak a `_MA.Zero` if the expression is an empty summation, or
# other structure that returns `_MA.Zero()`.
_replace_zero($model, $expr_var)
_replace_zero($model, $expr)
end
container = get(kwargs, :container, :Auto)
return _finalize_macro(
model,
Containers.container_code(index_vars, indices, code, container),
Expand Down
Loading
Loading