Patch flash attn in test to simulate environment without it installed
dakinggg authored Dec 5, 2023
1 parent 6ff3f27 commit 1ed9d22
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -382,13 +382,17 @@ def test_huggingface_conversion_callback_interval(
     'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints',
     [('1ba', '1ba', '1ba', 1, 1)])
 @patch('os.cpu_count', MagicMock(return_value=None))
-def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
-                                         tie_word_embeddings: bool,
-                                         fsdp_state_dict_type: Optional[str],
-                                         hf_save_interval: str,
-                                         save_interval: str, max_duration: str,
-                                         expected_hf_checkpoints: int,
-                                         expected_normal_checkpoints: int):
+def test_huggingface_conversion_callback(
+    model: str,
+    tmp_path: pathlib.Path,
+    tie_word_embeddings: bool,
+    fsdp_state_dict_type: Optional[str],
+    hf_save_interval: str,
+    save_interval: str,
+    max_duration: str,
+    expected_hf_checkpoints: int,
+    expected_normal_checkpoints: int,
+):
     delete_transformers_cache()
 
     dist.initialize_dist(get_device('gpu'))
@@ -580,12 +584,15 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
     assert len(normal_checkpoints) == expected_normal_checkpoints
     assert len(huggingface_checkpoints) == expected_hf_checkpoints
 
-    # Load the last huggingface checkpoint
-    loaded_model = transformers.AutoModelForCausalLM.from_pretrained(
-        os.path.join(tmp_path, 'checkpoints', 'huggingface',
-                     f'ba{batches_per_epoch}'),
-        trust_remote_code=True,
-    )
+    # Patch the flash_attn package to be empty to simulate loading the model
+    # in an environment without flash attention installed
+    with patch.dict('sys.modules', {'flash_attn': None}):
+        # Load the last huggingface checkpoint
+        loaded_model = transformers.AutoModelForCausalLM.from_pretrained(
+            os.path.join(tmp_path, 'checkpoints', 'huggingface',
+                         f'ba{batches_per_epoch}'),
+            trust_remote_code=True,
+        )
 
     # Check that the loaded model has the correct precision, and then set it back
     # to the original for the equivalence check
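
As a minimal standalone sketch (not part of the commit itself): the trick the diff relies on is that mapping a module name to None in sys.modules makes any subsequent import of that name raise ImportError, and patch.dict restores sys.modules when the context exits. Something like the following demonstrates the behavior:

    # Setting a sys.modules entry to None blocks imports of that name,
    # simulating an environment where flash_attn is not installed.
    from unittest.mock import patch

    with patch.dict('sys.modules', {'flash_attn': None}):
        try:
            import flash_attn  # raises ImportError while the patch is active
        except ImportError:
            print('flash_attn import blocked, as expected')
    # Outside the context manager, sys.modules is restored automatically.

This works whether or not flash_attn is actually installed, which is what lets the test exercise the no-flash-attention code path on any machine.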