Skip to content

Commit

Permalink
More bugfixing in Tracker and Enzyme gradients (#96)
Browse files Browse the repository at this point in the history
* More bugfixing in Tracker and Enzyme gradients

* Replace zeros with similar

* Use myzero in Enzyme
  • Loading branch information
gdalle authored Mar 25, 2024
1 parent 448f8d6 commit efef587
Show file tree
Hide file tree
Showing 13 changed files with 144 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module DifferentiationInterfaceChairmarksExt
using ADTypes: AbstractADType
using Chairmarks: @be, Benchmark, Sample
using DifferentiationInterface
using DifferentiationInterface: myzero
using DifferentiationInterface: mysimilar
using DifferentiationInterface.DifferentiationTest: Scenario, BenchmarkData, record!
using Test: @testset, @test

Expand All @@ -18,7 +18,7 @@ function run_benchmark!(
)
(; f, x, dx, dy) = deepcopy(scen)
extras = prepare_pushforward(f, ba, x)
bench1 = @be myzero(dy) value_and_pushforward!!(f, _, ba, x, dx, extras)
bench1 = @be mysimilar(dy) value_and_pushforward!!(f, _, ba, x, dx, extras)
if allocations && dy isa Number
@test 0 == minimum(bench1).allocs
end
Expand All @@ -36,7 +36,7 @@ function run_benchmark!(
(; f, x, y, dx, dy) = deepcopy(scen)
f! = f
extras = prepare_pushforward(f!, ba, y, x)
bench1 = @be (myzero(y), myzero(dy)) value_and_pushforward!!(
bench1 = @be (mysimilar(y), mysimilar(dy)) value_and_pushforward!!(
f!, _[1], _[2], ba, x, dx, extras
)
if allocations
Expand All @@ -57,7 +57,7 @@ function run_benchmark!(
)
(; f, x, dx, dy) = deepcopy(scen)
extras = prepare_pullback(f, ba, x)
bench1 = @be myzero(dx) value_and_pullback!!(f, _, ba, x, dy, extras)
bench1 = @be mysimilar(dx) value_and_pullback!!(f, _, ba, x, dy, extras)
if allocations && dy isa Number
@test 0 == minimum(bench1).allocs
end
Expand All @@ -75,7 +75,7 @@ function run_benchmark!(
(; f, x, y, dx, dy) = deepcopy(scen)
f! = f
extras = prepare_pullback(f!, ba, y, x)
bench1 = @be (myzero(y), myzero(dx)) value_and_pullback!!(
bench1 = @be (mysimilar(y), mysimilar(dx)) value_and_pullback!!(
f!, _[1], _[2], ba, x, dy, extras
)
if allocations
Expand All @@ -96,7 +96,7 @@ function run_benchmark!(
)
(; f, x, y, dy) = deepcopy(scen)
extras = prepare_derivative(f, ba, x)
bench1 = @be myzero(dy) value_and_derivative!!(f, _, ba, x, extras)
bench1 = @be mysimilar(dy) value_and_derivative!!(f, _, ba, x, extras)
# only test allocations if the output is scalar
if allocations && y isa Number
@test 0 == minimum(bench1).allocs
Expand All @@ -115,7 +115,7 @@ function run_benchmark!(
(; f, x, y, dy) = deepcopy(scen)
f! = f
extras = prepare_derivative(f!, ba, y, x)
bench1 = @be (myzero(y), myzero(dy)) value_and_derivative!!(
bench1 = @be (mysimilar(y), mysimilar(dy)) value_and_derivative!!(
f!, _[1], _[2], ba, x, extras
)
if allocations
Expand All @@ -136,7 +136,7 @@ function run_benchmark!(
)
(; f, x, dx) = deepcopy(scen)
extras = prepare_gradient(f, ba, x)
bench1 = @be myzero(dx) value_and_gradient!!(f, _, ba, x, extras)
bench1 = @be mysimilar(dx) value_and_gradient!!(f, _, ba, x, extras)
if allocations
@test 0 == minimum(bench1).allocs
end
Expand All @@ -155,8 +155,8 @@ function run_benchmark!(
)
(; f, x, y) = deepcopy(scen)
extras = prepare_jacobian(f, ba, x)
jac_template = zeros(eltype(y), length(y), length(x))
bench1 = @be myzero(jac_template) value_and_jacobian!!(f, _, ba, x, extras)
jac_template = similar(y, length(y), length(x))
bench1 = @be mysimilar(jac_template) value_and_jacobian!!(f, _, ba, x, extras)
# never test allocations
record!(data, ba, op, value_and_jacobian!!, scen, bench1)
return nothing
Expand All @@ -172,8 +172,8 @@ function run_benchmark!(
(; f, x, y) = deepcopy(scen)
f! = f
extras = prepare_jacobian(f!, ba, y, x)
jac_template = zeros(eltype(y), length(y), length(x))
bench1 = @be (myzero(y), myzero(jac_template)) value_and_jacobian!!(
jac_template = similar(y, length(y), length(x))
bench1 = @be (mysimilar(y), mysimilar(jac_template)) value_and_jacobian!!(
f!, _[1], _[2], ba, x, extras
)
if allocations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module DifferentiationInterfaceCorrectnessTestExt

using ADTypes: AbstractADType
using DifferentiationInterface
using DifferentiationInterface: myisapprox, myzero
using DifferentiationInterface: myisapprox, mysimilar
using DifferentiationInterface.DifferentiationTest: Scenario
import DifferentiationInterface.DifferentiationTest as DT
using ForwardDiff: ForwardDiff
Expand All @@ -22,15 +22,15 @@ end
## Pushforward

function test_correctness(ba::AbstractADType, ::typeof(pushforward), scen::Scenario{false})
(; f, x, y, dx) = new_scen = deepcopy(scen)
(; f, x, y, dx, dy) = new_scen = deepcopy(scen)
dy_true = true_pushforward(f, x, y, dx; mutating=false)

y_out1, dy_out1 = value_and_pushforward(f, ba, x, dx)
dy_in2 = myzero(dy_out1)
dy_in2 = mysimilar(dy)
y_out2, dy_out2 = value_and_pushforward!!(f, dy_in2, ba, x, dx)

dy_out3 = pushforward(f, ba, x, dx)
dy_in4 = myzero(dy_out3)
dy_in4 = mysimilar(dy)
dy_out4 = pushforward!!(f, dy_in4, ba, x, dx)

@testset "Primal value" begin
Expand All @@ -47,12 +47,12 @@ function test_correctness(ba::AbstractADType, ::typeof(pushforward), scen::Scena
end

function test_correctness(ba::AbstractADType, ::typeof(pushforward), scen::Scenario{true})
(; f, x, y, dx) = new_scen = deepcopy(scen)
(; f, x, y, dx, dy) = new_scen = deepcopy(scen)
f! = f
dy_true = true_pushforward(f!, x, y, dx; mutating=true)

y_in = myzero(y)
dy_in = myzero(dy_true)
y_in = mysimilar(y)
dy_in = mysimilar(dy)
y_out, dy_out = value_and_pushforward!!(f!, y_in, dy_in, ba, x, dx)

@testset "Primal value" begin
Expand All @@ -72,15 +72,15 @@ end
## Pullback

function test_correctness(ba::AbstractADType, ::typeof(pullback), scen::Scenario{false})
(; f, x, y, dy) = new_scen = deepcopy(scen)
(; f, x, y, dx, dy) = new_scen = deepcopy(scen)
dx_true = true_pullback(f, x, y, dy; mutating=false)

y_out1, dx_out1 = value_and_pullback(f, ba, x, dy)
dx_in2 = myzero(dx_out1)
dx_in2 = mysimilar(dx)
y_out2, dx_out2 = value_and_pullback!!(f, dx_in2, ba, x, dy)

dx_out3 = pullback(f, ba, x, dy)
dx_in4 = myzero(dx_out3)
dx_in4 = mysimilar(dx)
dx_out4 = pullback!!(f, dx_in4, ba, x, dy)

@testset "Primal value" begin
Expand All @@ -97,12 +97,12 @@ function test_correctness(ba::AbstractADType, ::typeof(pullback), scen::Scenario
end

function test_correctness(ba::AbstractADType, ::typeof(pullback), scen::Scenario{true})
(; f, x, y, dy) = new_scen = deepcopy(scen)
(; f, x, y, dx, dy) = new_scen = deepcopy(scen)
f! = f
dx_true = true_pullback(f, x, y, dy; mutating=true)

y_in = myzero(y)
dx_in = myzero(dx_true)
y_in = mysimilar(y)
dx_in = mysimilar(dx)
y_out, dx_out = value_and_pullback!!(f!, y_in, dx_in, ba, x, dy)

@testset "Primal value" begin
Expand All @@ -122,15 +122,15 @@ end
## Derivative

function test_correctness(ba::AbstractADType, ::typeof(derivative), scen::Scenario{false})
(; f, x, y) = new_scen = deepcopy(scen)
(; f, x, y, dx, dy) = new_scen = deepcopy(scen)
der_true = ForwardDiff.derivative(f, x)

y_out1, der_out1 = value_and_derivative(f, ba, x)
der_in2 = myzero(der_out1)
der_in2 = mysimilar(dy)
y_out2, der_out2 = value_and_derivative!!(f, der_in2, ba, x)

der_out3 = derivative(f, ba, x)
der_in4 = myzero(der_out3)
der_in4 = mysimilar(dy)
der_out4 = derivative!!(f, der_in4, ba, x)

@testset "Primal value" begin
Expand All @@ -147,12 +147,12 @@ function test_correctness(ba::AbstractADType, ::typeof(derivative), scen::Scenar
end

function test_correctness(ba::AbstractADType, ::typeof(derivative), scen::Scenario{true})
(; f, x, y) = new_scen = deepcopy(scen)
(; f, x, y, dx, dy) = new_scen = deepcopy(scen)
f! = f
der_true = ForwardDiff.derivative(f!, y, x)

y_in = myzero(y)
der_in = myzero(der_true)
y_in = mysimilar(y)
der_in = mysimilar(dy)
y_out, der_out = value_and_derivative!!(f!, y_in, der_in, ba, x)

@testset "Primal value" begin
Expand All @@ -172,19 +172,19 @@ end
## Gradient

function test_correctness(ba::AbstractADType, ::typeof(gradient), scen::Scenario{false})
(; f, x, y) = new_scen = deepcopy(scen)
(; f, x, y, dx, dy) = new_scen = deepcopy(scen)
grad_true = if x isa Number
ForwardDiff.derivative(f, x)
else
only(Zygote.gradient(f, x))
end

y_out1, grad_out1 = value_and_gradient(f, ba, x)
grad_in2 = myzero(grad_out1)
grad_in2 = mysimilar(dx)
y_out2, grad_out2 = value_and_gradient!!(f, grad_in2, ba, x)

grad_out3 = gradient(f, ba, x)
grad_in4 = myzero(grad_out3)
grad_in4 = mysimilar(dx)
grad_out4 = gradient!!(f, grad_in4, ba, x)

@testset "Primal value" begin
Expand All @@ -207,11 +207,11 @@ function test_correctness(ba::AbstractADType, ::typeof(jacobian), scen::Scenario
jac_true = ForwardDiff.jacobian(f, x)

y_out1, jac_out1 = value_and_jacobian(f, ba, x)
jac_in2 = myzero(jac_out1)
jac_in2 = mysimilar(jac_true)
y_out2, jac_out2 = value_and_jacobian!!(f, jac_in2, ba, x)

jac_out3 = jacobian(f, ba, x)
jac_in4 = myzero(jac_out3)
jac_in4 = mysimilar(jac_true)
jac_out4 = jacobian!!(f, jac_in4, ba, x)

@testset "Primal value" begin
Expand All @@ -236,8 +236,8 @@ function test_correctness(ba::AbstractADType, ::typeof(jacobian), scen::Scenario
f! = f
jac_true = ForwardDiff.jacobian(f!, y, x)

y_in = myzero(y)
jac_in = similar(y, length(y), length(x))
y_in = mysimilar(y)
jac_in = mysimilar(jac_true)
y_out, jac_out = value_and_jacobian!!(f!, y_in, jac_in, ba, x)

@testset "Primal value" begin
Expand Down Expand Up @@ -318,7 +318,7 @@ function true_pushforward(f, x::Number, y::AbstractArray, dx; mutating)
end

function true_pushforward(f, x::AbstractArray, y::Number, dx; mutating)
return dot(ForwardDiff.gradient(f, x), dx)
return dot(Zygote.gradient(f, x)[1], dx)
end

function true_pushforward(f, x::AbstractArray, y::AbstractArray, dx; mutating)
Expand All @@ -342,7 +342,7 @@ function true_pullback(f, x::Number, y::AbstractArray, dy; mutating)
end

function true_pullback(f, x::AbstractArray, y::Number, dy; mutating)
return ForwardDiff.gradient(f, x) .* dy
return Zygote.gradient(f, x)[1] .* dy
end

function true_pullback(f, x::AbstractArray, y::AbstractArray, dy; mutating)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module DifferentiationInterfaceEnzymeExt

using ADTypes: ADTypes, AutoEnzyme
using DifferentiationInterface: myupdate!!, mysimilar, myzero, myzero!!
using DifferentiationInterface: mymul!!, myupdate!!, mysimilar, myzero, myzero!!
import DifferentiationInterface as DI
using DocStringExtensions
using Enzyme:
Expand Down
22 changes: 21 additions & 1 deletion ext/DifferentiationInterfaceEnzymeExt/reverse_allocating.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ function DI.value_and_pullback!!(
dx_sametype = convert(typeof(x), dx)
dx_sametype = myzero!!(dx_sametype)
_, y = autodiff(ReverseWithPrimal, f, Active, Duplicated(x, dx_sametype))
dx_sametype .*= dy
dx_sametype = mymul!!(dx_sametype, dy)
return y, myupdate!!(dx, dx_sametype)
end

Expand All @@ -24,3 +24,23 @@ function DI.value_and_pullback(
dx = mysimilar(x)
return DI.value_and_pullback!!(f, dx, backend, x, dy, extras)
end

## Gradient

# Gradient of `f` at `x` for the reverse-mode Enzyme backend.
# `backend` and `extras` are used for dispatch only; the body never reads them.
# The work is forwarded to `gradient(Reverse, f, x)` (presumably Enzyme's
# `gradient`; the `using Enzyme:` list is not visible here -- confirm).
function DI.gradient(
    f::F, backend::AutoReverseEnzyme, x, extras::Nothing
) where {F}
    grad = gradient(Reverse, f, x)
    return grad
end

# Gradient of `f` at `x` for the reverse-mode Enzyme backend, given a
# caller-provided buffer `grad`.
# `backend` and `extras` are used for dispatch only; the body never reads them.
# The result is whatever `gradient!(Reverse, grad, f, x)` returns (presumably
# Enzyme's `gradient!`, which is expected to write into `grad` -- confirm).
function DI.gradient!!(
    f::F, grad, backend::AutoReverseEnzyme, x, extras::Nothing
) where {F}
    filled = gradient!(Reverse, grad, f, x)
    return filled
end

# Gradient (i.e. derivative) of `f` at a scalar `x` for the reverse-mode Enzyme
# backend. `backend` and `extras` are used for dispatch only.
#
# Reverse-mode `autodiff` does not return the derivative directly: per Enzyme's
# documentation the result is a 1-tuple wrapping the tuple of derivatives with
# respect to each argument, e.g. `((dx,),)` for a single `Active` argument.
# Returning that nested tuple as-is would hand callers a tuple where a number
# is expected, so both layers are unwrapped with `only`, which also throws if
# the result does not have exactly one element at each level.
function DI.gradient(f::F, backend::AutoReverseEnzyme, x::Number, extras::Nothing) where {F}
    derivs = only(autodiff(Reverse, f, Active(x)))
    return only(derivs)
end

# Gradient (i.e. derivative) of `f` at a scalar `x` for the reverse-mode Enzyme
# backend, `!!` variant. `grad` is accepted for signature parity with the array
# method but is not read or written here: the freshly computed value is returned
# instead. `backend` and `extras` are used for dispatch only.
#
# Reverse-mode `autodiff` does not return the derivative directly: per Enzyme's
# documentation the result is a 1-tuple wrapping the tuple of derivatives with
# respect to each argument, e.g. `((dx,),)` for a single `Active` argument.
# Both layers are unwrapped with `only` so that a plain number is returned
# rather than a nested tuple.
function DI.gradient!!(
    f::F, grad, backend::AutoReverseEnzyme, x::Number, extras::Nothing
) where {F}
    derivs = only(autodiff(Reverse, f, Active(x)))
    return only(derivs)
end
Loading

0 comments on commit efef587

Please sign in to comment.