Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Patch flash attn in test to simulate environment without it installed #778

Merged
merged 4 commits into from
Dec 5, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions tests/a_scripts/inference/test_convert_composer_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,13 +382,17 @@ def test_huggingface_conversion_callback_interval(
'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints',
[('1ba', '1ba', '1ba', 1, 1)])
@patch('os.cpu_count', MagicMock(return_value=None))
def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
tie_word_embeddings: bool,
fsdp_state_dict_type: Optional[str],
hf_save_interval: str,
save_interval: str, max_duration: str,
expected_hf_checkpoints: int,
expected_normal_checkpoints: int):
def test_huggingface_conversion_callback(
model: str,
tmp_path: pathlib.Path,
tie_word_embeddings: bool,
fsdp_state_dict_type: Optional[str],
hf_save_interval: str,
save_interval: str,
max_duration: str,
expected_hf_checkpoints: int,
expected_normal_checkpoints: int,
):
delete_transformers_cache()

dist.initialize_dist(get_device('gpu'))
Expand Down Expand Up @@ -580,12 +584,15 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
assert len(normal_checkpoints) == expected_normal_checkpoints
assert len(huggingface_checkpoints) == expected_hf_checkpoints

# Load the last huggingface checkpoint
loaded_model = transformers.AutoModelForCausalLM.from_pretrained(
os.path.join(tmp_path, 'checkpoints', 'huggingface',
f'ba{batches_per_epoch}'),
trust_remote_code=True,
)
# Patch flash_attn package to be empty to simulate loading the model in
# an environment without flash attention installed
with patch.dict('sys.modules', {'flash_attn': None}):
# Load the last huggingface checkpoint
loaded_model = transformers.AutoModelForCausalLM.from_pretrained(
os.path.join(tmp_path, 'checkpoints', 'huggingface',
f'ba{batches_per_epoch}'),
trust_remote_code=True,
)

# Check that the loaded model has the correct precision, and then set it back
# to the original for the equivalence check
Expand Down
Loading