From b626ef5f002cb7bf45716abc4b698590449ad1d4 Mon Sep 17 00:00:00 2001
From: AbdelKarim ELJANDOUBI <78537694+eljandoubi@users.noreply.github.com>
Date: Mon, 2 Dec 2024 19:45:30 +0100
Subject: [PATCH] Select the DeepSpeedCPUOptimizer based on the original optimizer class. (#3255)

* Select the DeepSpeedCPUOptimizer based on the original optimizer class.

* abstract out optimizer selection to a deepspeed util

* add deepspeed cpu Adam & AdamW
---
 src/accelerate/accelerator.py     |  6 ++----
 src/accelerate/utils/__init__.py  |  1 +
 src/accelerate/utils/deepspeed.py | 60 +++++++++++++++++++++++++++++++
 3 files changed, 63 insertions(+), 4 deletions(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 9baad9b56df..d112221dff2 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -119,6 +119,7 @@
         DeepSpeedSchedulerWrapper,
         DummyOptim,
         DummyScheduler,
+        map_pytorch_optim_to_deepspeed,
     )
 
 if is_megatron_lm_available():
@@ -1839,10 +1840,7 @@ def _prepare_deepspeed(self, *args):
                 if self.deepspeed_config["zero_optimization"].get("offload_optimizer", {}).get(
                     "device", "none"
                 ) != "none" and self.deepspeed_config.get("zero_force_ds_cpu_optimizer", True):
-                    from deepspeed.ops.adam import DeepSpeedCPUAdam
-
-                    defaults = {k: v for k, v in optimizer.defaults.items() if k in ["lr", "weight_decay"]}
-                    optimizer = DeepSpeedCPUAdam(optimizer.param_groups, **defaults)
+                    optimizer = map_pytorch_optim_to_deepspeed(optimizer)
                 kwargs["optimizer"] = optimizer
                 if scheduler is not None:
                     if type(scheduler).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES:
diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py
index 7f2b68228f1..273808597b5 100644
--- a/src/accelerate/utils/__init__.py
+++ b/src/accelerate/utils/__init__.py
@@ -199,6 +199,7 @@
         DummyScheduler,
         HfDeepSpeedConfig,
         get_active_deepspeed_plugin,
+        map_pytorch_optim_to_deepspeed,
     )
 
 from .bnb import has_4bit_bnb_layers, load_and_quantize_model
diff --git a/src/accelerate/utils/deepspeed.py b/src/accelerate/utils/deepspeed.py
index 4712daee30a..f46d69c6cc2 100644
--- a/src/accelerate/utils/deepspeed.py
+++ b/src/accelerate/utils/deepspeed.py
@@ -17,9 +17,69 @@
 import os
 from copy import deepcopy
 
+from torch import optim
+
 from ..optimizer import AcceleratedOptimizer
 from ..scheduler import AcceleratedScheduler
 from .dataclasses import DistributedType
+from .imports import is_bnb_available
+from .versions import compare_versions
+
+
+def map_pytorch_optim_to_deepspeed(optimizer):
+    """
+    Args:
+        optimizer: torch.optim.Optimizer
+
+    Returns the DeepSpeed CPU optimizer (deepspeed.ops) version of the optimizer.
+    """
+
+    defaults = {k: v for k, v in optimizer.defaults.items() if k in ["lr", "weight_decay"]}
+
+    # Select the DeepSpeedCPUOptimizer based on the original optimizer class.
+    # DeepSpeedCPUAdam is the default
+    from deepspeed.ops.adam import DeepSpeedCPUAdam
+
+    optimizer_class = DeepSpeedCPUAdam
+
+    # For DeepSpeedCPUAdam (adamw_mode)
+    if compare_versions("deepspeed", ">=", "0.3.1"):
+        defaults["adamw_mode"] = False
+        is_adaw = isinstance(optimizer, optim.AdamW)
+
+        if is_bnb_available() and not is_adaw:
+            import bitsandbytes.optim as bnb_opt
+
+            is_adaw = isinstance(optimizer, (bnb_opt.AdamW, bnb_opt.AdamW32bit)) and optimizer.optim_bits == 32
+
+        if is_adaw:
+            defaults["adamw_mode"] = True
+
+    # For DeepSpeedCPUAdagrad
+    if compare_versions("deepspeed", ">=", "0.5.5"):
+        # Check if the optimizer is PyTorch's Adagrad.
+        is_ada = isinstance(optimizer, optim.Adagrad)
+        # If not, and bitsandbytes is available,
+        # check if the optimizer is the 32-bit bitsandbytes Adagrad.
+        if is_bnb_available() and not is_ada:
+            import bitsandbytes.optim as bnb_opt
+
+            is_ada = isinstance(optimizer, (bnb_opt.Adagrad, bnb_opt.Adagrad32bit)) and optimizer.optim_bits == 32
+        if is_ada:
+            from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad
+
+            optimizer_class = DeepSpeedCPUAdagrad
+
+    # For DeepSpeedCPULion
+    if is_bnb_available(min_version="0.38.0") and compare_versions("deepspeed", ">=", "0.11.0"):
+        from bitsandbytes.optim import Lion, Lion32bit
+
+        if isinstance(optimizer, (Lion, Lion32bit)) and optimizer.optim_bits == 32:
+            from deepspeed.ops.lion import DeepSpeedCPULion
+
+            optimizer_class = DeepSpeedCPULion
+
+    return optimizer_class(optimizer.param_groups, **defaults)
 
 
 def get_active_deepspeed_plugin(state):
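
Usage sketch (illustrative only, not part of the diff): assuming DeepSpeed is installed and Accelerate includes this change, the new helper can be called directly on a PyTorch optimizer; the toy model and hyperparameters below are hypothetical placeholders.

    import torch
    from accelerate.utils import map_pytorch_optim_to_deepspeed

    model = torch.nn.Linear(8, 2)  # hypothetical toy model
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

    # Because the original optimizer is torch.optim.AdamW, the helper returns
    # DeepSpeedCPUAdam configured with adamw_mode=True. Before this patch, the
    # offloaded optimizer was always DeepSpeedCPUAdam regardless of the original
    # class; now Adagrad and 32-bit bitsandbytes Lion map to DeepSpeedCPUAdagrad
    # and DeepSpeedCPULion respectively.
    ds_opt = map_pytorch_optim_to_deepspeed(opt)
    print(type(ds_opt).__name__)  # expected: DeepSpeedCPUAdam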