From 3b848822e1feaae330be273ff29b83dfc6ad0cad Mon Sep 17 00:00:00 2001
From: Sylvain Gugger
Date: Thu, 5 Aug 2021 03:58:59 -0400
Subject: [PATCH 1/3] Fix FP16 mode by converting model outputs to FP32

---
 src/accelerate/accelerator.py |  2 ++
 src/accelerate/utils.py       | 40 +++++++++++++++++++++++++++++++++++
 2 files changed, 42 insertions(+)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 281f4e9a784..68941435ed8 100644
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -28,6 +28,7 @@
 from .utils import (
     DeepSpeedPlugin,
     RNGType,
+    convert_outputs_to_fp32,
     extract_model_from_parallel,
     gather,
     pad_across_processes,
@@ -295,6 +296,7 @@ def prepare_model(self, model):
             model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
         if self.native_amp:
             model.forward = torch.cuda.amp.autocast()(model.forward)
+            model.forward = convert_outputs_to_fp32(model.forward)
         return model
 
     def _prepare_deepspeed(self, *args):
diff --git a/src/accelerate/utils.py b/src/accelerate/utils.py
index 6edd1f1c00f..a6b7a7d74ae 100644
--- a/src/accelerate/utils.py
+++ b/src/accelerate/utils.py
@@ -137,6 +137,46 @@ def send_to_device(tensor, device):
     return tensor.to(device)
 
 
+def convert_to_fp32(tensor):
+    """
+    Recursively converts the elements of a nested list/tuple/dictionary of tensors in FP16 precision to FP32.
+
+    Args:
+        tensor (nested list/tuple/dictionary of :obj:`torch.Tensor`):
+            The data to convert from FP16 to FP32.
+
+    Returns:
+        The same data structure as :obj:`tensor` with all tensors that were in FP16 precision converted to FP32.
+    """
+    if isinstance(tensor, (list, tuple)):
+        return honor_type(tensor, (convert_to_fp32(t) for t in tensor))
+    elif isinstance(tensor, dict):
+        return type(tensor)({k: convert_to_fp32(v) for k, v in tensor.items()})
+    elif not hasattr(tensor, "dtype") or tensor.dtype != torch.float16:
+        return tensor
+    return tensor.float()
+
+
+def convert_outputs_to_fp32(model_forward):
+    """
+    Decorator to apply to a function outputting tensors (like a model forward pass) that ensures the outputs in FP16
+    precision will be converted back to FP32.
+
+    Args:
+        model_forward (:obj:`Callable`):
+            The function whose outputs we want to treat.
+
+    Returns:
+        The same function as :obj:`model_forward` but with converted outputs.
+    """
+
+    def convert_outputs(*args, **kwargs):
+        outputs = model_forward(*args, **kwargs)
+        return convert_to_fp32(outputs)
+
+    return convert_outputs
+
+
 def extract_model_from_parallel(model):
     """
     Extract a model from its distributed containers.

From e44119ce8af33b97f33060517e87630572db21bf Mon Sep 17 00:00:00 2001
From: Sylvain Gugger
Date: Thu, 5 Aug 2021 04:15:07 -0400
Subject: [PATCH 2/3] Add context manager and doc

---
 docs/source/quicktour.rst     | 13 +++++++++++++
 src/accelerate/accelerator.py | 14 ++++++++++++++
 2 files changed, 27 insertions(+)

diff --git a/docs/source/quicktour.rst b/docs/source/quicktour.rst
index 1e28fed85cc..7ccdf5dafdb 100644
--- a/docs/source/quicktour.rst
+++ b/docs/source/quicktour.rst
@@ -345,6 +345,19 @@ If you are using gradient clipping in your script, you should replace the calls
 :obj:`torch.nn.utils.clip_grad_norm_` or :obj:`torch.nn.utils.clip_grad_value_` with :obj:`accelerator.clip_grad_norm_`
 and :obj:`accelerator.clip_grad_value_` respectively.
 
+Mixed Precision training
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+If you are running your training in Mixed Precision with Accelerate, you will get the best result with your loss being
+computed inside your model (like in Transformer models for instance). Every computation outside of the model will be
+executed in full precision (which is generally what you want for loss computation, especially if it involves a
+softmax). However, you might want to put your loss computation inside the :obj:`accelerator.autocast` context manager:
+
+.. code-block:: python
+
+    with accelerator.autocast():
+        loss = complex_loss_function(outputs, target)
+
 Internal mechanism
 -----------------------------------------------------------------------------------------------------------------------
 
diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 68941435ed8..fa891fd0af3 100644
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -552,3 +552,17 @@ def get_state_dict(self, model):
                 state_dict[k] = state_dict[k].float()
 
         return state_dict
+
+    @contextmanager
+    def autocast(self):
+        """
+        Will apply automatic mixed precision inside the block inside thics context manager, if it is enabled. Nothing
+        different will happen otherwise.
+        """
+        if self.native_amp:
+            autocast_context = torch.cuda.amp.autocast()
+            autocast_context.__enter__()
+            yield
+            autocast_context.__exit__()
+        else:
+            yield

From 1d2a0475dac2ca39dcda7c0988f5d223ed6176cb Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Thu, 5 Aug 2021 16:51:51 +0200
Subject: [PATCH 3/3] Update src/accelerate/accelerator.py

Co-authored-by: Lysandre Debut
---
 src/accelerate/accelerator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index fa891fd0af3..a11b7a7a5dc 100644
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -556,7 +556,7 @@ def get_state_dict(self, model):
     @contextmanager
     def autocast(self):
         """
-        Will apply automatic mixed precision inside the block inside thics context manager, if it is enabled. Nothing
+        Will apply automatic mixed-precision inside the block inside this context manager, if it is enabled. Nothing
         different will happen otherwise.
         """
         if self.native_amp:
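
Taken together, the three patches wrap the prepared model's forward pass in autocast, convert its FP16 outputs back to FP32, and expose an `accelerator.autocast()` context manager for computations done outside the model. The sketch below shows one way these pieces fit together in a training step; it is not part of the patches: `SimpleModel` and `complex_loss_function` are made-up placeholders, a CUDA GPU is assumed, and `fp16=True` is the Accelerator flag name from this era of the library.

    import torch
    from accelerate import Accelerator

    class SimpleModel(torch.nn.Module):
        # Placeholder model; any torch.nn.Module prepared by Accelerator behaves the same way.
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(16, 4)

        def forward(self, x):
            return self.linear(x)

    def complex_loss_function(outputs, targets):
        # Placeholder for a loss computed outside the model (e.g. involving a softmax).
        return torch.nn.functional.cross_entropy(outputs, targets)

    accelerator = Accelerator(fp16=True)  # flag name at the time of these patches
    model = SimpleModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    model, optimizer = accelerator.prepare(model, optimizer)

    x = torch.randn(8, 16, device=accelerator.device)
    targets = torch.randint(0, 4, (8,), device=accelerator.device)

    # The prepared model's forward runs under torch.cuda.amp.autocast and, with patch 1,
    # any FP16 outputs are converted back to FP32 by convert_outputs_to_fp32.
    outputs = model(x)

    # Computations outside the model run in full precision; opt back into mixed
    # precision for the loss with the context manager added in patch 2.
    with accelerator.autocast():
        loss = complex_loss_function(outputs, targets)

    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()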