diff --git a/llmfoundry/optim/lion8b.py b/llmfoundry/optim/lion8b.py
index 2c2e6e2d35..9d1d1dda71 100644
--- a/llmfoundry/optim/lion8b.py
+++ b/llmfoundry/optim/lion8b.py
@@ -1,7 +1,7 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Any, Callable, Dict, Iterable, Optional, Tuple
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
 
 import torch
 
@@ -58,15 +58,17 @@ class DecoupledLionW_8bit(torch.optim.Optimizer):
         device, or b) step() is executed on a non-CUDA parameter.
     """
 
-    def __init__(self,
-                 params: Iterable[torch.Tensor],
-                 lr: float = 1e-3,
-                 betas: Tuple[float, float] = (0.9, 0.99),
-                 weight_decay: float = 0,
-                 quantize: bool = True,
-                 compress_state_dict: bool = False,
-                 error_correction: bool = False,
-                 _fused: bool = True):  # XXX this flag is mostly for testing...
+    def __init__(
+        self,
+        params: Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]],
+        lr: float = 1e-3,
+        betas: Tuple[float, float] = (0.9, 0.99),
+        weight_decay: float = 0,
+        quantize: bool = True,
+        compress_state_dict: bool = False,
+        error_correction: bool = False,
+        _fused: bool = True,  # XXX this flag is mostly for testing...
+    ):
 
         if lr < 0.0:
             raise ValueError('Invalid learning rate: {}'.format(lr))
diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py
index c31917efc6..14196c3ef9 100644
--- a/llmfoundry/utils/builders.py
+++ b/llmfoundry/utils/builders.py
@@ -1,10 +1,13 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+import functools
 import logging
 import os
+import re
 import warnings
-from typing import Any, Dict, List, Optional, Tuple, Union
+from collections import OrderedDict
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 from composer import algorithms
@@ -155,18 +158,121 @@ def build_algorithm(name: str, kwargs: Dict[str, Any]) -> Algorithm:
         raise ValueError(f'Not sure how to build algorithm: {name}')
 
 
+def _extract_param_groups(
+    model: torch.nn.Module,
+    optimizer_config: Dict[str, Any],
+) -> Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]:
+    """Extracts parameter groups defined in the optimizer config.
+
+    The optimizer_config defines the optimizer args. It can additionally have a
+    key `disable_grad`, which is a string or list of strings. If one of these
+    strings matches a parameter name (via `re.search`), that parameter will have
+    `requires_grad=False`. This is useful for freezing parameters. It can also
+    have a key `param_groups`, which is a list of dicts. In each dict, the key
+    `param_str_match` defines a string; if a parameter name matches this string
+    (via `re.search`), the parameter is placed in this parameter group. This is
+    useful for grouping parameters together. Each dict can also contain any
+    other key that is a valid optimizer arg.
+    Note: to handle name overlap conflicts, params are assigned to parameter
+    groups and added to `param_groups` in the order in which their
+    `param_str_match` patterns appear in `param_groups`.
+
+    Usage
+    To disable gradient for all parameters that contain the string "norm" or "bias":
+    ```
+    optimizer_config: {
+        "name": "decoupled_lionw",
+        "lr": 1e-3,
+        "weight_decay": 1e-2,
+        "betas": [0.9, 0.999],
+        "eps": 1e-8,
+        "disable_grad": ["norm", "bias"]
+    }
+    ```
+
+    To create separate parameter groups, each with its own optimizer settings,
+    for all parameters that contain the string "norm" or "bias":
+    ```
+    optimizer_config: {
+        "name": "decoupled_lionw",
+        "lr": 1e-3,
+        "weight_decay": 1e-2,
+        "betas": [0.9, 0.999],
+        "eps": 1e-8,
+        "param_groups": [
+            {
+                "param_str_match": "norm",
+                "lr": 1e-4,
+                "weight_decay": 0.0,
+            },
+            {
+                "param_str_match": "bias",
+                "lr": 5e-4,
+                "weight_decay": 0.0,
+            },
+        ],
+    }
+    ```
+
+    Args:
+        model (torch.nn.Module): model to extract parameters from
+        optimizer_config (Dict[str, Any]): optimizer config
+
+    Returns:
+        Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]: an iterable of
+            torch.Tensors or dicts. Specifies which Tensors should be optimized
+            and their param groupings.
+    """
+    if 'disable_grad' in optimizer_config.keys():
+        str_matches = optimizer_config.pop('disable_grad')
+        if isinstance(str_matches, str):
+            str_matches = [str_matches]
+        for str_match in str_matches:
+            for n, p in model.named_parameters():
+                if re.search(str_match, n):
+                    p.requires_grad = False
+                    log.debug(f'Setting `{n}.requires_grad = False`.')
+
+    param_groups_config = optimizer_config.pop('param_groups', None)
+    if param_groups_config is not None:
+        params = []
+        param_dict = OrderedDict((n, p) for n, p in model.named_parameters())
+
+        log.debug(f'Default optimizer settings: {optimizer_config}.')
+        for param_group_config in param_groups_config:
+            str_match = param_group_config.pop('param_str_match')
+            filter_fn = functools.partial(re.search, str_match)
+            param_names = [n for n in param_dict.keys() if filter_fn(n)]
+            group_params = {'params': [param_dict.pop(n) for n in param_names]}
+            group_params.update(param_group_config)
+
+            log.debug(
+                f'Creating optimizer param_group with parameters: {param_names} ' +\
+                f'(extracted using {str_match=}). The param_group optimizer ' +\
+                f'setting overrides are: {param_group_config}.')
+
+            params.append(group_params)
+
+        params.insert(0, {'params': param_dict.values()})
+        return params
+
+    return model.parameters()
+
+
 def build_optimizer(model: torch.nn.Module, name: str,
                     optimizer_config: Dict[str, Any]) -> Optimizer:
+
+    params = _extract_param_groups(model, optimizer_config)
+
     if name == 'decoupled_adamw':
-        return DecoupledAdamW(model.parameters(), **optimizer_config)
+        return DecoupledAdamW(params, **optimizer_config)
     elif name == 'decoupled_lionw':
-        return DecoupledLionW(model.parameters(), **optimizer_config)
+        return DecoupledLionW(params, **optimizer_config)
     elif name == 'clip_lion':
-        return DecoupledClipLion(model.parameters(), **optimizer_config)
+        return DecoupledClipLion(params, **optimizer_config)
     elif name == 'adalr_lion':
-        return DecoupledAdaLRLion(model.parameters(), **optimizer_config)
+        return DecoupledAdaLRLion(params, **optimizer_config)
     elif name == 'decoupled_lionw_8b':
-        return DecoupledLionW_8bit(model.parameters(), **optimizer_config)
+        return DecoupledLionW_8bit(params, **optimizer_config)
     else:
         raise ValueError(f'Not sure how to build optimizer: {name}')
diff --git a/tests/test_builders.py b/tests/test_builders.py
index 237e27b52b..7ac179720e 100644
--- a/tests/test_builders.py
+++ b/tests/test_builders.py
@@ -1,17 +1,22 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+import re
 import unittest.mock as mock
-from typing import Union
+from copy import deepcopy
+from typing import Any, Dict, Union
 
 import pytest
+import torch
+import torch.nn as nn
 from composer.callbacks import Generate
 from omegaconf import OmegaConf as om
 from transformers import PreTrainedTokenizerBase
 
 from llmfoundry.callbacks import HuggingFaceCheckpointer
 from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper
-from llmfoundry.utils.builders import build_callback, build_tokenizer
+from llmfoundry.utils.builders import (build_callback, build_optimizer,
+                                       build_tokenizer)
 
 
 @pytest.mark.parametrize('tokenizer_name,tokenizer_kwargs', [
@@ -110,3 +115,83 @@ def test_build_hf_checkpointer_callback():
     assert isinstance(kwargs['mlflow_logging_config'], dict)
     assert isinstance(kwargs['mlflow_logging_config']['metadata'], dict)
     assert kwargs['mlflow_logging_config'] == mlflow_logging_config_dict
+
+
+class _DummyModule(nn.Module):
+
+    def __init__(self, device: str = 'cpu', dtype: torch.dtype = torch.float32):
+        super().__init__()
+        self.linear0 = nn.Linear(4, 3, device=device, dtype=dtype)
+        self.norm0 = nn.LayerNorm(3, device=device, dtype=dtype)
+        self.linear1 = nn.Linear(3, 5, device=device, dtype=dtype)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type:ignore
+        return self.linear1(self.norm0(self.linear0(x)))
+
+
+@pytest.mark.parametrize('name, optimizer_config', [
+    ('decoupled_adamw', {}),
+    ('decoupled_lionw', {}),
+    ('clip_lion', {}),
+    ('adalr_lion', {}),
+    pytest.param('decoupled_lionw_8b', {}, marks=pytest.mark.gpu),
+])
+@pytest.mark.parametrize('opt_additional_config', [
+    {
+        'disable_grad': 'norm'
+    },
+    {
+        'disable_grad': ['norm', 'bias']
+    },
+    {
+        'param_groups': [{
+            'param_str_match': 'norm',
+            'lr': 1e-9,
+            'weight_decay': 0.0,
+        },]
+    },
+    {
+        'param_groups': [{
+            'param_str_match': 'no.*.bias',
+            'lr': 1e-9,
+            'weight_decay': 0.0,
+        },]
+    },
+    {
+        'param_groups': [{
+            'param_str_match': 'norm',
+            'lr': 1e-4,
+            'weight_decay': 0.0,
+        },],
+        'disable_grad': ['bias'],
+    },
+])
+def test_build_optimizer(name: str, optimizer_config: Dict[str, Any],
+                         opt_additional_config: Dict[str, Any]):
+    model = _DummyModule()
+    optimizer_config.update(deepcopy(opt_additional_config))
+    optimizer = build_optimizer(model, name, optimizer_config)
+
+    if 'disable_grad' in opt_additional_config.keys():
+        disable_grad = opt_additional_config['disable_grad']
+        if isinstance(disable_grad, str):
+            disable_grad = [disable_grad]
+        for n, p in model.named_parameters():
+            for k in disable_grad:
+                if re.search(k, n):
+                    assert not p.requires_grad
+
+    if 'param_groups' in opt_additional_config.keys():
+        for param_group_config, param_group in zip(
+                opt_additional_config['param_groups'],
+                optimizer.param_groups[1:]):
+            param_group_config = deepcopy(param_group_config)
+            param_str_match = param_group_config.pop('param_str_match')
+
+            for k, v in param_group_config.items():
+                assert param_group[k] == v
+
+            param_ids = [id(p) for p in param_group['params']]
+            for n, p in model.named_parameters():
+                if re.search(param_str_match, n):
+                    assert id(p) in param_ids
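
The sketch below is not part of the diff; it is a minimal end-to-end illustration of how the new `disable_grad` and `param_groups` keys behave, assuming a build of llm-foundry that includes this change. It reuses the `_DummyModule` layout from the test above, and `demo_config` plus its hyperparameter values are illustrative assumptions, not values taken from this PR.

```
# Hedged usage sketch: assumes llm-foundry with this change is importable.
import re

import torch.nn as nn

from llmfoundry.utils.builders import build_optimizer


class _DummyModule(nn.Module):  # mirrors the module defined in the test above

    def __init__(self):
        super().__init__()
        self.linear0 = nn.Linear(4, 3)
        self.norm0 = nn.LayerNorm(3)
        self.linear1 = nn.Linear(3, 5)


model = _DummyModule()

# `demo_config` is illustrative only: freeze every parameter whose name
# matches 'bias', and give parameters whose names match 'norm' their own
# param group with a smaller learning rate and no weight decay.
demo_config = {
    'lr': 1e-3,
    'weight_decay': 1e-2,
    'betas': (0.9, 0.99),
    'disable_grad': ['bias'],
    'param_groups': [{
        'param_str_match': 'norm',
        'lr': 1e-4,
        'weight_decay': 0.0,
    }],
}
optimizer = build_optimizer(model, 'decoupled_lionw', demo_config)

# Group 0 holds the unmatched parameters with the top-level settings;
# group 1 holds the 'norm' parameters with the overridden lr/weight_decay.
assert len(optimizer.param_groups) == 2
assert optimizer.param_groups[0]['lr'] == 1e-3
assert optimizer.param_groups[1]['lr'] == 1e-4
assert all(not p.requires_grad
           for n, p in model.named_parameters()
           if re.search('bias', n))
```

Because `_extract_param_groups` inserts the leftover parameters at index 0 (`params.insert(0, ...)`), group 0 always collects the parameters that matched no `param_str_match`; this is why the test above compares the configured groups against `optimizer.param_groups[1:]`.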