diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index d21c942dee..5c3d0f1830 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -382,13 +382,17 @@ def test_huggingface_conversion_callback_interval( 'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints', [('1ba', '1ba', '1ba', 1, 1)]) @patch('os.cpu_count', MagicMock(return_value=None)) -def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, - tie_word_embeddings: bool, - fsdp_state_dict_type: Optional[str], - hf_save_interval: str, - save_interval: str, max_duration: str, - expected_hf_checkpoints: int, - expected_normal_checkpoints: int): +def test_huggingface_conversion_callback( + model: str, + tmp_path: pathlib.Path, + tie_word_embeddings: bool, + fsdp_state_dict_type: Optional[str], + hf_save_interval: str, + save_interval: str, + max_duration: str, + expected_hf_checkpoints: int, + expected_normal_checkpoints: int, +): delete_transformers_cache() dist.initialize_dist(get_device('gpu')) @@ -580,12 +584,15 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, assert len(normal_checkpoints) == expected_normal_checkpoints assert len(huggingface_checkpoints) == expected_hf_checkpoints - # Load the last huggingface checkpoint - loaded_model = transformers.AutoModelForCausalLM.from_pretrained( - os.path.join(tmp_path, 'checkpoints', 'huggingface', - f'ba{batches_per_epoch}'), - trust_remote_code=True, - ) + # Patch flash_attn package to be empty to simulate loading the model in + # an environment without flash atttention installed + with patch.dict('sys.modules', {'flash_attn': None}): + # Load the last huggingface checkpoint + loaded_model = transformers.AutoModelForCausalLM.from_pretrained( + os.path.join(tmp_path, 'checkpoints', 'huggingface', + f'ba{batches_per_epoch}'), + trust_remote_code=True, + ) # Check that the loaded model has the correct precision, and then set it back # to the original for the equivalence check