From 3d73bb0f9337c0fb1db78f3cb4195a94c1a9ae63 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 24 Nov 2023 04:57:14 +0000 Subject: [PATCH] updt to use RegEx search --- llmfoundry/utils/builders.py | 20 ++++++++++---------- tests/test_builders.py | 5 +++-- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 0ad15955ae..8665fd7ff7 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -1,8 +1,10 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import functools import logging import os +import re import warnings from collections import OrderedDict from typing import Any, Dict, Iterable, List, Optional, Tuple, Union @@ -220,12 +222,12 @@ def _extract_param_groups( torch.Tensor's or dict's. Specifies what Tensors should be optimized. """ if 'disable_grad' in optimizer_config.keys(): - str_match = optimizer_config.pop('disable_grad') - if isinstance(str_match, str): - str_match = [str_match] - for _str_match in str_match: + str_matches = optimizer_config.pop('disable_grad') + if isinstance(str_matches, str): + str_matches = [str_matches] + for str_match in str_matches: for n, p in model.named_parameters(): - if _str_match in n: + if re.search(str_match, n): p.requires_grad = False log.debug(f'Setting `{n}.requires_grad = False`.') @@ -235,11 +237,9 @@ def _extract_param_groups( for param_group_config in optimizer_config['param_groups']: str_match = param_group_config.pop('param_str_match') - group_param_names = [n for n in param_dict.keys() if str_match in n] - _params = [] - for n in group_param_names: - _params.append(param_dict.pop(n)) - group_params = {'params': _params} + filter_fn = functools.partial(re.search, str_match) + param_names = [n for n in param_dict.keys() if filter_fn(n)] + group_params = {'params': [param_dict.pop(n) for n in param_names]} group_params.update(param_group_config) params.append(group_params) diff --git a/tests/test_builders.py b/tests/test_builders.py index 5fd84b33a5..7a11f16ff0 100644 --- a/tests/test_builders.py +++ b/tests/test_builders.py @@ -1,6 +1,7 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import re import unittest.mock as mock from copy import deepcopy from typing import Any, Dict, Union @@ -170,7 +171,7 @@ def test_build_optimizer(name: str, optimizer_config: Dict[str, Any], disable_grad = [disable_grad] for n, p in model.named_parameters(): for k in disable_grad: - if k in n: + if re.search(k, n): assert not p.requires_grad if 'param_groups' in opt_additional_config.keys(): @@ -185,5 +186,5 @@ def test_build_optimizer(name: str, optimizer_config: Dict[str, Any], param_ids = [id(p) for p in param_group['params']] for n, p in model.named_parameters(): - if param_str_match in n: + if re.search(param_str_match, n): assert id(p) in param_ids