Check torch version for ADOPT optimizer + integrating new ADOPT updates (#2104)

* added torch check for adopt, wip

* lint

* gonna put torch version checking somewhere else

* added ENVcapabilities class for torch version checking

* lint + pydantic

* ENVCapabilities -> EnvCapabilities

* forgot to git add v0_4_1/__init__.py

* removed redundancy

* add check if env_capabilities not specified

* make env_capabilities compulsory [skip e2e]

* fixup env_capabilities

* modified test_validation.py to accommodate env_capabilities

* adopt torch version test [skip e2e]

* raise error

* test correct torch version

* test torch version above requirement

* Update src/axolotl/utils/config/models/input/v0_4_1/__init__.py

Co-authored-by: Wing Lian <[email protected]>

* removed unused is_torch_min

---------

Co-authored-by: Wing Lian <[email protected]>
Co-authored-by: Wing Lian <[email protected]>
3 people authored Dec 3, 2024
1 parent 9f6d0b5 commit d5f58b6
Showing 10 changed files with 184 additions and 61 deletions.
2 changes: 1 addition & 1 deletion docs/config.qmd
@@ -409,7 +409,7 @@ lr_div_factor: # Learning rate div factor
# - adamw_torch_fused
# - adamw_torch_xla
# - adamw_apex_fused
# - adopt_adamw (only for torch version >= 2.5.1)
# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
# - adafactor
# - adamw_anyprecision
# - sgd
7 changes: 5 additions & 2 deletions src/axolotl/cli/__init__.py
@@ -100,8 +100,8 @@ def print_dep_versions():
print("*" * 40)
print("**** Axolotl Dependency Versions *****")
for pkg in packages:
version = _is_package_available(pkg, return_version=True)
print(f"{pkg: >29}: {version[1]: <15}")
pkg_version = _is_package_available(pkg, return_version=True)
print(f"{pkg: >29}: {pkg_version[1]: <15}")
print("*" * 40)


@@ -444,6 +444,9 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
"compute_capability": gpu_version,
},
env_capabilities={
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0]
},
)

prepare_optim_env(cfg)
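Note: the env_capabilities entry above normalizes torch.__version__ before it reaches validation. A standalone sketch of the same normalization (the MIN_ADOPT_TORCH constant is illustrative, not part of the diff):

import torch
from packaging import version

MIN_ADOPT_TORCH = "2.5.1"  # illustrative constant; the validator below hard-codes 2.5.1

# torch.__version__ may carry a local build tag such as "2.5.1+cu124";
# splitting on "+" keeps only the plain release segment.
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
print(version.parse(torch_version) >= version.parse(MIN_ADOPT_TORCH))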
4 changes: 3 additions & 1 deletion src/axolotl/core/trainer_builder.py
@@ -562,7 +562,9 @@ def create_optimizer(self):

self.optimizer = ( # pylint: disable=attribute-defined-outside-init
ADOPT(
optimizer_grouped_parameters, decoupled=True, **optimizer_kwargs
optimizer_grouped_parameters,
decouple=True,
**optimizer_kwargs,
)
)

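Note: the keyword rename (decoupled= to decouple=) tracks the updated ADOPT signature further down in adopt.py. A hedged usage sketch outside the trainer; the model and hyperparameters are made up for illustration:

import torch
from axolotl.utils.optimizers.adopt import ADOPT

model = torch.nn.Linear(16, 4)
optimizer = ADOPT(
    model.parameters(),
    lr=1e-3,
    weight_decay=0.01,
    decouple=True,  # decoupled weight decay, formerly `decoupled=True`
)

loss = model(torch.randn(8, 16)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()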
20 changes: 17 additions & 3 deletions src/axolotl/utils/config/__init__.py
@@ -229,7 +229,11 @@ def normalize_cfg_datasets(cfg):
cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja


def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
def validate_config(
cfg: DictDefault,
capabilities: Optional[dict] = None,
env_capabilities: Optional[dict] = None,
):
AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase
AxolotlInputConfig = AxolotlInputConfigBase

@@ -239,14 +243,24 @@ def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
AxolotlInputConfig, # pylint: disable=invalid-name
) = merge_input_args()

if capabilities:
if capabilities or env_capabilities:
if (capabilities and not env_capabilities) or (
env_capabilities and not capabilities
):
raise ValueError(
"Both capabilities and env_capabilities must be provided or not provided."
)

return DictDefault(
dict(
AxolotlConfigWCapabilities(
**cfg.to_dict(), capabilities=capabilities
**cfg.to_dict(),
capabilities=capabilities,
env_capabilities=env_capabilities,
).model_dump(exclude_none=True)
)
)

return DictDefault(
dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
)
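Note: as the new guard enforces, capabilities and env_capabilities must now arrive together. A sketch of a conforming call site; the cfg keys and values are illustrative, and a real config carries the usual required fields:

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "base_model": "openlm-research/open_llama_3b",  # illustrative
        "learning_rate": 2e-5,
        "micro_batch_size": 1,
        "gradient_accumulation_steps": 1,
    }
)

validated = validate_config(
    cfg,
    capabilities={"n_gpu": 1, "compute_capability": "8.0"},
    env_capabilities={"torch_version": "2.5.1"},  # both dicts or neither
)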
22 changes: 21 additions & 1 deletion src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -9,6 +9,7 @@
from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union

from packaging import version
from pydantic import (
BaseModel,
Field,
@@ -21,7 +22,7 @@
from transformers.training_args import OptimizerNames
from transformers.utils.import_utils import is_torch_npu_available

from axolotl.utils.config.models.internals import GPUCapabilities
from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities

LOG = logging.getLogger("axolotl.utils.config.models.input")

@@ -1477,6 +1478,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to validate gpu capabilities with the configured options"""

capabilities: GPUCapabilities
env_capabilities: EnvCapabilities

@model_validator(mode="after")
def check_bf16(self):
@@ -1551,3 +1553,21 @@ def check_multigpu_unsloth(cls, data):
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
)
return data

@model_validator(mode="before")
@classmethod
def check_adopt_torch_version(cls, data):
if (data.get("optimizer") is not None) and ("adopt" in data.get("optimizer")):
env_capabilities = data.get("env_capabilities", {})
torch_version = env_capabilities.get("torch_version")

if torch_version is None:
import torch

torch_version = str(torch.__version__).split("+", maxsplit=1)[0]

if version.parse(torch_version) < version.parse("2.5.1"):
raise ValueError(
"ADOPT optimizer is incompatible with torch version < 2.5.1"
)
return data
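Note: the validator reduces to a single packaging comparison once a release-only version string is in hand. A minimal restatement (the helper name is hypothetical):

from packaging import version

def adopt_supported(torch_version: str, minimum: str = "2.5.1") -> bool:
    """Hypothetical helper mirroring check_adopt_torch_version's comparison."""
    return version.parse(torch_version) >= version.parse(minimum)

assert adopt_supported("2.5.1")      # exact requirement passes
assert adopt_supported("2.6.0")      # anything newer passes
assert not adopt_supported("2.4.0")  # the validator raises ValueError here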
6 changes: 6 additions & 0 deletions src/axolotl/utils/config/models/internals/__init__.py
@@ -12,3 +12,9 @@ class GPUCapabilities(BaseModel):
n_gpu: int = Field(default=1)
n_node: int = Field(default=1)
compute_capability: Optional[str] = Field(default=None)


class EnvCapabilities(BaseModel):
"""model to manage the environment capabilities statically"""

torch_version: Optional[str] = Field(default=None)
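Note: a quick instantiation sketch of the new model (value illustrative):

from axolotl.utils.config.models.internals import EnvCapabilities

caps = EnvCapabilities(torch_version="2.5.1")
print(caps.model_dump())  # {'torch_version': '2.5.1'}

# torch_version is optional; the ADOPT validator falls back to
# detecting torch.__version__ when it is left as None.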
135 changes: 83 additions & 52 deletions src/axolotl/utils/optimizers/adopt.py
@@ -6,21 +6,29 @@
"""
# mypy: ignore-errors
# pylint: skip-file
# flake8: noqa
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import List, Optional, Tuple, Union, cast
from typing import Callable, List, Optional, Tuple, Union, cast

import torch
from torch import Tensor
from torch.optim.optimizer import (
from torch.optim.optimizer import ( # DeviceDict,; _capturable_doc,; _differentiable_doc,; _foreach_doc,; _fused_doc,; _maximize_doc,; _stack_if_compiling,
DeviceDict,
Optimizer,
ParamsT,
_capturable_doc,
_default_to_fused_or_foreach,
_device_dtype_check_for_fused,
_differentiable_doc,
_disable_dynamo_if_unsupported,
_foreach_doc,
_fused_doc,
_get_capturable_supported_devices,
_get_scalar_dtype,
_get_value,
_maximize_doc,
_stack_if_compiling,
_use_grad_for_differentiable,
_view_as_real,
)
@@ -35,8 +43,9 @@ def __init__(
lr: Union[float, Tensor] = 1e-3,
betas: Tuple[float, float] = (0.9, 0.9999),
eps: float = 1e-6,
clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
weight_decay: float = 0.0,
decoupled: bool = False,
decouple: bool = False,
*,
foreach: Optional[bool] = None,
maximize: bool = False,
@@ -62,12 +71,14 @@
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")

self.clip_lambda = clip_lambda

defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
decoupled=decoupled,
decouple=decouple,
maximize=maximize,
foreach=foreach,
capturable=capturable,
@@ -219,8 +230,9 @@ def step(self, closure=None):
beta1=beta1,
beta2=beta2,
lr=group["lr"],
clip_lambda=self.clip_lambda,
weight_decay=group["weight_decay"],
decoupled=group["decoupled"],
decouple=group["decouple"],
eps=group["eps"],
maximize=group["maximize"],
foreach=group["foreach"],
@@ -247,8 +259,9 @@ def _single_tensor_adopt(
beta1: float,
beta2: float,
lr: Union[float, Tensor],
clip_lambda: Optional[Callable[[int], float]],
weight_decay: float,
decoupled: bool,
decouple: bool,
eps: float,
maximize: bool,
capturable: bool,
@@ -276,14 +289,10 @@
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

# update step
step_t += 1
step = step_t if capturable or differentiable else _get_value(step_t)

if weight_decay != 0:
if decoupled:
param.add_(param, alpha=-lr * weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)
if weight_decay != 0 and not decouple:
grad = grad.add(param, alpha=weight_decay)

if torch.is_complex(param):
grad = torch.view_as_real(grad)
@@ -293,20 +302,29 @@
exp_avg_sq = torch.view_as_real(exp_avg_sq)
param = torch.view_as_real(param)

step = step_t if capturable or differentiable else _get_value(step_t)
if step == 1:
if step == 0:
exp_avg_sq.addcmul_(grad, grad.conj())
# update step
step_t += 1
continue

if weight_decay != 0 and decouple:
param.add_(param, alpha=-lr * weight_decay)

denom = torch.clamp(exp_avg_sq.sqrt(), eps)
if step == 2:
exp_avg.addcdiv_(grad, denom)
else:
exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)
normed_grad = grad.div(denom)
if clip_lambda is not None:
clip = clip_lambda(step)
normed_grad.clamp_(-clip, clip)

exp_avg.lerp_(normed_grad, 1 - beta1)

param.add_(exp_avg, alpha=-lr)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)

# update step
step_t += 1
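Note: the reworked loop follows the updated ADOPT recipe: divide the gradient by the clamped second-moment root, clip the result with a step-dependent bound (step**0.25 by default), fold it into the first moment with lerp_, and only then refresh the second moment. A standalone restatement of one parameter's update (weight decay omitted; all names here are illustrative):

import torch

def adopt_param_step(param, grad, exp_avg, exp_avg_sq, step,
                     lr=1e-3, beta1=0.9, beta2=0.9999, eps=1e-6,
                     clip_lambda=lambda s: s**0.25):
    """Illustrative single-parameter version of the updated ADOPT step."""
    if step == 0:
        exp_avg_sq.addcmul_(grad, grad.conj())  # step 0 only seeds the second moment
        return
    denom = torch.clamp(exp_avg_sq.sqrt(), eps)
    normed_grad = grad.div(denom)
    if clip_lambda is not None:
        clip = clip_lambda(step)
        normed_grad.clamp_(-clip, clip)    # new: clipped, normalized gradient
    exp_avg.lerp_(normed_grad, 1 - beta1)  # m <- beta1*m + (1-beta1)*g_hat
    param.add_(exp_avg, alpha=-lr)         # theta <- theta - lr*m
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)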


def _multi_tensor_adopt(
params: List[Tensor],
@@ -321,8 +339,9 @@ def _multi_tensor_adopt(
beta1: float,
beta2: float,
lr: Union[float, Tensor],
clip_lambda: Optional[Callable[[int], float]],
weight_decay: float,
decoupled: bool,
decouple: bool,
eps: float,
maximize: bool,
capturable: bool,
@@ -376,52 +395,62 @@
if maximize:
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]

# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
torch._foreach_add_(
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)
else:
torch._foreach_add_(device_state_steps, 1)
if weight_decay != 0 and not decouple:
# Re-use the intermediate memory (device_grads) already allocated for maximize
if maximize:
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
else:
device_grads = torch._foreach_add( # type: ignore[assignment]
device_grads, device_params, alpha=weight_decay
)

if device_state_steps[0] == 0:
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)

if weight_decay != 0:
if decoupled:
# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
torch._foreach_add_(
device_params, device_params, alpha=-lr * weight_decay
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)
else:
# Re-use the intermediate memory (device_grads) already allocated for maximize
if maximize:
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
else:
device_grads = torch._foreach_add( # type: ignore[assignment]
device_grads, device_params, alpha=weight_decay
)
torch._foreach_add_(device_state_steps, 1)

if device_state_steps[0] == 1:
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
continue

if weight_decay != 0 and decouple:
torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)

exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps)
torch._foreach_maximum_(exp_avg_sq_sqrt, eps)

if device_state_steps[0] == 2:
torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt)
else:
torch._foreach_mul_(device_exp_avgs, beta1)
torch._foreach_addcdiv_(
device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1
)
normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt)
if clip_lambda is not None:
clip = clip_lambda(device_state_steps[0])
torch._foreach_maximum_(normed_grad, -clip)
torch._foreach_minimum_(normed_grad, clip)

torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1)

torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
torch._foreach_mul_(device_exp_avg_sqs, beta2)
torch._foreach_addcmul_(
device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
)

# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
torch._foreach_add_(
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)
else:
torch._foreach_add_(device_state_steps, 1)
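Note: elementwise clamp_ has no one-shot _foreach equivalent, so the multi-tensor path expresses the clip as a floor followed by a ceiling. A toy demonstration of the same pair of calls (values illustrative):

import torch

tensors = [torch.tensor([-3.0, 0.5, 2.0]), torch.tensor([4.0, -0.1, -5.0])]
clip = 1.5  # ADOPT derives this from clip_lambda(step)

torch._foreach_maximum_(tensors, -clip)  # floor at -clip
torch._foreach_minimum_(tensors, clip)   # ceiling at +clip

print(tensors)  # [tensor([-1.5000,  0.5000,  1.5000]), tensor([ 1.5000, -0.1000, -1.5000])]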


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt)
def adopt(
@@ -443,8 +472,9 @@ def adopt(
beta1: float,
beta2: float,
lr: Union[float, Tensor],
clip_lambda: Optional[Callable[[int], float]],
weight_decay: float,
decoupled: bool,
decouple: bool,
eps: float,
maximize: bool,
):
@@ -497,8 +527,9 @@ def adopt(
beta1=beta1,
beta2=beta2,
lr=lr,
clip_lambda=clip_lambda,
weight_decay=weight_decay,
decoupled=decoupled,
decouple=decouple,
eps=eps,
maximize=maximize,
capturable=capturable,