Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forward Mode #503

Merged
merged 27 commits into from
Jun 22, 2020
Merged

Forward Mode #503

merged 27 commits into from
Jun 22, 2020

Conversation

MikeInnes
Copy link
Member

julia> function pow(x::Real, n::Integer)
         r = 1
         while n > 0
           n -= 1
           r *= x
         end
         return r
       end
pow (generic function with 1 method)

julia> forw = pushforward(pow, 2, 3)
#8 (generic function with 1 method)

julia> forw(1, 0)
12

Lots more to do, but this is a starting point.

@willtebbutt
Copy link
Member

willtebbutt commented Feb 7, 2020

Sure would be great if this hooked in to ChainRulesCore.jl from the get go so that we don't all have to reinvent the wheel and can jointly benefit from each others work implementing custom tangents / pushforwards / frules.

@willtebbutt
Copy link
Member

willtebbutt commented Feb 7, 2020

Additionally, it appears that @tangent functions don't fuse function evaluation and differential computation. In chain rules language, the tangents are of the form

frule(::typeof(f), x...) = f(x), dx->some_computation

While this approach is appropriate for reverse-mode, it doesn't work for forwards mode, as discussed here.

edit: looks like the internal function _tangent can handle this though. Maybe not a problem.

@MikeInnes
Copy link
Member Author

Yes, that's just syntax sugar. Things are fused under the hood but it's nicer to write the rules out that way, in most cases; and it's easy to avoid the sugar if needed.

Incorporating chainrules is a bit tricky for the same reasons as in #366. Once we have a solution for that we can do it for forwards and reverse together. There are also some semantics questions to figure out, since in implementing this I've found that the "zero type" approach is tricky for forwards mode in a way that doesn't come up for reverse.

@willtebbutt
Copy link
Member

willtebbutt commented Feb 8, 2020

Yes, that's just syntax sugar. Things are fused under the hood but it's nicer to write the rules out that way, in most cases; and it's easy to avoid the sugar if needed.

Glad that we agree on this.

Incorporating chainrules is a bit tricky for the same reasons as in #366.

