-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enable testing correctness against reference, reorganize testing (#103)
* Start reference tests * Continue references * Restructure test suite * No timing * Add Documenter * Qualify doctest * Fix correctness * Fix docs * Refix * fIX NESTED
- Loading branch information
Showing
51 changed files
with
1,175 additions
and
1,119 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
52 changes: 45 additions & 7 deletions
52
ext/DifferentiationInterfaceComponentArraysExt/DifferentiationInterfaceComponentArraysExt.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,64 @@ | ||
module DifferentiationInterfaceComponentArraysExt | ||
|
||
using ComponentArrays: ComponentVector | ||
using DifferentiationInterface.DifferentiationTest: Scenario, SCALING_VEC | ||
using DifferentiationInterface.DifferentiationTest: | ||
Scenario, Reference, make_scalar_to_array, make_scalar_to_array!, scalar_to_array_ref | ||
using LinearAlgebra: dot | ||
|
||
const SCALING_CVEC = ComponentVector(; a=collect(1:7), b=collect(8:12)) | ||
|
||
function scalar_to_componentvector(x::Number)::ComponentVector | ||
return sin.(SCALING_CVEC .* x) # output size 12 | ||
end | ||
## Vector to scalar | ||
|
||
function componentvector_to_scalar(x::ComponentVector)::Number | ||
return sum(sin, x.a) + sum(cos, x.b) | ||
end | ||
|
||
componentvector_to_scalar_gradient(x) = ComponentVector(; a=cos.(x.a), b=-sin.(x.b)) | ||
|
||
function componentvector_to_scalar_pushforward(x, dx) | ||
return dot(componentvector_to_scalar_gradient(x), dx) | ||
end | ||
|
||
function componentvector_to_scalar_pullback(x, dy) | ||
return componentvector_to_scalar_gradient(x) .* dy | ||
end | ||
|
||
function componentvector_to_scalar_ref() | ||
return Reference(; | ||
pushforward=componentvector_to_scalar_pushforward, | ||
pullback=componentvector_to_scalar_pullback, | ||
gradient=componentvector_to_scalar_gradient, | ||
) | ||
end | ||
|
||
## Gather | ||
|
||
const SCALING_CVEC = ComponentVector(; a=collect(1:7), b=collect(8:12)) | ||
|
||
function component_scenarios_allocating() | ||
return [ | ||
Scenario(scalar_to_componentvector; x=2.0), | ||
Scenario( | ||
make_scalar_to_array(SCALING_CVEC); x=2.0, ref=scalar_to_array_ref(SCALING_CVEC) | ||
), | ||
Scenario( | ||
componentvector_to_scalar; | ||
x=ComponentVector{Float64}(; a=collect(1:7), b=collect(8:12)), | ||
ref=componentvector_to_scalar_ref(), | ||
), | ||
] | ||
end | ||
|
||
function component_scenarios_mutating() | ||
return [ | ||
Scenario( | ||
make_scalar_to_array!(SCALING_CVEC); | ||
x=2.0, | ||
y=float.(SCALING_CVEC), | ||
ref=scalar_to_array_ref(SCALING_CVEC), | ||
), | ||
] | ||
end | ||
|
||
function component_scenarios() | ||
return vcat(component_scenarios_allocating(), component_scenarios_mutating()) | ||
end | ||
|
||
end |
Oops, something went wrong.