From f39592ab7fdbb0801545bddbe86e0949e8e4da0d Mon Sep 17 00:00:00 2001
From: Vitaliy Chiley
Date: Thu, 23 Nov 2023 13:43:02 -0800
Subject: [PATCH] add doc string

---
 llmfoundry/utils/builders.py | 57 +++++++++++++++++++++++++++++++++++-
 1 file changed, 56 insertions(+), 1 deletion(-)

diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py
index 06e9b26805..260b1f3dc1 100644
--- a/llmfoundry/utils/builders.py
+++ b/llmfoundry/utils/builders.py
@@ -160,7 +160,62 @@ def extract_param_groups(
     model: torch.nn.Module,
     optimizer_config: Dict[str, Any],
 ) -> Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]:
-
+    """Extracts parameter groups defined in the optimizer config.
+
+    The optimizer_config defines the optimizer args. It can additionally have a
+    key `disable_grad` which is a string or list of strings. If a string matches
+    a parameter name, then that parameter will have `requires_grad=False`. This
+    is useful for freezing parameters. It can additionally have a key
+    `param_groups` which is a list of dicts. In this dict, key `param_str_match`
+    defines a string; if a parameter name contains this string, then it will be
+    in this parameter group. This is useful for grouping parameters together.
+    The dict can also contain any other key that is a valid optimizer arg.
+
+    Usage:
+    To disable gradients for all parameters whose names contain "norm" or "bias":
+    ```
+    optimizer_config: {
+        "name": "decoupled_lionw",
+        "lr": 1e-3,
+        "weight_decay": 1e-2,
+        "betas": [0.9, 0.999],
+        "eps": 1e-8,
+        "disable_grad": ["norm", "bias"]
+    }
+    ```
+
+    To override optimizer args for all parameters whose names contain the
+    string "norm" or "bias", with each group configured separately:
+    ```
+    optimizer_config: {
+        "name": "decoupled_lionw",
+        "lr": 1e-3,
+        "weight_decay": 1e-2,
+        "betas": [0.9, 0.999],
+        "eps": 1e-8,
+        "param_groups": [
+            {
+                "param_str_match": "norm",
+                "lr": 1e-4,
+                "weight_decay": 0.0,
+            },
+            {
+                "param_str_match": "bias",
+                "lr": 5e-4,
+                "weight_decay": 0.0,
+            },
+        ],
+    }
+    ```
+
+    Args:
+        model (torch.nn.Module): model to extract parameters from
+        optimizer_config (Dict[str, Any]): optimizer config
+
+    Returns:
+        Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]: an iterable of
+            torch.Tensors or dicts; specifies which Tensors should be optimized.
+    """
     if 'disable_grad' in optimizer_config.keys():
         str_match = optimizer_config.pop('disable_grad')
         if isinstance(str_match, str):
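
Note for reviewers: the hunk above is truncated before the rest of the function body, so here is a minimal sketch of the behavior the new docstring describes. This is an illustration only, not necessarily the actual llm-foundry implementation; treating the match strings as regex patterns via `re.search`, and the handling of the default group, are assumptions.

```
import re
from typing import Any, Dict, Iterable, Union

import torch


def _extract_param_groups_sketch(
    model: torch.nn.Module,
    optimizer_config: Dict[str, Any],
) -> Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]:
    # `disable_grad`: freeze every parameter whose name matches one of the
    # given strings (treated here as regex patterns -- an assumption).
    if 'disable_grad' in optimizer_config:
        str_matches = optimizer_config.pop('disable_grad')
        if isinstance(str_matches, str):
            str_matches = [str_matches]
        for str_match in str_matches:
            for name, param in model.named_parameters():
                if re.search(str_match, name):
                    param.requires_grad = False

    # `param_groups`: carve matching parameters out of the default group and
    # attach the remaining dict keys (lr, weight_decay, ...) as per-group
    # optimizer args.
    if 'param_groups' in optimizer_config:
        param_dict = dict(model.named_parameters())
        groups = []
        for group_cfg in optimizer_config.pop('param_groups'):
            group_cfg = dict(group_cfg)  # avoid mutating the caller's config
            str_match = group_cfg.pop('param_str_match')
            matched = [n for n in list(param_dict) if re.search(str_match, n)]
            group = {'params': [param_dict.pop(n) for n in matched]}
            group.update(group_cfg)
            groups.append(group)
        # Parameters not claimed by any group keep the top-level optimizer args.
        return [{'params': list(param_dict.values())}] + groups

    return model.parameters()
```

Either return shape can be passed directly as the first argument of a `torch.optim.Optimizer` subclass, which accepts an iterable of Tensors or an iterable of param-group dicts.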