-
-
Notifications
You must be signed in to change notification settings - Fork 212
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Forward Mode #503
Forward Mode #503
Conversation
Sure would be great if this hooked in to ChainRulesCore.jl from the get go so that we don't all have to reinvent the wheel and can jointly benefit from each others work implementing custom tangents / pushforwards / frules. |
Additionally, it appears that frule(::typeof(f), x...) = f(x), dx->some_computation While this approach is appropriate for reverse-mode, it doesn't work for forwards mode, as discussed here. edit: looks like the internal function |
Yes, that's just syntax sugar. Things are fused under the hood but it's nicer to write the rules out that way, in most cases; and it's easy to avoid the sugar if needed. Incorporating chainrules is a bit tricky for the same reasons as in #366. Once we have a solution for that we can do it for forwards and reverse together. There are also some semantics questions to figure out, since in implementing this I've found that the "zero type" approach is tricky for forwards mode in a way that doesn't come up for reverse. |
Glad that we agree on this.
Not really sure what the blockers are here (I've not delved into the compiler details), but I agree that this should be resolved promtly. @oxinabox , @dhairyagandhi96 and yourself were discussing issues at length a while ago, and it would be fantastic if these were resolved sooner rather than later so that we can all move on with life. In particular, @MikeInnes you are always welcome in Cambridge -- both @oxinabox and I would love do hash all of this out over a beer.
@oxinabox and I agree on this. We've been discussing mutable zeros for forwards mode at length internally over the last week or two, and I think we've wound up at a similar idea of what the solution should look like to your |
julia> using ForwardDiff2:DI
julia> using Zygote
julia> f(x) = tanh(log(x * sin(1)) + cos(exp(x+1)))
f (generic function with 1 method)
julia> @btime DI($f)(1);
108.671 ns (0 allocations: 0 bytes)
julia> g = pushforward(f, 1); @btime $g(1);
10.427 μs (30 allocations: 1.61 KiB)
julia> using ForwardDiff: derivative
julia> @btime derivative($f, 1);
21.460 ns (0 allocations: 0 bytes) Can we make this fast? |
Note that I hadn't paid any attention to performance up to now, so what's in the PR definitely doesn't represent any kind of best case. The recent patch fixes an obvious issue that makes things about as fast as FD2. Will keep digging to see what else we can do. Note that @shashi found a performance bug that affects both Cassette and IRTools no-op passes. That's something we can hopefully fix eventually, but presumably doesn't matter for the Zygote vs FD2 comparison since they'd both be equally affected. |
On my machine I get about 150ns for FD2 vs 200ns for Zygote. The remaining difference is almost certainly just due to the fact that Zygote is using DiffRules without CSE; ChainRules avoids some redundant computation. I'm assuming this is pretty negligible for things FD2 was actually designed for and we can just wait for ChainRules to fix this. |
src/forward/array.jl
Outdated
|
||
@tangent A::AbstractArray * B::AbstractArray = A*B, (Ȧ, Ḃ) -> Ȧ*B .+ A*Ḃ | ||
|
||
@tangent sum(x; dims = :) = sum(x; dims = dims), ẋ -> sum(x, dims = dims) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this rule's wrong
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
whoops – do you want commit so you can hack on this branch?
FD2 and Zygote are now both 100ns on my machine, for whatever it's worth. |
Can this be used in #338, to avoid the use of |
yes. |
I assume the answer is yes, but I just want to make sure,. would this forward mode support non-scalar functions? |
@MasonProtter Has to, if it is going to help with #338. So by implication, I think yes 😉 |
Some promising initial results: julia> ff = Chain(Dense(784, 32, tanh), Dense(32, 10))
Chain(Dense(784, 32, tanh), Dense(32, 10))
julia> scalar_ff(x) = ff(fill(x, (784, 100)))
scalar_ff (generic function with 1 method)
julia> x = Float32(0.5)
0.5f0
julia> @btime ForwardDiff.derivative(scalar_ff, x);
13.158 ms (34 allocations: 683.69 KiB)
julia> @btime Zygote.pushforward(scalar_ff, x)(one(x));
1.068 ms (145 allocations: 840.61 KiB) So a little over 10x faster than ForwardDiff to run a single forward pass over an MLP. And of course unlike FD we can use GPUs, XLA etc. which will give us further speedups. I'm expecting to clean up the tests and merge this soon. It's going to be very beta for a while but we can add features as needed. |
Co-authored-by: Seth Axen <[email protected]>
Co-authored-by: Seth Axen <[email protected]>
We're not done here, but I think the interface and core ideas are pretty stable now, so the best way for this to move forward is having it on master so people can play around with it, open PRs for various improvements etc. bors r+ |
503: Forward Mode r=MikeInnes a=MikeInnes ```julia julia> function pow(x::Real, n::Integer) r = 1 while n > 0 n -= 1 r *= x end return r end pow (generic function with 1 method) julia> forw = pushforward(pow, 2, 3) #8 (generic function with 1 method) julia> forw(1, 0) 12 ``` Lots more to do, but this is a starting point. Co-authored-by: Mike J Innes <[email protected]> Co-authored-by: Mike Innes <[email protected]> Co-authored-by: Dhairya Gandhi <[email protected]>
Great to see this merged. |
Build failed: |
bors r+ |
503: Forward Mode r=MikeInnes a=MikeInnes ```julia julia> function pow(x::Real, n::Integer) r = 1 while n > 0 n -= 1 r *= x end return r end pow (generic function with 1 method) julia> forw = pushforward(pow, 2, 3) #8 (generic function with 1 method) julia> forw(1, 0) 12 ``` Lots more to do, but this is a starting point. Co-authored-by: Mike J Innes <[email protected]> Co-authored-by: Mike Innes <[email protected]> Co-authored-by: Dhairya Gandhi <[email protected]>
Build failed: |
bors r+ |
Merge conflict. |
bors r+ |
Build succeeded: |
Lots more to do, but this is a starting point.