Select the DeepSpeedCPUOptimizer based on the original optimizer class. (#3255)
* Select the DeepSpeedCPUOptimizer based on the original optimizer class.

* abstract out optimizer selection to a deepspeed util

* add deepspeed cpu Adam & AdamW
eljandoubi authored Dec 2, 2024
1 parent dd68af8 commit b626ef5
Showing 3 changed files with 63 additions and 4 deletions.
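For context, the code path this commit touches is exercised when a plain PyTorch optimizer is prepared under a DeepSpeed ZeRO config that offloads optimizer state to CPU. A minimal usage sketch, not part of the commit; it assumes accelerate is launched with a DeepSpeed plugin and a CPU-offload config like the fragment shown after the first diff hunk:

    import torch
    from accelerate import Accelerator

    # Sketch: under a ZeRO config with optimizer CPU offload, prepare() now swaps
    # torch.optim.AdamW for DeepSpeedCPUAdam(adamw_mode=True) instead of always
    # building a plain DeepSpeedCPUAdam, reusing the optimizer's lr/weight_decay.
    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

    accelerator = Accelerator()  # DeepSpeed config comes from `accelerate config` / launch
    model, optimizer = accelerator.prepare(model, optimizer)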
src/accelerate/accelerator.py (2 additions, 4 deletions)

@@ -119,6 +119,7 @@
         DeepSpeedSchedulerWrapper,
         DummyOptim,
         DummyScheduler,
+        map_pytorch_optim_to_deepspeed,
     )
 
 if is_megatron_lm_available():
@@ -1839,10 +1840,7 @@ def _prepare_deepspeed(self, *args):
 if self.deepspeed_config["zero_optimization"].get("offload_optimizer", {}).get(
     "device", "none"
 ) != "none" and self.deepspeed_config.get("zero_force_ds_cpu_optimizer", True):
-    from deepspeed.ops.adam import DeepSpeedCPUAdam
-
-    defaults = {k: v for k, v in optimizer.defaults.items() if k in ["lr", "weight_decay"]}
-    optimizer = DeepSpeedCPUAdam(optimizer.param_groups, **defaults)
+    optimizer = map_pytorch_optim_to_deepspeed(optimizer)
 kwargs["optimizer"] = optimizer
 if scheduler is not None:
     if type(scheduler).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES:
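The branch above only fires when the optimizer is offloaded to CPU and zero_force_ds_cpu_optimizer is left at its default of True. A minimal illustrative config fragment; only the two keys inspected above come from the diff, the other values are assumptions:

    # Illustrative ZeRO config fragment: "offload_optimizer.device" != "none" together
    # with "zero_force_ds_cpu_optimizer" (default True) routes the optimizer through
    # map_pytorch_optim_to_deepspeed().
    ds_config = {
        "zero_optimization": {
            "stage": 2,
            "offload_optimizer": {"device": "cpu", "pin_memory": True},
        },
        "zero_force_ds_cpu_optimizer": True,
    }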
src/accelerate/utils/__init__.py (1 addition, 0 deletions)

@@ -199,6 +199,7 @@
     DummyScheduler,
     HfDeepSpeedConfig,
     get_active_deepspeed_plugin,
+    map_pytorch_optim_to_deepspeed,
 )
 
 from .bnb import has_4bit_bnb_layers, load_and_quantize_model
src/accelerate/utils/deepspeed.py (60 additions, 0 deletions)

@@ -17,9 +17,69 @@
 import os
 from copy import deepcopy
 
+from torch import optim
+
 from ..optimizer import AcceleratedOptimizer
 from ..scheduler import AcceleratedScheduler
 from .dataclasses import DistributedType
+from .imports import is_bnb_available
+from .versions import compare_versions


+def map_pytorch_optim_to_deepspeed(optimizer):
+    """
+    Args:
+        optimizer: torch.optim.Optimizer
+
+    Returns the DeepSpeedCPUOptimizer (deepspeed.ops) version of the optimizer.
+    """
+
+    defaults = {k: v for k, v in optimizer.defaults.items() if k in ["lr", "weight_decay"]}
+
+    # Select the DeepSpeedCPUOptimizer based on the original optimizer class.
+    # DeepSpeedCPUAdam is the default.
+    from deepspeed.ops.adam import DeepSpeedCPUAdam
+
+    optimizer_class = DeepSpeedCPUAdam
+
+    # For DeepSpeedCPUAdam (adamw_mode)
+    if compare_versions("deepspeed", ">=", "0.3.1"):
+        defaults["adamw_mode"] = False
+        is_adaw = isinstance(optimizer, optim.AdamW)
+
+        if is_bnb_available() and not is_adaw:
+            import bitsandbytes.optim as bnb_opt
+
+            is_adaw = isinstance(optimizer, (bnb_opt.AdamW, bnb_opt.AdamW32bit)) and optimizer.optim_bits == 32
+
+        if is_adaw:
+            defaults["adamw_mode"] = True
+
+    # For DeepSpeedCPUAdagrad
+    if compare_versions("deepspeed", ">=", "0.5.5"):
+        # Check if the optimizer is PyTorch's Adagrad.
+        is_ada = isinstance(optimizer, optim.Adagrad)
+        # If not, and bitsandbytes is available, check if it is the 32-bit bitsandbytes Adagrad.
+        if is_bnb_available() and not is_ada:
+            import bitsandbytes.optim as bnb_opt
+
+            is_ada = isinstance(optimizer, (bnb_opt.Adagrad, bnb_opt.Adagrad32bit)) and optimizer.optim_bits == 32
+        if is_ada:
+            from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad
+
+            optimizer_class = DeepSpeedCPUAdagrad
+
+    # For DeepSpeedCPULion
+    if is_bnb_available(min_version="0.38.0") and compare_versions("deepspeed", ">=", "0.11.0"):
+        from bitsandbytes.optim import Lion, Lion32bit
+
+        if isinstance(optimizer, (Lion, Lion32bit)) and optimizer.optim_bits == 32:
+            from deepspeed.ops.lion import DeepSpeedCPULion
+
+            optimizer_class = DeepSpeedCPULion
+
+    return optimizer_class(optimizer.param_groups, **defaults)


 def get_active_deepspeed_plugin(state):
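A minimal usage sketch of the new helper on its own, not part of the commit; it assumes deepspeed is installed, since DeepSpeedCPUAdam builds its CPU op on first use:

    import torch
    from accelerate.utils import map_pytorch_optim_to_deepspeed

    model = torch.nn.Linear(4, 4)
    adamw = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

    # The helper reuses the original param_groups plus its lr/weight_decay defaults and,
    # for torch.optim.AdamW, returns a DeepSpeedCPUAdam configured with adamw_mode=True.
    ds_optimizer = map_pytorch_optim_to_deepspeed(adamw)
    print(type(ds_optimizer).__name__)  # DeepSpeedCPUAdam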
