Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: customizable DIT benchmarks #636

Merged
merged 3 commits into from
Nov 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ function DIT.lux_scenarios(rng::AbstractRNG=default_rng())
scen = Scenario{:gradient,:out}(
square_loss,
ComponentArray(ps);
contexts=(Constant(model), Constant(x), Constant(st)),
contexts=(DI.Constant(model), DI.Constant(x), DI.Constant(st)),
res1=g,
)
push!(scens, scen)
Expand Down
10 changes: 10 additions & 0 deletions DifferentiationInterfaceTest/src/test_differentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ Each setting tests/benchmarks a different subset of calls:

- `count_calls=true`: whether to also count function calls during benchmarking
- `benchmark_test=true`: whether to include tests which succeed iff benchmark doesn't error
- `benchmark_seconds=1`: how long to run each benchmark for
- `benchmark_aggregation=minimum`: function used to aggregate sample measurements
"""
function test_differentiation(
backends::Vector{<:AbstractADType},
Expand Down Expand Up @@ -87,6 +89,8 @@ function test_differentiation(
# benchmark options
count_calls::Bool=true,
benchmark_test::Bool=true,
benchmark_seconds::Real=1,
benchmark_aggregation=minimum,
)
@assert type_stability in (:none, :prepared, :full)
@assert allocations in (:none, :prepared, :full)
Expand Down Expand Up @@ -173,6 +177,8 @@ function test_differentiation(
subset=benchmark,
count_calls,
benchmark_test,
benchmark_seconds,
benchmark_aggregation,
)
end
yield()
Expand Down Expand Up @@ -211,6 +217,8 @@ function benchmark_differentiation(
logging::Bool=false,
count_calls::Bool=true,
benchmark_test::Bool=true,
benchmark_seconds::Real=1,
benchmark_aggregation=minimum,
)
return test_differentiation(
backends,
Expand All @@ -223,5 +231,7 @@ function benchmark_differentiation(
excluded,
count_calls,
benchmark_test,
benchmark_seconds,
benchmark_aggregation,
)
end
37 changes: 19 additions & 18 deletions DifferentiationInterfaceTest/src/tests/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@

See the documentation of [Chairmarks.jl](https://github.com/LilithHafner/Chairmarks.jl) for more details on the measurement fields.
"""
Base.@kwdef struct DifferentiationBenchmarkDataRow
Base.@kwdef struct DifferentiationBenchmarkDataRow{T}
"backend used for benchmarking"
backend::AbstractADType
"scenario used for benchmarking"
Expand All @@ -71,16 +71,16 @@
samples::Int
"number of evaluations used for averaging in each sample"
evals::Int
"minimum runtime over all samples, in seconds"
time::Float64
"minimum number of allocations over all samples"
allocs::Float64
"minimum memory allocated over all samples, in bytes"
bytes::Float64
"minimum fraction of time spent in garbage collection over all samples, between 0.0 and 1.0"
gc_fraction::Float64
"minimum fraction of time spent compiling over all samples, between 0.0 and 1.0"
compile_fraction::Float64
"aggregated runtime over all samples, in seconds"
time::T
"aggregated number of allocations over all samples"
allocs::T
"aggregated memory allocated over all samples, in bytes"
bytes::T
"aggregated fraction of time spent in garbage collection over all samples, between 0.0 and 1.0"
gc_fraction::T
"aggregated fraction of time spent compiling over all samples, between 0.0 and 1.0"
compile_fraction::T
end

function record!(
Expand All @@ -91,21 +91,22 @@
prepared::Union{Nothing,Bool},
bench::Benchmark,
calls::Integer,
aggregation,
)
bench_min = minimum(bench)
bench_agg = aggregation(bench)

Check warning on line 96 in DifferentiationInterfaceTest/src/tests/benchmark.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterfaceTest/src/tests/benchmark.jl#L96

Added line #L96 was not covered by tests
row = DifferentiationBenchmarkDataRow(;
backend=backend,
scenario=scenario,
operator=Symbol(operator),
prepared=prepared,
calls=calls,
samples=length(bench.samples),
evals=Int(bench_min.evals),
time=bench_min.time,
allocs=bench_min.allocs,
bytes=bench_min.bytes,
gc_fraction=bench_min.gc_fraction,
compile_fraction=bench_min.compile_fraction,
evals=Int(bench_agg.evals),
time=bench_agg.time,
allocs=bench_agg.allocs,
bytes=bench_agg.bytes,
gc_fraction=bench_agg.gc_fraction,
compile_fraction=bench_agg.compile_fraction,
)
return push!(data, row)
end
Loading
Loading