From 26bbc571cbbb44c77f8d745cddd987449fa2c92c Mon Sep 17 00:00:00 2001 From: root Date: Wed, 22 Nov 2023 23:00:37 +0000 Subject: [PATCH 1/8] enable param group configuration in llm-foundry --- llmfoundry/optim/lion8b.py | 22 ++++++++------- llmfoundry/utils/builders.py | 53 ++++++++++++++++++++++++++++++++---- 2 files changed, 59 insertions(+), 16 deletions(-) diff --git a/llmfoundry/optim/lion8b.py b/llmfoundry/optim/lion8b.py index 2c2e6e2d35..9d1d1dda71 100644 --- a/llmfoundry/optim/lion8b.py +++ b/llmfoundry/optim/lion8b.py @@ -1,7 +1,7 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Callable, Dict, Iterable, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union import torch @@ -58,15 +58,17 @@ class DecoupledLionW_8bit(torch.optim.Optimizer): device, or b) step() is executed on a non-CUDA parameter. """ - def __init__(self, - params: Iterable[torch.Tensor], - lr: float = 1e-3, - betas: Tuple[float, float] = (0.9, 0.99), - weight_decay: float = 0, - quantize: bool = True, - compress_state_dict: bool = False, - error_correction: bool = False, - _fused: bool = True): # XXX this flag is mostly for testing... + def __init__( + self, + params: Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]], + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.99), + weight_decay: float = 0, + quantize: bool = True, + compress_state_dict: bool = False, + error_correction: bool = False, + _fused: bool = True, # XXX this flag is mostly for testing... + ): if lr < 0.0: raise ValueError('Invalid learning rate: {}'.format(lr)) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index c31917efc6..06e9b26805 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -4,7 +4,8 @@ import logging import os import warnings -from typing import Any, Dict, List, Optional, Tuple, Union +from collections import OrderedDict +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from composer import algorithms @@ -155,18 +156,58 @@ def build_algorithm(name: str, kwargs: Dict[str, Any]) -> Algorithm: raise ValueError(f'Not sure how to build algorithm: {name}') +def extract_param_groups( + model: torch.nn.Module, + optimizer_config: Dict[str, Any], +) -> Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]: + + if 'disable_grad' in optimizer_config.keys(): + str_match = optimizer_config.pop('disable_grad') + if isinstance(str_match, str): + str_match = [str_match] + for _str_match in str_match: + for n, p in model.named_parameters(): + if n in _str_match: + p.requires_grad = False + + if 'param_groups' in optimizer_config.keys(): + params = [] + param_dict = OrderedDict((n, p) for n, p in model.named_parameters()) + + for param_group_config in optimizer_config['param_groups']: + str_match = param_group_config.pop('param_str_match') + group_param_names = [n for n in param_dict.keys() if str_match in n] + _params = [] + for n in group_param_names: + _params.append(param_dict.pop(n)) + group_params = {'params': _params} + group_params.update(param_group_config) + + params.append(group_params) + + optimizer_config.pop('param_groups') + + params.insert(0, {'params': param_dict.values()}) + return params + + return model.parameters() + + def build_optimizer(model: torch.nn.Module, name: str, optimizer_config: Dict[str, Any]) -> Optimizer: + + params = extract_param_groups(model, optimizer_config) + if name == 'decoupled_adamw': - return DecoupledAdamW(model.parameters(), **optimizer_config) + return DecoupledAdamW(params, **optimizer_config) elif name == 'decoupled_lionw': - return DecoupledLionW(model.parameters(), **optimizer_config) + return DecoupledLionW(params, **optimizer_config) elif name == 'clip_lion': - return DecoupledClipLion(model.parameters(), **optimizer_config) + return DecoupledClipLion(params, **optimizer_config) elif name == 'adalr_lion': - return DecoupledAdaLRLion(model.parameters(), **optimizer_config) + return DecoupledAdaLRLion(params, **optimizer_config) elif name == 'decoupled_lionw_8b': - return DecoupledLionW_8bit(model.parameters(), **optimizer_config) + return DecoupledLionW_8bit(params, **optimizer_config) else: raise ValueError(f'Not sure how to build optimizer: {name}') From f39592ab7fdbb0801545bddbe86e0949e8e4da0d Mon Sep 17 00:00:00 2001 From: Vitaliy Chiley Date: Thu, 23 Nov 2023 13:43:02 -0800 Subject: [PATCH 2/8] add doc string --- llmfoundry/utils/builders.py | 57 +++++++++++++++++++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 06e9b26805..260b1f3dc1 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -160,7 +160,62 @@ def extract_param_groups( model: torch.nn.Module, optimizer_config: Dict[str, Any], ) -> Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]: - + """Extracts parameter groups defined in the optimizer config. + + The optimizer_config defines the optimizer args. It can additionally have key + `disable_grad` which is a string or list of strings. If a string matches a + parameter name, then that parameter will have `requires_grad=False`. This is + useful for freezing parameters. It can additionally have a key + `param_groups` which is a list of dicts. In this dict, key `param_str_match` + defines a string; if a parameter name contains this string, then it will be + in this parameter group. This is useful for grouping parameters together. + The dict can also contain any other key that is a valid optimizer arg. + + Usage + To disable gradient for all parameters that contain the string "norm" or "bias": + ``` + optimizer_config: { + "name": "decoupled_lionw", + "lr": 1e-3, + "weight_decay": 1e-2, + "betas": [0.9, 0.999], + "eps": 1e-8, + "disable_grad": ["norm", "bias"] + } + ``` + + To create modify the optimizer parameters for all parameters that contain the + string "norm" and "bias" seperately: + ``` + optimizer_config: { + "name": "decoupled_lionw", + "lr": 1e-3, + "weight_decay": 1e-2, + "betas": [0.9, 0.999], + "eps": 1e-8, + "param_groups": [ + { + "param_str_match": "norm", + "lr": 1e-4, + "weight_decay": 0.0, + }, + { + "param_str_match": "bias", + "lr": 5e-4, + "weight_decay": 0.0, + }, + ], + } + ``` + + Args: + model (torch.nn.Module): model to extract parameters from + optimizer_config (Dict[str, Any]): optimizer config + + Returns: + Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]: an iterable of + torch.Tensor's or dict's. Specifies what Tensors should be optimized. + """ if 'disable_grad' in optimizer_config.keys(): str_match = optimizer_config.pop('disable_grad') if isinstance(str_match, str): From 2e9548b4ba31b322dcc0e499c25292d147abca32 Mon Sep 17 00:00:00 2001 From: Vitaliy Chiley Date: Thu, 23 Nov 2023 13:59:58 -0800 Subject: [PATCH 3/8] add debug logs --- llmfoundry/utils/builders.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 260b1f3dc1..338f2262a4 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -224,6 +224,7 @@ def extract_param_groups( for n, p in model.named_parameters(): if n in _str_match: p.requires_grad = False + log.debug(f'Setting `{n}.requires_grad = False`.') if 'param_groups' in optimizer_config.keys(): params = [] @@ -243,6 +244,9 @@ def extract_param_groups( optimizer_config.pop('param_groups') params.insert(0, {'params': param_dict.values()}) + + log.debug(f'Optimizer param_groups: {params}.') + return params return model.parameters() From fc69417fd5fff6bb2363ec7741c35f1c2fef2cad Mon Sep 17 00:00:00 2001 From: root Date: Fri, 24 Nov 2023 03:57:40 +0000 Subject: [PATCH 4/8] add test, fix bug --- llmfoundry/utils/builders.py | 9 ++-- tests/test_builders.py | 81 +++++++++++++++++++++++++++++++++++- 2 files changed, 85 insertions(+), 5 deletions(-) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 338f2262a4..6b1865780e 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -156,7 +156,7 @@ def build_algorithm(name: str, kwargs: Dict[str, Any]) -> Algorithm: raise ValueError(f'Not sure how to build algorithm: {name}') -def extract_param_groups( +def _extract_param_groups( model: torch.nn.Module, optimizer_config: Dict[str, Any], ) -> Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]: @@ -170,6 +170,9 @@ def extract_param_groups( defines a string; if a parameter name contains this string, then it will be in this parameter group. This is useful for grouping parameters together. The dict can also contain any other key that is a valid optimizer arg. + Note: to handle name overlap conflics, params are assiged to parameter + groups and added to `param_groups` in the order that `param_str_match` appear + in `param_groups`. Usage To disable gradient for all parameters that contain the string "norm" or "bias": @@ -222,7 +225,7 @@ def extract_param_groups( str_match = [str_match] for _str_match in str_match: for n, p in model.named_parameters(): - if n in _str_match: + if _str_match in n: p.requires_grad = False log.debug(f'Setting `{n}.requires_grad = False`.') @@ -255,7 +258,7 @@ def extract_param_groups( def build_optimizer(model: torch.nn.Module, name: str, optimizer_config: Dict[str, Any]) -> Optimizer: - params = extract_param_groups(model, optimizer_config) + params = _extract_param_groups(model, optimizer_config) if name == 'decoupled_adamw': return DecoupledAdamW(params, **optimizer_config) diff --git a/tests/test_builders.py b/tests/test_builders.py index 237e27b52b..40360037e1 100644 --- a/tests/test_builders.py +++ b/tests/test_builders.py @@ -2,16 +2,20 @@ # SPDX-License-Identifier: Apache-2.0 import unittest.mock as mock -from typing import Union +from copy import deepcopy +from typing import Any, Dict, Union import pytest +import torch +import torch.nn as nn from composer.callbacks import Generate from omegaconf import OmegaConf as om from transformers import PreTrainedTokenizerBase from llmfoundry.callbacks import HuggingFaceCheckpointer from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper -from llmfoundry.utils.builders import build_callback, build_tokenizer +from llmfoundry.utils.builders import (build_callback, build_optimizer, + build_tokenizer) @pytest.mark.parametrize('tokenizer_name,tokenizer_kwargs', [ @@ -110,3 +114,76 @@ def test_build_hf_checkpointer_callback(): assert isinstance(kwargs['mlflow_logging_config'], dict) assert isinstance(kwargs['mlflow_logging_config']['metadata'], dict) assert kwargs['mlflow_logging_config'] == mlflow_logging_config_dict + + +class _DummyModule(nn.Module): + + def __init__(self, device: str = 'cpu', dtype: torch.dtype = torch.float32): + super().__init__() + self.linear0 = nn.Linear(4, 3, device=device, dtype=dtype) + self.norm0 = nn.LayerNorm(3, device=device, dtype=dtype) + self.linear1 = nn.Linear(3, 5, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # type:ignore + return self.linear1(self.norm0(self.linear0(x))) + + +@pytest.mark.parametrize('name, optimizer_config', [ + ('decoupled_adamw', {}), + ('decoupled_lionw', {}), + ('clip_lion', {}), + ('adalr_lion', {}), + ('decoupled_lionw_8b', {}), +]) +@pytest.mark.parametrize('opt_additional_config', [ + { + 'disable_grad': 'norm' + }, + { + 'disable_grad': ['norm', 'bias'] + }, + { + 'param_groups': [{ + 'param_str_match': 'norm', + 'lr': 1e-9, + 'weight_decay': 0.0, + },] + }, + { + 'param_groups': [{ + 'param_str_match': 'norm', + 'lr': 1e-4, + 'weight_decay': 0.0, + },], + 'disable_grad': ['bias'], + }, +]) +def test_build_optimizer(name: str, optimizer_config: Dict[str, Any], + opt_additional_config: Dict[str, Any]): + model = _DummyModule() + optimizer_config.update(deepcopy(opt_additional_config)) + optimizer = build_optimizer(model, name, optimizer_config) + + if 'disable_grad' in opt_additional_config.keys(): + disable_grad = opt_additional_config['disable_grad'] + if isinstance(disable_grad, str): + disable_grad = [disable_grad] + for n, p in model.named_parameters(): + for k in disable_grad: + if k in n: + assert not p.requires_grad + + if 'param_groups' in opt_additional_config.keys(): + for param_group_config, param_group in zip( + opt_additional_config['param_groups'], + optimizer.param_groups[1:]): + param_group_config = deepcopy(param_group_config) + param_str_match = param_group_config.pop('param_str_match') + + for k, v in param_group_config.items(): + assert param_group[k] == v + + param_ids = [id(p) for p in param_group['params']] + for n, p in model.named_parameters(): + if param_str_match in n: + assert id(p) in param_ids From 167e173b853b1ea5b86e7c58d166713f16714b94 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 24 Nov 2023 04:10:46 +0000 Subject: [PATCH 5/8] spell check; mark test gpu --- llmfoundry/utils/builders.py | 4 ++-- tests/test_builders.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 6b1865780e..0ad15955ae 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -170,7 +170,7 @@ def _extract_param_groups( defines a string; if a parameter name contains this string, then it will be in this parameter group. This is useful for grouping parameters together. The dict can also contain any other key that is a valid optimizer arg. - Note: to handle name overlap conflics, params are assiged to parameter + Note: to handle name overlap conflicts, params are assigned to parameter groups and added to `param_groups` in the order that `param_str_match` appear in `param_groups`. @@ -188,7 +188,7 @@ def _extract_param_groups( ``` To create modify the optimizer parameters for all parameters that contain the - string "norm" and "bias" seperately: + string "norm" and "bias" separately: ``` optimizer_config: { "name": "decoupled_lionw", diff --git a/tests/test_builders.py b/tests/test_builders.py index 40360037e1..5fd84b33a5 100644 --- a/tests/test_builders.py +++ b/tests/test_builders.py @@ -133,7 +133,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # type:ignore ('decoupled_lionw', {}), ('clip_lion', {}), ('adalr_lion', {}), - ('decoupled_lionw_8b', {}), + pytest.param('decoupled_lionw_8b', {}, marks=pytest.mark.gpu), ]) @pytest.mark.parametrize('opt_additional_config', [ { From 3d73bb0f9337c0fb1db78f3cb4195a94c1a9ae63 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 24 Nov 2023 04:57:14 +0000 Subject: [PATCH 6/8] updt to use RegEx search --- llmfoundry/utils/builders.py | 20 ++++++++++---------- tests/test_builders.py | 5 +++-- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 0ad15955ae..8665fd7ff7 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -1,8 +1,10 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import functools import logging import os +import re import warnings from collections import OrderedDict from typing import Any, Dict, Iterable, List, Optional, Tuple, Union @@ -220,12 +222,12 @@ def _extract_param_groups( torch.Tensor's or dict's. Specifies what Tensors should be optimized. """ if 'disable_grad' in optimizer_config.keys(): - str_match = optimizer_config.pop('disable_grad') - if isinstance(str_match, str): - str_match = [str_match] - for _str_match in str_match: + str_matches = optimizer_config.pop('disable_grad') + if isinstance(str_matches, str): + str_matches = [str_matches] + for str_match in str_matches: for n, p in model.named_parameters(): - if _str_match in n: + if re.search(str_match, n): p.requires_grad = False log.debug(f'Setting `{n}.requires_grad = False`.') @@ -235,11 +237,9 @@ def _extract_param_groups( for param_group_config in optimizer_config['param_groups']: str_match = param_group_config.pop('param_str_match') - group_param_names = [n for n in param_dict.keys() if str_match in n] - _params = [] - for n in group_param_names: - _params.append(param_dict.pop(n)) - group_params = {'params': _params} + filter_fn = functools.partial(re.search, str_match) + param_names = [n for n in param_dict.keys() if filter_fn(n)] + group_params = {'params': [param_dict.pop(n) for n in param_names]} group_params.update(param_group_config) params.append(group_params) diff --git a/tests/test_builders.py b/tests/test_builders.py index 5fd84b33a5..7a11f16ff0 100644 --- a/tests/test_builders.py +++ b/tests/test_builders.py @@ -1,6 +1,7 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import re import unittest.mock as mock from copy import deepcopy from typing import Any, Dict, Union @@ -170,7 +171,7 @@ def test_build_optimizer(name: str, optimizer_config: Dict[str, Any], disable_grad = [disable_grad] for n, p in model.named_parameters(): for k in disable_grad: - if k in n: + if re.search(k, n): assert not p.requires_grad if 'param_groups' in opt_additional_config.keys(): @@ -185,5 +186,5 @@ def test_build_optimizer(name: str, optimizer_config: Dict[str, Any], param_ids = [id(p) for p in param_group['params']] for n, p in model.named_parameters(): - if param_str_match in n: + if re.search(param_str_match, n): assert id(p) in param_ids From 6130334e97d4fdc89caae44748a121befce7417c Mon Sep 17 00:00:00 2001 From: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> Date: Tue, 28 Nov 2023 15:32:25 -0800 Subject: [PATCH 7/8] Apply suggestions from code review Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- llmfoundry/utils/builders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 8665fd7ff7..d7a7531d35 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -189,7 +189,7 @@ def _extract_param_groups( } ``` - To create modify the optimizer parameters for all parameters that contain the + To create and modify the optimizer parameters for all parameters that contain the string "norm" and "bias" separately: ``` optimizer_config: { @@ -219,7 +219,7 @@ def _extract_param_groups( Returns: Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]: an iterable of - torch.Tensor's or dict's. Specifies what Tensors should be optimized. + torch.Tensor's or dict's. Specifies what Tensors should be optimized and their param groupings. """ if 'disable_grad' in optimizer_config.keys(): str_matches = optimizer_config.pop('disable_grad') From 6b9ccbf0e055a4955a82ce265c0baef158c1aa0b Mon Sep 17 00:00:00 2001 From: root Date: Tue, 28 Nov 2023 23:57:47 +0000 Subject: [PATCH 8/8] updt with dakinggg pr comments --- llmfoundry/utils/builders.py | 23 +++++++++++++---------- tests/test_builders.py | 7 +++++++ 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index d7a7531d35..14196c3ef9 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -189,8 +189,8 @@ def _extract_param_groups( } ``` - To create and modify the optimizer parameters for all parameters that contain the - string "norm" and "bias" separately: + To create and modify the optimizer parameters for all parameters that contain + the string "norm" and "bias" separately: ``` optimizer_config: { "name": "decoupled_lionw", @@ -219,7 +219,8 @@ def _extract_param_groups( Returns: Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]: an iterable of - torch.Tensor's or dict's. Specifies what Tensors should be optimized and their param groupings. + torch.Tensor's or dict's. Specifies what Tensors should be optimized + and their param groupings. """ if 'disable_grad' in optimizer_config.keys(): str_matches = optimizer_config.pop('disable_grad') @@ -231,25 +232,27 @@ def _extract_param_groups( p.requires_grad = False log.debug(f'Setting `{n}.requires_grad = False`.') - if 'param_groups' in optimizer_config.keys(): + param_groups_config = optimizer_config.pop('param_groups', None) + if param_groups_config is not None: params = [] param_dict = OrderedDict((n, p) for n, p in model.named_parameters()) - for param_group_config in optimizer_config['param_groups']: + log.debug(f'Default optimizer settings: {optimizer_config}.') + for param_group_config in param_groups_config: str_match = param_group_config.pop('param_str_match') filter_fn = functools.partial(re.search, str_match) param_names = [n for n in param_dict.keys() if filter_fn(n)] group_params = {'params': [param_dict.pop(n) for n in param_names]} group_params.update(param_group_config) - params.append(group_params) + log.debug( + f'Creating optimizer param_group with parameters: {param_names} ' +\ + f'(extracted using {str_match=}). The param_group optimizer ' +\ + f'setting overrides are: {param_group_config}.') - optimizer_config.pop('param_groups') + params.append(group_params) params.insert(0, {'params': param_dict.values()}) - - log.debug(f'Optimizer param_groups: {params}.') - return params return model.parameters() diff --git a/tests/test_builders.py b/tests/test_builders.py index 7a11f16ff0..7ac179720e 100644 --- a/tests/test_builders.py +++ b/tests/test_builders.py @@ -150,6 +150,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # type:ignore 'weight_decay': 0.0, },] }, + { + 'param_groups': [{ + 'param_str_match': 'no.*.bias', + 'lr': 1e-9, + 'weight_decay': 0.0, + },] + }, { 'param_groups': [{ 'param_str_match': 'norm',