Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix fp16 by converting outputs back to FP32 #134

Merged
merged 3 commits into from
Aug 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions docs/source/quicktour.rst
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,19 @@ If you are using gradient clipping in your script, you should replace the calls
:obj:`torch.nn.utils.clip_grad_norm_` or :obj:`torch.nn.utils.clip_grad_value_` with :obj:`accelerator.clip_grad_norm_`
and :obj:`accelerator.clip_grad_value_` respectively.

Mixed Precision training
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If you are running your training in Mixed Precision with Accelerate, you will get the best result with your loss being
computed inside your model (like in Transformer models for instance). Every computation outside of the model will be
executed in full precision (which is generally what you want for loss computation, expecially if it involves a
softmax). However you might want to put your loss computation inside the `accelerator.autocast` context manager:

.. codeblock::

with accelerator.autocast():
loss = complex_loss_function(outputs, target):


Internal mechanism
-----------------------------------------------------------------------------------------------------------------------
Expand Down
16 changes: 16 additions & 0 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .utils import (
DeepSpeedPlugin,
RNGType,
convert_outputs_to_fp32,
extract_model_from_parallel,
gather,
pad_across_processes,
Expand Down Expand Up @@ -295,6 +296,7 @@ def prepare_model(self, model):
model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
if self.native_amp:
model.forward = torch.cuda.amp.autocast()(model.forward)
model.forward = convert_outputs_to_fp32(model.forward)
return model

def _prepare_deepspeed(self, *args):
Expand Down Expand Up @@ -550,3 +552,17 @@ def get_state_dict(self, model):
state_dict[k] = state_dict[k].float()

return state_dict

@contextmanager
def autocast(self):
"""
Will apply automatic mixed-precision inside the block inside this context manager, if it is enabled. Nothing
different will happen otherwise.
"""
if self.native_amp:
autocast_context = torch.cuda.amp.autocast()
autocast_context.__enter__()
yield
autocast_context.__exit__()
else:
yield
40 changes: 40 additions & 0 deletions src/accelerate/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,46 @@ def send_to_device(tensor, device):
return tensor.to(device)


def convert_to_fp32(tensor):
"""
Recursively converts the lements nested list/tuple/dictionary of tensors in FP16 precision to FP32.

Args:
tensor (nested list/tuple/dictionary of :obj:`torch.Tensor`):
The data to convert from FP16 to FP32.

Returns:
The same data structure as :obj:`tensor` with all tensors that were in FP16 precision converted to FP32.
"""
if isinstance(tensor, (list, tuple)):
return honor_type(tensor, (convert_to_fp32(t) for t in tensor))
elif isinstance(tensor, dict):
return type(tensor)({k: convert_to_fp32(v) for k, v in tensor.items()})
elif not hasattr(tensor, "dtype") or tensor.dtype != torch.float16:
return tensor
return tensor.float()


def convert_outputs_to_fp32(model_forward):
"""
Decorator to apply to a function outputing tensors (like a model forward pass) that ensures the outputs in FP16
precision will be convert back to FP32.

Args:
model_forward (:obj:`Callable`):
The function which outputs we want to treat.

Returns:
The same function as :obj:`model_forward` but with converted outputs.
"""

def convert_outputs(*args, **kwargs):
outputs = model_forward(*args, **kwargs)
return convert_to_fp32(outputs)

return convert_outputs


def extract_model_from_parallel(model):
"""
Extract a model from its distributed containers.
Expand Down