Not really sure what the blockers are here (I've not delved into the compiler details), but I agree that this should be resolved promtly. @oxinabox , @dhairyagandhi96 and yourself were discussing issues at length a while ago, and it would be fantastic if these were resolved sooner rather than later so that we can all move on with life. In particular, @MikeInnes you are always welcome in Cambridge -- both @oxinabox and I would love do hash all of this out over a beer.

There are also some semantics questions to figure out, since in implementing this I've found that the "zero type" approach is tricky for forwards mode in a way that doesn't come up for reverse.

@oxinabox and I agree on this. We've been discussing mutable zeros for forwards mode at length internally over the last week or two, and I think we've wound up at a similar idea of what the solution should look like to your zeros_like.

@shashi
Copy link

shashi commented Mar 5, 2020

julia> using ForwardDiff2:DI

julia> using Zygote

julia> f(x) = tanh(log(x * sin(1)) + cos(exp(x+1)))
f (generic function with 1 method)

julia> @btime DI($f)(1);
  108.671 ns (0 allocations: 0 bytes)

julia> g = pushforward(f, 1); @btime $g(1);
  10.427 μs (30 allocations: 1.61 KiB)

julia> using ForwardDiff: derivative

julia> @btime derivative($f, 1);
  21.460 ns (0 allocations: 0 bytes)

Can we make this fast?

@MikeInnes
Copy link
Member Author

Note that I hadn't paid any attention to performance up to now, so what's in the PR definitely doesn't represent any kind of best case.

The recent patch fixes an obvious issue that makes things about as fast as FD2. Will keep digging to see what else we can do.

Note that @shashi found a performance bug that affects both Cassette and IRTools no-op passes. That's something we can hopefully fix eventually, but presumably doesn't matter for the Zygote vs FD2 comparison since they'd both be equally affected.

@MikeInnes
Copy link
Member Author

On my machine I get about 150ns for FD2 vs 200ns for Zygote. The remaining difference is almost certainly just due to the fact that Zygote is using DiffRules without CSE; ChainRules avoids some redundant computation. I'm assuming this is pretty negligible for things FD2 was actually designed for and we can just wait for ChainRules to fix this.


@tangent A::AbstractArray * B::AbstractArray = A*B, (Ȧ, Ḃ) -> Ȧ*B .+ A*Ḃ

@tangent sum(x; dims = :) = sum(x; dims = dims), ẋ -> sum(x, dims = dims)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this rule's wrong

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whoops – do you want commit so you can hack on this branch?

@MikeInnes
Copy link
Member Author

FD2 and Zygote are now both 100ns on my machine, for whatever it's worth.

@cossio
Copy link
Contributor

cossio commented Apr 26, 2020

Can this be used in #338, to avoid the use of Dual numbers? That is, to have differentiation of functions that take native types.

@oxinabox
Copy link
Member

yes.

@MasonProtter
Copy link
Contributor

MasonProtter commented May 13, 2020

I assume the answer is yes, but I just want to make sure,. would this forward mode support non-scalar functions?

@cossio
Copy link
Contributor

cossio commented May 15, 2020

@MasonProtter Has to, if it is going to help with #338. So by implication, I think yes 😉

@MikeInnes
Copy link
Member Author

Some promising initial results:

julia> ff = Chain(Dense(784, 32, tanh), Dense(32, 10))
Chain(Dense(784, 32, tanh), Dense(32, 10))

julia> scalar_ff(x) = ff(fill(x, (784, 100)))
scalar_ff (generic function with 1 method)

julia> x = Float32(0.5)
0.5f0

julia> @btime ForwardDiff.derivative(scalar_ff, x);
  13.158 ms (34 allocations: 683.69 KiB)

julia> @btime Zygote.pushforward(scalar_ff, x)(one(x));
  1.068 ms (145 allocations: 840.61 KiB)

So a little over 10x faster than ForwardDiff to run a single forward pass over an MLP. And of course unlike FD we can use GPUs, XLA etc. which will give us further speedups.

I'm expecting to clean up the tests and merge this soon. It's going to be very beta for a while but we can add features as needed.

@MikeInnes
Copy link
Member Author

We're not done here, but I think the interface and core ideas are pretty stable now, so the best way for this to move forward is having it on master so people can play around with it, open PRs for various improvements etc.

bors r+

bors bot added a commit that referenced this pull request Jun 18, 2020
503: Forward Mode r=MikeInnes a=MikeInnes

```julia
julia> function pow(x::Real, n::Integer)
         r = 1
         while n > 0
           n -= 1
           r *= x
         end
         return r
       end
pow (generic function with 1 method)

julia> forw = pushforward(pow, 2, 3)
#8 (generic function with 1 method)

julia> forw(1, 0)
12
```

Lots more to do, but this is a starting point.

Co-authored-by: Mike J Innes <[email protected]>
Co-authored-by: Mike Innes <[email protected]>
Co-authored-by: Dhairya Gandhi <[email protected]>
@cossio
Copy link
Contributor

cossio commented Jun 18, 2020

Great to see this merged.
If there are docs, or at least more code comments, it will be easier for more people to try this and contribute.

@bors
Copy link
Contributor

bors bot commented Jun 18, 2020

Build failed:

@MikeInnes
Copy link
Member Author

bors r+

bors bot added a commit that referenced this pull request Jun 18, 2020
503: Forward Mode r=MikeInnes a=MikeInnes

```julia
julia> function pow(x::Real, n::Integer)
         r = 1
         while n > 0
           n -= 1
           r *= x
         end
         return r
       end
pow (generic function with 1 method)

julia> forw = pushforward(pow, 2, 3)
#8 (generic function with 1 method)

julia> forw(1, 0)
12
```

Lots more to do, but this is a starting point.

Co-authored-by: Mike J Innes <[email protected]>
Co-authored-by: Mike Innes <[email protected]>
Co-authored-by: Dhairya Gandhi <[email protected]>
@bors
Copy link
Contributor

bors bot commented Jun 18, 2020

Build failed:

@MikeInnes
Copy link
Member Author

bors r+

@bors
Copy link
Contributor

bors bot commented Jun 18, 2020

Merge conflict.

@MikeInnes
Copy link
Member Author

bors r+

@bors
Copy link
Contributor

bors bot commented Jun 22, 2020

Build succeeded:

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

9 participants