From 9d904739bea00bb68ed0cccc513d3ef5b49f24cb Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 5 Dec 2023 19:13:54 +0000 Subject: [PATCH] add the patching --- .../inference/test_convert_composer_to_hf.py | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index d21c942dee..3eec135697 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -6,7 +6,7 @@ import pathlib import shutil from argparse import Namespace -from typing import Callable, Optional, cast +from typing import Any, Callable, Optional, cast from unittest.mock import ANY, MagicMock, patch import pytest @@ -258,7 +258,7 @@ def test_callback_inits(): @pytest.mark.parametrize( 'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints', [('3ba', '2ba', '4ba', 2, 2), ('1dur', '2ba', '1ep', 1, 2)]) -@patch('os.cpu_count', MagicMock(return_value=None)) +@patch('os.cpu_count', MagicMock(return_value=1)) def test_huggingface_conversion_callback_interval( tmp_path: pathlib.Path, log_to_mlflow: bool, hf_save_interval: str, save_interval: str, max_duration: str, expected_hf_checkpoints: int, @@ -381,14 +381,12 @@ def test_huggingface_conversion_callback_interval( @pytest.mark.parametrize( 'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints', [('1ba', '1ba', '1ba', 1, 1)]) -@patch('os.cpu_count', MagicMock(return_value=None)) -def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, - tie_word_embeddings: bool, - fsdp_state_dict_type: Optional[str], - hf_save_interval: str, - save_interval: str, max_duration: str, - expected_hf_checkpoints: int, - expected_normal_checkpoints: int): +@patch('os.cpu_count', MagicMock(return_value=1)) +def test_huggingface_conversion_callback( + model: str, tmp_path: pathlib.Path, tie_word_embeddings: bool, + fsdp_state_dict_type: Optional[str], hf_save_interval: str, + save_interval: str, max_duration: str, expected_hf_checkpoints: int, + expected_normal_checkpoints: int, monkeypatch: Any): delete_transformers_cache() dist.initialize_dist(get_device('gpu')) @@ -580,12 +578,15 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, assert len(normal_checkpoints) == expected_normal_checkpoints assert len(huggingface_checkpoints) == expected_hf_checkpoints - # Load the last huggingface checkpoint - loaded_model = transformers.AutoModelForCausalLM.from_pretrained( - os.path.join(tmp_path, 'checkpoints', 'huggingface', - f'ba{batches_per_epoch}'), - trust_remote_code=True, - ) + # Patch flash_attn package to be empty to simulate loading the model in + # an environment without flash atttention installed + with patch.dict('sys.modules', {'flash_attn': None}): + # Load the last huggingface checkpoint + loaded_model = transformers.AutoModelForCausalLM.from_pretrained( + os.path.join(tmp_path, 'checkpoints', 'huggingface', + f'ba{batches_per_epoch}'), + trust_remote_code=True, + ) # Check that the loaded model has the correct precision, and then set it back # to the original for the equivalence check