From 9be236abb96b97b67c2b9c57e349f3d18461944e Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 10 Oct 2024 11:13:35 +0200 Subject: [PATCH 01/18] Improve type stability tests and benchmarking --- .github/workflows/Test.yml | 2 +- .../test/Back/Enzyme/test.jl | 2 +- .../test/Back/ForwardDiff/test.jl | 5 +- .../test/Misc/ZeroBackends/test.jl | 8 +- DifferentiationInterfaceTest/Project.toml | 7 +- DifferentiationInterfaceTest/docs/src/api.md | 13 + .../docs/src/tutorial.md | 2 +- .../src/DifferentiationInterfaceTest.jl | 24 +- .../src/scenarios/default.jl | 2 + .../src/scenarios/scenario.jl | 11 - .../src/test_differentiation.jl | 211 +++---- .../src/tests/benchmark.jl | 23 +- .../src/tests/benchmark_eval.jl | 557 +++++++++++++----- .../src/tests/correctness_eval.jl | 59 +- .../src/tests/type_stability_eval.jl | 365 ++++++------ DifferentiationInterfaceTest/src/utils.jl | 4 +- .../test/zero_backends.jl | 25 +- 17 files changed, 802 insertions(+), 518 deletions(-) diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index 2192752b6..10b23be4c 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -25,7 +25,7 @@ jobs: actions: write contents: read strategy: - fail-fast: true + fail-fast: false # TODO: toggle matrix: version: - "1.10" diff --git a/DifferentiationInterface/test/Back/Enzyme/test.jl b/DifferentiationInterface/test/Back/Enzyme/test.jl index 8ccc09e73..be3786c0e 100644 --- a/DifferentiationInterface/test/Back/Enzyme/test.jl +++ b/DifferentiationInterface/test/Back/Enzyme/test.jl @@ -54,7 +54,7 @@ test_differentiation( AutoEnzyme(; mode=Enzyme.Forward), # TODO: add more default_scenarios(; include_batchified=false); correctness=false, - type_stability=true, + type_stability=:prepared, second_order=false, logging=LOGGING, ); diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index ed316be0e..14b8ea48e 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -27,10 +27,7 @@ test_differentiation( ); test_differentiation( - AutoForwardDiff(; chunksize=5); - correctness=false, - type_stability=(; preparation=true, prepared_op=true, unprepared_op=false), - logging=LOGGING, + AutoForwardDiff(; chunksize=5); correctness=false, type_stability=:full, logging=LOGGING ); test_differentiation( diff --git a/DifferentiationInterface/test/Misc/ZeroBackends/test.jl b/DifferentiationInterface/test/Misc/ZeroBackends/test.jl index 3895cb544..956bef3b9 100644 --- a/DifferentiationInterface/test/Misc/ZeroBackends/test.jl +++ b/DifferentiationInterface/test/Misc/ZeroBackends/test.jl @@ -22,7 +22,7 @@ test_differentiation( AutoZeroForward(), default_scenarios(; include_batchified=false, include_constantified=true); correctness=false, - type_stability=true, + type_stability=:full, logging=LOGGING, ) @@ -30,7 +30,7 @@ test_differentiation( AutoZeroReverse(), default_scenarios(; include_batchified=false, include_constantified=true); correctness=false, - type_stability=(; preparation=true, prepared_op=true, unprepared_op=false), + type_stability=:full, logging=LOGGING, ) @@ -41,7 +41,7 @@ test_differentiation( ], default_scenarios(; include_batchified=false, include_constantified=true); correctness=false, - type_stability=(; preparation=true, prepared_op=true, unprepared_op=true), + type_stability=:full, first_order=false, logging=LOGGING, ) @@ -50,7 +50,7 @@ test_differentiation( AutoSparse.(zero_backends, coloring_algorithm=GreedyColoringAlgorithm()), default_scenarios(; include_constantified=true); correctness=false, - type_stability=(; preparation=true, prepared_op=true, unprepared_op=false), + type_stability=:full, excluded=[:pushforward, :pullback, :gradient, :derivative, :hvp, :second_derivative], logging=LOGGING, ) diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index b8bbe9e2a..032a37b28 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterfaceTest" uuid = "a82114a7-5aa3-49a8-9643-716bb13727a3" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.7.1" +version = "0.8.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -9,19 +9,18 @@ Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [weakdeps] ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" @@ -31,7 +30,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] DifferentiationInterfaceTestComponentArraysExt = "ComponentArrays" -DifferentiationInterfaceTestFluxExt = ["FiniteDifferences", "Flux"] +DifferentiationInterfaceTestFluxExt = ["FiniteDifferences", "Flux", "Functors"] DifferentiationInterfaceTestJLArraysExt = "JLArrays" DifferentiationInterfaceTestLuxExt = ["ComponentArrays", "ForwardDiff", "Lux", "LuxTestUtils"] DifferentiationInterfaceTestStaticArraysExt = "StaticArrays" diff --git a/DifferentiationInterfaceTest/docs/src/api.md b/DifferentiationInterfaceTest/docs/src/api.md index e11ab369f..9ba404c7a 100644 --- a/DifferentiationInterfaceTest/docs/src/api.md +++ b/DifferentiationInterfaceTest/docs/src/api.md @@ -15,6 +15,13 @@ DifferentiationInterfaceTest Scenario test_differentiation benchmark_differentiation +FIRST_ORDER +SECOND_ORDER +``` + +## Utilities + +```@docs DifferentiationBenchmarkDataRow ``` @@ -30,6 +37,12 @@ gpu_scenarios static_scenarios ``` +## Utilities + +```@docs +DifferentiationBenchmarkDataRow +``` + ## Internals This is not part of the public API. diff --git a/DifferentiationInterfaceTest/docs/src/tutorial.md b/DifferentiationInterfaceTest/docs/src/tutorial.md index 3fe3a8470..2aa890123 100644 --- a/DifferentiationInterfaceTest/docs/src/tutorial.md +++ b/DifferentiationInterfaceTest/docs/src/tutorial.md @@ -55,7 +55,7 @@ test_differentiation( backends, # the backends you want to compare scenarios, # the scenarios you defined, correctness=true, # compares values against the reference - type_stability=false, # checks type stability with JET.jl + type_stability=:none, # checks type stability with JET.jl detailed=true, # prints a detailed test set ) ``` diff --git a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl index 214995ee7..206cbf26e 100644 --- a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl +++ b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl @@ -51,15 +51,29 @@ using DifferentiationInterface: Rewrap import DifferentiationInterface as DI using DocStringExtensions -using Functors: fmap -using JET: JET +using JET: @test_opt using LinearAlgebra: Adjoint, Diagonal, Transpose, dot, parent using ProgressMeter: ProgressUnknown, next! using Random: AbstractRNG, default_rng, rand! -using SparseArrays: SparseArrays, SparseMatrixCSC, nnz, spdiagm -import SparseMatrixColorings as SMC +using SparseArrays: SparseArrays, AbstractSparseMatrix, SparseMatrixCSC, nnz, spdiagm using Test: @testset, @test +""" + FIRST_ORDER = [:pushforward, :pullback, :derivative, :gradient, :jacobian] + +List of all first-order operators, to facilitate exclusion during tests. +""" +const FIRST_ORDER = [:pushforward, :pullback, :derivative, :gradient, :jacobian] + +""" + SECOND_ORDER = [:hvp, :second_derivative, :hessian] + +List of all second-order operators, to facilitate exclusion during tests. +""" +const SECOND_ORDER = [:hvp, :second_derivative, :hessian] + +const ALL_OPS = vcat(FIRST_ORDER, SECOND_ORDER) + include("utils.jl") include("scenarios/scenario.jl") @@ -71,11 +85,11 @@ include("scenarios/extensions.jl") include("tests/correctness_eval.jl") include("tests/type_stability_eval.jl") -include("tests/sparsity.jl") include("tests/benchmark.jl") include("tests/benchmark_eval.jl") include("test_differentiation.jl") +export FIRST_ORDER, SECOND_ORDER export Scenario export default_scenarios, sparse_scenarios export test_differentiation, benchmark_differentiation diff --git a/DifferentiationInterfaceTest/src/scenarios/default.jl b/DifferentiationInterfaceTest/src/scenarios/default.jl index 25b85eda6..74506155e 100644 --- a/DifferentiationInterfaceTest/src/scenarios/default.jl +++ b/DifferentiationInterfaceTest/src/scenarios/default.jl @@ -78,6 +78,8 @@ struct NumToArr{A} end NumToArr(::Type{A}) where {A} = NumToArr{A}() Base.eltype(::NumToArr{A}) where {A} = eltype(A) +Base.show(io::IO, ::NumToArr{A}) where {A} = print(io, "num_to_arr{$A}") + function (f::NumToArr{A})(x::Number) where {A} a = multiplicator(A) return sin.(x .* a) diff --git a/DifferentiationInterfaceTest/src/scenarios/scenario.jl b/DifferentiationInterfaceTest/src/scenarios/scenario.jl index 2de9281f1..3cbb44ec8 100644 --- a/DifferentiationInterfaceTest/src/scenarios/scenario.jl +++ b/DifferentiationInterfaceTest/src/scenarios/scenario.jl @@ -1,14 +1,3 @@ -const ALL_OPS = ( - :pushforward, - :pullback, - :derivative, - :gradient, - :jacobian, - :hessian, - :hvp, - :second_derivative, -) - """ Scenario{op,pl_op,pl_fun} diff --git a/DifferentiationInterfaceTest/src/test_differentiation.jl b/DifferentiationInterfaceTest/src/test_differentiation.jl index b96ffc84b..413c84922 100644 --- a/DifferentiationInterfaceTest/src/test_differentiation.jl +++ b/DifferentiationInterfaceTest/src/test_differentiation.jl @@ -1,86 +1,94 @@ -function filter_scenarios( - scenarios::Vector{<:Scenario}; - input_type::Type, - output_type::Type, - first_order::Bool, - second_order::Bool, - excluded::Vector{Symbol}, -) - scenarios = filter(s -> (s.x isa input_type && s.y isa output_type), scenarios) - !first_order && (scenarios = filter(s -> order(s) != 1, scenarios)) - !second_order && (scenarios = filter(s -> order(s) != 2, scenarios)) - scenarios = filter(s -> !(operator(s) in excluded), scenarios) - # sort for nice printing - scenarios = sort(scenarios; by=s -> (operator(s), string(s.f))) - return scenarios -end - """ $(TYPEDSIGNATURES) -Cross-test a list of `backends` on a list of `scenarios`, running a variety of different tests. +Apply a list of `backends` on a list of `scenarios`, running a variety of different tests and/or benchmarks. + +# Return + +This function always creates and runs a `@testset`, though its contents may vary. + +- if `benchmark == :none`, it returns `nothing`. +- if `benchmark != :none`, it returns a `DataFrame` of benchmark results, whose columns correspond to the fields of [`DifferentiationBenchmarkDataRow`](@ref). -# Default arguments +# Positional arguments -- `scenarios::Vector{<:Scenario}`: the output of [`default_scenarios()`](@ref) +- `backends::Vector{<:AbstractADType}`: the backends to test +- `scenarios::Vector{<:Scenario}`: the scenarios on which to test them (defaults to the output of [`default_scenarios()`](@ref)) # Keyword arguments -Testing: +**Test categories:** - `correctness=true`: whether to compare the differentiation results with the theoretical values specified in each scenario -- `type_stability=false`: whether to check type stability of operators with JET.jl (thanks to `JET.@test_opt`). It can be either a `Bool` or a more detailed named tuple `(; preparation, prepared_op, unprepared_op)` to specify which variants should be analyzed. -- `sparsity`: whether to check sparsity of the jacobian / hessian -- `detailed=false`: whether to print a detailed or condensed test log +- `type_stability=:none`: whether (and how) to check type stability of operators with JET.jl. +- `benchmark=:none`: whether (and how) to benchmark operators with Chairmarks.jl -Filtering: +For `type_stability` and `benchmark`, the possible values are `:none`, `:prepared` or `:full`, each concerns a different subset of calls: -- `input_type=Any`, `output_type=Any`: restrict scenario inputs / outputs to subtypes of this -- `first_order=true`, `second_order=true`: include first order / second order operators +| kwarg | prepared operator | unprepared operator | preparation | +|---|---|---|---| +| `:none` | no | no | no | +| `:prepared` | yes | no | no | +| `:full` | yes | yes | yes | -Options: +**Misc options:** +- `excluded::Vector{Symbol}`: list of operators to exclude, such as [`FIRST_ORDER`](@ref) or [`SECOND_ORDER`](@ref) +- `detailed=false`: whether to create a detailed or condensed testset - `logging=false`: whether to log progress + +**Correctness options:** + - `isapprox=isapprox`: function used to compare objects approximately, with the standard signature `isapprox(x, y; atol, rtol)` - `atol=0`: absolute precision for correctness testing (when comparing to the reference outputs) - `rtol=1e-3`: relative precision for correctness testing (when comparing to the reference outputs) - `scenario_intact=true`: whether to check that the scenario remains unchanged after the operators are applied +- `sparsity`: whether to check sparsity patterns for Jacobians / Hessians + +**Type stability options:** + +- `ignored_modules=nothing`: list of modules that JET.jl should ignore + +**Benchmark options:** + +- `count_calls::Bool`: whether to also count function calls during benchmarking """ function test_differentiation( backends::Vector{<:AbstractADType}, scenarios::Vector{<:Scenario}=default_scenarios(); - # testing + # test categories correctness::Bool=true, - type_stability=false, - call_count::Bool=false, - sparsity::Bool=false, - detailed=false, - # filtering - input_type::Type=Any, - output_type::Type=Any, - first_order::Bool=true, - second_order::Bool=true, + type_stability::Symbol=:none, + benchmark::Symbol=:none, + # misc options excluded::Vector{Symbol}=Symbol[], - # options + detailed::Bool=false, logging::Bool=false, + # correctness options isapprox=isapprox, atol::Real=0, rtol::Real=1e-3, scenario_intact::Bool=true, + sparsity::Bool=true, + # type stability options + ignored_modules=nothing, + # benchmark options + count_calls::Bool=true, ) - scenarios = filter_scenarios( - scenarios; first_order, second_order, input_type, output_type, excluded - ) + @assert type_stability in (:none, :prepared, :full) + @assert benchmark in (:none, :prepared, :full) - bool_type_stability = (type_stability == true || type_stability isa NamedTuple) + scenarios = filter(s -> !(operator(s) in excluded), scenarios) + scenarios = sort(scenarios; by=s -> (operator(s), string(s.f))) title_additions = - (correctness != false ? " + correctness" : "") * - (call_count ? " + calls" : "") * - (bool_type_stability ? " + type stability" : "") * - (sparsity ? " + sparsity" : "") + (correctness ? " + correctness" : "") * + ((type_stability != :none) ? " + type stability" : "") * + ((benchmark != :none) ? " + benchmarks" : "") title = "Testing" * title_additions[3:end] + benchmark_data = DifferentiationBenchmarkDataRow[] + prog = ProgressUnknown(; desc="$title", spinner=true, enabled=logging) @testset verbose = true "$title" begin @@ -113,26 +121,40 @@ function test_differentiation( adapted_backend = adapt_batchsize(backend, scen) correctness && @testset "Correctness" begin test_correctness( - adapted_backend, scen; isapprox, atol, rtol, scenario_intact + adapted_backend, + scen; + isapprox, + atol, + rtol, + scenario_intact, + sparsity, ) end - kwargs_type_stability = if type_stability isa NamedTuple - type_stability - else - (; preparation=false, prepared_op=type_stability, unprepared_op=false) - end - bool_type_stability && @testset "Type stability" begin - test_jet(adapted_backend, scen; kwargs_type_stability...) + yield() + (type_stability != :none) && @testset "Type stability" begin + test_jet(adapted_backend, scen; subset=type_stability, ignored_modules) end - sparsity && @testset "Sparsity" begin - test_sparsity(adapted_backend, scen) + yield() + (benchmark != :none) && @testset "Benchmark" begin + run_benchmark!( + benchmark_data, + adapted_backend, + scen; + logging, + subset=benchmark, + count_calls, + ) end yield() end end end end - return nothing + if benchmark != :none + return DataFrame(benchmark_data) + else + return nothing + end end """ @@ -147,73 +169,24 @@ end """ $(TYPEDSIGNATURES) -Benchmark a list of `backends` for a list of `operators` on a list of `scenarios`. - -The object returned is a `DataFrames.DataFrame` where each column corresponds to a field of [`DifferentiationBenchmarkDataRow`](@ref). +Shortcut for [`test_differentiation`](@ref) with only benchmarks and no correctness or type stability checks. -The keyword arguments available here have the same meaning as those in [`test_differentiation`](@ref). +Specifying the set of scenarios is mandatory for this function. """ function benchmark_differentiation( - backends::Vector{<:AbstractADType}, + backends, scenarios::Vector{<:Scenario}; - # filtering - input_type::Type=Any, - output_type::Type=Any, - first_order::Bool=true, - second_order::Bool=true, + benchmark::Symbol=:full, excluded::Vector{Symbol}=Symbol[], - # options logging::Bool=false, ) - scenarios = filter_scenarios( - scenarios; first_order, second_order, input_type, output_type, excluded + return test_differentiation( + backends, + scenarios; + correctness=false, + type_stability=:none, + benchmark, + logging, + excluded, ) - - benchmark_data = DifferentiationBenchmarkDataRow[] - prog = ProgressUnknown(; desc="Benchmarking", spinner=true, enabled=logging) - for (i, backend) in enumerate(backends) - filtered_scenarios = filter(s -> compatible(backend, s), scenarios) - grouped_scenarios = group_by_operator(filtered_scenarios) - for (j, (op, op_group)) in enumerate(pairs(grouped_scenarios)) - for (k, scen) in enumerate(op_group) - next!( - prog; - showvalues=[ - (:backend, "$backend - $i/$(length(backends))"), - (:scenario_type, "$op - $j/$(length(grouped_scenarios))"), - (:scenario, "$k/$(length(op_group))"), - (:operator_place, operator_place(scen)), - (:function_place, function_place(scen)), - (:function, scen.f), - (:input_type, typeof(scen.x)), - (:input_size, mysize(scen.x)), - (:output_type, typeof(scen.y)), - (:output_size, mysize(scen.y)), - (:nb_tangents, scen.tang isa NTuple ? length(scen.tang) : nothing), - (:nb_contexts, length(scen.contexts)), - ], - ) - adapted_backend = adapt_batchsize(backend, scen) - run_benchmark!(benchmark_data, adapted_backend, scen; logging) - yield() - end - end - end - return DataFrame(benchmark_data) -end - -""" - test_allocfree(benchmark_data::DataFrame) - -Test that every row in `benchmark_data` which is not a preparation row has zero allocation. -""" -function test_allocfree(benchmark_data::DataFrame) - preparation_rows = startswith.(string.(benchmark_data[!, :operator]), Ref("prepare")) - useful_data = benchmark_data[.!preparation_rows, :] - - @testset verbose = true "No allocations" begin - @testset "$(row[:scenario]) - $(row[:operator])" for row in eachrow(useful_data) - @test row[:allocs] == 0 - end - end end diff --git a/DifferentiationInterfaceTest/src/tests/benchmark.jl b/DifferentiationInterfaceTest/src/tests/benchmark.jl index 4aaadf85a..859fe82da 100644 --- a/DifferentiationInterfaceTest/src/tests/benchmark.jl +++ b/DifferentiationInterfaceTest/src/tests/benchmark.jl @@ -45,25 +45,14 @@ function failed_bench() return Benchmark([sample]) end -function failed_benchs(k::Integer) - return ntuple(i -> failed_bench(), k) -end +failed_benchs(k::Integer) = ntuple(i -> failed_bench(), k) +failed_calls(k::Integer) = ntuple(i -> -1, k) """ DifferentiationBenchmarkDataRow Ad-hoc storage type for differentiation benchmarking results. -If you have a vector `rows::Vector{DifferentiationBenchmarkDataRow}`, you can turn it into a `DataFrame` as follows: - -```julia -using DataFrames - -df = DataFrame(rows) -``` - -The resulting `DataFrame` will have one column for each of the following fields. - # Fields $(TYPEDFIELDS) @@ -77,6 +66,8 @@ Base.@kwdef struct DifferentiationBenchmarkDataRow scenario::Scenario "differentiation operator used for benchmarking, e.g. `:gradient` or `:hessian`" operator::Symbol + "whether the operator had been prepared" + prepared::Union{Nothing,Bool} "number of calls to the differentiated function for one call to the operator" calls::Int "number of benchmarking samples taken" @@ -96,10 +87,11 @@ Base.@kwdef struct DifferentiationBenchmarkDataRow end function record!( - data::Vector{DifferentiationBenchmarkDataRow}, + data::Vector{DifferentiationBenchmarkDataRow}; backend::AbstractADType, scenario::Scenario, - operator, + operator::String, + prepared::Union{Nothing,Bool}, bench::Benchmark, calls::Integer, ) @@ -108,6 +100,7 @@ function record!( backend=backend, scenario=scenario, operator=Symbol(operator), + prepared=prepared, calls=calls, samples=length(bench.samples), evals=Int(bench_min.evals), diff --git a/DifferentiationInterfaceTest/src/tests/benchmark_eval.jl b/DifferentiationInterfaceTest/src/tests/benchmark_eval.jl index 9116d7d4c..2ecf0bed6 100644 --- a/DifferentiationInterfaceTest/src/tests/benchmark_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/benchmark_eval.jl @@ -1,14 +1,20 @@ +@kwdef struct BenchmarkResult + prepared_valop::Benchmark = failed_bench() + prepared_op::Benchmark = failed_bench() + preparation::Benchmark = failed_bench() + unprepared_valop::Benchmark = failed_bench() + unprepared_op::Benchmark = failed_bench() +end + +@kwdef struct CallsResult + preparation::Int = -1 + prepared_valop::Int = -1 + prepared_op::Int = -1 + unprepared_valop::Int = -1 + unprepared_op::Int = -1 +end -for op in [ - :derivative, - :gradient, - :hessian, - :hvp, - :jacobian, - :pullback, - :pushforward, - :second_derivative, -] +for op in ALL_OPS op! = Symbol(op, "!") val_prefix = if op == :second_derivative "value_derivative_and_" @@ -31,257 +37,506 @@ for op in [ backend::AbstractADType, scenario::Union{$S1out,$S1in,$S2out,$S2in}; logging::Bool, + subset::Symbol, + count_calls::Bool, ) - (; bench0, bench1, bench2, calls0, calls1, calls2) = try - run_benchmark_aux(backend, scenario) + @assert subset in (:full, :prepared) + + bench_result = try + benchmark_aux(backend, scenario; subset) catch exception logging && @warn "Error during benchmarking" backend scenario exception - bench0, bench1, bench2 = failed_benchs(3) - calls0, calls1, calls2 = -1, -1, -1 - (; bench0, bench1, bench2, calls0, calls1, calls2) + BenchmarkResult() + end + + if count_calls + calls_result = try + calls_aux(backend, scenario; subset) + catch exception + logging && @warn "Error during call counting" backend scenario exception + CallsResult() + end + else + calls_result = CallsResult() end - record!(data, backend, scenario, $prep_op, bench0, calls0) + + prep_string = $(string(prep_op)) if scenario isa Union{$S1out,$S2out} - record!(data, backend, scenario, $(string(val_and_op)), bench1, calls1) - record!(data, backend, scenario, $(string(op)), bench2, calls2) - elseif scenario isa Union{$S1in,$S2in} - record!(data, backend, scenario, $(string(val_and_op!)), bench1, calls1) - record!(data, backend, scenario, $(string(op!)), bench2, calls2) + valop_string = $(string(val_and_op)) + op_string = $(string(op)) + else + valop_string = $(string(val_and_op!)) + op_string = $(string(op!)) + end + + record!( + data; + backend, + scenario, + operator=valop_string, + prepared=true, + bench=bench_result.prepared_valop, + calls=calls_result.prepared_valop, + ) + record!( + data; + backend, + scenario, + operator=op_string, + prepared=true, + bench=bench_result.prepared_op, + calls=calls_result.prepared_op, + ) + if subset == :full + record!( + data; + backend, + scenario, + operator=prep_string, + prepared=nothing, + bench=bench_result.preparation, + calls=calls_result.preparation, + ) + record!( + data; + backend, + scenario, + operator=valop_string, + prepared=false, + bench=bench_result.unprepared_valop, + calls=calls_result.unprepared_valop, + ) + record!( + data; + backend, + scenario, + operator=op_string, + prepared=false, + bench=bench_result.unprepared_op, + calls=calls_result.unprepared_op, + ) end return nothing end if op in [:derivative, :gradient, :jacobian] - @eval function run_benchmark_aux(ba::AbstractADType, scen::$S1out) + @eval function benchmark_aux(ba::AbstractADType, scen::$S1out; subset::Symbol) (; f, x, contexts) = deepcopy(scen) - # benchmark prep = $prep_op(f, ba, x, contexts...) - bench0 = @be $prep_op(f, ba, x, contexts...) samples = 1 evals = 1 - bench1 = @be prep $val_and_op(f, _, ba, x, contexts...) evals = 1 - bench2 = @be prep $op(f, _, ba, x, contexts...) evals = 1 - # count + prepared_valop = @be prep $val_and_op(f, _, ba, x, contexts...) + prepared_op = @be prep $op(f, _, ba, x, contexts...) + if subset == :full + preparation = @be $prep_op(f, ba, x, contexts...) + unprepared_valop = @be $val_and_op(f, ba, x, contexts...) + unprepared_op = @be $op(f, ba, x, contexts...) + return BenchmarkResult(; + prepared_valop, + prepared_op, + preparation, + unprepared_valop, + unprepared_op, + ) + else + return BenchmarkResult(; prepared_valop, prepared_op) + end + end + + @eval function calls_aux(ba::AbstractADType, scen::$S1out; subset::Symbol) + (; f, x, contexts) = deepcopy(scen) cc = CallCounter(f) prep = $prep_op(cc, ba, x, contexts...) - calls0 = reset_count!(cc) + preparation = reset_count!(cc) $val_and_op(cc, prep, ba, x, contexts...) - calls1 = reset_count!(cc) + prepared_valop = reset_count!(cc) $op(cc, prep, ba, x, contexts...) - calls2 = reset_count!(cc) - return (; bench0, bench1, bench2, calls0, calls1, calls2) + prepared_op = reset_count!(cc) + return CallsResult(; preparation, prepared_valop, prepared_op) end - @eval function run_benchmark_aux(ba::AbstractADType, scen::$S1in) + @eval function benchmark_aux(ba::AbstractADType, scen::$S1in; subset::Symbol) (; f, x, res1, contexts) = deepcopy(scen) - # benchmark prep = $prep_op(f, ba, x, contexts...) - bench0 = @be $prep_op(f, ba, x, contexts...) samples = 1 evals = 1 - bench1 = @be (res1, prep) $val_and_op!(f, _[1], _[2], ba, x, contexts...) evals = - 1 - bench2 = @be (res1, prep) $op!(f, _[1], _[2], ba, x, contexts...) evals = 1 - # count + prepared_valop = @be (res1, prep) $val_and_op!( + f, _[1], _[2], ba, x, contexts... + ) + prepared_op = @be (res1, prep) $op!(f, _[1], _[2], ba, x, contexts...) + if subset == :full + preparation = @be $prep_op(f, ba, x, contexts...) + unprepared_valop = @be res1 $val_and_op!(f, _, ba, x, contexts...) + unprepared_op = @be res1 $op!(f, _, ba, x, contexts...) + return BenchmarkResult(; + prepared_valop, + prepared_op, + preparation, + unprepared_valop, + unprepared_op, + ) + else + return BenchmarkResult(; prepared_valop, prepared_op) + end + end + + @eval function calls_aux(ba::AbstractADType, scen::$S1in; subset::Symbol) + (; f, x, res1, contexts) = deepcopy(scen) cc = CallCounter(f) prep = $prep_op(cc, ba, x, contexts...) - calls0 = reset_count!(cc) + preparation = reset_count!(cc) $val_and_op!(cc, res1, prep, ba, x, contexts...) - calls1 = reset_count!(cc) + prepared_valop = reset_count!(cc) $op!(cc, res1, prep, ba, x, contexts...) - calls2 = reset_count!(cc) - return (; bench0, bench1, bench2, calls0, calls1, calls2) + prepared_op = reset_count!(cc) + return CallsResult(; preparation, prepared_valop, prepared_op) end op == :gradient && continue - @eval function run_benchmark_aux(ba::AbstractADType, scen::$S2out) + @eval function benchmark_aux(ba::AbstractADType, scen::$S2out; subset::Symbol) (; f, x, y, contexts) = deepcopy(scen) - # benchmark prep = $prep_op(f, y, ba, x, contexts...) - bench0 = @be $prep_op(f, y, ba, x, contexts...) samples = 1 evals = 1 - bench1 = @be (y, prep) $val_and_op(f, _[1], _[2], ba, x, contexts...) evals = 1 - bench2 = @be (y, prep) $op(f, _[1], _[2], ba, x, contexts...) evals = 1 - # count + prepared_valop = @be (y, prep) $val_and_op(f, _[1], _[2], ba, x, contexts...) + prepared_op = @be (y, prep) $op(f, _[1], _[2], ba, x, contexts...) + if subset == :full + preparation = @be $prep_op(f, y, ba, x, contexts...) + unprepared_valop = @be y $val_and_op(f, _, ba, x, contexts...) + unprepared_op = @be y $op(f, _, ba, x, contexts...) + return BenchmarkResult(; + prepared_valop, + prepared_op, + preparation, + unprepared_valop, + unprepared_op, + ) + else + return BenchmarkResult(; prepared_valop, prepared_op) + end + end + + @eval function calls_aux(ba::AbstractADType, scen::$S2out; subset::Symbol) + (; f, x, y, contexts) = deepcopy(scen) cc = CallCounter(f) prep = $prep_op(cc, y, ba, x, contexts...) - calls0 = reset_count!(cc) + preparation = reset_count!(cc) $val_and_op(cc, y, prep, ba, x, contexts...) - calls1 = reset_count!(cc) + prepared_valop = reset_count!(cc) $op(cc, y, prep, ba, x, contexts...) - calls2 = reset_count!(cc) - return (; bench0, bench1, bench2, calls0, calls1, calls2) + prepared_op = reset_count!(cc) + return CallsResult(; preparation, prepared_valop, prepared_op) end - @eval function run_benchmark_aux(ba::AbstractADType, scen::$S2in) + @eval function benchmark_aux(ba::AbstractADType, scen::$S2in; subset::Symbol) (; f, x, y, res1, contexts) = deepcopy(scen) - # benchmark prep = $prep_op(f, y, ba, x, contexts...) - bench0 = @be $prep_op(f, y, ba, x, contexts...) samples = 1 evals = 1 - bench1 = @be (y, res1, prep) $val_and_op!( + prepared_valop = @be (y, res1, prep) $val_and_op!( f, _[1], _[2], _[3], ba, x, contexts... - ) evals = 1 - bench2 = @be (y, res1, prep) $op!(f, _[1], _[2], _[3], ba, x, contexts...) evals = - 1 - # count + ) + prepared_op = @be (y, res1, prep) $op!(f, _[1], _[2], _[3], ba, x, contexts...) + if subset == :full + preparation = @be $prep_op(f, y, ba, x, contexts...) + unprepared_valop = @be (y, res1) $val_and_op!( + f, _[1], _[2], ba, x, contexts... + ) + unprepared_op = @be (y, res1) $op!(f, _[1], _[2], ba, x, contexts...) + return BenchmarkResult(; + prepared_valop, + prepared_op, + preparation, + unprepared_valop, + unprepared_op, + ) + else + return BenchmarkResult(; prepared_valop, prepared_op) + end + end + + @eval function calls_aux(ba::AbstractADType, scen::$S2in; subset::Symbol) + (; f, x, y, res1, contexts) = deepcopy(scen) cc = CallCounter(f) prep = $prep_op(cc, y, ba, x, contexts...) - calls0 = reset_count!(cc) + preparation = reset_count!(cc) $val_and_op!(cc, y, res1, prep, ba, x, contexts...) - calls1 = reset_count!(cc) + prepared_valop = reset_count!(cc) $op!(cc, y, res1, prep, ba, x, contexts...) - calls2 = reset_count!(cc) - return (; bench0, bench1, bench2, calls0, calls1, calls2) + prepared_op = reset_count!(cc) + return CallsResult(; preparation, prepared_valop, prepared_op) end elseif op in [:hessian, :second_derivative] - @eval function run_benchmark_aux(ba::AbstractADType, scen::$S1out) + @eval function benchmark_aux(ba::AbstractADType, scen::$S1out; subset::Symbol) (; f, x, contexts) = deepcopy(scen) - # benchmark prep = $prep_op(f, ba, x, contexts...) - bench0 = @be $prep_op(f, ba, x, contexts...) samples = 1 evals = 1 - bench1 = @be prep $val_and_op(f, _, ba, x, contexts...) evals = 1 - bench2 = @be prep $op(f, _, ba, x, contexts...) evals = 1 - # count + prepared_valop = @be prep $val_and_op(f, _, ba, x, contexts...) + prepared_op = @be prep $op(f, _, ba, x, contexts...) + if subset == :full + preparation = @be $prep_op(f, ba, x, contexts...) + unprepared_valop = @be $val_and_op(f, ba, x, contexts...) + unprepared_op = @be $op(f, ba, x, contexts...) + return BenchmarkResult(; + prepared_valop, + prepared_op, + preparation, + unprepared_valop, + unprepared_op, + ) + else + return BenchmarkResult(; prepared_valop, prepared_op) + end + end + + @eval function calls_aux(ba::AbstractADType, scen::$S1out; subset::Symbol) + (; f, x, contexts) = deepcopy(scen) cc = CallCounter(f) prep = $prep_op(cc, ba, x, contexts...) - calls0 = reset_count!(cc) + preparation = reset_count!(cc) $val_and_op(cc, prep, ba, x, contexts...) - calls1 = reset_count!(cc) + prepared_valop = reset_count!(cc) $op(cc, prep, ba, x, contexts...) - calls2 = reset_count!(cc) - return (; bench0, bench1, bench2, calls0, calls1, calls2) + prepared_op = reset_count!(cc) + return CallsResult(; preparation, prepared_valop, prepared_op) end - @eval function run_benchmark_aux(ba::AbstractADType, scen::$S1in) + @eval function benchmark_aux(ba::AbstractADType, scen::$S1in; subset::Symbol) (; f, x, res1, res2, contexts) = deepcopy(scen) - # benchmark + prep = $prep_op(f, ba, x, contexts...) - bench0 = @be $prep_op(f, ba, x, contexts...) samples = 1 evals = 1 - bench1 = @be (res1, res2, prep) $val_and_op!( + prepared_valop = @be (res1, res2, prep) $val_and_op!( f, _[1], _[2], _[3], ba, x, contexts... - ) evals = 1 - bench2 = @be (res2, prep) $op!(f, _[1], _[2], ba, x, contexts...) evals = 1 - # count + ) + prepared_op = @be (res2, prep) $op!(f, _[1], _[2], ba, x, contexts...) + if subset == :full + preparation = @be $prep_op(f, ba, x, contexts...) + unprepared_valop = @be (res1, res2) $val_and_op!( + f, _[1], _[2], ba, x, contexts... + ) + unprepared_op = @be res2 $op!(f, _, ba, x, contexts...) + return BenchmarkResult(; + prepared_valop, + prepared_op, + preparation, + unprepared_valop, + unprepared_op, + ) + else + return BenchmarkResult(; prepared_valop, prepared_op) + end + end + + @eval function calls_aux(ba::AbstractADType, scen::$S1in; subset::Symbol) + (; f, x, res1, res2, contexts) = deepcopy(scen) cc = CallCounter(f) prep = $prep_op(cc, ba, x, contexts...) - calls0 = reset_count!(cc) + preparation = reset_count!(cc) $val_and_op!(cc, res1, res2, prep, ba, x, contexts...) - calls1 = reset_count!(cc) + prepared_valop = reset_count!(cc) $op!(cc, res2, prep, ba, x, contexts...) - calls2 = reset_count!(cc) - return (; bench0, bench1, bench2, calls0, calls1, calls2) + prepared_op = reset_count!(cc) + return CallsResult(; preparation, prepared_valop, prepared_op) end elseif op in [:pushforward, :pullback] - @eval function run_benchmark_aux(ba::AbstractADType, scen::$S1out) + @eval function benchmark_aux(ba::AbstractADType, scen::$S1out; subset::Symbol) (; f, x, tang, contexts) = deepcopy(scen) - # benchmark prep = $prep_op(f, ba, x, tang, contexts...) - bench0 = @be $prep_op(f, ba, x, tang, contexts...) samples = 1 evals = 1 - bench1 = @be prep $val_and_op(f, _, ba, x, tang, contexts...) evals = 1 - bench2 = @be prep $op(f, _, ba, x, tang, contexts...) evals = 1 - # count + prepared_valop = @be prep $val_and_op(f, _, ba, x, tang, contexts...) + prepared_op = @be prep $op(f, _, ba, x, tang, contexts...) + if subset == :full + preparation = @be $prep_op(f, ba, x, tang, contexts...) + unprepared_valop = @be $val_and_op(f, ba, x, tang, contexts...) + unprepared_op = @be $op(f, ba, x, tang, contexts...) + return BenchmarkResult(; + prepared_valop, + prepared_op, + preparation, + unprepared_valop, + unprepared_op, + ) + else + return BenchmarkResult(; prepared_valop, prepared_op) + end + end + @eval function calls_aux(ba::AbstractADType, scen::$S1out; subset::Symbol) + (; f, x, tang, contexts) = deepcopy(scen) cc = CallCounter(f) prep = $prep_op(cc, ba, x, tang, contexts...) - calls0 = reset_count!(cc) + preparation = reset_count!(cc) $val_and_op(cc, prep, ba, x, tang, contexts...) - calls1 = reset_count!(cc) + prepared_valop = reset_count!(cc) $op(cc, prep, ba, x, tang, contexts...) - calls2 = reset_count!(cc) - return (; bench0, bench1, bench2, calls0, calls1, calls2) + prepared_op = reset_count!(cc) + return CallsResult(; preparation, prepared_valop, prepared_op) end - @eval function run_benchmark_aux(ba::AbstractADType, scen::$S1in) + @eval function benchmark_aux(ba::AbstractADType, scen::$S1in; subset::Symbol) (; f, x, tang, res1, contexts) = deepcopy(scen) - # benchmark prep = $prep_op(f, ba, x, tang, contexts...) - bench0 = @be $prep_op(f, ba, x, tang, contexts...) samples = 1 evals = 1 - bench1 = @be (res1, prep) $val_and_op!(f, _[1], _[2], ba, x, tang, contexts...) evals = - 1 - bench2 = @be (res1, prep) $op!(f, _[1], _[2], ba, x, tang, contexts...) evals = - 1 - # count + prepared_valop = @be (res1, prep) $val_and_op!( + f, _[1], _[2], ba, x, tang, contexts... + ) + prepared_op = @be (res1, prep) $op!(f, _[1], _[2], ba, x, tang, contexts...) + if subset == :full + preparation = @be $prep_op(f, ba, x, tang, contexts...) + unprepared_valop = @be res1 $val_and_op!(f, _, ba, x, tang, contexts...) + unprepared_op = @be res1 $op!(f, _, ba, x, tang, contexts...) + return BenchmarkResult(; + prepared_valop, + prepared_op, + preparation, + unprepared_valop, + unprepared_op, + ) + else + return BenchmarkResult(; prepared_valop, prepared_op) + end + end + + @eval function calls_aux(ba::AbstractADType, scen::$S1in; subset::Symbol) + (; f, x, tang, res1, contexts) = deepcopy(scen) cc = CallCounter(f) prep = $prep_op(cc, ba, x, tang, contexts...) - calls0 = reset_count!(cc) + preparation = reset_count!(cc) $val_and_op!(cc, res1, prep, ba, x, tang, contexts...) - calls1 = reset_count!(cc) + prepared_valop = reset_count!(cc) $op!(cc, res1, prep, ba, x, tang, contexts...) - calls2 = reset_count!(cc) - return (; bench0, bench1, bench2, calls0, calls1, calls2) + prepared_op = reset_count!(cc) + return CallsResult(; preparation, prepared_valop, prepared_op) end - @eval function run_benchmark_aux(ba::AbstractADType, scen::$S2out) + @eval function benchmark_aux(ba::AbstractADType, scen::$S2out; subset::Symbol) (; f, x, y, tang, contexts) = deepcopy(scen) - # benchmark prep = $prep_op(f, y, ba, x, tang, contexts...) - bench0 = @be $prep_op(f, y, ba, x, tang, contexts...) samples = 1 evals = 1 - bench1 = @be (y, prep) $val_and_op(f, _[1], _[2], ba, x, tang, contexts...) evals = - 1 - bench2 = @be (y, prep) $op(f, _[1], _[2], ba, x, tang, contexts...) evals = 1 - # count + prepared_valop = @be (y, prep) $val_and_op( + f, _[1], _[2], ba, x, tang, contexts... + ) + prepared_op = @be (y, prep) $op(f, _[1], _[2], ba, x, tang, contexts...) + if subset == :full + preparation = @be $prep_op(f, y, ba, x, tang, contexts...) + unprepared_valop = @be y $val_and_op(f, _, ba, x, tang, contexts...) + unprepared_op = @be y $op(f, _, ba, x, tang, contexts...) + return BenchmarkResult(; + prepared_valop, + prepared_op, + preparation, + unprepared_valop, + unprepared_op, + ) + else + return BenchmarkResult(; prepared_valop, prepared_op) + end + end + + @eval function calls_aux(ba::AbstractADType, scen::$S2out; subset::Symbol) + (; f, x, y, tang, contexts) = deepcopy(scen) cc = CallCounter(f) prep = $prep_op(cc, y, ba, x, tang, contexts...) - calls0 = reset_count!(cc) + preparation = reset_count!(cc) $val_and_op(cc, y, prep, ba, x, tang, contexts...) - calls1 = reset_count!(cc) + prepared_valop = reset_count!(cc) $op(cc, y, prep, ba, x, tang, contexts...) - calls2 = reset_count!(cc) - return (; bench0, bench1, bench2, calls0, calls1, calls2) + prepared_op = reset_count!(cc) + return CallsResult(; preparation, prepared_valop, prepared_op) end - @eval function run_benchmark_aux(ba::AbstractADType, scen::$S2in) + @eval function benchmark_aux(ba::AbstractADType, scen::$S2in; subset::Symbol) (; f, x, y, tang, res1, contexts) = deepcopy(scen) - # benchmark prep = $prep_op(f, y, ba, x, tang, contexts...) - bench0 = @be $prep_op(f, y, ba, x, tang, contexts...) samples = 1 evals = 1 - bench1 = @be (y, res1, prep) $val_and_op!( + prepared_valop = @be (y, res1, prep) $val_and_op!( + f, _[1], _[2], _[3], ba, x, tang, contexts... + ) + prepared_op = @be (y, res1, prep) $op!( f, _[1], _[2], _[3], ba, x, tang, contexts... - ) evals = 1 - bench2 = @be (y, res1, prep) $op!(f, _[1], _[2], _[3], ba, x, tang, contexts...) evals = - 1 - # count + ) + if subset == :full + preparation = @be $prep_op(f, y, ba, x, tang, contexts...) + unprepared_valop = @be (y, res1) $val_and_op!( + f, _[1], _[2], ba, x, tang, contexts... + ) + unprepared_op = @be (y, res1) $op!(f, _[1], _[2], ba, x, tang, contexts...) + return BenchmarkResult(; + prepared_valop, + prepared_op, + preparation, + unprepared_valop, + unprepared_op, + ) + else + return BenchmarkResult(; prepared_valop, prepared_op) + end + end + + @eval function calls_aux(ba::AbstractADType, scen::$S2in; subset::Symbol) + (; f, x, y, tang, res1, contexts) = deepcopy(scen) cc = CallCounter(f) prep = $prep_op(cc, y, ba, x, tang, contexts...) - calls0 = reset_count!(cc) + preparation = reset_count!(cc) $val_and_op!(cc, y, res1, prep, ba, x, tang, contexts...) - calls1 = reset_count!(cc) + prepared_valop = reset_count!(cc) $op!(cc, y, res1, prep, ba, x, tang, contexts...) - calls2 = reset_count!(cc) - return (; bench0, bench1, bench2, calls0, calls1, calls2) + prepared_op = reset_count!(cc) + return CallsResult(; preparation, prepared_valop, prepared_op) end elseif op in [:hvp] - @eval function run_benchmark_aux(ba::AbstractADType, scen::$S1out) + @eval function benchmark_aux(ba::AbstractADType, scen::$S1out; subset::Symbol) (; f, x, tang, contexts) = deepcopy(scen) - # benchmark prep = $prep_op(f, ba, x, tang, contexts...) - bench0 = @be $prep_op(f, ba, x, tang, contexts...) samples = 1 evals = 1 - bench1 = @be +(1, 1) evals = 1 # TODO: fix - bench2 = @be prep $op(f, _, ba, x, tang, contexts...) evals = 1 - # count + prepared_valop = @be +(1, 1) # TODO: fix + prepared_op = @be prep $op(f, _, ba, x, tang, contexts...) + if subset == :full + preparation = @be $prep_op(f, ba, x, tang, contexts...) + unprepared_valop = @be +(1, 1) # TODO: fix + unprepared_op = @be $op(f, ba, x, tang, contexts...) + return BenchmarkResult(; + prepared_valop, + prepared_op, + preparation, + unprepared_valop, + unprepared_op, + ) + else + return BenchmarkResult(; prepared_valop, prepared_op) + end + end + + @eval function calls_aux(ba::AbstractADType, scen::$S1out; subset::Symbol) + (; f, x, tang, contexts) = deepcopy(scen) cc = CallCounter(f) prep = $prep_op(cc, ba, x, tang, contexts...) - calls0 = reset_count!(cc) - calls1 = -1 # TODO: fix + preparation = reset_count!(cc) + prepared_valop = -1 # TODO: fix $op(cc, prep, ba, x, tang, contexts...) - calls2 = reset_count!(cc) - return (; bench0, bench1, bench2, calls0, calls1, calls2) + prepared_op = reset_count!(cc) + return CallsResult(; preparation, prepared_valop, prepared_op) end - @eval function run_benchmark_aux(ba::AbstractADType, scen::$S1in) + @eval function benchmark_aux(ba::AbstractADType, scen::$S1in; subset::Symbol) (; f, x, tang, res2, contexts) = deepcopy(scen) - # benchmark prep = $prep_op(f, ba, x, tang, contexts...) - bench0 = @be $prep_op(f, ba, x, tang, contexts...) samples = 1 evals = 1 - bench1 = @be +(1, 1) evals = 1 # TODO: fix - bench2 = @be (res2, prep) $op!(f, _[1], _[2], ba, x, tang, contexts...) evals = - 1 - # count + prepared_valop = @be +(1, 1) # TODO: fix + prepared_op = @be (res2, prep) $op!(f, _[1], _[2], ba, x, tang, contexts...) + if subset == :full + preparation = @be $prep_op(f, ba, x, tang, contexts...) + unprepared_valop = @be +(1, 1) # TODO: fix + unprepared_op = @be res2 $op!(f, _, ba, x, tang, contexts...) + return BenchmarkResult(; + prepared_valop, + prepared_op, + preparation, + unprepared_valop, + unprepared_op, + ) + else + return BenchmarkResult(; prepared_valop, prepared_op) + end + end + + @eval function calls_aux(ba::AbstractADType, scen::$S1in; subset::Symbol) + (; f, x, tang, res2, contexts) = deepcopy(scen) cc = CallCounter(f) prep = $prep_op(cc, ba, x, tang, contexts...) - calls0 = reset_count!(cc) - calls1 = -1 # TODO: fix + preparation = reset_count!(cc) + prepared_valop = -1 # TODO: fix $op!(cc, res2, prep, ba, x, tang, contexts...) - calls2 = reset_count!(cc) - return (; bench0, bench1, bench2, calls0, calls1, calls2) + prepared_op = reset_count!(cc) + return CallsResult(; preparation, prepared_valop, prepared_op) end end end diff --git a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl index 7302281f2..56bfeecd9 100644 --- a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl @@ -1,13 +1,4 @@ -for op in [ - :derivative, - :gradient, - :hessian, - :hvp, - :jacobian, - :pullback, - :pushforward, - :second_derivative, -] +for op in ALL_OPS op! = Symbol(op, "!") val_prefix = if op == :second_derivative "value_derivative_and_" @@ -53,6 +44,7 @@ for op in [ atol::Real, rtol::Real, scenario_intact::Bool, + sparsity::Bool, ) (; f, x, y, res1, contexts) = new_scen = deepcopy(scen) xrand = myrandom(x) @@ -87,6 +79,12 @@ for op in [ @test res1_out1_noval ≈ scen.res1 @test res1_out2_noval ≈ scen.res1 end + if sparsity && $op == :jacobian + @test mynnz(res1_out1_val) == mynnz(scen.res1) + @test mynnz(res1_out2_val) == mynnz(scen.res1) + @test mynnz(res1_out1_noval) == mynnz(scen.res1) + @test mynnz(res1_out2_noval) == mynnz(scen.res1) + end end scenario_intact && @test new_scen == scen return nothing @@ -99,6 +97,7 @@ for op in [ atol::Real, rtol::Real, scenario_intact::Bool, + sparsity::Bool, ) (; f, x, y, res1, contexts) = new_scen = deepcopy(scen) xrand = myrandom(x) @@ -145,6 +144,12 @@ for op in [ @test res1_out1_noval ≈ scen.res1 @test res1_out2_noval ≈ scen.res1 end + if sparsity && $op == :jacobian + @test mynnz(res1_out1_val) == mynnz(scen.res1) + @test mynnz(res1_out2_val) == mynnz(scen.res1) + @test mynnz(res1_out1_noval) == mynnz(scen.res1) + @test mynnz(res1_out2_noval) == mynnz(scen.res1) + end end scenario_intact && @test new_scen == scen return nothing @@ -159,6 +164,7 @@ for op in [ atol::Real, rtol::Real, scenario_intact::Bool, + sparsity::Bool, ) (; f, x, y, res1, contexts) = new_scen = deepcopy(scen) xrand, yrand = myrandom(x), myrandom(y) @@ -200,6 +206,12 @@ for op in [ @test res1_out1_noval ≈ scen.res1 @test res1_out2_noval ≈ scen.res1 end + if sparsity && $op == :jacobian + @test mynnz(res1_out1_val) == mynnz(scen.res1) + @test mynnz(res1_out2_val) == mynnz(scen.res1) + @test mynnz(res1_out1_noval) == mynnz(scen.res1) + @test mynnz(res1_out2_noval) == mynnz(scen.res1) + end end scenario_intact && @test new_scen == scen return nothing @@ -212,6 +224,7 @@ for op in [ atol::Real, rtol::Real, scenario_intact::Bool, + sparsity::Bool, ) (; f, x, y, res1, contexts) = new_scen = deepcopy(scen) xrand, yrand = myrandom(x), myrandom(y) @@ -261,6 +274,12 @@ for op in [ @test res1_out1_noval ≈ scen.res1 @test res1_out2_noval ≈ scen.res1 end + if sparsity && $op == :jacobian + @test mynnz(res1_out1_val) == mynnz(scen.res1) + @test mynnz(res1_out2_val) == mynnz(scen.res1) + @test mynnz(res1_out1_noval) == mynnz(scen.res1) + @test mynnz(res1_out2_noval) == mynnz(scen.res1) + end end scenario_intact && @test new_scen == scen return nothing @@ -274,6 +293,7 @@ for op in [ atol::Real, rtol::Real, scenario_intact::Bool, + sparsity::Bool, ) (; f, x, y, res1, res2, contexts) = new_scen = deepcopy(scen) xrand = myrandom(x) @@ -310,6 +330,12 @@ for op in [ @test res2_out1_noval ≈ scen.res2 @test res2_out2_noval ≈ scen.res2 end + if sparsity && $op == :hessian + @test mynnz(res2_out1_val) == mynnz(scen.res2) + @test mynnz(res2_out2_val) == mynnz(scen.res2) + @test mynnz(res2_out1_noval) == mynnz(scen.res2) + @test mynnz(res2_out2_noval) == mynnz(scen.res2) + end end scenario_intact && @test new_scen == scen return nothing @@ -322,6 +348,7 @@ for op in [ atol::Real, rtol::Real, scenario_intact::Bool, + sparsity::Bool, ) (; f, x, y, res1, res2, contexts) = new_scen = deepcopy(scen) xrand = myrandom(x) @@ -372,6 +399,12 @@ for op in [ @test res2_out1_noval ≈ scen.res2 @test res2_out2_noval ≈ scen.res2 end + if sparsity && $op == :hessian + @test mynnz(res2_out1_val) == mynnz(scen.res2) + @test mynnz(res2_out2_val) == mynnz(scen.res2) + @test mynnz(res2_out1_noval) == mynnz(scen.res2) + @test mynnz(res2_out2_noval) == mynnz(scen.res2) + end end scenario_intact && @test new_scen == scen return nothing @@ -385,6 +418,7 @@ for op in [ atol::Real, rtol::Real, scenario_intact::Bool, + sparsity::Bool, ) (; f, x, y, tang, res1, contexts) = new_scen = deepcopy(scen) xrand, tangrand = myrandom(x), myrandom(tang) @@ -433,6 +467,7 @@ for op in [ atol::Real, rtol::Real, scenario_intact::Bool, + sparsity::Bool, ) (; f, x, y, tang, res1, contexts) = new_scen = deepcopy(scen) xrand, tangrand = myrandom(x), myrandom(tang) @@ -493,6 +528,7 @@ for op in [ atol::Real, rtol::Real, scenario_intact::Bool, + sparsity::Bool, ) (; f, x, y, tang, res1, contexts) = new_scen = deepcopy(scen) xrand, yrand, tangrand = myrandom(x), myrandom(y), myrandom(tang) @@ -552,6 +588,7 @@ for op in [ atol::Real, rtol::Real, scenario_intact::Bool, + sparsity::Bool, ) (; f, x, y, tang, res1, contexts) = new_scen = deepcopy(scen) xrand, yrand, tangrand = myrandom(x), myrandom(y), myrandom(tang) @@ -630,6 +667,7 @@ for op in [ atol::Real, rtol::Real, scenario_intact::Bool, + sparsity::Bool, ) (; f, x, y, tang, res2, contexts) = new_scen = deepcopy(scen) xrand, tangrand = myrandom(x), myrandom(tang) @@ -666,6 +704,7 @@ for op in [ atol::Real, rtol::Real, scenario_intact::Bool, + sparsity::Bool, ) (; f, x, y, tang, res2, contexts) = new_scen = deepcopy(scen) xrand, tangrand = myrandom(x), myrandom(tang) diff --git a/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl b/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl index 0a67ca05f..bb67b1856 100644 --- a/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl @@ -1,13 +1,4 @@ -for op in [ - :derivative, - :gradient, - :hessian, - :hvp, - :jacobian, - :pullback, - :pushforward, - :second_derivative, -] +for op in ALL_OPS op! = Symbol(op, "!") val_prefix = if op == :second_derivative "value_derivative_and_" @@ -27,259 +18,277 @@ for op in [ if op in [:derivative, :gradient, :jacobian] @eval function test_jet( - ba::AbstractADType, - scen::$S1out; - preparation::Bool, - prepared_op::Bool, - unprepared_op::Bool, + ba::AbstractADType, scen::$S1out; subset::Symbol, ignored_modules ) (; f, x, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, contexts...) - preparation && JET.@test_opt $prep_op(f, ba, x, contexts...) - prepared_op && JET.@test_opt $op(f, prep, ba, x, contexts...) - prepared_op && JET.@test_call $op(f, prep, ba, x, contexts...) - prepared_op && JET.@test_opt $val_and_op(f, prep, ba, x, contexts...) - prepared_op && JET.@test_call $val_and_op(f, prep, ba, x, contexts...) - unprepared_op && JET.@test_opt $op(f, ba, x, contexts...) - unprepared_op && JET.@test_call $op(f, ba, x, contexts...) - unprepared_op && JET.@test_opt $val_and_op(f, ba, x, contexts...) - unprepared_op && JET.@test_call $val_and_op(f, ba, x, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules $prep_op(f, ba, x, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules $op(f, prep, ba, x, contexts...) + (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op( + f, prep, ba, x, contexts... + ) + (subset in (:prepared, :full)) && + @test_opt ignored_modules = ignored_modules $op(f, ba, x, contexts...) + (subset in (:prepared, :full)) && + @test_opt ignored_modules = ignored_modules $val_and_op( + f, ba, x, contexts... + ) return nothing end @eval function test_jet( - ba::AbstractADType, - scen::$S1in; - preparation::Bool, - prepared_op::Bool, - unprepared_op::Bool, + ba::AbstractADType, scen::$S1in; subset::Symbol, ignored_modules ) (; f, x, res1, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, contexts...) - preparation && JET.@test_opt $prep_op(f, ba, x, contexts...) - prepared_op && JET.@test_opt $op!(f, res1, prep, ba, x, contexts...) - prepared_op && JET.@test_call $op!(f, res1, prep, ba, x, contexts...) - prepared_op && JET.@test_opt $val_and_op!(f, res1, prep, ba, x, contexts...) - prepared_op && JET.@test_call $val_and_op!(f, res1, prep, ba, x, contexts...) - unprepared_op && JET.@test_opt $op!(f, res1, ba, x, contexts...) - unprepared_op && JET.@test_call $op!(f, res1, ba, x, contexts...) - unprepared_op && JET.@test_opt $val_and_op!(f, res1, ba, x, contexts...) - unprepared_op && JET.@test_call $val_and_op!(f, res1, ba, x, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules $prep_op(f, ba, x, contexts...) + (subset == :full) && @test_opt ignored_modules = ignored_modules $op!( + f, res1, prep, ba, x, contexts... + ) + (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op!( + f, res1, prep, ba, x, contexts... + ) + (subset in (:prepared, :full)) && + @test_opt ignored_modules = ignored_modules $op!( + f, res1, ba, x, contexts... + ) + (subset in (:prepared, :full)) && + @test_opt ignored_modules = ignored_modules $val_and_op!( + f, res1, ba, x, contexts... + ) return nothing end op == :gradient && continue @eval function test_jet( - ba::AbstractADType, - scen::$S2out; - preparation::Bool, - prepared_op::Bool, - unprepared_op::Bool, + ba::AbstractADType, scen::$S2out; subset::Symbol, ignored_modules ) (; f, x, y, contexts) = deepcopy(scen) prep = $prep_op(f, y, ba, x, contexts...) - preparation && JET.@test_opt $prep_op(f, y, ba, x, contexts...) - prepared_op && JET.@test_opt $op(f, y, prep, ba, x, contexts...) - prepared_op && JET.@test_call $op(f, y, prep, ba, x, contexts...) - prepared_op && JET.@test_opt $val_and_op(f, y, prep, ba, x, contexts...) - prepared_op && JET.@test_call $val_and_op(f, y, prep, ba, x, contexts...) - unprepared_op && JET.@test_opt $op(f, y, ba, x, contexts...) - unprepared_op && JET.@test_call $op(f, y, ba, x, contexts...) - unprepared_op && JET.@test_opt $val_and_op(f, y, ba, x, contexts...) - unprepared_op && JET.@test_call $val_and_op(f, y, ba, x, contexts...) + (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op( + f, y, ba, x, contexts... + ) + (subset == :full) && @test_opt ignored_modules = ignored_modules $op( + f, y, prep, ba, x, contexts... + ) + (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op( + f, y, prep, ba, x, contexts... + ) + (subset in (:prepared, :full)) && + @test_opt ignored_modules = ignored_modules $op(f, y, ba, x, contexts...) + (subset in (:prepared, :full)) && + @test_opt ignored_modules = ignored_modules $val_and_op( + f, y, ba, x, contexts... + ) return nothing end @eval function test_jet( - ba::AbstractADType, - scen::$S2in; - preparation::Bool, - prepared_op::Bool, - unprepared_op::Bool, + ba::AbstractADType, scen::$S2in; subset::Symbol, ignored_modules ) (; f, x, y, res1, contexts) = deepcopy(scen) prep = $prep_op(f, y, ba, x, contexts...) - preparation && JET.@test_opt $prep_op(f, y, ba, x, contexts...) - prepared_op && JET.@test_opt $op!(f, y, res1, prep, ba, x, contexts...) - prepared_op && JET.@test_call $op!(f, y, res1, prep, ba, x, contexts...) - prepared_op && JET.@test_opt $val_and_op!(f, y, res1, prep, ba, x, contexts...) - prepared_op && JET.@test_call $val_and_op!(f, y, res1, prep, ba, x, contexts...) - unprepared_op && JET.@test_opt $op!(f, y, res1, ba, x, contexts...) - unprepared_op && JET.@test_call $op!(f, y, res1, ba, x, contexts...) - unprepared_op && JET.@test_opt $val_and_op!(f, y, res1, ba, x, contexts...) - unprepared_op && JET.@test_call $val_and_op!(f, y, res1, ba, x, contexts...) + (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op( + f, y, ba, x, contexts... + ) + (subset == :full) && @test_opt ignored_modules = ignored_modules $op!( + f, y, res1, prep, ba, x, contexts... + ) + (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op!( + f, y, res1, prep, ba, x, contexts... + ) + (subset in (:prepared, :full)) && + @test_opt ignored_modules = ignored_modules $op!( + f, y, res1, ba, x, contexts... + ) + (subset in (:prepared, :full)) && + @test_opt ignored_modules = ignored_modules $val_and_op!( + f, y, res1, ba, x, contexts... + ) return nothing end elseif op in [:second_derivative, :hessian] @eval function test_jet( - ba::AbstractADType, - scen::$S1out; - preparation::Bool, - prepared_op::Bool, - unprepared_op::Bool, + ba::AbstractADType, scen::$S1out; subset::Symbol, ignored_modules ) (; f, x, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, contexts...) - preparation && JET.@test_opt $prep_op(f, ba, x, contexts...) - prepared_op && JET.@test_opt $op(f, prep, ba, x, contexts...) - prepared_op && JET.@test_call $op(f, prep, ba, x, contexts...) - prepared_op && JET.@test_opt $val_and_op(f, prep, ba, x, contexts...) - prepared_op && JET.@test_call $val_and_op(f, prep, ba, x, contexts...) - unprepared_op && JET.@test_opt $op(f, ba, x, contexts...) - unprepared_op && JET.@test_call $op(f, ba, x, contexts...) - unprepared_op && JET.@test_opt $val_and_op(f, ba, x, contexts...) - unprepared_op && JET.@test_call $val_and_op(f, ba, x, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules $prep_op(f, ba, x, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules $op(f, prep, ba, x, contexts...) + (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op( + f, prep, ba, x, contexts... + ) + (subset in (:prepared, :full)) && + @test_opt ignored_modules = ignored_modules $op(f, ba, x, contexts...) + (subset in (:prepared, :full)) && + @test_opt ignored_modules = ignored_modules $val_and_op( + f, ba, x, contexts... + ) return nothing end @eval function test_jet( - ba::AbstractADType, - scen::$S1in; - preparation::Bool, - prepared_op::Bool, - unprepared_op::Bool, + ba::AbstractADType, scen::$S1in; subset::Symbol, ignored_modules ) (; f, x, res1, res2, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, contexts...) - preparation && JET.@test_opt $prep_op(f, ba, x, contexts...) - prepared_op && JET.@test_opt $op!(f, res2, prep, ba, x, contexts...) - prepared_op && JET.@test_call $op!(f, res2, prep, ba, x, contexts...) - prepared_op && - JET.@test_opt $val_and_op!(f, res1, res2, prep, ba, x, contexts...) - prepared_op && - JET.@test_call $val_and_op!(f, res1, res2, prep, ba, x, contexts...) - unprepared_op && JET.@test_opt $op!(f, res2, ba, x, contexts...) - unprepared_op && JET.@test_call $op!(f, res2, ba, x, contexts...) - unprepared_op && JET.@test_opt $val_and_op!(f, res1, res2, ba, x, contexts...) - unprepared_op && JET.@test_call $val_and_op!(f, res1, res2, ba, x, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules $prep_op(f, ba, x, contexts...) + (subset == :full) && @test_opt ignored_modules = ignored_modules $op!( + f, res2, prep, ba, x, contexts... + ) + (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op!( + f, res1, res2, prep, ba, x, contexts... + ) + (subset in (:prepared, :full)) && + @test_opt ignored_modules = ignored_modules $op!( + f, res2, ba, x, contexts... + ) + (subset in (:prepared, :full)) && + @test_opt ignored_modules = ignored_modules $val_and_op!( + f, res1, res2, ba, x, contexts... + ) return nothing end elseif op in [:pushforward, :pullback] @eval function test_jet( - ba::AbstractADType, - scen::$S1out; - preparation::Bool, - prepared_op::Bool, - unprepared_op::Bool, + ba::AbstractADType, scen::$S1out; subset::Symbol, ignored_modules ) (; f, x, tang, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, tang, contexts...) - preparation && JET.@test_opt $prep_op(f, ba, x, tang, contexts...) - prepared_op && JET.@test_opt $op(f, prep, ba, x, tang, contexts...) - prepared_op && JET.@test_call $op(f, prep, ba, x, tang, contexts...) - prepared_op && JET.@test_opt $val_and_op(f, prep, ba, x, tang, contexts...) - prepared_op && JET.@test_call $val_and_op(f, prep, ba, x, tang, contexts...) - unprepared_op && JET.@test_opt $op(f, ba, x, tang, contexts...) - unprepared_op && JET.@test_call $op(f, ba, x, tang, contexts...) - unprepared_op && JET.@test_opt $val_and_op(f, ba, x, tang, contexts...) - unprepared_op && JET.@test_call $val_and_op(f, ba, x, tang, contexts...) + (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op( + f, ba, x, tang, contexts... + ) + (subset == :full) && @test_opt ignored_modules = ignored_modules $op( + f, prep, ba, x, tang, contexts... + ) + (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op( + f, prep, ba, x, tang, contexts... + ) + (subset in (:prepared, :full)) && + @test_opt ignored_modules = ignored_modules $op(f, ba, x, tang, contexts...) + (subset in (:prepared, :full)) && + @test_opt ignored_modules = ignored_modules $val_and_op( + f, ba, x, tang, contexts... + ) return nothing end @eval function test_jet( - ba::AbstractADType, - scen::$S1in; - preparation::Bool, - prepared_op::Bool, - unprepared_op::Bool, + ba::AbstractADType, scen::$S1in; subset::Symbol, ignored_modules ) (; f, x, tang, res1, res2, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, tang, contexts...) - preparation && JET.@test_opt $prep_op(f, ba, x, tang, contexts...) - prepared_op && JET.@test_opt $op!(f, res1, prep, ba, x, tang, contexts...) - prepared_op && JET.@test_call $op!(f, res1, prep, ba, x, tang, contexts...) - prepared_op && - JET.@test_opt $val_and_op!(f, res1, prep, ba, x, tang, contexts...) - prepared_op && - JET.@test_call $val_and_op!(f, res1, prep, ba, x, tang, contexts...) - unprepared_op && JET.@test_opt $op!(f, res1, ba, x, tang, contexts...) - unprepared_op && JET.@test_call $op!(f, res1, ba, x, tang, contexts...) - unprepared_op && JET.@test_opt $val_and_op!(f, res1, ba, x, tang, contexts...) - unprepared_op && JET.@test_call $val_and_op!(f, res1, ba, x, tang, contexts...) + (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op( + f, ba, x, tang, contexts... + ) + (subset == :full) && @test_opt ignored_modules = ignored_modules $op!( + f, res1, prep, ba, x, tang, contexts... + ) + (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op!( + f, res1, prep, ba, x, tang, contexts... + ) + (subset in (:prepared, :full)) && + @test_opt ignored_modules = ignored_modules $op!( + f, res1, ba, x, tang, contexts... + ) + (subset in (:prepared, :full)) && + @test_opt ignored_modules = ignored_modules $val_and_op!( + f, res1, ba, x, tang, contexts... + ) return nothing end @eval function test_jet( - ba::AbstractADType, - scen::$S2out; - preparation::Bool, - prepared_op::Bool, - unprepared_op::Bool, + ba::AbstractADType, scen::$S2out; subset::Symbol, ignored_modules ) (; f, x, y, tang, contexts) = deepcopy(scen) prep = $prep_op(f, y, ba, x, tang, contexts...) - preparation && JET.@test_opt $prep_op(f, y, ba, x, tang, contexts...) - prepared_op && JET.@test_opt $op(f, y, prep, ba, x, tang, contexts...) - prepared_op && JET.@test_call $op(f, y, prep, ba, x, tang, contexts...) - prepared_op && JET.@test_opt $val_and_op(f, y, prep, ba, x, tang, contexts...) - prepared_op && JET.@test_call $val_and_op(f, y, prep, ba, x, tang, contexts...) - unprepared_op && JET.@test_opt $op(f, y, ba, x, tang, contexts...) - unprepared_op && JET.@test_call $op(f, y, ba, x, tang, contexts...) - unprepared_op && JET.@test_opt $val_and_op(f, y, ba, x, tang, contexts...) - unprepared_op && JET.@test_call $val_and_op(f, y, ba, x, tang, contexts...) + (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op( + f, y, ba, x, tang, contexts... + ) + (subset == :full) && @test_opt ignored_modules = ignored_modules $op( + f, y, prep, ba, x, tang, contexts... + ) + (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op( + f, y, prep, ba, x, tang, contexts... + ) + (subset in (:prepared, :full)) && + @test_opt ignored_modules = ignored_modules $op( + f, y, ba, x, tang, contexts... + ) + (subset in (:prepared, :full)) && + @test_opt ignored_modules = ignored_modules $val_and_op( + f, y, ba, x, tang, contexts... + ) return nothing end @eval function test_jet( - ba::AbstractADType, - scen::$S2in; - preparation::Bool, - prepared_op::Bool, - unprepared_op::Bool, + ba::AbstractADType, scen::$S2in; subset::Symbol, ignored_modules ) (; f, x, y, tang, res1, contexts) = deepcopy(scen) prep = $prep_op(f, y, ba, x, tang, contexts...) - preparation && JET.@test_opt $prep_op(f, y, ba, x, tang, contexts...) - prepared_op && JET.@test_opt $op!(f, y, res1, prep, ba, x, tang, contexts...) - prepared_op && JET.@test_call $op!(f, y, res1, prep, ba, x, tang, contexts...) - prepared_op && - JET.@test_opt $val_and_op!(f, y, res1, prep, ba, x, tang, contexts...) - prepared_op && - JET.@test_call $val_and_op!(f, y, res1, prep, ba, x, tang, contexts...) - unprepared_op && JET.@test_opt $op!(f, y, res1, ba, x, tang, contexts...) - unprepared_op && JET.@test_call $op!(f, y, res1, ba, x, tang, contexts...) - unprepared_op && - JET.@test_opt $val_and_op!(f, y, res1, ba, x, tang, contexts...) - unprepared_op && - JET.@test_call $val_and_op!(f, y, res1, ba, x, tang, contexts...) + (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op( + f, y, ba, x, tang, contexts... + ) + (subset == :full) && @test_opt ignored_modules = ignored_modules $op!( + f, y, res1, prep, ba, x, tang, contexts... + ) + (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op!( + f, y, res1, prep, ba, x, tang, contexts... + ) + (subset in (:prepared, :full)) && + @test_opt ignored_modules = ignored_modules $op!( + f, y, res1, ba, x, tang, contexts... + ) + (subset in (:prepared, :full)) && + @test_opt ignored_modules = ignored_modules $val_and_op!( + f, y, res1, ba, x, tang, contexts... + ) return nothing end elseif op in [:hvp] @eval function test_jet( - ba::AbstractADType, - scen::$S1out; - preparation::Bool, - prepared_op::Bool, - unprepared_op::Bool, + ba::AbstractADType, scen::$S1out; subset::Symbol, ignored_modules ) (; f, x, tang, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, tang, contexts...) - preparation && JET.@test_opt $prep_op(f, ba, x, tang, contexts...) - prepared_op && JET.@test_opt $op(f, prep, ba, x, tang, contexts...) - prepared_op && JET.@test_call $op(f, prep, ba, x, tang, contexts...) - unprepared_op && JET.@test_opt $op(f, ba, x, tang, contexts...) - unprepared_op && JET.@test_call $op(f, ba, x, tang, contexts...) + (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op( + f, ba, x, tang, contexts... + ) + (subset == :full) && @test_opt ignored_modules = ignored_modules $op( + f, prep, ba, x, tang, contexts... + ) + (subset in (:prepared, :full)) && + @test_opt ignored_modules = ignored_modules $op(f, ba, x, tang, contexts...) return nothing end @eval function test_jet( - ba::AbstractADType, - scen::$S1in; - preparation::Bool, - prepared_op::Bool, - unprepared_op::Bool, + ba::AbstractADType, scen::$S1in; subset::Symbol, ignored_modules ) (; f, x, tang, res1, res2, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, tang, contexts...) - preparation && JET.@test_opt $prep_op(f, ba, x, tang, contexts...) - prepared_op && JET.@test_opt $op!(f, res2, prep, ba, x, tang, contexts...) - prepared_op && JET.@test_call $op!(f, res2, prep, ba, x, tang, contexts...) - unprepared_op && JET.@test_opt $op!(f, res2, ba, x, tang, contexts...) - unprepared_op && JET.@test_call $op!(f, res2, ba, x, tang, contexts...) + (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op( + f, ba, x, tang, contexts... + ) + (subset == :full) && @test_opt ignored_modules = ignored_modules $op!( + f, res2, prep, ba, x, tang, contexts... + ) + (subset in (:prepared, :full)) && + @test_opt ignored_modules = ignored_modules $op!( + f, res2, ba, x, tang, contexts... + ) return nothing end end diff --git a/DifferentiationInterfaceTest/src/utils.jl b/DifferentiationInterfaceTest/src/utils.jl index 7e2f9f0b5..ada6fa2d6 100644 --- a/DifferentiationInterfaceTest/src/utils.jl +++ b/DifferentiationInterfaceTest/src/utils.jl @@ -24,5 +24,5 @@ mymultiply(x::AbstractArray, a::Number) = a .* x mymultiply(x::NTuple, a::Number) = map(Base.Fix2(mymultiply, a), x) mymultiply(::Nothing, a::Number) = nothing -mynnz(A::AbstractMatrix) = nnz(A) -mynnz(A::Union{Transpose,Adjoint}) = nnz(parent(A)) # fix for Julia 1.6 +mynnz(A::AbstractMatrix) = count(!iszero, A) +mynnz(A::AbstractSparseMatrix) = nnz(A) diff --git a/DifferentiationInterfaceTest/test/zero_backends.jl b/DifferentiationInterfaceTest/test/zero_backends.jl index 92e53d5a1..81d4902c9 100644 --- a/DifferentiationInterfaceTest/test/zero_backends.jl +++ b/DifferentiationInterfaceTest/test/zero_backends.jl @@ -2,7 +2,7 @@ using ADTypes using DifferentiationInterface using DifferentiationInterface: AutoZeroForward, AutoZeroReverse using DifferentiationInterfaceTest -using DifferentiationInterfaceTest: test_allocfree, allocfree_scenarios +using DifferentiationInterfaceTest: allocfree_scenarios using Test @@ -12,25 +12,26 @@ LOGGING = get(ENV, "CI", "false") == "false" test_differentiation( AutoZeroForward(), - zero.(default_scenarios(; include_batchified=false)); - correctness=true, - type_stability=(; preparation=true, prepared_op=true, unprepared_op=true), + default_scenarios(; include_batchified=false); + correctness=false, + type_stability=:full, logging=LOGGING, ) test_differentiation( AutoZeroReverse(), - zero.(default_scenarios(; include_batchified=false)); - correctness=true, - type_stability=true, + default_scenarios(; include_batchified=false); + correctness=false, + type_stability=:full, logging=LOGGING, ) ## Benchmark data1 = benchmark_differentiation( - [AutoZeroForward()], + AutoZeroForward(), default_scenarios(; include_batchified=false, include_constantified=true); + benchmark=:full, logging=LOGGING, ); @@ -38,7 +39,7 @@ struct FakeBackend <: ADTypes.AbstractADType end ADTypes.mode(::FakeBackend) = ADTypes.ForwardMode() data2 = benchmark_differentiation( - [FakeBackend()], default_scenarios(; include_batchified=false); logging=false + FakeBackend(), default_scenarios(; include_batchified=false); logging=false ); @testset "Benchmarking DataFrame" begin @@ -58,17 +59,17 @@ end data_allocfree = vcat( benchmark_differentiation( - [AutoZeroForward()], + AutoZeroForward(), allocfree_scenarios(); excluded=[:pullback, :gradient], logging=LOGGING, ), benchmark_differentiation( - [AutoZeroReverse()], + AutoZeroReverse(), allocfree_scenarios(); excluded=[:pushforward, :derivative], logging=LOGGING, ), ) -test_allocfree(data_allocfree); +@test is_allocfree(data_allocfree); From 387c01d096bfcd9af606e192fead03d2bee4716b Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 10 Oct 2024 11:20:37 +0200 Subject: [PATCH 02/18] Remove `first_order` and `second_order` --- .../ChainRulesBackends/chainrules_zygote.jl | 2 +- .../test/Back/ChainRulesBackends/diffractor.jl | 5 ++++- .../test/Back/Enzyme/test.jl | 17 +++++++---------- .../test/Back/FiniteDifferences/test.jl | 2 +- .../test/Back/ForwardDiff/test.jl | 2 +- .../test/Back/Mooncake/test.jl | 2 +- .../test/Back/Tracker/test.jl | 2 +- .../test/Back/Zygote/test.jl | 4 ++-- .../test/Misc/DifferentiateWith/test.jl | 2 +- .../test/Misc/ZeroBackends/test.jl | 1 - DifferentiationInterfaceTest/test/weird.jl | 4 ++-- 11 files changed, 21 insertions(+), 22 deletions(-) diff --git a/DifferentiationInterface/test/Back/ChainRulesBackends/chainrules_zygote.jl b/DifferentiationInterface/test/Back/ChainRulesBackends/chainrules_zygote.jl index af28e8e08..ecbd4b8f8 100644 --- a/DifferentiationInterface/test/Back/ChainRulesBackends/chainrules_zygote.jl +++ b/DifferentiationInterface/test/Back/ChainRulesBackends/chainrules_zygote.jl @@ -23,6 +23,6 @@ test_differentiation( test_differentiation( AutoChainRules(ZygoteRuleConfig()), default_scenarios(; include_normal=false, include_constantified=true); - second_order=false, + excluded=SECOND_ORDER, logging=LOGGING, ); diff --git a/DifferentiationInterface/test/Back/ChainRulesBackends/diffractor.jl b/DifferentiationInterface/test/Back/ChainRulesBackends/diffractor.jl index 23f6412a6..e8b890300 100644 --- a/DifferentiationInterface/test/Back/ChainRulesBackends/diffractor.jl +++ b/DifferentiationInterface/test/Back/ChainRulesBackends/diffractor.jl @@ -13,5 +13,8 @@ for backend in [AutoDiffractor()] end test_differentiation( - AutoDiffractor(), default_scenarios(; linalg=false); second_order=false, logging=LOGGING + AutoDiffractor(), + default_scenarios(; linalg=false); + excluded=SECOND_ORDER, + logging=LOGGING, ); diff --git a/DifferentiationInterface/test/Back/Enzyme/test.jl b/DifferentiationInterface/test/Back/Enzyme/test.jl index be3786c0e..4433bc580 100644 --- a/DifferentiationInterface/test/Back/Enzyme/test.jl +++ b/DifferentiationInterface/test/Back/Enzyme/test.jl @@ -31,19 +31,19 @@ end; ## First order -test_differentiation(backends, default_scenarios(); second_order=false, logging=LOGGING); +test_differentiation(backends, default_scenarios(); excluded=SECOND_ORDER, logging=LOGGING); test_differentiation( backends[1:3], default_scenarios(; include_normal=false, include_constantified=true); - second_order=false, + excluded=SECOND_ORDER, logging=LOGGING, ); test_differentiation( duplicated_backends, default_scenarios(; include_normal=false, include_closurified=true); - second_order=false, + excluded=SECOND_ORDER, logging=LOGGING, ); @@ -55,7 +55,7 @@ test_differentiation( default_scenarios(; include_batchified=false); correctness=false, type_stability=:prepared, - second_order=false, + excluded=SECOND_ORDER, logging=LOGGING, ); =# @@ -65,27 +65,24 @@ test_differentiation( test_differentiation( AutoEnzyme(), default_scenarios(; include_constantified=true); - first_order=false, + excluded=FIRST_ORDER, logging=LOGGING, ); test_differentiation( AutoEnzyme(; mode=Enzyme.Forward); - first_order=false, - excluded=[:hessian, :hvp], + excluded=vcat(FIRST_ORDER, [:hessian, :hvp]), logging=LOGGING, ); test_differentiation( AutoEnzyme(; mode=Enzyme.Reverse); - first_order=false, - excluded=[:second_derivative], + excluded=vcat(FIRST_ORDER, [:second_derivative]), logging=LOGGING, ); test_differentiation( SecondOrder(AutoEnzyme(; mode=Enzyme.Reverse), AutoEnzyme(; mode=Enzyme.Forward)); - first_order=false, logging=LOGGING, ); diff --git a/DifferentiationInterface/test/Back/FiniteDifferences/test.jl b/DifferentiationInterface/test/Back/FiniteDifferences/test.jl index cae649e42..a1ad1bda0 100644 --- a/DifferentiationInterface/test/Back/FiniteDifferences/test.jl +++ b/DifferentiationInterface/test/Back/FiniteDifferences/test.jl @@ -15,6 +15,6 @@ end test_differentiation( AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1)), default_scenarios(; include_constantified=true); - second_order=false, + excluded=SECOND_ORDER, logging=LOGGING, ); diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index 14b8ea48e..806b83521 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -34,7 +34,7 @@ test_differentiation( backends, vcat(component_scenarios(), static_scenarios()); # FD accesses individual indices excluded=[:jacobian], # jacobian is super slow for some reason - second_order=false, + excluded=SECOND_ORDER, logging=LOGGING, ); diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index 9323e7732..8dee4e17d 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -17,6 +17,6 @@ end test_differentiation( backends, default_scenarios(; include_constantified=true); - second_order=false, + excluded=SECOND_ORDER, logging=LOGGING, ); diff --git a/DifferentiationInterface/test/Back/Tracker/test.jl b/DifferentiationInterface/test/Back/Tracker/test.jl index 8be27cfac..ef2df3832 100644 --- a/DifferentiationInterface/test/Back/Tracker/test.jl +++ b/DifferentiationInterface/test/Back/Tracker/test.jl @@ -15,6 +15,6 @@ end test_differentiation( AutoTracker(), default_scenarios(; include_constantified=true); - second_order=false, + excluded=SECOND_ORDER, logging=LOGGING, ); diff --git a/DifferentiationInterface/test/Back/Zygote/test.jl b/DifferentiationInterface/test/Back/Zygote/test.jl index 8db507747..d7316bec1 100644 --- a/DifferentiationInterface/test/Back/Zygote/test.jl +++ b/DifferentiationInterface/test/Back/Zygote/test.jl @@ -28,12 +28,12 @@ test_differentiation( logging=LOGGING, ); -test_differentiation(second_order_backends; first_order=false, logging=LOGGING); +test_differentiation(second_order_backends; logging=LOGGING); test_differentiation( backends[1], vcat(component_scenarios(), gpu_scenarios()); - second_order=false, + excluded=SECOND_ORDER, logging=LOGGING, ) diff --git a/DifferentiationInterface/test/Misc/DifferentiateWith/test.jl b/DifferentiationInterface/test/Misc/DifferentiateWith/test.jl index 0086a7965..c8ea57c0b 100644 --- a/DifferentiationInterface/test/Misc/DifferentiateWith/test.jl +++ b/DifferentiationInterface/test/Misc/DifferentiateWith/test.jl @@ -24,6 +24,6 @@ end test_differentiation( [AutoForwardDiff(), AutoZygote()], differentiatewith_scenarios(); - second_order=false, + excluded=SECOND_ORDER, logging=LOGGING, ) diff --git a/DifferentiationInterface/test/Misc/ZeroBackends/test.jl b/DifferentiationInterface/test/Misc/ZeroBackends/test.jl index 956bef3b9..b352fee87 100644 --- a/DifferentiationInterface/test/Misc/ZeroBackends/test.jl +++ b/DifferentiationInterface/test/Misc/ZeroBackends/test.jl @@ -42,7 +42,6 @@ test_differentiation( default_scenarios(; include_batchified=false, include_constantified=true); correctness=false, type_stability=:full, - first_order=false, logging=LOGGING, ) diff --git a/DifferentiationInterfaceTest/test/weird.jl b/DifferentiationInterfaceTest/test/weird.jl index dc3f17567..89dd59c45 100644 --- a/DifferentiationInterfaceTest/test/weird.jl +++ b/DifferentiationInterfaceTest/test/weird.jl @@ -36,14 +36,14 @@ test_differentiation(AutoForwardDiff(), static_scenarios(); logging=LOGGING) test_differentiation(AutoForwardDiff(), component_scenarios(); logging=LOGGING) -test_differentiation(AutoZygote(), gpu_scenarios(); second_order=false, logging=LOGGING) +test_differentiation(AutoZygote(), gpu_scenarios(); excluded=SECOND_ORDER, logging=LOGGING) ## Closures test_differentiation( AutoFiniteDiff(), default_scenarios(; include_normal=false, include_closurified=true); - second_order=false, + excluded=SECOND_ORDER, logging=LOGGING, ); From e6b9d33b5d19ce45327e892ada2f9fc8a3ed8019 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 10 Oct 2024 11:49:45 +0200 Subject: [PATCH 03/18] Docs --- DifferentiationInterfaceTest/docs/src/api.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterfaceTest/docs/src/api.md b/DifferentiationInterfaceTest/docs/src/api.md index 9ba404c7a..a9267d9dd 100644 --- a/DifferentiationInterfaceTest/docs/src/api.md +++ b/DifferentiationInterfaceTest/docs/src/api.md @@ -22,7 +22,7 @@ SECOND_ORDER ## Utilities ```@docs -DifferentiationBenchmarkDataRow +DifferentiationInterfaceTest.DifferentiationBenchmarkDataRow ``` ## Pre-made scenario lists From 2ea0c77b3cfe29b2b3e67a00b594ebe6633b5539 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 10 Oct 2024 12:21:44 +0200 Subject: [PATCH 04/18] Zero allocs --- DifferentiationInterface/test/Back/ForwardDiff/test.jl | 2 +- DifferentiationInterfaceTest/docs/src/api.md | 6 ------ DifferentiationInterfaceTest/test/zero_backends.jl | 4 +++- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index 806b83521..8adc6e2dd 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -23,7 +23,7 @@ test_differentiation( ); test_differentiation( - AutoForwardDiff(); correctness=false, type_stability=true, logging=LOGGING + AutoForwardDiff(); correctness=false, type_stability=:full, logging=LOGGING ); test_differentiation( diff --git a/DifferentiationInterfaceTest/docs/src/api.md b/DifferentiationInterfaceTest/docs/src/api.md index a9267d9dd..b6c2c95ba 100644 --- a/DifferentiationInterfaceTest/docs/src/api.md +++ b/DifferentiationInterfaceTest/docs/src/api.md @@ -37,12 +37,6 @@ gpu_scenarios static_scenarios ``` -## Utilities - -```@docs -DifferentiationBenchmarkDataRow -``` - ## Internals This is not part of the public API. diff --git a/DifferentiationInterfaceTest/test/zero_backends.jl b/DifferentiationInterfaceTest/test/zero_backends.jl index 81d4902c9..b62bd8480 100644 --- a/DifferentiationInterfaceTest/test/zero_backends.jl +++ b/DifferentiationInterfaceTest/test/zero_backends.jl @@ -62,14 +62,16 @@ data_allocfree = vcat( AutoZeroForward(), allocfree_scenarios(); excluded=[:pullback, :gradient], + benchmark=:prepared, logging=LOGGING, ), benchmark_differentiation( AutoZeroReverse(), allocfree_scenarios(); excluded=[:pushforward, :derivative], + benchmark=:prepared, logging=LOGGING, ), ) -@test is_allocfree(data_allocfree); +@test all(iszero, data_allocfree[!, :allocs]) From 60b1a156e26769ed430a5379239f2eb3bea72acd Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 10 Oct 2024 13:46:04 +0200 Subject: [PATCH 05/18] Fixes --- .../test/Back/ForwardDiff/test.jl | 8 +- .../src/test_differentiation.jl | 9 +- .../src/tests/benchmark.jl | 3 - .../src/tests/correctness_eval.jl | 12 +- .../src/tests/sparsity.jl | 93 ---------------- .../src/tests/type_stability_eval.jl | 103 +++++++++--------- .../test/zero_backends.jl | 15 ++- 7 files changed, 84 insertions(+), 159 deletions(-) delete mode 100644 DifferentiationInterfaceTest/src/tests/sparsity.jl diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index 8adc6e2dd..a9a77f4aa 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -23,11 +23,15 @@ test_differentiation( ); test_differentiation( - AutoForwardDiff(); correctness=false, type_stability=:full, logging=LOGGING + AutoForwardDiff(); correctness=false, type_stability=:prepared, logging=LOGGING ); test_differentiation( - AutoForwardDiff(; chunksize=5); correctness=false, type_stability=:full, logging=LOGGING + AutoForwardDiff(; chunksize=5); + correctness=false, + type_stability=:full, + excluded=[:hessian], + logging=LOGGING, ); test_differentiation( diff --git a/DifferentiationInterfaceTest/src/test_differentiation.jl b/DifferentiationInterfaceTest/src/test_differentiation.jl index 413c84922..ec6c83907 100644 --- a/DifferentiationInterfaceTest/src/test_differentiation.jl +++ b/DifferentiationInterfaceTest/src/test_differentiation.jl @@ -132,7 +132,12 @@ function test_differentiation( end yield() (type_stability != :none) && @testset "Type stability" begin - test_jet(adapted_backend, scen; subset=type_stability, ignored_modules) + test_jet( + adapted_backend, + scen; + subset=type_stability, + ignored_modules=ignored_modules, + ) end yield() (benchmark != :none) && @testset "Benchmark" begin @@ -176,7 +181,7 @@ Specifying the set of scenarios is mandatory for this function. function benchmark_differentiation( backends, scenarios::Vector{<:Scenario}; - benchmark::Symbol=:full, + benchmark::Symbol=:prepared, excluded::Vector{Symbol}=Symbol[], logging::Bool=false, ) diff --git a/DifferentiationInterfaceTest/src/tests/benchmark.jl b/DifferentiationInterfaceTest/src/tests/benchmark.jl index 859fe82da..5ee559922 100644 --- a/DifferentiationInterfaceTest/src/tests/benchmark.jl +++ b/DifferentiationInterfaceTest/src/tests/benchmark.jl @@ -45,9 +45,6 @@ function failed_bench() return Benchmark([sample]) end -failed_benchs(k::Integer) = ntuple(i -> failed_bench(), k) -failed_calls(k::Integer) = ntuple(i -> -1, k) - """ DifferentiationBenchmarkDataRow diff --git a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl index 56bfeecd9..c7a2d6da7 100644 --- a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl @@ -79,7 +79,7 @@ for op in ALL_OPS @test res1_out1_noval ≈ scen.res1 @test res1_out2_noval ≈ scen.res1 end - if sparsity && $op == :jacobian + if sparsity && op == :jacobian @test mynnz(res1_out1_val) == mynnz(scen.res1) @test mynnz(res1_out2_val) == mynnz(scen.res1) @test mynnz(res1_out1_noval) == mynnz(scen.res1) @@ -144,7 +144,7 @@ for op in ALL_OPS @test res1_out1_noval ≈ scen.res1 @test res1_out2_noval ≈ scen.res1 end - if sparsity && $op == :jacobian + if sparsity && op == :jacobian @test mynnz(res1_out1_val) == mynnz(scen.res1) @test mynnz(res1_out2_val) == mynnz(scen.res1) @test mynnz(res1_out1_noval) == mynnz(scen.res1) @@ -206,7 +206,7 @@ for op in ALL_OPS @test res1_out1_noval ≈ scen.res1 @test res1_out2_noval ≈ scen.res1 end - if sparsity && $op == :jacobian + if sparsity && op == :jacobian @test mynnz(res1_out1_val) == mynnz(scen.res1) @test mynnz(res1_out2_val) == mynnz(scen.res1) @test mynnz(res1_out1_noval) == mynnz(scen.res1) @@ -274,7 +274,7 @@ for op in ALL_OPS @test res1_out1_noval ≈ scen.res1 @test res1_out2_noval ≈ scen.res1 end - if sparsity && $op == :jacobian + if sparsity && op == :jacobian @test mynnz(res1_out1_val) == mynnz(scen.res1) @test mynnz(res1_out2_val) == mynnz(scen.res1) @test mynnz(res1_out1_noval) == mynnz(scen.res1) @@ -330,7 +330,7 @@ for op in ALL_OPS @test res2_out1_noval ≈ scen.res2 @test res2_out2_noval ≈ scen.res2 end - if sparsity && $op == :hessian + if sparsity && op == :hessian @test mynnz(res2_out1_val) == mynnz(scen.res2) @test mynnz(res2_out2_val) == mynnz(scen.res2) @test mynnz(res2_out1_noval) == mynnz(scen.res2) @@ -399,7 +399,7 @@ for op in ALL_OPS @test res2_out1_noval ≈ scen.res2 @test res2_out2_noval ≈ scen.res2 end - if sparsity && $op == :hessian + if sparsity && op == :hessian @test mynnz(res2_out1_val) == mynnz(scen.res2) @test mynnz(res2_out2_val) == mynnz(scen.res2) @test mynnz(res2_out1_noval) == mynnz(scen.res2) diff --git a/DifferentiationInterfaceTest/src/tests/sparsity.jl b/DifferentiationInterfaceTest/src/tests/sparsity.jl deleted file mode 100644 index 6a92ce38d..000000000 --- a/DifferentiationInterfaceTest/src/tests/sparsity.jl +++ /dev/null @@ -1,93 +0,0 @@ -## Jacobian - -function test_sparsity(ba::AbstractADType, scen::Scenario{:jacobian,:out,:out}) - (; f, x, y, contexts) = scen = deepcopy(scen) - prep = prepare_jacobian(f, ba, x, contexts...) - - _, jac1 = value_and_jacobian(f, prep, ba, x, contexts...) - jac2 = jacobian(f, prep, ba, x, contexts...) - - @testset "Sparsity pattern" begin - @test mynnz(jac1) == mynnz(scen.res1) - @test mynnz(jac2) == mynnz(scen.res1) - end - return nothing -end - -function test_sparsity(ba::AbstractADType, scen::Scenario{:jacobian,:in,:out}) - (; f, x, y, contexts) = deepcopy(scen) - prep = prepare_jacobian(f, ba, x, contexts...) - - _, jac1 = value_and_jacobian!(f, mysimilar(scen.res1), prep, ba, x, contexts...) - jac2 = jacobian!(f, mysimilar(scen.res1), prep, ba, x, contexts...) - - @testset "Sparsity pattern" begin - @test mynnz(jac1) == mynnz(scen.res1) - @test mynnz(jac2) == mynnz(scen.res1) - end - return nothing -end - -function test_sparsity(ba::AbstractADType, scen::Scenario{:jacobian,:out,:in}) - (; f, x, y, contexts) = deepcopy(scen) - f! = f - prep = prepare_jacobian(f!, mysimilar(y), ba, x, contexts...) - - _, jac1 = value_and_jacobian(f!, mysimilar(y), prep, ba, x, contexts...) - jac2 = jacobian(f!, mysimilar(y), prep, ba, x, contexts...) - - @testset "Sparsity pattern" begin - @test mynnz(jac1) == mynnz(scen.res1) - @test mynnz(jac2) == mynnz(scen.res1) - end - return nothing -end - -function test_sparsity(ba::AbstractADType, scen::Scenario{:jacobian,:in,:in}) - (; f, x, y, contexts) = deepcopy(scen) - f! = f - prep = prepare_jacobian(f!, mysimilar(y), ba, x, contexts...) - - _, jac1 = value_and_jacobian!( - f!, mysimilar(y), mysimilar(scen.res1), prep, ba, x, contexts... - ) - jac2 = jacobian!(f!, mysimilar(y), mysimilar(scen.res1), prep, ba, x, contexts...) - - @testset "Sparsity pattern" begin - @test mynnz(jac1) == mynnz(scen.res1) - @test mynnz(jac2) == mynnz(scen.res1) - end - return nothing -end - -## Hessian - -function test_sparsity(ba::AbstractADType, scen::Scenario{:hessian,:out,:out}) - (; f, x, y, contexts) = deepcopy(scen) - prep = prepare_hessian(f, ba, x, contexts...) - - hess1 = hessian(f, prep, ba, x, contexts...) - _, _, hess2 = value_gradient_and_hessian(f, prep, ba, x, contexts...) - - @testset "Sparsity pattern" begin - @test mynnz(hess1) == mynnz(scen.res2) - @test mynnz(hess2) == mynnz(scen.res2) - end - return nothing -end - -function test_sparsity(ba::AbstractADType, scen::Scenario{:hessian,:in,:out}) - (; f, x, y, contexts) = deepcopy(scen) - prep = prepare_hessian(f, ba, x, contexts...) - - hess1 = hessian!(f, mysimilar(scen.res2), prep, ba, x, contexts...) - _, _, hess2 = value_gradient_and_hessian!( - f, mysimilar(x), mysimilar(scen.res2), prep, ba, x, contexts... - ) - - @testset "Sparsity pattern" begin - @test mynnz(hess1) == mynnz(scen.res2) - @test mynnz(hess2) == mynnz(scen.res2) - end - return nothing -end diff --git a/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl b/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl index bb67b1856..1afb4a2f8 100644 --- a/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl @@ -25,15 +25,15 @@ for op in ALL_OPS (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op(f, ba, x, contexts...) (subset == :full) && - @test_opt ignored_modules = ignored_modules $op(f, prep, ba, x, contexts...) + @test_opt ignored_modules = ignored_modules $op(f, ba, x, contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op( - f, prep, ba, x, contexts... + f, ba, x, contexts... ) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op(f, ba, x, contexts...) + @test_opt ignored_modules = ignored_modules $op(f, prep, ba, x, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules $val_and_op( - f, ba, x, contexts... + f, prep, ba, x, contexts... ) return nothing end @@ -46,18 +46,18 @@ for op in ALL_OPS (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op(f, ba, x, contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules $op!( - f, res1, prep, ba, x, contexts... + f, res1, ba, x, contexts... ) (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op!( - f, res1, prep, ba, x, contexts... + f, res1, ba, x, contexts... ) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules $op!( - f, res1, ba, x, contexts... + f, res1, prep, ba, x, contexts... ) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules $val_and_op!( - f, res1, ba, x, contexts... + f, res1, prep, ba, x, contexts... ) return nothing end @@ -72,17 +72,18 @@ for op in ALL_OPS (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op( f, y, ba, x, contexts... ) - (subset == :full) && @test_opt ignored_modules = ignored_modules $op( - f, y, prep, ba, x, contexts... - ) + (subset == :full) && + @test_opt ignored_modules = ignored_modules $op(f, y, ba, x, contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op( - f, y, prep, ba, x, contexts... + f, y, ba, x, contexts... ) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op(f, y, ba, x, contexts...) + @test_opt ignored_modules = ignored_modules $op( + f, y, prep, ba, x, contexts... + ) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules $val_and_op( - f, y, ba, x, contexts... + f, y, prep, ba, x, contexts... ) return nothing end @@ -96,18 +97,18 @@ for op in ALL_OPS f, y, ba, x, contexts... ) (subset == :full) && @test_opt ignored_modules = ignored_modules $op!( - f, y, res1, prep, ba, x, contexts... + f, y, res1, ba, x, contexts... ) (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op!( - f, y, res1, prep, ba, x, contexts... + f, y, res1, ba, x, contexts... ) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules $op!( - f, y, res1, ba, x, contexts... + f, y, res1, prep, ba, x, contexts... ) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules $val_and_op!( - f, y, res1, ba, x, contexts... + f, y, res1, prep, ba, x, contexts... ) return nothing end @@ -121,15 +122,15 @@ for op in ALL_OPS (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op(f, ba, x, contexts...) (subset == :full) && - @test_opt ignored_modules = ignored_modules $op(f, prep, ba, x, contexts...) + @test_opt ignored_modules = ignored_modules $op(f, ba, x, contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op( - f, prep, ba, x, contexts... + f, ba, x, contexts... ) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op(f, ba, x, contexts...) + @test_opt ignored_modules = ignored_modules $op(f, prep, ba, x, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules $val_and_op( - f, ba, x, contexts... + f, prep, ba, x, contexts... ) return nothing end @@ -142,18 +143,18 @@ for op in ALL_OPS (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op(f, ba, x, contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules $op!( - f, res2, prep, ba, x, contexts... + f, res2, ba, x, contexts... ) (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op!( - f, res1, res2, prep, ba, x, contexts... + f, res1, res2, ba, x, contexts... ) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules $op!( - f, res2, ba, x, contexts... + f, res2, prep, ba, x, contexts... ) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules $val_and_op!( - f, res1, res2, ba, x, contexts... + f, res1, res2, prep, ba, x, contexts... ) return nothing end @@ -167,17 +168,18 @@ for op in ALL_OPS (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op( f, ba, x, tang, contexts... ) - (subset == :full) && @test_opt ignored_modules = ignored_modules $op( - f, prep, ba, x, tang, contexts... - ) + (subset == :full) && + @test_opt ignored_modules = ignored_modules $op(f, ba, x, tang, contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op( - f, prep, ba, x, tang, contexts... + f, ba, x, tang, contexts... ) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op(f, ba, x, tang, contexts...) + @test_opt ignored_modules = ignored_modules $op( + f, prep, ba, x, tang, contexts... + ) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules $val_and_op( - f, ba, x, tang, contexts... + f, prep, ba, x, tang, contexts... ) return nothing end @@ -191,18 +193,18 @@ for op in ALL_OPS f, ba, x, tang, contexts... ) (subset == :full) && @test_opt ignored_modules = ignored_modules $op!( - f, res1, prep, ba, x, tang, contexts... + f, res1, ba, x, tang, contexts... ) (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op!( - f, res1, prep, ba, x, tang, contexts... + f, res1, ba, x, tang, contexts... ) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules $op!( - f, res1, ba, x, tang, contexts... + f, res1, prep, ba, x, tang, contexts... ) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules $val_and_op!( - f, res1, ba, x, tang, contexts... + f, res1, prep, ba, x, tang, contexts... ) return nothing end @@ -216,18 +218,18 @@ for op in ALL_OPS f, y, ba, x, tang, contexts... ) (subset == :full) && @test_opt ignored_modules = ignored_modules $op( - f, y, prep, ba, x, tang, contexts... + f, y, ba, x, tang, contexts... ) (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op( - f, y, prep, ba, x, tang, contexts... + f, y, ba, x, tang, contexts... ) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules $op( - f, y, ba, x, tang, contexts... + f, y, prep, ba, x, tang, contexts... ) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules $val_and_op( - f, y, ba, x, tang, contexts... + f, y, prep, ba, x, tang, contexts... ) return nothing end @@ -241,18 +243,18 @@ for op in ALL_OPS f, y, ba, x, tang, contexts... ) (subset == :full) && @test_opt ignored_modules = ignored_modules $op!( - f, y, res1, prep, ba, x, tang, contexts... + f, y, res1, ba, x, tang, contexts... ) (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op!( - f, y, res1, prep, ba, x, tang, contexts... + f, y, res1, ba, x, tang, contexts... ) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules $op!( - f, y, res1, ba, x, tang, contexts... + f, y, res1, prep, ba, x, tang, contexts... ) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules $val_and_op!( - f, y, res1, ba, x, tang, contexts... + f, y, res1, prep, ba, x, tang, contexts... ) return nothing end @@ -266,11 +268,12 @@ for op in ALL_OPS (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op( f, ba, x, tang, contexts... ) - (subset == :full) && @test_opt ignored_modules = ignored_modules $op( - f, prep, ba, x, tang, contexts... - ) - (subset in (:prepared, :full)) && + (subset == :full) && @test_opt ignored_modules = ignored_modules $op(f, ba, x, tang, contexts...) + (subset in (:prepared, :full)) && + @test_opt ignored_modules = ignored_modules $op( + f, prep, ba, x, tang, contexts... + ) return nothing end @@ -283,11 +286,11 @@ for op in ALL_OPS f, ba, x, tang, contexts... ) (subset == :full) && @test_opt ignored_modules = ignored_modules $op!( - f, res2, prep, ba, x, tang, contexts... + f, res2, ba, x, tang, contexts... ) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules $op!( - f, res2, ba, x, tang, contexts... + f, res2, prep, ba, x, tang, contexts... ) return nothing end diff --git a/DifferentiationInterfaceTest/test/zero_backends.jl b/DifferentiationInterfaceTest/test/zero_backends.jl index b62bd8480..28a99e1b1 100644 --- a/DifferentiationInterfaceTest/test/zero_backends.jl +++ b/DifferentiationInterfaceTest/test/zero_backends.jl @@ -22,15 +22,21 @@ test_differentiation( AutoZeroReverse(), default_scenarios(; include_batchified=false); correctness=false, - type_stability=:full, + type_stability=:prepared, logging=LOGGING, ) ## Benchmark -data1 = benchmark_differentiation( +data0 = benchmark_differentiation( AutoZeroForward(), default_scenarios(; include_batchified=false, include_constantified=true); + logging=LOGGING, +); + +data1 = benchmark_differentiation( + AutoZeroForward(), + default_scenarios(; include_batchified=false); benchmark=:full, logging=LOGGING, ); @@ -39,7 +45,10 @@ struct FakeBackend <: ADTypes.AbstractADType end ADTypes.mode(::FakeBackend) = ADTypes.ForwardMode() data2 = benchmark_differentiation( - FakeBackend(), default_scenarios(; include_batchified=false); logging=false + FakeBackend(), + default_scenarios(; include_batchified=false); + count_calls=false, + logging=false, ); @testset "Benchmarking DataFrame" begin From 371eadfefd0988ead7ecb8e854de03973f83b68c Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 10 Oct 2024 14:11:22 +0200 Subject: [PATCH 06/18] Call count --- .../src/tests/benchmark_eval.jl | 94 ++++++++++++++++--- 1 file changed, 82 insertions(+), 12 deletions(-) diff --git a/DifferentiationInterfaceTest/src/tests/benchmark_eval.jl b/DifferentiationInterfaceTest/src/tests/benchmark_eval.jl index 2ecf0bed6..3e1b22c68 100644 --- a/DifferentiationInterfaceTest/src/tests/benchmark_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/benchmark_eval.jl @@ -150,7 +150,13 @@ for op in ALL_OPS prepared_valop = reset_count!(cc) $op(cc, prep, ba, x, contexts...) prepared_op = reset_count!(cc) - return CallsResult(; preparation, prepared_valop, prepared_op) + $val_and_op(cc, ba, x, contexts...) + unprepared_valop = reset_count!(cc) + $op(cc, ba, x, contexts...) + unprepared_op = reset_count!(cc) + return CallsResult(; + prepared_valop, prepared_op, preparation, unprepared_valop, unprepared_op + ) end @eval function benchmark_aux(ba::AbstractADType, scen::$S1in; subset::Symbol) @@ -185,7 +191,13 @@ for op in ALL_OPS prepared_valop = reset_count!(cc) $op!(cc, res1, prep, ba, x, contexts...) prepared_op = reset_count!(cc) - return CallsResult(; preparation, prepared_valop, prepared_op) + $val_and_op!(cc, res1, ba, x, contexts...) + unprepared_valop = reset_count!(cc) + $op!(cc, res1, ba, x, contexts...) + unprepared_op = reset_count!(cc) + return CallsResult(; + prepared_valop, prepared_op, preparation, unprepared_valop, unprepared_op + ) end op == :gradient && continue @@ -220,7 +232,13 @@ for op in ALL_OPS prepared_valop = reset_count!(cc) $op(cc, y, prep, ba, x, contexts...) prepared_op = reset_count!(cc) - return CallsResult(; preparation, prepared_valop, prepared_op) + $val_and_op(cc, y,, ba, x, contexts...) + unprepared_valop = reset_count!(cc) + $op(cc, y,, ba, x, contexts...) + unprepared_op = reset_count!(cc) + return CallsResult(; + prepared_valop, prepared_op, preparation, unprepared_valop, unprepared_op + ) end @eval function benchmark_aux(ba::AbstractADType, scen::$S2in; subset::Symbol) @@ -257,7 +275,13 @@ for op in ALL_OPS prepared_valop = reset_count!(cc) $op!(cc, y, res1, prep, ba, x, contexts...) prepared_op = reset_count!(cc) - return CallsResult(; preparation, prepared_valop, prepared_op) + $val_and_op!(cc, y, res1, ba, x, contexts...) + unprepared_valop = reset_count!(cc) + $op!(cc, y, res1, ba, x, contexts...) + unprepared_op = reset_count!(cc) + return CallsResult(; + prepared_valop, prepared_op, preparation, unprepared_valop, unprepared_op + ) end elseif op in [:hessian, :second_derivative] @@ -291,7 +315,13 @@ for op in ALL_OPS prepared_valop = reset_count!(cc) $op(cc, prep, ba, x, contexts...) prepared_op = reset_count!(cc) - return CallsResult(; preparation, prepared_valop, prepared_op) + $val_and_op(cc, ba, x, contexts...) + unprepared_valop = reset_count!(cc) + $op(cc, ba, x, contexts...) + unprepared_op = reset_count!(cc) + return CallsResult(; + prepared_valop, prepared_op, preparation, unprepared_valop, unprepared_op + ) end @eval function benchmark_aux(ba::AbstractADType, scen::$S1in; subset::Symbol) @@ -329,7 +359,13 @@ for op in ALL_OPS prepared_valop = reset_count!(cc) $op!(cc, res2, prep, ba, x, contexts...) prepared_op = reset_count!(cc) - return CallsResult(; preparation, prepared_valop, prepared_op) + $val_and_op!(cc, res1, res2, ba, x, contexts...) + unprepared_valop = reset_count!(cc) + $op!(cc, res2, ba, x, contexts...) + unprepared_op = reset_count!(cc) + return CallsResult(; + prepared_valop, prepared_op, preparation, unprepared_valop, unprepared_op + ) end elseif op in [:pushforward, :pullback] @@ -362,7 +398,13 @@ for op in ALL_OPS prepared_valop = reset_count!(cc) $op(cc, prep, ba, x, tang, contexts...) prepared_op = reset_count!(cc) - return CallsResult(; preparation, prepared_valop, prepared_op) + $val_and_op(cc, ba, x, tang, contexts...) + unprepared_valop = reset_count!(cc) + $op(cc, ba, x, tang, contexts...) + unprepared_op = reset_count!(cc) + return CallsResult(; + prepared_valop, prepared_op, preparation, unprepared_valop, unprepared_op + ) end @eval function benchmark_aux(ba::AbstractADType, scen::$S1in; subset::Symbol) @@ -397,7 +439,13 @@ for op in ALL_OPS prepared_valop = reset_count!(cc) $op!(cc, res1, prep, ba, x, tang, contexts...) prepared_op = reset_count!(cc) - return CallsResult(; preparation, prepared_valop, prepared_op) + $val_and_op!(cc, res1, ba, x, tang, contexts...) + unprepared_valop = reset_count!(cc) + $op!(cc, res1, ba, x, tang, contexts...) + unprepared_op = reset_count!(cc) + return CallsResult(; + prepared_valop, prepared_op, preparation, unprepared_valop, unprepared_op + ) end @eval function benchmark_aux(ba::AbstractADType, scen::$S2out; subset::Symbol) @@ -432,7 +480,13 @@ for op in ALL_OPS prepared_valop = reset_count!(cc) $op(cc, y, prep, ba, x, tang, contexts...) prepared_op = reset_count!(cc) - return CallsResult(; preparation, prepared_valop, prepared_op) + $val_and_op(cc, y, ba, x, tang, contexts...) + unprepared_valop = reset_count!(cc) + $op(cc, y, ba, x, tang, contexts...) + unprepared_op = reset_count!(cc) + return CallsResult(; + prepared_valop, prepared_op, preparation, unprepared_valop, unprepared_op + ) end @eval function benchmark_aux(ba::AbstractADType, scen::$S2in; subset::Symbol) @@ -471,7 +525,13 @@ for op in ALL_OPS prepared_valop = reset_count!(cc) $op!(cc, y, res1, prep, ba, x, tang, contexts...) prepared_op = reset_count!(cc) - return CallsResult(; preparation, prepared_valop, prepared_op) + $val_and_op!(cc, y, res1, ba, x, tang, contexts...) + unprepared_valop = reset_count!(cc) + $op!(cc, y, res1, ba, x, tang, contexts...) + unprepared_op = reset_count!(cc) + return CallsResult(; + prepared_valop, prepared_op, preparation, unprepared_valop, unprepared_op + ) end elseif op in [:hvp] @@ -504,7 +564,12 @@ for op in ALL_OPS prepared_valop = -1 # TODO: fix $op(cc, prep, ba, x, tang, contexts...) prepared_op = reset_count!(cc) - return CallsResult(; preparation, prepared_valop, prepared_op) + unprepared_valop = -1 # TODO: fix + $op(cc, ba, x, tang, contexts...) + unprepared_op = reset_count!(cc) + return CallsResult(; + prepared_valop, prepared_op, preparation, unprepared_valop, unprepared_op + ) end @eval function benchmark_aux(ba::AbstractADType, scen::$S1in; subset::Symbol) @@ -536,7 +601,12 @@ for op in ALL_OPS prepared_valop = -1 # TODO: fix $op!(cc, res2, prep, ba, x, tang, contexts...) prepared_op = reset_count!(cc) - return CallsResult(; preparation, prepared_valop, prepared_op) + unprepared_valop = -1 # TODO: fix + $op!(cc, res2, ba, x, tang, contexts...) + unprepared_op = reset_count!(cc) + return CallsResult(; + prepared_valop, prepared_op, preparation, unprepared_valop, unprepared_op + ) end end end From a9d986b7c8b3bae3db8cfe7e8a43eeb71d4734a5 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 10 Oct 2024 14:15:27 +0200 Subject: [PATCH 07/18] Fix --- DifferentiationInterfaceTest/src/tests/benchmark_eval.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterfaceTest/src/tests/benchmark_eval.jl b/DifferentiationInterfaceTest/src/tests/benchmark_eval.jl index 3e1b22c68..8e72f71a2 100644 --- a/DifferentiationInterfaceTest/src/tests/benchmark_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/benchmark_eval.jl @@ -232,9 +232,9 @@ for op in ALL_OPS prepared_valop = reset_count!(cc) $op(cc, y, prep, ba, x, contexts...) prepared_op = reset_count!(cc) - $val_and_op(cc, y,, ba, x, contexts...) + $val_and_op(cc, y, ba, x, contexts...) unprepared_valop = reset_count!(cc) - $op(cc, y,, ba, x, contexts...) + $op(cc, y, ba, x, contexts...) unprepared_op = reset_count!(cc) return CallsResult(; prepared_valop, prepared_op, preparation, unprepared_valop, unprepared_op From 9fb00b66081b68054c62ef8d8293770b619df2d3 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 10 Oct 2024 14:33:23 +0200 Subject: [PATCH 08/18] Fix --- .../src/tests/correctness_eval.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl index c7a2d6da7..7d990fe92 100644 --- a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl @@ -79,7 +79,7 @@ for op in ALL_OPS @test res1_out1_noval ≈ scen.res1 @test res1_out2_noval ≈ scen.res1 end - if sparsity && op == :jacobian + if sparsity && $op == jacobian @test mynnz(res1_out1_val) == mynnz(scen.res1) @test mynnz(res1_out2_val) == mynnz(scen.res1) @test mynnz(res1_out1_noval) == mynnz(scen.res1) @@ -144,7 +144,7 @@ for op in ALL_OPS @test res1_out1_noval ≈ scen.res1 @test res1_out2_noval ≈ scen.res1 end - if sparsity && op == :jacobian + if sparsity && $op == jacobian @test mynnz(res1_out1_val) == mynnz(scen.res1) @test mynnz(res1_out2_val) == mynnz(scen.res1) @test mynnz(res1_out1_noval) == mynnz(scen.res1) @@ -206,7 +206,7 @@ for op in ALL_OPS @test res1_out1_noval ≈ scen.res1 @test res1_out2_noval ≈ scen.res1 end - if sparsity && op == :jacobian + if sparsity && $op == jacobian @test mynnz(res1_out1_val) == mynnz(scen.res1) @test mynnz(res1_out2_val) == mynnz(scen.res1) @test mynnz(res1_out1_noval) == mynnz(scen.res1) @@ -274,7 +274,7 @@ for op in ALL_OPS @test res1_out1_noval ≈ scen.res1 @test res1_out2_noval ≈ scen.res1 end - if sparsity && op == :jacobian + if sparsity && $op == jacobian @test mynnz(res1_out1_val) == mynnz(scen.res1) @test mynnz(res1_out2_val) == mynnz(scen.res1) @test mynnz(res1_out1_noval) == mynnz(scen.res1) @@ -330,7 +330,7 @@ for op in ALL_OPS @test res2_out1_noval ≈ scen.res2 @test res2_out2_noval ≈ scen.res2 end - if sparsity && op == :hessian + if sparsity && $op == hessian @test mynnz(res2_out1_val) == mynnz(scen.res2) @test mynnz(res2_out2_val) == mynnz(scen.res2) @test mynnz(res2_out1_noval) == mynnz(scen.res2) @@ -399,7 +399,7 @@ for op in ALL_OPS @test res2_out1_noval ≈ scen.res2 @test res2_out2_noval ≈ scen.res2 end - if sparsity && op == :hessian + if sparsity && $op == hessian @test mynnz(res2_out1_val) == mynnz(scen.res2) @test mynnz(res2_out2_val) == mynnz(scen.res2) @test mynnz(res2_out1_noval) == mynnz(scen.res2) From dddf71ec5b59b0ab958fa2632ecc720f7722531e Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 10 Oct 2024 14:44:07 +0200 Subject: [PATCH 09/18] Add count calls --- DifferentiationInterfaceTest/src/test_differentiation.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/DifferentiationInterfaceTest/src/test_differentiation.jl b/DifferentiationInterfaceTest/src/test_differentiation.jl index ec6c83907..1edd85b5a 100644 --- a/DifferentiationInterfaceTest/src/test_differentiation.jl +++ b/DifferentiationInterfaceTest/src/test_differentiation.jl @@ -184,6 +184,7 @@ function benchmark_differentiation( benchmark::Symbol=:prepared, excluded::Vector{Symbol}=Symbol[], logging::Bool=false, + count_calls::Bool, ) return test_differentiation( backends, @@ -193,5 +194,6 @@ function benchmark_differentiation( benchmark, logging, excluded, + count_calls, ) end From 544a63d4f35aedc81b7e6cee3674e2645aecd7a4 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 10 Oct 2024 14:54:04 +0200 Subject: [PATCH 10/18] Default count calls --- DifferentiationInterfaceTest/src/test_differentiation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterfaceTest/src/test_differentiation.jl b/DifferentiationInterfaceTest/src/test_differentiation.jl index 1edd85b5a..a10b0b68d 100644 --- a/DifferentiationInterfaceTest/src/test_differentiation.jl +++ b/DifferentiationInterfaceTest/src/test_differentiation.jl @@ -184,7 +184,7 @@ function benchmark_differentiation( benchmark::Symbol=:prepared, excluded::Vector{Symbol}=Symbol[], logging::Bool=false, - count_calls::Bool, + count_calls::Bool=true, ) return test_differentiation( backends, From cd6e1f607a9568d7dbc02fc88d15856652053514 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 10 Oct 2024 15:32:11 +0200 Subject: [PATCH 11/18] Fix --- DifferentiationInterface/test/Back/ForwardDiff/test.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index a9a77f4aa..ff2225127 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -37,8 +37,7 @@ test_differentiation( test_differentiation( backends, vcat(component_scenarios(), static_scenarios()); # FD accesses individual indices - excluded=[:jacobian], # jacobian is super slow for some reason - excluded=SECOND_ORDER, + excluded=vcat(SECOND_ORDER, [:jacobian]), # jacobian is super slow for some reason logging=LOGGING, ); From fc3967954d6a0555686734636312e6146c806f64 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 10 Oct 2024 16:26:50 +0200 Subject: [PATCH 12/18] Custom stacking for StaticArrays --- DifferentiationInterface/Project.toml | 3 +++ .../DifferentiationInterfaceStaticArraysExt.jl | 10 ++++++++++ .../src/DifferentiationInterface.jl | 1 + DifferentiationInterface/src/first_order/jacobian.jl | 4 ++-- DifferentiationInterface/src/second_order/hessian.jl | 2 +- DifferentiationInterface/src/utils/linalg.jl | 2 ++ 6 files changed, 19 insertions(+), 3 deletions(-) create mode 100644 DifferentiationInterface/ext/DifferentiationInterfaceStaticArraysExt/DifferentiationInterfaceStaticArraysExt.jl create mode 100644 DifferentiationInterface/src/utils/linalg.jl diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index ce7502c68..63cbaecec 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -20,6 +20,7 @@ PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -37,6 +38,7 @@ DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff" DifferentiationInterfaceReverseDiffExt = "ReverseDiff" DifferentiationInterfaceSparseArraysExt = "SparseArrays" DifferentiationInterfaceSparseMatrixColoringsExt = "SparseMatrixColorings" +DifferentiationInterfaceStaticArraysExt = "StaticArrays" DifferentiationInterfaceSymbolicsExt = "Symbolics" DifferentiationInterfaceTrackerExt = "Tracker" DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"] @@ -56,6 +58,7 @@ PolyesterForwardDiff = "0.1.2" ReverseDiff = "1.15.1" SparseArrays = "<0.0.1,1" SparseConnectivityTracer = "0.5.0,0.6" +StaticArrays = "1.9.7" SparseMatrixColorings = "0.4.5" Symbolics = "5.27.1, 6" Tracker = "0.2.33" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceStaticArraysExt/DifferentiationInterfaceStaticArraysExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceStaticArraysExt/DifferentiationInterfaceStaticArraysExt.jl new file mode 100644 index 000000000..53d6e7aff --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceStaticArraysExt/DifferentiationInterfaceStaticArraysExt.jl @@ -0,0 +1,10 @@ +module DifferentiationInterfaceStaticArraysExt + +import DifferentiationInterface as DI +using StaticArrays: SArray + +function DI.stack_vec_col(t::NTuple{B,<:SArray}) where {B} + return hcat(map(vec, t)...) +end + +end diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index 8cee0ad18..65c434b5c 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -43,6 +43,7 @@ include("utils/check.jl") include("utils/exceptions.jl") include("utils/printing.jl") include("utils/context.jl") +include("utils/linalg.jl") include("first_order/pushforward.jl") include("first_order/pullback.jl") diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index 3984dd49f..bde1c5255 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -241,7 +241,7 @@ function _jacobian_aux( batched_seeds[a], contexts..., ) - block = stack(vec, dy_batch; dims=2) + block = stack_vec_col(dy_batch) if N % B != 0 && a == lastindex(batched_seeds) block = block[:, 1:(N - (a - 1) * B)] end @@ -269,7 +269,7 @@ function _jacobian_aux( dx_batch = pullback( f_or_f!y..., pullback_prep_same, backend, x, batched_seeds[a], contexts... ) - block = stack(vec, dx_batch; dims=1) + block = stack_vec_row(dx_batch) if M % B != 0 && a == lastindex(batched_seeds) block = block[1:(M - (a - 1) * B), :] end diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index f8e612fd5..471d9e89b 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -113,7 +113,7 @@ function hessian( hess_blocks = map(eachindex(batched_seeds)) do a dg_batch = hvp(f, hvp_prep_same, backend, x, batched_seeds[a], contexts...) - block = stack(vec, dg_batch; dims=2) + block = stack_vec_col(dg_batch) if N % B != 0 && a == lastindex(batched_seeds) block = block[:, 1:(N - (a - 1) * B)] end diff --git a/DifferentiationInterface/src/utils/linalg.jl b/DifferentiationInterface/src/utils/linalg.jl new file mode 100644 index 000000000..392c7416f --- /dev/null +++ b/DifferentiationInterface/src/utils/linalg.jl @@ -0,0 +1,2 @@ +stack_vec_col(t::NTuple) = stack(vec, t; dims=2) +stack_vec_row(t::NTuple) = stack(vec, t; dims=1) From 1b1f9ab868ceb45a490e2bb3259f829729614cf8 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 10 Oct 2024 16:27:23 +0200 Subject: [PATCH 13/18] Bump --- DifferentiationInterface/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 63cbaecec..06169c44e 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.9" +version = "0.6.10" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 85d319589fe1fc36d200cb9053acdce0279e1d87 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 10 Oct 2024 17:17:51 +0200 Subject: [PATCH 14/18] Clearer modulo --- .../hessian.jl | 2 +- .../jacobian.jl | 4 ++-- DifferentiationInterface/src/first_order/jacobian.jl | 4 ++-- DifferentiationInterface/src/second_order/hessian.jl | 6 ++---- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl index 11624950e..e976e59fc 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl @@ -62,7 +62,7 @@ function _prepare_sparse_hessian_aux( seeds = [multibasis(backend, x, eachindex(x)[group]) for group in groups] compressed_matrix = stack(_ -> vec(similar(x)), groups; dims=2) batched_seeds = [ - ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % Ng], Val(B)) for + ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), Ng)], Val(B)) for a in 1:div(Ng, B, RoundUp) ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl index b8ed2cbe7..6dd5c0caa 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl @@ -113,7 +113,7 @@ function _prepare_sparse_jacobian_aux( seeds = [multibasis(backend, x, eachindex(x)[group]) for group in groups] compressed_matrix = stack(_ -> vec(similar(y)), groups; dims=2) batched_seeds = [ - ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % Ng], Val(B)) for + ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), Ng)], Val(B)) for a in 1:div(Ng, B, RoundUp) ] batched_results = [ntuple(b -> similar(y), Val(B)) for _ in batched_seeds] @@ -150,7 +150,7 @@ function _prepare_sparse_jacobian_aux( seeds = [multibasis(backend, y, eachindex(y)[group]) for group in groups] compressed_matrix = stack(_ -> vec(similar(x)), groups; dims=1) batched_seeds = [ - ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % Ng], Val(B)) for + ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), Ng)], Val(B)) for a in 1:div(Ng, B, RoundUp) ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index bde1c5255..0a2304b13 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -111,7 +111,7 @@ function _prepare_jacobian_aux( N = length(x) seeds = [basis(backend, x, ind) for ind in eachindex(x)] batched_seeds = [ - ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for + ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), N)], Val(B)) for a in 1:div(N, B, RoundUp) ] batched_results = [ntuple(b -> similar(y), Val(B)) for _ in batched_seeds] @@ -138,7 +138,7 @@ function _prepare_jacobian_aux( M = length(y) seeds = [basis(backend, y, ind) for ind in eachindex(y)] batched_seeds = [ - ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % M], Val(B)) for + ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), M)], Val(B)) for a in 1:div(M, B, RoundUp) ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index 025c69315..5a200b671 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -82,7 +82,7 @@ function _prepare_hessian_aux( N = length(x) seeds = [basis(backend, x, ind) for ind in eachindex(x)] batched_seeds = [ - ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for + ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), N)], Val(B)) for a in 1:div(N, B, RoundUp) ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] @@ -111,7 +111,7 @@ function hessian( f, hvp_prep, backend, x, batched_seeds[1], contexts... ) - hess_blocks = map(eachindex(batched_seeds)) do a + hess = mapreduce(hcat, eachindex(batched_seeds)) do a dg_batch = hvp(f, hvp_prep_same, backend, x, batched_seeds[a], contexts...) block = stack_vec_col(dg_batch) if N % B != 0 && a == lastindex(batched_seeds) @@ -119,8 +119,6 @@ function hessian( end block end - - hess = reduce(hcat, hess_blocks) return hess end From 720be01d67abdbe708204adcf1cbd63aa3c6ee26 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 10 Oct 2024 17:19:10 +0200 Subject: [PATCH 15/18] Woops --- DifferentiationInterface/src/second_order/hessian.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index 5a200b671..d56027d50 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -111,7 +111,7 @@ function hessian( f, hvp_prep, backend, x, batched_seeds[1], contexts... ) - hess = mapreduce(hcat, eachindex(batched_seeds)) do a + hess_blocks = map(eachindex(batched_seeds)) do a dg_batch = hvp(f, hvp_prep_same, backend, x, batched_seeds[a], contexts...) block = stack_vec_col(dg_batch) if N % B != 0 && a == lastindex(batched_seeds) @@ -119,6 +119,8 @@ function hessian( end block end + + hess = reduce(hcat, hess_blocks) return hess end From 71e1f4551547a019af38df1fea7bfc78f20853fa Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 10 Oct 2024 18:06:32 +0200 Subject: [PATCH 16/18] Undo mo1 --- .../hessian.jl | 2 +- .../jacobian.jl | 4 ++-- DifferentiationInterface/src/first_order/jacobian.jl | 4 ++-- DifferentiationInterface/src/second_order/hessian.jl | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl index e976e59fc..11624950e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl @@ -62,7 +62,7 @@ function _prepare_sparse_hessian_aux( seeds = [multibasis(backend, x, eachindex(x)[group]) for group in groups] compressed_matrix = stack(_ -> vec(similar(x)), groups; dims=2) batched_seeds = [ - ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), Ng)], Val(B)) for + ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % Ng], Val(B)) for a in 1:div(Ng, B, RoundUp) ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl index 6dd5c0caa..b8ed2cbe7 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl @@ -113,7 +113,7 @@ function _prepare_sparse_jacobian_aux( seeds = [multibasis(backend, x, eachindex(x)[group]) for group in groups] compressed_matrix = stack(_ -> vec(similar(y)), groups; dims=2) batched_seeds = [ - ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), Ng)], Val(B)) for + ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % Ng], Val(B)) for a in 1:div(Ng, B, RoundUp) ] batched_results = [ntuple(b -> similar(y), Val(B)) for _ in batched_seeds] @@ -150,7 +150,7 @@ function _prepare_sparse_jacobian_aux( seeds = [multibasis(backend, y, eachindex(y)[group]) for group in groups] compressed_matrix = stack(_ -> vec(similar(x)), groups; dims=1) batched_seeds = [ - ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), Ng)], Val(B)) for + ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % Ng], Val(B)) for a in 1:div(Ng, B, RoundUp) ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index 0a2304b13..bde1c5255 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -111,7 +111,7 @@ function _prepare_jacobian_aux( N = length(x) seeds = [basis(backend, x, ind) for ind in eachindex(x)] batched_seeds = [ - ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), N)], Val(B)) for + ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:div(N, B, RoundUp) ] batched_results = [ntuple(b -> similar(y), Val(B)) for _ in batched_seeds] @@ -138,7 +138,7 @@ function _prepare_jacobian_aux( M = length(y) seeds = [basis(backend, y, ind) for ind in eachindex(y)] batched_seeds = [ - ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), M)], Val(B)) for + ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % M], Val(B)) for a in 1:div(M, B, RoundUp) ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index d56027d50..025c69315 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -82,7 +82,7 @@ function _prepare_hessian_aux( N = length(x) seeds = [basis(backend, x, ind) for ind in eachindex(x)] batched_seeds = [ - ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), N)], Val(B)) for + ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:div(N, B, RoundUp) ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] From ba3c1309c9ca939c3f07fe826e4707c8cf053deb Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 10 Oct 2024 18:45:18 +0200 Subject: [PATCH 17/18] Mapreduce --- DifferentiationInterface/src/first_order/jacobian.jl | 8 ++------ DifferentiationInterface/src/second_order/hessian.jl | 4 +--- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index bde1c5255..7919d223b 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -232,7 +232,7 @@ function _jacobian_aux( f_or_f!y..., pushforward_prep, backend, x, batched_seeds[1], contexts... ) - jac_blocks = map(eachindex(batched_seeds)) do a + jac = mapreduce(hcat, eachindex(batched_seeds)) do a dy_batch = pushforward( f_or_f!y..., pushforward_prep_same, @@ -247,8 +247,6 @@ function _jacobian_aux( end block end - - jac = reduce(hcat, jac_blocks) return jac end @@ -265,7 +263,7 @@ function _jacobian_aux( f_or_f!y..., prep.pullback_prep, backend, x, batched_seeds[1], contexts... ) - jac_blocks = map(eachindex(batched_seeds)) do a + jac = mapreduce(vcat, eachindex(batched_seeds)) do a dx_batch = pullback( f_or_f!y..., pullback_prep_same, backend, x, batched_seeds[a], contexts... ) @@ -275,8 +273,6 @@ function _jacobian_aux( end block end - - jac = reduce(vcat, jac_blocks) return jac end diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index 025c69315..d5d9921fc 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -111,7 +111,7 @@ function hessian( f, hvp_prep, backend, x, batched_seeds[1], contexts... ) - hess_blocks = map(eachindex(batched_seeds)) do a + hess = mapreduce(hcat, eachindex(batched_seeds)) do a dg_batch = hvp(f, hvp_prep_same, backend, x, batched_seeds[a], contexts...) block = stack_vec_col(dg_batch) if N % B != 0 && a == lastindex(batched_seeds) @@ -119,8 +119,6 @@ function hessian( end block end - - hess = reduce(hcat, hess_blocks) return hess end From 4ca7fa187589848afb6bb65961e42b4991d4127d Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 10 Oct 2024 20:29:49 +0200 Subject: [PATCH 18/18] Add function filter to type stability checks --- .../src/test_differentiation.jl | 9 +- .../src/tests/type_stability_eval.jl | 359 ++++++++++-------- 2 files changed, 207 insertions(+), 161 deletions(-) diff --git a/DifferentiationInterfaceTest/src/test_differentiation.jl b/DifferentiationInterfaceTest/src/test_differentiation.jl index a10b0b68d..6929033c6 100644 --- a/DifferentiationInterfaceTest/src/test_differentiation.jl +++ b/DifferentiationInterfaceTest/src/test_differentiation.jl @@ -48,6 +48,7 @@ For `type_stability` and `benchmark`, the possible values are `:none`, `:prepare **Type stability options:** - `ignored_modules=nothing`: list of modules that JET.jl should ignore +- `function_filter`: filter for functions that JET.jl should ignore (with a reasonable default) **Benchmark options:** @@ -72,6 +73,11 @@ function test_differentiation( sparsity::Bool=true, # type stability options ignored_modules=nothing, + function_filter=if VERSION >= v"1.11" + @nospecialize(f) -> true + else + @nospecialize(f) -> f != Base.mapreduce_empty # fix for `mapreduce` in jacobian and hessian + end, # benchmark options count_calls::Bool=true, ) @@ -136,7 +142,8 @@ function test_differentiation( adapted_backend, scen; subset=type_stability, - ignored_modules=ignored_modules, + ignored_modules, + function_filter, ) end yield() diff --git a/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl b/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl index 1afb4a2f8..01e173951 100644 --- a/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl @@ -18,280 +18,319 @@ for op in ALL_OPS if op in [:derivative, :gradient, :jacobian] @eval function test_jet( - ba::AbstractADType, scen::$S1out; subset::Symbol, ignored_modules + ba::AbstractADType, + scen::$S1out; + subset::Symbol, + ignored_modules, + function_filter, ) (; f, x, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, contexts...) (subset == :full) && - @test_opt ignored_modules = ignored_modules $prep_op(f, ba, x, contexts...) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $prep_op(f, ba, x, contexts...) (subset == :full) && - @test_opt ignored_modules = ignored_modules $op(f, ba, x, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op( - f, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op(f, ba, x, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op(f, ba, x, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op(f, prep, ba, x, contexts...) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op(f, prep, ba, x, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $val_and_op( - f, prep, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op(f, prep, ba, x, contexts...) return nothing end @eval function test_jet( - ba::AbstractADType, scen::$S1in; subset::Symbol, ignored_modules + ba::AbstractADType, + scen::$S1in; + subset::Symbol, + ignored_modules, + function_filter, ) (; f, x, res1, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, contexts...) (subset == :full) && - @test_opt ignored_modules = ignored_modules $prep_op(f, ba, x, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $op!( - f, res1, ba, x, contexts... - ) - (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op!( - f, res1, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $prep_op(f, ba, x, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op!(f, res1, ba, x, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op!(f, res1, ba, x, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op!( - f, res1, prep, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op!(f, res1, prep, ba, x, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $val_and_op!( - f, res1, prep, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op!(f, res1, prep, ba, x, contexts...) return nothing end op == :gradient && continue @eval function test_jet( - ba::AbstractADType, scen::$S2out; subset::Symbol, ignored_modules + ba::AbstractADType, + scen::$S2out; + subset::Symbol, + ignored_modules, + function_filter, ) (; f, x, y, contexts) = deepcopy(scen) prep = $prep_op(f, y, ba, x, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op( - f, y, ba, x, contexts... - ) - (subset == :full) && - @test_opt ignored_modules = ignored_modules $op(f, y, ba, x, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op( - f, y, ba, x, contexts... - ) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $prep_op(f, y, ba, x, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op(f, y, ba, x, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op(f, y, ba, x, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op( - f, y, prep, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op(f, y, prep, ba, x, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $val_and_op( - f, y, prep, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op(f, y, prep, ba, x, contexts...) return nothing end @eval function test_jet( - ba::AbstractADType, scen::$S2in; subset::Symbol, ignored_modules + ba::AbstractADType, + scen::$S2in; + subset::Symbol, + ignored_modules, + function_filter, ) (; f, x, y, res1, contexts) = deepcopy(scen) prep = $prep_op(f, y, ba, x, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op( - f, y, ba, x, contexts... - ) - (subset == :full) && @test_opt ignored_modules = ignored_modules $op!( - f, y, res1, ba, x, contexts... - ) - (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op!( - f, y, res1, ba, x, contexts... - ) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $prep_op(f, y, ba, x, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op!(f, y, res1, ba, x, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op!(f, y, res1, ba, x, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op!( - f, y, res1, prep, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op!(f, y, res1, prep, ba, x, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $val_and_op!( - f, y, res1, prep, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op!(f, y, res1, prep, ba, x, contexts...) return nothing end elseif op in [:second_derivative, :hessian] @eval function test_jet( - ba::AbstractADType, scen::$S1out; subset::Symbol, ignored_modules + ba::AbstractADType, + scen::$S1out; + subset::Symbol, + ignored_modules, + function_filter, ) (; f, x, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, contexts...) (subset == :full) && - @test_opt ignored_modules = ignored_modules $prep_op(f, ba, x, contexts...) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $prep_op(f, ba, x, contexts...) (subset == :full) && - @test_opt ignored_modules = ignored_modules $op(f, ba, x, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op( - f, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op(f, ba, x, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op(f, ba, x, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op(f, prep, ba, x, contexts...) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op(f, prep, ba, x, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $val_and_op( - f, prep, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op(f, prep, ba, x, contexts...) return nothing end @eval function test_jet( - ba::AbstractADType, scen::$S1in; subset::Symbol, ignored_modules + ba::AbstractADType, + scen::$S1in; + subset::Symbol, + ignored_modules, + function_filter, ) (; f, x, res1, res2, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, contexts...) (subset == :full) && - @test_opt ignored_modules = ignored_modules $prep_op(f, ba, x, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $op!( - f, res2, ba, x, contexts... - ) - (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op!( - f, res1, res2, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $prep_op(f, ba, x, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op!(f, res2, ba, x, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op!(f, res1, res2, ba, x, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op!( - f, res2, prep, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op!(f, res2, prep, ba, x, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $val_and_op!( - f, res1, res2, prep, ba, x, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op!(f, res1, res2, prep, ba, x, contexts...) return nothing end elseif op in [:pushforward, :pullback] @eval function test_jet( - ba::AbstractADType, scen::$S1out; subset::Symbol, ignored_modules + ba::AbstractADType, + scen::$S1out; + subset::Symbol, + ignored_modules, + function_filter, ) (; f, x, tang, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, tang, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op( - f, ba, x, tang, contexts... - ) - (subset == :full) && - @test_opt ignored_modules = ignored_modules $op(f, ba, x, tang, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op( - f, ba, x, tang, contexts... - ) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $prep_op(f, ba, x, tang, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op(f, ba, x, tang, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op(f, ba, x, tang, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op( - f, prep, ba, x, tang, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op(f, prep, ba, x, tang, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $val_and_op( - f, prep, ba, x, tang, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op(f, prep, ba, x, tang, contexts...) return nothing end @eval function test_jet( - ba::AbstractADType, scen::$S1in; subset::Symbol, ignored_modules + ba::AbstractADType, + scen::$S1in; + subset::Symbol, + ignored_modules, + function_filter, ) (; f, x, tang, res1, res2, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, tang, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op( - f, ba, x, tang, contexts... - ) - (subset == :full) && @test_opt ignored_modules = ignored_modules $op!( - f, res1, ba, x, tang, contexts... - ) - (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op!( - f, res1, ba, x, tang, contexts... - ) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $prep_op(f, ba, x, tang, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op!(f, res1, ba, x, tang, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op!(f, res1, ba, x, tang, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op!( - f, res1, prep, ba, x, tang, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op!(f, res1, prep, ba, x, tang, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $val_and_op!( - f, res1, prep, ba, x, tang, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op!(f, res1, prep, ba, x, tang, contexts...) return nothing end @eval function test_jet( - ba::AbstractADType, scen::$S2out; subset::Symbol, ignored_modules + ba::AbstractADType, + scen::$S2out; + subset::Symbol, + ignored_modules, + function_filter, ) (; f, x, y, tang, contexts) = deepcopy(scen) prep = $prep_op(f, y, ba, x, tang, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op( - f, y, ba, x, tang, contexts... - ) - (subset == :full) && @test_opt ignored_modules = ignored_modules $op( - f, y, ba, x, tang, contexts... - ) - (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op( - f, y, ba, x, tang, contexts... - ) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $prep_op(f, y, ba, x, tang, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op(f, y, ba, x, tang, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op(f, y, ba, x, tang, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op( - f, y, prep, ba, x, tang, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op(f, y, prep, ba, x, tang, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $val_and_op( - f, y, prep, ba, x, tang, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op(f, y, prep, ba, x, tang, contexts...) return nothing end @eval function test_jet( - ba::AbstractADType, scen::$S2in; subset::Symbol, ignored_modules + ba::AbstractADType, + scen::$S2in; + subset::Symbol, + ignored_modules, + function_filter, ) (; f, x, y, tang, res1, contexts) = deepcopy(scen) prep = $prep_op(f, y, ba, x, tang, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op( - f, y, ba, x, tang, contexts... - ) - (subset == :full) && @test_opt ignored_modules = ignored_modules $op!( - f, y, res1, ba, x, tang, contexts... - ) - (subset == :full) && @test_opt ignored_modules = ignored_modules $val_and_op!( - f, y, res1, ba, x, tang, contexts... - ) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $prep_op(f, y, ba, x, tang, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op!(f, y, res1, ba, x, tang, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op!(f, y, res1, ba, x, tang, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op!( - f, y, res1, prep, ba, x, tang, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op!(f, y, res1, prep, ba, x, tang, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $val_and_op!( - f, y, res1, prep, ba, x, tang, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op!(f, y, res1, prep, ba, x, tang, contexts...) return nothing end elseif op in [:hvp] @eval function test_jet( - ba::AbstractADType, scen::$S1out; subset::Symbol, ignored_modules + ba::AbstractADType, + scen::$S1out; + subset::Symbol, + ignored_modules, + function_filter, ) (; f, x, tang, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, tang, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op( - f, ba, x, tang, contexts... - ) (subset == :full) && - @test_opt ignored_modules = ignored_modules $op(f, ba, x, tang, contexts...) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $prep_op(f, ba, x, tang, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op(f, ba, x, tang, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op( - f, prep, ba, x, tang, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op(f, prep, ba, x, tang, contexts...) return nothing end @eval function test_jet( - ba::AbstractADType, scen::$S1in; subset::Symbol, ignored_modules + ba::AbstractADType, + scen::$S1in; + subset::Symbol, + ignored_modules, + function_filter, ) (; f, x, tang, res1, res2, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, tang, contexts...) - (subset == :full) && @test_opt ignored_modules = ignored_modules $prep_op( - f, ba, x, tang, contexts... - ) - (subset == :full) && @test_opt ignored_modules = ignored_modules $op!( - f, res2, ba, x, tang, contexts... - ) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $prep_op(f, ba, x, tang, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op!(f, res2, ba, x, tang, contexts...) (subset in (:prepared, :full)) && - @test_opt ignored_modules = ignored_modules $op!( - f, res2, prep, ba, x, tang, contexts... - ) + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $op!(f, res2, prep, ba, x, tang, contexts...) return nothing end end