Summary
Currently, for both delayed and dynamic linear, we have made a memory-versus-computation trade-off: we always save the fp8-casted version of the weight for backward.
On a single node, this causes both the high-precision model.weight tensor and the casted Float8Tensor version of the weight to live in GPU memory until the casted weight is freed during the backward pass.
We would like to provide an option to disable this.
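A minimal sketch of how such an option could look (hypothetical, not the library's actual code: `ToyCastedLinear`, `cast_to_lowp`, and the `save_casted_weight` flag are made-up names, and bfloat16 stands in for fp8) is an autograd.Function that either saves the casted weight or re-casts it during backward:

```python
import torch

def cast_to_lowp(w: torch.Tensor) -> torch.Tensor:
    # Stand-in for the fp8 cast; bfloat16 is used purely for illustration.
    return w.to(torch.bfloat16)

class ToyCastedLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w, save_casted_weight: bool):
        w_lowp = cast_to_lowp(w)
        ctx.save_casted_weight = save_casted_weight
        ctx.w_dtype = w.dtype
        if save_casted_weight:
            # Current behavior: keep the casted copy alive for backward,
            # so both w and w_lowp occupy memory until backward runs.
            ctx.save_for_backward(x, w_lowp)
        else:
            # Proposed option: save only the high-precision weight and
            # re-cast during backward (less memory, more compute).
            ctx.save_for_backward(x, w)
        return x @ w_lowp.to(x.dtype).t()

    @staticmethod
    def backward(ctx, grad_out):
        x, w_saved = ctx.saved_tensors
        w_lowp = w_saved if ctx.save_casted_weight else cast_to_lowp(w_saved)
        grad_x = grad_out @ w_lowp.to(grad_out.dtype)
        grad_w = (grad_out.t() @ x).to(ctx.w_dtype)
        return grad_x, grad_w, None

# Usage (shapes and flag value are illustrative)
x = torch.randn(8, 1024, requires_grad=True)
w = torch.nn.Parameter(torch.randn(1024, 1024))
out = ToyCastedLinear.apply(x, w, False)  # False: re-cast in backward, don't keep the casted copy
out.sum().backward()
```

The proposed option would correspond to the `save_casted_weight=False` path: one extra cast per backward in exchange for not keeping two copies of the weight alive between forward and backward.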
This becomes more of an issue with the existing FSDP implementation. The high-precision weight is sharded across the GPUs. During forward, we all-gather the shards so that each GPU holds the unsharded weight. But since the nn.Parameter is not the tensor used for the matmul (the casted Float8Tensor is), FSDP's existing mechanism for freeing the unsharded weight does not kick in, and we end up saving the unsharded casted tensor for backward on each device. This extra memory usage scales with the number of GPUs.
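To see why FSDP's resharding does not help here, a small standalone repro (again hypothetical, with bfloat16 standing in for fp8) can use `torch.autograd.graph.saved_tensors_hooks` to show that what autograd keeps for backward is the casted copy, a separate allocation from the nn.Parameter that FSDP frees after forward:

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 1024, dtype=torch.bfloat16, requires_grad=True)
w = torch.nn.Parameter(torch.randn(1024, 1024))  # stands in for the (unsharded) high-precision weight
w_lowp = w.detach().to(torch.bfloat16)           # stands in for the fp8 cast

def pack(t):
    # Called once for each tensor that autograd decides to keep for backward.
    print("saved for backward:", tuple(t.shape), t.dtype)
    return t

with torch.autograd.graph.saved_tensors_hooks(pack, lambda t: t):
    out = F.linear(x, w_lowp)

# The casted (bf16 here, fp8 in the real case) weight shows up among the saved
# tensors. FSDP can free the all-gathered nn.Parameter after forward, but this
# separately allocated, unsharded-sized copy stays alive on each rank until
# backward runs, which is the extra memory described above.
out.sum().backward()
```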