Skip to content

Commit

Permalink
First tests for second order (#93)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored Mar 24, 2024
1 parent 224f8cf commit 5b3b330
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 9 deletions.
2 changes: 2 additions & 0 deletions src/DifferentiationTest/DifferentiationTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ using ..DifferentiationInterface
import ..DifferentiationInterface as DI
using ..DifferentiationInterface:
AutoTaped,
inner,
mode,
mysimilar,
myzero,
myzero!!,
outer,
supports_mutation,
supports_pushforward,
supports_pullback
Expand Down
4 changes: 4 additions & 0 deletions src/DifferentiationTest/printing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,7 @@ function backend_string(backend::AbstractADType)
error("Unknown mode")
end
end

function backend_string(backend::SecondOrder)
return "$(backend_string(outer(backend))) / $(backend_string(inner(backend)))"
end
14 changes: 7 additions & 7 deletions src/hvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,6 @@ By order of preference:
"""
hvp(f, backend, x, v, [extras]) -> p
"""
function hvp(
f::F, backend::AbstractADType, x::Number, v, extras=prepare_hvp(f, backend, x)
) where {F}
return v * second_derivative(f, backend, x, extras)
end

function hvp(
f::F, backend::AbstractADType, x, v, extras=prepare_hvp(f, backend, x)
) where {F}
Expand All @@ -27,7 +21,13 @@ function hvp(
return hvp(f, new_backend, x, v, new_extras)
end

function hvp(f::F, backend::SecondOrder, x, v, extras=prepare_hvp(backend, f, x)) where {F}
function hvp(
f::F, backend::SecondOrder, x::Number, v::Number, extras=prepare_hvp(f, backend, x)
) where {F}
return v * second_derivative(f, backend, x, extras)
end

function hvp(f::F, backend::SecondOrder, x, v, extras=prepare_hvp(f, backend, x)) where {F}
return hvp_aux(f, backend, x, v, extras, hvp_mode(backend))
end

Expand Down
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,8 @@ using Zygote: Zygote
@time include("zygote.jl")
end
end

@testset verbose = true "Second order" begin
include("second_order.jl")
end
end;
31 changes: 31 additions & 0 deletions test/second_order.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using ADTypes
using DifferentiationInterface
using DifferentiationInterface.DifferentiationTest
using DifferentiationInterface.DifferentiationTest: backend_string

using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff
using Enzyme: Enzyme
using Zygote: Zygote

using JET: JET
using Test

SECOND_ORDER_BACKENDS = Dict(
"forward/forward" => [
SecondOrder(AutoEnzyme(Enzyme.Forward), AutoForwardDiff()),
SecondOrder(AutoForwardDiff(), AutoEnzyme(Enzyme.Forward)),
],
"forward/reverse" => [SecondOrder(AutoForwardDiff(), AutoZygote())],
"reverse/forward" => [],
)

@testset verbose = true "Cross backends" begin
@testset verbose = true "$second_order_mode" for (second_order_mode, backends) in
pairs(SECOND_ORDER_BACKENDS)
@info "Testing $second_order_mode..."
@time @testset "$(backend_string(backend))" for backend in backends
test_operators(backend; first_order=false, type_stability=false)
end
end
end;
3 changes: 1 addition & 2 deletions test/zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,8 @@ test_operators(

test_operators(
[AutoZeroForward(), AutoZeroReverse()];
allocating=false,
correctness=false,
type_stability=true,
type_stability=false,
allocations=true,
);

Expand Down

0 comments on commit 5b3b330

Please sign in to comment.