Support automatic differentiation of the NN inside the loss function? #150
Comments
I saw that you plan to support automatic differentiation of the NN inside the loss function. Do you have a plan/roadmap for this? I'm interested in looking into how this could be done, so if you have some code snippets and/or notes on this, please share them.

The problem is that the most efficient way to do this is to use forward mode inside of the loss function and then reverse over that. Using ForwardDiff in here kind of works, but it requires a few workarounds and is hard to generalize. But Zygote does have a forward mode: FluxML/Zygote.jl#503. It needs a few more things to be fully compatible with the standard NNs people want to use for physics-informed neural networks, though: FluxML/Zygote.jl#654. But we are well aware of this, @Keno has been working on a major improvement to the AD system which will make this a lot better, and @DhairyaLGandhi is aware of this use case.

That said, the reason why it's not a huge issue is that the computational complexity of numerical and forward-mode differentiation is the same, with forward mode just decreasing the number of primal calculations and allowing a bit more SIMD/CSE in some cases. You never see forward mode more than 4x faster than the fastest numerical differentiation schemes, and usually it's closer to 2x. What's essential for performance is the reverse mode over the loss function, which is already there, since that is where the massive complexity change is.

This is why I haven't made a big push to get "something for now and better for later", and am instead just waiting for the big nested AD changes coming later this year. The actual difference to a user will be rather slim (much slimmer than you might suspect), so it's not worth making a fuss about until special compiler tools allow for faster higher-order derivatives (which something like PyTorch doesn't do, so that would be something to write home about).

Pinging @KirillZubov since I know he was curious about this detail as well. But indeed, this is an issue, so thanks for opening it so we can formally track it. We'll start updating the public on this more often since there's a lot going on here.
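To make the nesting concrete, here is a minimal sketch, not NeuralPDE.jl's actual implementation; the tiny hand-written network, the step size, and the collocation points are all hypothetical placeholders. The inner spatial derivative of the network is taken by central differences (which, as noted above, has the same asymptotic cost as forward mode), and the outer gradient of the whole loss with respect to the parameters is taken in reverse mode with Zygote. Swapping the central difference for `ForwardDiff.derivative(x -> u(θ, x), x)` gives the forward-mode-inside-reverse-mode nesting described above, which is the part that currently needs workarounds.

```julia
using Zygote

# Hypothetical tiny scalar network u(x; θ): one hidden tanh layer with 16 units,
# parameters hand-packed into a single vector θ of length 49.
function u(θ, x)
    W1, b1 = θ[1:16], θ[17:32]      # first layer: weights and biases
    W2, b2 = θ[33:48], θ[49]        # output layer: weights and bias
    h = tanh.(W1 .* x .+ b1)        # hidden activations
    return sum(W2 .* h) + b2        # scalar output u(x)
end

# Inner spatial derivative u_x by central differences; ϵ is a placeholder step.
# Replacing this with ForwardDiff.derivative(x -> u(θ, x), x) would be the
# forward-mode version discussed above.
u_x(θ, x; ϵ = 1e-6) = (u(θ, x + ϵ) - u(θ, x - ϵ)) / (2ϵ)

# Toy PINN-style residual for u'(x) = u(x) on a few collocation points.
xs = collect(range(0.0, 1.0; length = 20))
loss(θ) = sum(abs2, map(x -> u_x(θ, x) - u(θ, x), xs))

θ0 = 0.1 .* randn(49)
# Outer reverse-mode gradient of the whole loss with respect to the parameters.
grad = Zygote.gradient(loss, θ0)[1]
```

The design point is the last line: the outer `Zygote.gradient` call is where the complexity advantage lives, while the choice of numerical versus forward mode for the inner `u_x` only changes constant factors (roughly 2-4x, per the comment above).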