add the patching
dakinggg committed Dec 5, 2023
1 parent 61cd110 commit 9d90473
Showing 1 changed file with 17 additions and 16 deletions.
tests/a_scripts/inference/test_convert_composer_to_hf.py (17 additions & 16 deletions)

@@ -6,7 +6,7 @@
 import pathlib
 import shutil
 from argparse import Namespace
-from typing import Callable, Optional, cast
+from typing import Any, Callable, Optional, cast
 from unittest.mock import ANY, MagicMock, patch

 import pytest
@@ -258,7 +258,7 @@ def test_callback_inits():
 @pytest.mark.parametrize(
     'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints',
     [('3ba', '2ba', '4ba', 2, 2), ('1dur', '2ba', '1ep', 1, 2)])
-@patch('os.cpu_count', MagicMock(return_value=None))
+@patch('os.cpu_count', MagicMock(return_value=1))
 def test_huggingface_conversion_callback_interval(
         tmp_path: pathlib.Path, log_to_mlflow: bool, hf_save_interval: str,
         save_interval: str, max_duration: str, expected_hf_checkpoints: int,
@@ -381,14 +381,12 @@ def test_huggingface_conversion_callback_interval(
 @pytest.mark.parametrize(
     'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints',
     [('1ba', '1ba', '1ba', 1, 1)])
-@patch('os.cpu_count', MagicMock(return_value=None))
-def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
-                                         tie_word_embeddings: bool,
-                                         fsdp_state_dict_type: Optional[str],
-                                         hf_save_interval: str,
-                                         save_interval: str, max_duration: str,
-                                         expected_hf_checkpoints: int,
-                                         expected_normal_checkpoints: int):
+@patch('os.cpu_count', MagicMock(return_value=1))
+def test_huggingface_conversion_callback(
+        model: str, tmp_path: pathlib.Path, tie_word_embeddings: bool,
+        fsdp_state_dict_type: Optional[str], hf_save_interval: str,
+        save_interval: str, max_duration: str, expected_hf_checkpoints: int,
+        expected_normal_checkpoints: int, monkeypatch: Any):
     delete_transformers_cache()

     dist.initialize_dist(get_device('gpu'))
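Aside: both hunks above change the stubbed os.cpu_count to return 1 rather than None. A minimal, self-contained sketch of the same decorator pattern (the test name and assertion here are invented for illustration, not taken from this commit):

import os
from unittest.mock import MagicMock, patch

# Swap os.cpu_count for a mock that always returns 1 while the decorated
# test runs; unittest.mock restores the real function afterwards.
@patch('os.cpu_count', MagicMock(return_value=1))
def test_cpu_count_is_patched():
    assert os.cpu_count() == 1

os.cpu_count() is documented to return None when the CPU count cannot be determined, so stubbing it to 1 presumably keeps any downstream worker-count arithmetic on a plain integer rather than a None it cannot handle.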
@@ -580,12 +578,15 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
     assert len(normal_checkpoints) == expected_normal_checkpoints
     assert len(huggingface_checkpoints) == expected_hf_checkpoints

-    # Load the last huggingface checkpoint
-    loaded_model = transformers.AutoModelForCausalLM.from_pretrained(
-        os.path.join(tmp_path, 'checkpoints', 'huggingface',
-                     f'ba{batches_per_epoch}'),
-        trust_remote_code=True,
-    )
+    # Patch flash_attn package to be empty to simulate loading the model in
+    # an environment without flash attention installed
+    with patch.dict('sys.modules', {'flash_attn': None}):
+        # Load the last huggingface checkpoint
+        loaded_model = transformers.AutoModelForCausalLM.from_pretrained(
+            os.path.join(tmp_path, 'checkpoints', 'huggingface',
+                         f'ba{batches_per_epoch}'),
+            trust_remote_code=True,
+        )

     # Check that the loaded model has the correct precision, and then set it back
     # to the original for the equivalence check
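Aside on the technique the commit title refers to: mapping a module name to None in sys.modules makes any later import of that name raise ImportError, and patch.dict restores the original sys.modules contents on exit. A minimal sketch, independent of this repository, of how the new with-block simulates a machine without flash attention installed:

import sys
from unittest.mock import patch

# While the patch is active, importing flash_attn fails even if the
# package is installed, because None in sys.modules halts the import.
with patch.dict('sys.modules', {'flash_attn': None}):
    try:
        import flash_attn  # noqa: F401
        print('unexpected: flash_attn imported inside the patch')
    except ImportError:
        print('flash_attn unavailable, as the test intends')

# Outside the context manager sys.modules is restored, so a real
# flash_attn installation would import normally again.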
