This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Enable fp8_weight recomputation during backwards pass #185

Closed

drisspg opened this issue Jan 13, 2024 · 2 comments

@drisspg
Contributor

drisspg commented Jan 13, 2024

Summary

Currently, for both delayed and dynamic linear we have made a memory-versus-computation trade-off: we always save the fp8-casted version of the weight for backwards.

On a single node, this causes both the high-precision model.weight tensor and the casted Float8Tensor version of the weight to exist in GPU memory until the casted weight is freed during the backward pass.

We would like to provide an option to disable this and instead recompute the casted weight during the backward pass, as sketched below.
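A minimal sketch of what such an option could look like. This is not the float8_experimental implementation; `cast_to_fp8` and `RecomputeFp8Linear` are hypothetical stand-ins. The idea is that only the high-precision weight is saved in the autograd graph and the fp8 cast is redone in backward, trading an extra cast for lower peak memory.

```python
# Illustrative sketch only: recompute the fp8 cast in backward instead of
# saving the casted weight for the backward pass.
import torch

def cast_to_fp8(w: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real fp8 cast: round-trip through float8_e4m3fn when
    # the dtype is available so the numerics resemble an fp8 quantization.
    if hasattr(torch, "float8_e4m3fn"):
        return w.to(torch.float8_e4m3fn).to(w.dtype)
    return w  # no-op fallback on older PyTorch builds

class RecomputeFp8Linear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight):
        # Save only the high-precision weight; the casted copy is a temporary
        # and is freed as soon as forward returns.
        ctx.save_for_backward(x, weight)
        return x @ cast_to_fp8(weight).t()

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        w_fp8 = cast_to_fp8(weight)  # recomputed here instead of being saved
        grad_x = grad_out @ w_fp8
        grad_w = grad_out.t() @ x
        return grad_x, grad_w

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=True)
RecomputeFp8Linear.apply(x, w).sum().backward()
```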

This becomes more of an issue for the existing FSDP implementation. The high-precision weight is sharded across the GPUs. When running forward, we all-gather to unshard the weights onto each GPU. But since the nn.Parameter is not the tensor used for the matmul (the casted float8 tensor is), the existing FSDP mechanism for freeing the unsharded weight does not kick in, and we end up saving the unsharded tensor for backwards on each device. This extra memory usage scales with the number of GPUs.
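For a rough sense of scale (illustrative numbers only, assuming a single 8192 x 8192 linear weight in bf16 with an fp8 cast): the FSDP shard shrinks with the world size, but the weight saved for backward stays at full, unsharded size on every rank, so the aggregate overhead grows with the number of GPUs.

```python
# Back-of-the-envelope illustration of the overhead described above.
num_gpus = 8
weight_elems = 8192 * 8192
bf16_bytes, fp8_bytes = 2, 1

sharded_per_gpu = weight_elems * bf16_bytes / num_gpus  # FSDP shard
saved_fp8_per_gpu = weight_elems * fp8_bytes            # full, unsharded copy

print(f"sharded bf16 shard per GPU : {sharded_per_gpu / 2**20:.1f} MiB")
print(f"saved fp8 weight per GPU   : {saved_fp8_per_gpu / 2**20:.1f} MiB")
# Each rank holds a full-size saved copy, so total overhead across the job
# grows linearly with the number of GPUs.
```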

@drisspg
Contributor Author

drisspg commented Jan 31, 2024

@vkuzo
Contributor

vkuzo commented Jul 30, 2024

pytorch/ao#562

vkuzo closed this as completed Jul 30, 2024