Skip to content

Commit

Permalink
integrated helper function into other test
Browse files Browse the repository at this point in the history
  • Loading branch information
milocress committed Mar 26, 2024
1 parent 37e073a commit ae663cd
Showing 1 changed file with 9 additions and 16 deletions.
25 changes: 9 additions & 16 deletions tests/a_scripts/inference/test_convert_composer_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@
from scripts.inference.convert_composer_to_hf import convert_composer_to_hf
from tests.data_utils import make_tiny_ft_dataset

_OPTIMIZER_CFG = lambda: {
'name': 'decoupled_adamw',
'lr': 6e-4,
'betas': [0.9, 0.95],
'eps': 1e-8,
'weight_decay': 0.0,
}


def _save_model_mock(*args: Any, path: str, **kwargs: Any):
os.makedirs(path, exist_ok=True)
Expand Down Expand Up @@ -292,13 +300,7 @@ def test_huggingface_conversion_callback_interval(

original_model = build_tiny_mpt()

optimizer_config = {
'name': 'decoupled_adamw',
'lr': 6e-4,
'betas': [0.9, 0.95],
'eps': 1e-8,
'weight_decay': 0.0,
}
optimizer_config = _OPTIMIZER_CFG()
optimizer_name = optimizer_config.pop('name')
optimizer = build_optimizer(original_model, optimizer_name,
optimizer_config)
Expand Down Expand Up @@ -527,15 +529,6 @@ def _get_dataloader_cfg(tiny_dataset_folder_path: str, max_seq_len: int):
return dataloader_cfg


_OPTIMIZER_CFG = lambda: {
'name': 'decoupled_adamw',
'lr': 6e-4,
'betas': [0.9, 0.95],
'eps': 1e-8,
'weight_decay': 0.0,
}


def _assert_checkpoint_equivalence(tmp_path: pathlib.Path,
expected_normal_checkpoints: int,
expected_hf_checkpoints: int,
Expand Down

0 comments on commit ae663cd

Please sign in to comment.