Skip to content

Commit

Permalink
[Containers] add utilities to support extensions writing macros (#3620)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Dec 14, 2023
1 parent a7a65e6 commit 65d12dc
Show file tree
Hide file tree
Showing 10 changed files with 364 additions and 267 deletions.
10 changes: 9 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,15 @@ jump_api_reference = DocumenterReference.automatic_reference_documentation(;
"Containers.nested" => DocumenterReference.DOCTYPE_FUNCTION,
"Containers.vectorized_product" =>
DocumenterReference.DOCTYPE_FUNCTION,
"Containers.build_ref_sets" =>
"Containers.build_error_fn" =>
DocumenterReference.DOCTYPE_FUNCTION,
"Containers.parse_macro_arguments" =>
DocumenterReference.DOCTYPE_FUNCTION,
"Containers.parse_ref_sets" =>
DocumenterReference.DOCTYPE_FUNCTION,
"Containers.build_name_expr" =>
DocumenterReference.DOCTYPE_FUNCTION,
"Containers.add_additional_args" =>
DocumenterReference.DOCTYPE_FUNCTION,
"Containers.container_code" =>
DocumenterReference.DOCTYPE_FUNCTION,
Expand Down
278 changes: 235 additions & 43 deletions src/Containers/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,8 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.

_get_name(c::Union{Symbol,Nothing}) = c
_get_name(c) = error("Expression `$c::$(typeof(c))` cannot be used as a name.")

function _get_name(c::Expr)
if Meta.isexpr(c, :vcat) || Meta.isexpr(c, :vect)
return nothing
elseif Meta.isexpr(c, :ref) || Meta.isexpr(c, :typed_vcat)
return _get_name(c.args[1])
end
return error("Expression $c cannot be used as a name.")
end

function _reorder_parameters(args)
if !Meta.isexpr(args[1], :parameters)
if isempty(args) || !Meta.isexpr(args[1], :parameters)
return args
end
args = collect(args)
Expand All @@ -29,32 +17,65 @@ function _reorder_parameters(args)
end

"""
parse_macro_arguments(error_fn::Function, args)
parse_macro_arguments(
error_fn::Function,
args;
valid_kwargs::Union{Nothing,Vector{Symbol}} = nothing,
num_positional_args::Union{Nothing,Int,UnitRange{Int}} = nothing,
)
Returns a `Tuple{Vector{Any},Dict{Symbol,Any}}` containing the ordered
positional arguments and a dictionary mapping the keyword arguments.
This specially handles the distinction of `@foo(key = value)` and
`@foo(; key = value)` in macros.
Throws an error if mulitple keyword arguments are passed with the same name.
An error is thrown if multiple keyword arguments are passed with the same key.
If `valid_kwargs` is a `Vector{Symbol}`, an error is thrown if a keyword is not
in `valid_kwargs`.
If `num_positional_args` is not nothing, an error is thrown if the number of
positional arguments is not in `num_positional_args`.
"""
function parse_macro_arguments(error_fn::Function, args)
pos_args, kw_args = Any[], Dict{Symbol,Any}()
function parse_macro_arguments(
error_fn::Function,
args;
valid_kwargs::Union{Nothing,Vector{Symbol}} = nothing,
num_positional_args::Union{Nothing,Int,UnitRange{Int}} = nothing,
)
pos_args, kwargs = Any[], Dict{Symbol,Any}()
for arg in _reorder_parameters(args)
if Meta.isexpr(arg, :(=), 2)
if haskey(kw_args, arg.args[1])
if haskey(kwargs, arg.args[1])
error_fn(
"the keyword argument `$(arg.args[1])` was given " *
"multiple times.",
)
elseif valid_kwargs !== nothing && !(arg.args[1] in valid_kwargs)
error_fn("unsupported keyword argument `$(arg.args[1])`.")
end
kw_args[arg.args[1]] = arg.args[2]
kwargs[arg.args[1]] = arg.args[2]
else
push!(pos_args, arg)
end
end
return pos_args, kw_args
if num_positional_args isa Int
n = length(pos_args)
if n != num_positional_args
error_fn(
"expected $num_positional_args positional arguments, got $n.",
)
end
elseif num_positional_args isa UnitRange{Int}
if !(length(pos_args) in num_positional_args)
a, b = num_positional_args.start, num_positional_args.stop
error_fn(
"expected $a to $b positional arguments, got $(length(pos_args)).",
)
end
end
return pos_args, kwargs
end

"""
Expand Down Expand Up @@ -204,8 +225,27 @@ function _has_dependent_sets(index_vars::Vector{Any}, index_sets::Vector{Any})
return false
end

function _container_name(error_fn::Function, x)
return error_fn("Expression `$x::$(typeof(x))` cannot be used as a name.")
end

_container_name(::Function, expr::Union{Symbol,Nothing}) = expr

function _container_name(error_fn::Function, expr::Expr)
if Meta.isexpr(expr, (:vcat, :vect))
return nothing
elseif Meta.isexpr(expr, (:ref, :typed_vcat))
return _container_name(error_fn, expr.args[1])
end
return error_fn("Expression $expr cannot be used as a name.")
end

"""
build_ref_sets(error_fn::Function, expr)
parse_ref_sets(
error_fn::Function,
expr;
invalid_index_variables::Vector{Symbol} = Symbol[],
)
Helper function for macros to construct container objects.
Expand All @@ -223,16 +263,64 @@ Helper function for macros to construct container objects.
## Returns
1. `index_vars`: a `Vector{Any}` of names for the index variables, e.g.,
1. `name`: the name of the container, if given, otherwise `nothing`
2. `index_vars`: a `Vector{Any}` of names for the index variables, e.g.,
`[:i, gensym(), :k]`. These may also be expressions, like `:((i, j))` from a
call like `:(x[(i, j) in S])`.
2. `indices`: an iterator over the indices, for example,
3. `indices`: an iterator over the indices, for example,
[`Containers.NestedIterator`](@ref)
## Example
See [`container_code`](@ref) for a worked example.
"""
function parse_ref_sets(
error_fn::Function,
expr::Union{Nothing,Symbol,Expr};
invalid_index_variables::Vector = Symbol[],
)
name = _container_name(error_fn, expr)
index_vars, indices = build_ref_sets(error_fn, expr)
for name in invalid_index_variables
if name in index_vars
error_fn(
"the index name `$name` conflicts with another variable in " *
"this scope. Use a different name for the index.",
)
end
end
return name, index_vars, indices
end

# This method is needed because Julia v1.10 prints LineNumberNode in the string
# representation of an expression.
function _strip_LineNumberNode(x::Expr)
if Meta.isexpr(x, :block)
return Expr(:block, filter(!Base.Fix2(isa, LineNumberNode), x.args)...)
end
return x
end

_strip_LineNumberNode(x) = x

"""
build_error_fn(macro_name, args, source)
Return a function that can be used in place of `Base.error`, but which
additionally prints the macro from which it was called.
"""
function build_error_fn(macro_name, args, source)
str_args = join(_strip_LineNumberNode.(args), ", ")
msg = "At $(source.file):$(source.line): `@$macro_name($str_args)`: "
error_fn(str...) = error(msg, str...)
return error_fn
end

"""
build_ref_sets(error_fn::Function, expr)
This function is deprecated. Use [`parse_ref_sets`](@ref) instead.
"""
function build_ref_sets(error_fn::Function, expr)
index_vars, index_sets, condition = _parse_ref_sets(error_fn, expr)
if any(_expr_is_splat, index_sets)
Expand Down Expand Up @@ -263,16 +351,107 @@ function build_ref_sets(error_fn::Function, expr)
return index_vars, indices
end

"""
add_additional_args(
call::Expr,
args::Vector,
kwargs::Dict{Symbol,Any};
kwarg_exclude::Vector{Symbol} = Symbol[],
)
Add the positional arguments `args` to the function call expression `call`,
escaping each argument expression.
This function is able to incorporate additional positional arguments to `call`s
that already have keyword arguments.
"""
function add_additional_args(
call::Expr,
args::Vector,
kwargs::Dict{Symbol,Any};
kwarg_exclude::Vector{Symbol} = Symbol[],
)
call_args = call.args
if Meta.isexpr(call, :.)
# call is broadcasted
call_args = call.args[2].args
end
# Cache all keyword arguments
kw_args = filter(Base.Fix2(Meta.isexpr, :kw), call_args)
# Remove keyowrd arguments from the end
filter!(!Base.Fix2(Meta.isexpr, :kw), call_args)
# Add the new positional arguments
append!(call_args, esc.(args))
# Re-add the cached keyword arguments back to the end
append!(call_args, kw_args)
for (key, value) in kwargs
if !(key in kwarg_exclude)
push!(call_args, esc(Expr(:kw, key, value)))
end
end
return
end

"""
build_name_expr(
name::Union{Symbol,Nothing},
index_vars::Vector,
kwargs::Dict{Symbol,Any},
)
Returns an expression for the name of a container element, where `name` and
`index_vars` are the values returned by [`parse_ref_sets`](@ref) and `kwargs`
is the dictionary returned by [`parse_macro_arguments`](@ref).
This assumes that the key in `kwargs` used to over-ride the name choice is
`:base_name`.
## Examples
```jldoctest
julia> Containers.build_name_expr(:x, [:i, :j], Dict{Symbol,Any}())
:(string("x", "[", string(\$(Expr(:escape, :i))), ",", string(\$(Expr(:escape, :j))), "]"))
julia> Containers.build_name_expr(nothing, [:i, :j], Dict{Symbol,Any}())
""
julia> Containers.build_name_expr(:y, [:i, :j], Dict{Symbol,Any}(:base_name => "y"))
:(string("y", "[", string(\$(Expr(:escape, :i))), ",", string(\$(Expr(:escape, :j))), "]"))
```
"""
function build_name_expr(
name::Union{Symbol,Nothing},
index_vars::Vector,
kwargs::Dict{Symbol,Any},
)
base_name = get(kwargs, :base_name, string(something(name, "")))
if base_name isa Expr
base_name = esc(base_name)
end
if isempty(index_vars) || base_name == ""
return base_name
end
expr = Expr(:call, :string, base_name, "[")
for index in index_vars
# Converting the arguments to strings before concatenating is faster:
# https://github.com/JuliaLang/julia/issues/29550.
push!(expr.args, :(string($(esc(index)))))
push!(expr.args, ",")
end
expr.args[end] = "]"
return expr
end

"""
container_code(
index_vars::Vector{Any},
indices::Expr,
code,
requested_container::Union{Symbol,Expr},
requested_container::Union{Symbol,Expr,Dict{Symbol,Any}},
)
Used in macros to construct a call to [`container`](@ref). This should be used
in conjunction with [`build_ref_sets`](@ref).
in conjunction with [`parse_ref_sets`](@ref).
## Arguments
Expand All @@ -285,7 +464,8 @@ in conjunction with [`build_ref_sets`](@ref).
* `requested_container`: passed to the third argument of [`container`](@ref).
For built-in JuMP types, choose one of `:Array`, `:DenseAxisArray`,
`:SparseAxisArray`, or `:Auto`. For a user-defined container, this expression
must evaluate to the correct type.
must evaluate to the correct type. You may also pass the `kwargs` dictionary
from [`parse_macro_arguments`](@ref).
!!! warning
In most cases, you should `esc(code)` before passing it to `container_code`.
Expand All @@ -294,17 +474,20 @@ in conjunction with [`build_ref_sets`](@ref).
```jldoctest
julia> macro foo(ref_sets, code)
index_vars, indices = Containers.build_ref_sets(error, ref_sets)
return Containers.container_code(
index_vars,
indices,
esc(code),
:Auto,
)
name, index_vars, indices =
Containers.parse_ref_sets(error, ref_sets)
@assert name !== nothing # Anonymous container not supported
container =
Containers.container_code(index_vars, indices, esc(code), :Auto)
return quote
\$(esc(name)) = \$container
end
end
@foo (macro with 1 method)
julia> @foo(x[i=1:2, j=["A", "B"]], j^i)
julia> @foo(x[i=1:2, j=["A", "B"]], j^i);
julia> x
2-dimensional DenseAxisArray{String,2,...} with index sets:
Dimension 1, Base.OneTo(2)
Dimension 2, ["A", "B"]
Expand Down Expand Up @@ -342,6 +525,16 @@ function container_code(
return Expr(:call, container, f, indices, container_type, index_vars)
end

function container_code(
index_vars::Vector{Any},
indices::Expr,
code,
kwargs::Dict{Symbol,Any},
)
container = get(kwargs, :container, :Auto)
return container_code(index_vars, indices, code, container)
end

"""
@container([i=..., j=..., ...], expr[, container = :Auto])
Expand Down Expand Up @@ -400,15 +593,14 @@ SparseAxisArray{Int64, 2, Tuple{Int64, Int64}} with 6 entries:
```
"""
macro container(input_args...)
args, kw_args = parse_macro_arguments(error, input_args)
container = get(kw_args, :container, :Auto)
@assert length(args) == 2
for key in keys(kw_args)
@assert key == :container
end
index_vars, indices = build_ref_sets(error, args[1])
code = container_code(index_vars, indices, esc(args[2]), container)
name = _get_name(args[1])
args, kwargs = parse_macro_arguments(
error,
input_args;
num_positional_args = 2,
valid_kwargs = [:container],
)
name, index_vars, indices = parse_ref_sets(error, args[1])
code = container_code(index_vars, indices, esc(args[2]), kwargs)
if name === nothing
return code
end
Expand Down
Loading

0 comments on commit 65d12dc

Please sign in to comment.