Skip to content

Commit

Permalink
Update to use regex search
Browse files: browse the repository at this point in the history
  • Loading branch information
vchiley committed Nov 24, 2023
1 parent 167e173 commit 3d73bb0
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
20 changes: 10 additions & 10 deletions llmfoundry/utils/builders.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import functools
import logging
import os
import re
import warnings
from collections import OrderedDict
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
Expand Down Expand Up @@ -220,12 +222,12 @@ def _extract_param_groups(
torch.Tensor's or dict's. Specifies what Tensors should be optimized.
"""
if 'disable_grad' in optimizer_config.keys():
str_match = optimizer_config.pop('disable_grad')
if isinstance(str_match, str):
str_match = [str_match]
for _str_match in str_match:
str_matches = optimizer_config.pop('disable_grad')
if isinstance(str_matches, str):
str_matches = [str_matches]
for str_match in str_matches:
for n, p in model.named_parameters():
if _str_match in n:
if re.search(str_match, n):
p.requires_grad = False
log.debug(f'Setting `{n}.requires_grad = False`.')

Expand All @@ -235,11 +237,9 @@ def _extract_param_groups(

for param_group_config in optimizer_config['param_groups']:
str_match = param_group_config.pop('param_str_match')
group_param_names = [n for n in param_dict.keys() if str_match in n]
_params = []
for n in group_param_names:
_params.append(param_dict.pop(n))
group_params = {'params': _params}
filter_fn = functools.partial(re.search, str_match)
param_names = [n for n in param_dict.keys() if filter_fn(n)]
group_params = {'params': [param_dict.pop(n) for n in param_names]}
group_params.update(param_group_config)

params.append(group_params)
Expand Down
5 changes: 3 additions & 2 deletions tests/test_builders.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import re
import unittest.mock as mock
from copy import deepcopy
from typing import Any, Dict, Union
Expand Down Expand Up @@ -170,7 +171,7 @@ def test_build_optimizer(name: str, optimizer_config: Dict[str, Any],
disable_grad = [disable_grad]
for n, p in model.named_parameters():
for k in disable_grad:
if k in n:
if re.search(k, n):
assert not p.requires_grad

if 'param_groups' in opt_additional_config.keys():
Expand All @@ -185,5 +186,5 @@ def test_build_optimizer(name: str, optimizer_config: Dict[str, Any],

param_ids = [id(p) for p in param_group['params']]
for n, p in model.named_parameters():
if param_str_match in n:
if re.search(param_str_match, n):
assert id(p) in param_ids

0 comments on commit 3d73bb0

Please sign in to comment.