Hi, is it necessary to wrap the forward pass in `autocast` when using FSDP2? I noticed that the `torchtitan` training loop does not.
If I wrap the forward pass in `torch.autocast(device_type="cuda", dtype=torch.bfloat16)`, my matmuls will run in `bfloat16`, but my softmaxes (say) will run in `float32`. This behavior requires the `autocast` wrapper.
This is the usual way to do DDP or non-distributed mixed-precision training.
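For reference, here is a minimal sketch of that usual AMP pattern. It uses a hypothetical toy module `TinyBlock` (a linear layer followed by a softmax), not a torchtitan model. The weights stay in `float32`, and `autocast` picks the compute dtype op by op:

```python
import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    """Hypothetical toy module: a linear projection followed by a softmax."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor):
        logits = self.proj(x)                  # linear/matmul: autocast runs it in bf16
        probs = torch.softmax(logits, dim=-1)  # CUDA autocast runs softmax in float32
        return logits, probs


def amp_forward_dtypes(device: str):
    # Parameters are left in float32; only the autocast context changes op dtypes.
    model = TinyBlock().to(device)
    x = torch.randn(4, 16, device=device)
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        logits, probs = model(x)
    return logits.dtype, probs.dtype


# Fall back to CPU so the sketch runs anywhere (CPU autocast uses different op lists).
device = "cuda" if torch.cuda.is_available() else "cpu"
print(amp_forward_dtypes(device))
```

On a CUDA device, the autocast op lists put `linear` in `bfloat16` and `softmax` in `float32`, so this prints `(torch.bfloat16, torch.float32)`. The CPU fallback only exists so the snippet runs without a GPU. CPU autocast has its own op lists and does not necessarily upcast softmax.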
It seems to me that this behavior is lost in the `torchtitan` training loop, which doesn't use the `autocast` context manager. Is that not the case? Does FSDP2 somehow still perform the upcast for the ops that AMP normally runs in higher precision, like softmax? I don't see how it would, and I can't easily test it at the moment.
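Here is one quick way to see what happens without `autocast`. This sketch assumes that, for dtype purposes, FSDP2 with `param_dtype=torch.bfloat16` behaves like a module whose parameters and inputs are already `bfloat16`. No sharding or process group is involved, so it is a stand-in rather than a real FSDP2 run. The hypothetical helper `no_autocast_dtypes` reports the dtype each op produced. The last op shows the explicit-upcast alternative through the `dtype` argument of `F.softmax`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def no_autocast_dtypes(device: str = "cpu"):
    # Assumption: stands in for FSDP2's bf16 unsharded params plus bf16-cast inputs.
    proj = nn.Linear(16, 16).to(device=device, dtype=torch.bfloat16)
    x = torch.randn(4, 16, device=device, dtype=torch.bfloat16)

    logits = proj(x)                       # runs in bf16 (input/weight dtype)
    probs = torch.softmax(logits, dim=-1)  # no autocast active: stays bf16
    # Explicit upcast: F.softmax casts its input to float32 before computing.
    probs_fp32 = F.softmax(logits, dim=-1, dtype=torch.float32)
    return logits.dtype, probs.dtype, probs_fp32.dtype


print(no_autocast_dtypes())
# -> (torch.bfloat16, torch.bfloat16, torch.float32)
```

With no `autocast` context active, each op simply follows its input dtype, so the plain softmax stays in `bfloat16`.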
As I understand it, `MixedPrecisionPolicy` controls the `dtype` that parameters are used in for forward/backward compute, the `dtype` that gradient reductions are performed in, and whether a given module's outputs are cast to a certain `dtype`. All of that is orthogonal to the dispatcher flags that `autocast` controls, IIUC.
Relates to #600 and #591. Also, I believe OLMo uses `autocast` with FSDP, but that was FSDP1 the last time I checked.
CC @awgu