diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index bde1c5255..7919d223b 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -232,7 +232,7 @@ function _jacobian_aux( f_or_f!y..., pushforward_prep, backend, x, batched_seeds[1], contexts... ) - jac_blocks = map(eachindex(batched_seeds)) do a + jac = mapreduce(hcat, eachindex(batched_seeds)) do a dy_batch = pushforward( f_or_f!y..., pushforward_prep_same, @@ -247,8 +247,6 @@ function _jacobian_aux( end block end - - jac = reduce(hcat, jac_blocks) return jac end @@ -265,7 +263,7 @@ function _jacobian_aux( f_or_f!y..., prep.pullback_prep, backend, x, batched_seeds[1], contexts... ) - jac_blocks = map(eachindex(batched_seeds)) do a + jac = mapreduce(vcat, eachindex(batched_seeds)) do a dx_batch = pullback( f_or_f!y..., pullback_prep_same, backend, x, batched_seeds[a], contexts... ) @@ -275,8 +273,6 @@ function _jacobian_aux( end block end - - jac = reduce(vcat, jac_blocks) return jac end diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index 025c69315..d5d9921fc 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -111,7 +111,7 @@ function hessian( f, hvp_prep, backend, x, batched_seeds[1], contexts... ) - hess_blocks = map(eachindex(batched_seeds)) do a + hess = mapreduce(hcat, eachindex(batched_seeds)) do a dg_batch = hvp(f, hvp_prep_same, backend, x, batched_seeds[a], contexts...) block = stack_vec_col(dg_batch) if N % B != 0 && a == lastindex(batched_seeds) @@ -119,8 +119,6 @@ function hessian( end block end - - hess = reduce(hcat, hess_blocks) return hess end diff --git a/DifferentiationInterfaceTest/src/test_differentiation.jl b/DifferentiationInterfaceTest/src/test_differentiation.jl index a10b0b68d..6929033c6 100644 --- a/DifferentiationInterfaceTest/src/test_differentiation.jl +++ b/DifferentiationInterfaceTest/src/test_differentiation.jl @@ -48,6 +48,7 @@ For `type_stability` and `benchmark`, the possible values are `:none`, `:prepare **Type stability options:** - `ignored_modules=nothing`: list of modules that JET.jl should ignore +- `function_filter`: filter for functions that JET.jl should ignore (with a reasonable default) **Benchmark options:** @@ -72,6 +73,11 @@ function test_differentiation( sparsity::Bool=true, # type stability options ignored_modules=nothing, + function_filter=if VERSION >= v"1.11" + @nospecialize(f) -> true + else + @nospecialize(f) -> f != Base.mapreduce_empty # fix for `mapreduce` in jacobian and hessian + end, # benchmark options count_calls::Bool=true, ) @@ -136,7 +142,8 @@ function test_differentiation( adapted_backend, scen; subset=type_stability, - ignored_modules=ignored_modules, + ignored_modules, + function_filter, ) end yield() diff --git a/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl b/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl index 1afb4a2f8..01e173951 100644 --- a/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl @@ -18,280 +18,319 @@ for op in ALL_OPS if op in [:derivative, :gradient, :jacobian] @eval function test_jet( - ba::AbstractADType, scen::$S1out; subset::Symbol, ignored_modules + ba::AbstractADType, + scen::$S1out; + subset::Symbol, + ignored_modules, + function_filter, ) (; f, x, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, contexts...) (subset == :full) && - @test_opt ignored_modules = ignored_modules $prep_op(f, ba, x, contexts...) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $prep_op(f, ba, x, contexts...) (subset == :full) && - @test_opt ignored_modules = ignored_modules $op(f, ba, x, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op( - f, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op(f, ba, x, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op(f, ba, x, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op(f, prep, ba, x, contexts...) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op(f, prep, ba, x, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $val_and_op( - f, prep, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op(f, prep, ba, x, contexts...) return nothing end @eval function test_jet( - ba::AbstractADType, scen::$S1in; subset::Symbol, ignored_modules + ba::AbstractADType, + scen::$S1in; + subset::Symbol, + ignored_modules, + function_filter, ) (; f, x, res1, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, contexts...) (subset == :full) && - @test_opt ignored_modules = ignored_modules $prep_op(f, ba, x, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $op!( - f, res1, ba, x, contexts... - ) - (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op!( - f, res1, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $prep_op(f, ba, x, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op!(f, res1, ba, x, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op!(f, res1, ba, x, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op!( - f, res1, prep, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op!(f, res1, prep, ba, x, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $val_and_op!( - f, res1, prep, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op!(f, res1, prep, ba, x, contexts...) return nothing end op == :gradient && continue @eval function test_jet( - ba::AbstractADType, scen::$S2out; subset::Symbol, ignored_modules + ba::AbstractADType, + scen::$S2out; + subset::Symbol, + ignored_modules, + function_filter, ) (; f, x, y, contexts) = deepcopy(scen) prep = $prep_op(f, y, ba, x, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op( - f, y, ba, x, contexts... - ) - (subset == :full) && - @test_opt ignored_modules = ignored_modules $op(f, y, ba, x, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op( - f, y, ba, x, contexts... - ) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $prep_op(f, y, ba, x, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op(f, y, ba, x, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op(f, y, ba, x, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op( - f, y, prep, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op(f, y, prep, ba, x, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $val_and_op( - f, y, prep, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op(f, y, prep, ba, x, contexts...) return nothing end @eval function test_jet( - ba::AbstractADType, scen::$S2in; subset::Symbol, ignored_modules + ba::AbstractADType, + scen::$S2in; + subset::Symbol, + ignored_modules, + function_filter, ) (; f, x, y, res1, contexts) = deepcopy(scen) prep = $prep_op(f, y, ba, x, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op( - f, y, ba, x, contexts... - ) - (subset == :full) && @test_opt ignored_modules = ignored_modules $op!( - f, y, res1, ba, x, contexts... - ) - (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op!( - f, y, res1, ba, x, contexts... - ) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $prep_op(f, y, ba, x, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op!(f, y, res1, ba, x, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op!(f, y, res1, ba, x, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op!( - f, y, res1, prep, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op!(f, y, res1, prep, ba, x, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $val_and_op!( - f, y, res1, prep, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op!(f, y, res1, prep, ba, x, contexts...) return nothing end elseif op in [:second_derivative, :hessian] @eval function test_jet( - ba::AbstractADType, scen::$S1out; subset::Symbol, ignored_modules + ba::AbstractADType, + scen::$S1out; + subset::Symbol, + ignored_modules, + function_filter, ) (; f, x, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, contexts...) (subset == :full) && - @test_opt ignored_modules = ignored_modules $prep_op(f, ba, x, contexts...) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $prep_op(f, ba, x, contexts...) (subset == :full) && - @test_opt ignored_modules = ignored_modules $op(f, ba, x, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op( - f, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op(f, ba, x, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op(f, ba, x, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op(f, prep, ba, x, contexts...) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op(f, prep, ba, x, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $val_and_op( - f, prep, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op(f, prep, ba, x, contexts...) return nothing end @eval function test_jet( - ba::AbstractADType, scen::$S1in; subset::Symbol, ignored_modules + ba::AbstractADType, + scen::$S1in; + subset::Symbol, + ignored_modules, + function_filter, ) (; f, x, res1, res2, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, contexts...) (subset == :full) && - @test_opt ignored_modules = ignored_modules $prep_op(f, ba, x, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $op!( - f, res2, ba, x, contexts... - ) - (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op!( - f, res1, res2, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $prep_op(f, ba, x, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op!(f, res2, ba, x, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op!(f, res1, res2, ba, x, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op!( - f, res2, prep, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op!(f, res2, prep, ba, x, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $val_and_op!( - f, res1, res2, prep, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op!(f, res1, res2, prep, ba, x, contexts...) return nothing end elseif op in [:pushforward, :pullback] @eval function test_jet( - ba::AbstractADType, scen::$S1out; subset::Symbol, ignored_modules + ba::AbstractADType, + scen::$S1out; + subset::Symbol, + ignored_modules, + function_filter, ) (; f, x, tang, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, tang, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op( - f, ba, x, tang, contexts... - ) - (subset == :full) && - @test_opt ignored_modules = ignored_modules $op(f, ba, x, tang, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op( - f, ba, x, tang, contexts... - ) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $prep_op(f, ba, x, tang, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op(f, ba, x, tang, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op(f, ba, x, tang, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op( - f, prep, ba, x, tang, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op(f, prep, ba, x, tang, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $val_and_op( - f, prep, ba, x, tang, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op(f, prep, ba, x, tang, contexts...) return nothing end @eval function test_jet( - ba::AbstractADType, scen::$S1in; subset::Symbol, ignored_modules + ba::AbstractADType, + scen::$S1in; + subset::Symbol, + ignored_modules, + function_filter, ) (; f, x, tang, res1, res2, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, tang, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op( - f, ba, x, tang, contexts... - ) - (subset == :full) && @test_opt ignored_modules = ignored_modules $op!( - f, res1, ba, x, tang, contexts... - ) - (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op!( - f, res1, ba, x, tang, contexts... - ) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $prep_op(f, ba, x, tang, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op!(f, res1, ba, x, tang, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op!(f, res1, ba, x, tang, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op!( - f, res1, prep, ba, x, tang, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op!(f, res1, prep, ba, x, tang, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $val_and_op!( - f, res1, prep, ba, x, tang, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op!(f, res1, prep, ba, x, tang, contexts...) return nothing end @eval function test_jet( - ba::AbstractADType, scen::$S2out; subset::Symbol, ignored_modules + ba::AbstractADType, + scen::$S2out; + subset::Symbol, + ignored_modules, + function_filter, ) (; f, x, y, tang, contexts) = deepcopy(scen) prep = $prep_op(f, y, ba, x, tang, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op( - f, y, ba, x, tang, contexts... - ) - (subset == :full) && @test_opt ignored_modules = ignored_modules $op( - f, y, ba, x, tang, contexts... - ) - (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op( - f, y, ba, x, tang, contexts... - ) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $prep_op(f, y, ba, x, tang, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op(f, y, ba, x, tang, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op(f, y, ba, x, tang, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op( - f, y, prep, ba, x, tang, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op(f, y, prep, ba, x, tang, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $val_and_op( - f, y, prep, ba, x, tang, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op(f, y, prep, ba, x, tang, contexts...) return nothing end @eval function test_jet( - ba::AbstractADType, scen::$S2in; subset::Symbol, ignored_modules + ba::AbstractADType, + scen::$S2in; + subset::Symbol, + ignored_modules, + function_filter, ) (; f, x, y, tang, res1, contexts) = deepcopy(scen) prep = $prep_op(f, y, ba, x, tang, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op( - f, y, ba, x, tang, contexts... - ) - (subset == :full) && @test_opt ignored_modules = ignored_modules $op!( - f, y, res1, ba, x, tang, contexts... - ) - (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op!( - f, y, res1, ba, x, tang, contexts... - ) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $prep_op(f, y, ba, x, tang, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op!(f, y, res1, ba, x, tang, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op!(f, y, res1, ba, x, tang, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op!( - f, y, res1, prep, ba, x, tang, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op!(f, y, res1, prep, ba, x, tang, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $val_and_op!( - f, y, res1, prep, ba, x, tang, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op!(f, y, res1, prep, ba, x, tang, contexts...) return nothing end elseif op in [:hvp] @eval function test_jet( - ba::AbstractADType, scen::$S1out; subset::Symbol, ignored_modules + ba::AbstractADType, + scen::$S1out; + subset::Symbol, + ignored_modules, + function_filter, ) (; f, x, tang, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, tang, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op( - f, ba, x, tang, contexts... - ) (subset == :full) && - @test_opt ignored_modules = ignored_modules $op(f, ba, x, tang, contexts...) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $prep_op(f, ba, x, tang, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op(f, ba, x, tang, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op( - f, prep, ba, x, tang, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op(f, prep, ba, x, tang, contexts...) return nothing end @eval function test_jet( - ba::AbstractADType, scen::$S1in; subset::Symbol, ignored_modules + ba::AbstractADType, + scen::$S1in; + subset::Symbol, + ignored_modules, + function_filter, ) (; f, x, tang, res1, res2, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, tang, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op( - f, ba, x, tang, contexts... - ) - (subset == :full) && @test_opt ignored_modules = ignored_modules $op!( - f, res2, ba, x, tang, contexts... - ) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $prep_op(f, ba, x, tang, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op!(f, res2, ba, x, tang, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op!( - f, res2, prep, ba, x, tang, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op!(f, res2, prep, ba, x, tang, contexts...) return nothing end end