Refactor the metal backend to always reuse command encoders/buffers unless a shared memory access is requested #2037
base: main
…readability" This reverts commit bb594bc.
Hoping this branch will address some portion of #1939
…mandEncoderReuse
…mandEncoderReuse
Okay, I'm happy with the PR in its current state! @ivarflakstad it'd be great to get your eyes on this change as well! Feel free to reach out on Discord if you want to go through parts of the change together and avoid back and forth 👍
@LaurentMazare I did some cleanup of the areas commented above, lmk if you have any other ideas on improving docs
if args.metal_tracing {
    use candle::Device;
    if let Device::Metal(metal_device) = device.clone() {
        metal_device.capture("/tmp/candle.gputrace")?;
Let's make the metal-tracing arg an Option<String> so that it specifies the filename for the trace rather than hardcoding it?
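As a rough illustration of that suggestion, here is a minimal std-only sketch. The helper name `parse_metal_tracing` and the `--metal-tracing <path>` / `--metal-tracing=<path>` spellings are assumptions made for this example; the real example binary would presumably declare the option through its existing argument parser instead.

```rust
/// Hypothetical helper (not part of candle): extracts the trace path from
/// `--metal-tracing <path>` or `--metal-tracing=<path>`.
/// Returns `None` when the flag is absent, so tracing stays off by default.
fn parse_metal_tracing<I>(args: I) -> Option<String>
where
    I: IntoIterator<Item = String>,
{
    let mut args = args.into_iter();
    while let Some(arg) = args.next() {
        if arg == "--metal-tracing" {
            // The following argument, if present, is the output path.
            return args.next();
        }
        if let Some(path) = arg.strip_prefix("--metal-tracing=") {
            return Some(path.to_string());
        }
    }
    None
}

fn main() {
    // Skip the program name and look for the (assumed) flag.
    match parse_metal_tracing(std::env::args().skip(1)) {
        Some(path) => println!("would capture a GPU trace to {path}"),
        None => println!("GPU tracing disabled"),
    }
}
```

With an `Option<String>`, the absence of the flag doubles as the "tracing off" state, so no separate boolean is needed.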
Agreed. Adding to this, does it make sense to have this as an env var?
In other words, let GPU tracing be independent of what type of model you're running.
In which case we should log that gpu tracing is enabled
Would rather keep it as an explicit argument, as we try not to rely on env variables (though there might be one remaining). The gist of it is that they are hard to discover, whereas arguments are documented. Also, this way it doesn't depend on whether the env variable gets changed by the user code calling candle, etc.
(And on another note, candle actively tries to be log-free outside of the main.rs.)
Right. I was thinking it would only take effect with the debug or release-with-debug profiles. I imagine no one is doing GPU tracing in a production environment.
Debug logging is compiled away in release builds.
Just an update: I've run through most examples and they work fine. The exception is stable diffusion (wuerstchen is fine though...), which outputs a garbage image compared to main. I'll be looking into this and will tag when I find the root cause; it's not immediately obvious so far.
Gave it a quick look and it indeed looks pretty weird. When enabling
…mandEncoderReuse
Any luck with this? It would be great to have this for analyzing performance on Metal, so I'm pretty keen to have it if it still works with stable diffusion etc.
Unfortunately I haven't had the bandwidth over the last two weeks, but I should have more time coming up to investigate. I went back through my implementation trying to hunt down where the change was introduced, but nothing stood out, so now I'm running through the operations performed in the stable diffusion model to understand what's going on.
…mandEncoderReuse
Updated the branch here with main so we don't stay too far out of sync.
Just an update @LaurentMazare: I did finally trace this down to something odd going on in the memory of certain tensors. It appears that at some point the buffer for "encoder_hidden_states" is getting overwritten with zeros, and this also seems to be happening with other buffers, which is why the generated images are coming out black right now. I'm not sure where this is coming from. I initially thought it might be the buffer reuse, but after disabling that the issue is still reproducible. As for next steps, I'm not quite sure. If anyone has an inkling of why this occurs for the stable diffusion example but not for any other example, I would love to know; I'm just not familiar enough to have a clear "aha" moment on this yet.
This change aims to replace the pattern in which each tensor provisions its own command buffer and encoder for every kernel operation with a pattern in which an existing encoder is handed to the kernel so it can set up its operations.
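The difference between the two patterns can be sketched with a toy model. The `Device` and `Encoder` types and the `run_kernel_*` functions below are hypothetical stand-ins invented for illustration; they are not candle's or Metal's actual types, and they only count how many encoders get created.

```rust
/// Toy stand-in for a compute command encoder; it only records kernel names.
#[derive(Default)]
struct Encoder {
    dispatched: Vec<&'static str>,
}

/// Toy stand-in for a device that counts how many encoders it has created.
#[derive(Default)]
struct Device {
    encoders_created: usize,
}

impl Device {
    fn new_encoder(&mut self) -> Encoder {
        self.encoders_created += 1;
        Encoder::default()
    }
}

/// Old pattern: every kernel call provisions its own encoder.
fn run_kernel_old(device: &mut Device, name: &'static str) -> Encoder {
    let mut encoder = device.new_encoder();
    encoder.dispatched.push(name);
    encoder
}

/// New pattern: the caller hands an existing encoder to the kernel.
fn run_kernel_new(encoder: &mut Encoder, name: &'static str) {
    encoder.dispatched.push(name);
}

/// Returns (encoders created, kernels dispatched) for the old pattern.
fn count_old(kernels: &[&'static str]) -> (usize, usize) {
    let mut device = Device::default();
    let mut dispatched = 0;
    for &kernel in kernels {
        dispatched += run_kernel_old(&mut device, kernel).dispatched.len();
    }
    (device.encoders_created, dispatched)
}

/// Returns (encoders created, kernels dispatched) for the new pattern.
fn count_new(kernels: &[&'static str]) -> (usize, usize) {
    let mut device = Device::default();
    let mut encoder = device.new_encoder();
    for &kernel in kernels {
        run_kernel_new(&mut encoder, kernel);
    }
    (device.encoders_created, encoder.dispatched.len())
}

fn main() {
    let kernels = ["matmul", "softmax", "add"];
    println!("old pattern: {:?}", count_old(&kernels));
    println!("new pattern: {:?}", count_new(&kernels));
}
```

In this model the same three kernels are dispatched either way; only the number of encoders differs (one per kernel versus one in total).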
Prior to this change, we relied on the fact that command buffers execute sequentially to order operations and to ensure that the output of one operation was complete before a downstream operation used it. We now leverage Metal's resource-based memory barriers to do this more effectively: an operation only blocks on the completion of the operations that produce its input dependencies.
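To make the ordering argument concrete, here is a small simulation of dependency-based scheduling. It is only a model under assumed names (`Op`, `schedule_with_resource_barriers`, integer buffer ids) and does not call any Metal API; it computes the earliest "wave" each operation could run in when it waits solely on earlier operations that touch the same buffers.

```rust
use std::collections::HashMap;

/// Toy operation: reads some buffers and writes exactly one (ids are arbitrary).
struct Op {
    reads: Vec<u32>,
    writes: u32,
}

/// Earliest wave per op when each op waits only on conflicting earlier ops.
fn schedule_with_resource_barriers(ops: &[Op]) -> Vec<usize> {
    let mut last_write: HashMap<u32, usize> = HashMap::new(); // buffer -> wave of last writer
    let mut last_read: HashMap<u32, usize> = HashMap::new(); // buffer -> latest reader wave
    let mut waves = Vec::with_capacity(ops.len());

    for op in ops {
        let mut wave: usize = 0;
        // Read-after-write: wait for the producer of every input.
        for buf in &op.reads {
            if let Some(&w) = last_write.get(buf) {
                wave = wave.max(w + 1);
            }
        }
        // Write-after-write and write-after-read hazards on the output buffer.
        if let Some(&w) = last_write.get(&op.writes) {
            wave = wave.max(w + 1);
        }
        if let Some(&r) = last_read.get(&op.writes) {
            wave = wave.max(r + 1);
        }
        // Record this op's accesses for later ops.
        for buf in &op.reads {
            let entry = last_read.entry(*buf).or_insert(wave);
            *entry = (*entry).max(wave);
        }
        last_write.insert(op.writes, wave);
        waves.push(wave);
    }
    waves
}

fn main() {
    let ops = vec![
        Op { reads: vec![0], writes: 1 },    // independent of the next op
        Op { reads: vec![0], writes: 2 },    // independent of the previous op
        Op { reads: vec![1, 2], writes: 3 }, // needs both results
    ];
    // Strictly sequential execution would place these ops in waves 0, 1, 2.
    println!("{:?}", schedule_with_resource_barriers(&ops));
}
```

The first two operations share no written buffer, so in this model they land in the same wave, and only the third has to wait.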
What are the outcomes of this change in terms of performance? In summary, not much. In theory I would have expected a minor gain from the improved parallelization and the lower overhead of using only a few command buffers and encoders; in practice, however, we do not see a change in performance for the example models.
Although we don't see a change in performance on our models, we get a much more stable gputrace output as a result of this change. On main, generating a gputrace would often cause OOM errors because of the number of extra command buffers and encoders recorded. We now get a much more streamlined recording of the models, and the memory required to load these traces has been cut down drastically.
The major changes to note here are the following: