diff --git a/docs/config.qmd b/docs/config.qmd
index bc3730095d..f01a2ce267 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -409,7 +409,7 @@ lr_div_factor: # Learning rate div factor
 # - adamw_torch_fused
 # - adamw_torch_xla
 # - adamw_apex_fused
-# - adopt_adamw (only for torch version >= 2.5.1)
+# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
 # - adafactor
 # - adamw_anyprecision
 # - sgd
diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index 86cc30a401..2ef78e07d8 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -100,8 +100,8 @@ def print_dep_versions():
     print("*" * 40)
     print("**** Axolotl Dependency Versions *****")
     for pkg in packages:
-        version = _is_package_available(pkg, return_version=True)
-        print(f"{pkg: >32}: {version[1]: <15}")
+        pkg_version = _is_package_available(pkg, return_version=True)
+        print(f"{pkg: >32}: {pkg_version[1]: <15}")
     print("*" * 40)


@@ -444,6 +444,9 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
             "n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
             "compute_capability": gpu_version,
         },
+        env_capabilities={
+            "torch_version": str(torch.__version__).split("+", maxsplit=1)[0]
+        },
     )
     prepare_optim_env(cfg)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index e4bf7de229..9b03563e0c 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -562,7 +562,9 @@ def create_optimizer(self):
                 self.optimizer = (  # pylint: disable=attribute-defined-outside-init
                     ADOPT(
-                        optimizer_grouped_parameters, decoupled=True, **optimizer_kwargs
+                        optimizer_grouped_parameters,
+                        decouple=True,
+                        **optimizer_kwargs,
                     )
                 )

diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py
index 0100f23ea5..422ed78efb 100644
--- a/src/axolotl/utils/config/__init__.py
+++ b/src/axolotl/utils/config/__init__.py
@@ -229,7 +229,11 @@ def normalize_cfg_datasets(cfg):
                 cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja


-def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
+def validate_config(
+    cfg: DictDefault,
+    capabilities: Optional[dict] = None,
+    env_capabilities: Optional[dict] = None,
+):
     AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase
     AxolotlInputConfig = AxolotlInputConfigBase

@@ -239,14 +243,24 @@ def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
         AxolotlInputConfig,  # pylint: disable=invalid-name
     ) = merge_input_args()

-    if capabilities:
+    if capabilities or env_capabilities:
+        if (capabilities and not env_capabilities) or (
+            env_capabilities and not capabilities
+        ):
+            raise ValueError(
+                "Both capabilities and env_capabilities must be provided or not provided."
+ ) + return DictDefault( dict( AxolotlConfigWCapabilities( - **cfg.to_dict(), capabilities=capabilities + **cfg.to_dict(), + capabilities=capabilities, + env_capabilities=env_capabilities, ).model_dump(exclude_none=True) ) ) + return DictDefault( dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True)) ) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 1ac7efbfa5..0f01a7cadc 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -9,6 +9,7 @@ from enum import Enum from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union +from packaging import version from pydantic import ( BaseModel, Field, @@ -21,7 +22,7 @@ from transformers.training_args import OptimizerNames from transformers.utils.import_utils import is_torch_npu_available -from axolotl.utils.config.models.internals import GPUCapabilities +from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities LOG = logging.getLogger("axolotl.utils.config.models.input") @@ -1477,6 +1478,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): """wrapper to valdiate gpu capabilities with the configured options""" capabilities: GPUCapabilities + env_capabilities: EnvCapabilities @model_validator(mode="after") def check_bf16(self): @@ -1551,3 +1553,21 @@ def check_multigpu_unsloth(cls, data): "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training." ) return data + + @model_validator(mode="before") + @classmethod + def check_adopt_torch_version(cls, data): + if (data.get("optimizer") is not None) and ("adopt" in data.get("optimizer")): + env_capabilities = data.get("env_capabilities", {}) + torch_version = env_capabilities.get("torch_version") + + if torch_version is None: + import torch + + torch_version = str(torch.__version__).split("+", maxsplit=1)[0] + + if version.parse(torch_version) < version.parse("2.5.1"): + raise ValueError( + "ADOPT optimizer is incompatible with torch version < 2.5.1" + ) + return data diff --git a/src/axolotl/utils/config/models/internals/__init__.py b/src/axolotl/utils/config/models/internals/__init__.py index dd742caf45..7b4a12e035 100644 --- a/src/axolotl/utils/config/models/internals/__init__.py +++ b/src/axolotl/utils/config/models/internals/__init__.py @@ -12,3 +12,9 @@ class GPUCapabilities(BaseModel): n_gpu: int = Field(default=1) n_node: int = Field(default=1) compute_capability: Optional[str] = Field(default=None) + + +class EnvCapabilities(BaseModel): + """model to manage the environment capabilities statically""" + + torch_version: Optional[str] = Field(default=None) diff --git a/src/axolotl/utils/optimizers/adopt.py b/src/axolotl/utils/optimizers/adopt.py index 7e133285f7..36217730b3 100644 --- a/src/axolotl/utils/optimizers/adopt.py +++ b/src/axolotl/utils/optimizers/adopt.py @@ -6,21 +6,29 @@ """ # mypy: ignore-errors # pylint: skip-file +# flake8: noqa # mypy: allow-untyped-decorators # mypy: allow-untyped-defs -from typing import List, Optional, Tuple, Union, cast +from typing import Callable, List, Optional, Tuple, Union, cast import torch from torch import Tensor -from torch.optim.optimizer import ( +from torch.optim.optimizer import ( # DeviceDict,; _capturable_doc,; _differentiable_doc,; _foreach_doc,; _fused_doc,; _maximize_doc,; _stack_if_compiling, + DeviceDict, Optimizer, ParamsT, + _capturable_doc, _default_to_fused_or_foreach, 
_device_dtype_check_for_fused, + _differentiable_doc, _disable_dynamo_if_unsupported, + _foreach_doc, + _fused_doc, _get_capturable_supported_devices, _get_scalar_dtype, _get_value, + _maximize_doc, + _stack_if_compiling, _use_grad_for_differentiable, _view_as_real, ) @@ -35,8 +43,9 @@ def __init__( lr: Union[float, Tensor] = 1e-3, betas: Tuple[float, float] = (0.9, 0.9999), eps: float = 1e-6, + clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25, weight_decay: float = 0.0, - decoupled: bool = False, + decouple: bool = False, *, foreach: Optional[bool] = None, maximize: bool = False, @@ -62,12 +71,14 @@ def __init__( if not 0.0 <= weight_decay: raise ValueError(f"Invalid weight_decay value: {weight_decay}") + self.clip_lambda = clip_lambda + defaults = dict( lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, - decoupled=decoupled, + decouple=decouple, maximize=maximize, foreach=foreach, capturable=capturable, @@ -219,8 +230,9 @@ def step(self, closure=None): beta1=beta1, beta2=beta2, lr=group["lr"], + clip_lambda=self.clip_lambda, weight_decay=group["weight_decay"], - decoupled=group["decoupled"], + decouple=group["decouple"], eps=group["eps"], maximize=group["maximize"], foreach=group["foreach"], @@ -247,8 +259,9 @@ def _single_tensor_adopt( beta1: float, beta2: float, lr: Union[float, Tensor], + clip_lambda: Optional[Callable[[int], float]], weight_decay: float, - decoupled: bool, + decouple: bool, eps: float, maximize: bool, capturable: bool, @@ -276,14 +289,10 @@ def _single_tensor_adopt( and param.device.type in capturable_supported_devices ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." - # update step - step_t += 1 + step = step_t if capturable or differentiable else _get_value(step_t) - if weight_decay != 0: - if decoupled: - param.add_(param, alpha=-lr * weight_decay) - else: - grad = grad.add(param, alpha=weight_decay) + if weight_decay != 0 and not decouple: + grad = grad.add(param, alpha=weight_decay) if torch.is_complex(param): grad = torch.view_as_real(grad) @@ -293,20 +302,29 @@ def _single_tensor_adopt( exp_avg_sq = torch.view_as_real(exp_avg_sq) param = torch.view_as_real(param) - step = step_t if capturable or differentiable else _get_value(step_t) - if step == 1: + if step == 0: exp_avg_sq.addcmul_(grad, grad.conj()) + # update step + step_t += 1 continue + if weight_decay != 0 and decouple: + param.add_(param, alpha=-lr * weight_decay) + denom = torch.clamp(exp_avg_sq.sqrt(), eps) - if step == 2: - exp_avg.addcdiv_(grad, denom) - else: - exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1) + normed_grad = grad.div(denom) + if clip_lambda is not None: + clip = clip_lambda(step) + normed_grad.clamp_(-clip, clip) + + exp_avg.lerp_(normed_grad, 1 - beta1) param.add_(exp_avg, alpha=-lr) exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + # update step + step_t += 1 + def _multi_tensor_adopt( params: List[Tensor], @@ -321,8 +339,9 @@ def _multi_tensor_adopt( beta1: float, beta2: float, lr: Union[float, Tensor], + clip_lambda: Optional[Callable[[int], float]], weight_decay: float, - decoupled: bool, + decouple: bool, eps: float, maximize: bool, capturable: bool, @@ -376,45 +395,44 @@ def _multi_tensor_adopt( if maximize: device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] - # Update steps - # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over - # and over. 
1 will then be wrapped into a Tensor over and over again, which is slower than if we just - # wrapped it once now. The alpha is required to assure we go to the right overload. - if not torch._utils.is_compiling() and device_state_steps[0].is_cpu: - torch._foreach_add_( - device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 - ) - else: - torch._foreach_add_(device_state_steps, 1) + if weight_decay != 0 and not decouple: + # Re-use the intermediate memory (device_grads) already allocated for maximize + if maximize: + torch._foreach_add_(device_grads, device_params, alpha=weight_decay) + else: + device_grads = torch._foreach_add( # type: ignore[assignment] + device_grads, device_params, alpha=weight_decay + ) + + if device_state_steps[0] == 0: + torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads) - if weight_decay != 0: - if decoupled: + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. + if not torch._utils.is_compiling() and device_state_steps[0].is_cpu: torch._foreach_add_( - device_params, device_params, alpha=-lr * weight_decay + device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 ) else: - # Re-use the intermediate memory (device_grads) already allocated for maximize - if maximize: - torch._foreach_add_(device_grads, device_params, alpha=weight_decay) - else: - device_grads = torch._foreach_add( # type: ignore[assignment] - device_grads, device_params, alpha=weight_decay - ) + torch._foreach_add_(device_state_steps, 1) - if device_state_steps[0] == 1: - torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads) continue + if weight_decay != 0 and decouple: + torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay) + exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) - exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps) + torch._foreach_maximum_(exp_avg_sq_sqrt, eps) - if device_state_steps[0] == 2: - torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt) - else: - torch._foreach_mul_(device_exp_avgs, beta1) - torch._foreach_addcdiv_( - device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1 - ) + normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt) + if clip_lambda is not None: + clip = clip_lambda(device_state_steps[0]) + torch._foreach_maximum_(normed_grad, -clip) + torch._foreach_minimum_(normed_grad, clip) + + torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1) torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr) torch._foreach_mul_(device_exp_avg_sqs, beta2) @@ -422,6 +440,17 @@ def _multi_tensor_adopt( device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2 ) + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. 
+ if not torch._utils.is_compiling() and device_state_steps[0].is_cpu: + torch._foreach_add_( + device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(device_state_steps, 1) + @_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt) def adopt( @@ -443,8 +472,9 @@ def adopt( beta1: float, beta2: float, lr: Union[float, Tensor], + clip_lambda: Optional[Callable[[int], float]], weight_decay: float, - decoupled: bool, + decouple: bool, eps: float, maximize: bool, ): @@ -497,8 +527,9 @@ def adopt( beta1=beta1, beta2=beta2, lr=lr, + clip_lambda=clip_lambda, weight_decay=weight_decay, - decoupled=decoupled, + decouple=decouple, eps=eps, maximize=maximize, capturable=capturable, diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 439f6ac6f8..92e647e678 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -53,7 +53,7 @@ def is_min_2_3_1(): def require_torch_2_5_1(test_case): """ - Decorator marking a test that requires torch >= 2.3.1 + Decorator marking a test that requires torch >= 2.5.1 """ def is_min_2_5_1(): diff --git a/tests/test_validation.py b/tests/test_validation.py index 491f230c33..2e6fbab101 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -672,6 +672,9 @@ def test_merge_lora_no_bf16_fail(self, minimal_cfg): { "bf16": True, "capabilities": {"bf16": False}, + "env_capabilities": { + "torch_version": "2.5.1", + }, } ) | minimal_cfg @@ -1160,6 +1163,38 @@ def test_eval_strategy_remap(self, minimal_cfg): in self._caplog.records[0].message ) + def test_torch_version_adopt_req(self, minimal_cfg): + cfg = ( + DictDefault( + { + "optimizer": "adopt_adamw", + } + ) + | minimal_cfg + ) + + with pytest.raises( + ValueError, + match=r".*ADOPT optimizer is incompatible with torch version*", + ): + env_capabilities = {"torch_version": "2.3.0"} + capabilities = {"bf16": False} + _ = validate_config( + cfg, capabilities=capabilities, env_capabilities=env_capabilities + ) + + env_capabilities = {"torch_version": "2.5.1"} + capabilities = {"bf16": False} + _ = validate_config( + cfg, capabilities=capabilities, env_capabilities=env_capabilities + ) + + env_capabilities = {"torch_version": "2.5.2"} + capabilities = {"bf16": False} + _ = validate_config( + cfg, capabilities=capabilities, env_capabilities=env_capabilities + ) + class TestValidationCheckModelConfig(BaseValidation): """ diff --git a/tests/test_validation_dataset.py b/tests/test_validation_dataset.py index 14f9d34627..89f642051b 100644 --- a/tests/test_validation_dataset.py +++ b/tests/test_validation_dataset.py @@ -72,6 +72,9 @@ def _check_config(): "n_gpu": 1, "compute_capability": "8.0", }, + env_capabilities={ + "torch_version": "2.5.1", + }, ) _check_config() @@ -124,6 +127,9 @@ def _check_config(): "n_gpu": 1, "compute_capability": "8.0", }, + env_capabilities={ + "torch_version": "2.5.1", + }, ) _check_config() @@ -177,6 +183,9 @@ def _check_config(): "n_gpu": 1, "compute_capability": "8.0", }, + env_capabilities={ + "torch_version": "2.5.1", + }, ) _check_config() @@ -231,6 +240,9 @@ def _check_config(): "n_gpu": 1, "compute_capability": "8.0", }, + env_capabilities={ + "torch_version": "2.5.1", + }, ) _check_config()
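
Note on the optimizer change: the restructured per-parameter update in _single_tensor_adopt above reduces to the sketch below. This is illustrative only; the adopt_update helper, its tensor arguments, and the default hyperparameters are not part of the patch, and the capturable, foreach, and complex-view branches are omitted.

    import torch

    def adopt_update(param, grad, exp_avg, exp_avg_sq, step, *,
                     lr=1e-3, betas=(0.9, 0.9999), eps=1e-6,
                     weight_decay=0.0, decouple=False,
                     clip_lambda=lambda step: step**0.25):
        """Illustrative single-parameter ADOPT step; returns the incremented step count."""
        beta1, beta2 = betas

        # Coupled (L2-style) weight decay folds into the gradient.
        if weight_decay != 0 and not decouple:
            grad = grad.add(param, alpha=weight_decay)

        if step == 0:
            # First step only seeds the second-moment estimate; no parameter update.
            exp_avg_sq.addcmul_(grad, grad.conj())
            return step + 1

        # Decoupled (AdamW-style) weight decay acts on the parameter directly.
        if weight_decay != 0 and decouple:
            param.add_(param, alpha=-lr * weight_decay)

        denom = torch.clamp(exp_avg_sq.sqrt(), eps)
        normed_grad = grad.div(denom)
        if clip_lambda is not None:
            clip = clip_lambda(step)          # default clip grows as step**0.25
            normed_grad.clamp_(-clip, clip)

        exp_avg.lerp_(normed_grad, 1 - beta1)  # EMA of the normalized gradient
        param.add_(exp_avg, alpha=-lr)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
        return step + 1

The constructor keyword rename (decoupled to decouple) surfaces at the call site in trainer_builder.py. A minimal, hypothetical usage sketch follows; the toy model is illustrative only, and torch >= 2.5.1 is assumed, as enforced for adopt_adamw by the new check_adopt_torch_version validator.

    import torch
    from axolotl.utils.optimizers.adopt import ADOPT

    model = torch.nn.Linear(8, 2)
    optim = ADOPT(model.parameters(), lr=1e-3, weight_decay=0.01, decouple=True)

    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()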