
Commit

Fix 1.13 tests (#751)
dakinggg authored Nov 21, 2023
1 parent 9bf21f2 commit 6dc94a2
Showing 3 changed files with 11 additions and 8 deletions.
tests/test_fsdp_act_checkpoint.py (11 changes: 7 additions & 4 deletions)
@@ -3,7 +3,7 @@
 
 import pytest
 from composer import Trainer
-from composer.utils import get_device
+from composer.utils import get_device, using_torch_2
 from omegaconf import OmegaConf as om
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \
     CheckpointWrapper
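
For context, `using_torch_2` is a small helper from composer.utils that the updated tests use to branch on the installed PyTorch major version (assumed semantics: it reports whether torch 2.0 or newer is installed). A rough, illustrative equivalent, not the composer implementation:

from packaging import version

import torch


def is_torch_2_or_newer() -> bool:
    # Approximate stand-in for composer.utils.using_torch_2; returns True
    # when the installed torch is 2.0+. Sketch only.
    return version.parse(torch.__version__).release >= (2, 0)
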
@@ -60,9 +60,12 @@ def test_fsdp_act_checkpoint(activation_checkpointing: bool,
     ) or activation_checkpointing_target == [
         'mptblock', 'grouped_query_attention'
     ]:
-        assert isinstance(
-            trainer.state.model.model._fsdp_wrapped_module.transformer.
-            blocks[0]._fsdp_wrapped_module, CheckpointWrapper)
+        module = trainer.state.model.model._fsdp_wrapped_module.transformer.blocks[
+            0]._fsdp_wrapped_module
+        if not using_torch_2():
+            module = trainer.state.model.model._fsdp_wrapped_module.transformer.blocks[
+                0]._fsdp_wrapped_module._fpw_module
+        assert isinstance(module, CheckpointWrapper)
     elif activation_checkpointing_target == ['grouped_query_attention']:
         assert isinstance(
             trainer.state.model.model._fsdp_wrapped_module.transformer.
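
Why the extra hop above: under PyTorch 1.13, FSDP wraps each sharded submodule in a FlattenParamsWrapper, so the activation-checkpointing CheckpointWrapper sits one attribute deeper, behind `_fpw_module`; torch 2.x exposes the wrapped module directly. A minimal sketch of the unwrapping pattern the test now follows (the helper name is illustrative, not part of the commit):

from composer.utils import using_torch_2
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \
    CheckpointWrapper


def unwrap_checkpointed_block(fsdp_block):
    # Peel back the FSDP wrapping to reach the module that is expected to be
    # the CheckpointWrapper; torch 1.13 adds an extra `_fpw_module` level.
    module = fsdp_block._fsdp_wrapped_module
    if not using_torch_2():
        module = module._fpw_module
    return module


# Usage mirroring the assertion in the diff (sketch):
# block = trainer.state.model.model._fsdp_wrapped_module.transformer.blocks[0]
# assert isinstance(unwrap_checkpointed_block(block), CheckpointWrapper)
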
tests/test_hf_conversion_script.py (4 changes: 2 additions & 2 deletions)
@@ -10,7 +10,7 @@
 
 from composer import Trainer
 from composer.loggers import MLFlowLogger
-from composer.utils import dist, get_device
+from composer.utils import dist, get_device, using_torch_2
 
 from llmfoundry.callbacks import HuggingFaceCheckpointer
 from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM
@@ -501,7 +501,7 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
         'drop_last': False,
         'num_workers': 0,
         'pin_memory': False,
-        'prefetch_factor': None,
+        'prefetch_factor': None if using_torch_2() else 2,
         'persistent_workers': False,
         'timeout': 0
     }
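
The prefetch_factor change tracks a DataLoader default that moved between releases: torch 1.13 defaults prefetch_factor to 2 (and rejects other values when num_workers is 0), while torch 2.x defaults it to None and only applies it when num_workers > 0. A quick sketch of the behavior the expected-kwargs dict now accounts for (assumed defaults, illustrative only):

import torch
from composer.utils import using_torch_2
from torch.utils.data import DataLoader, TensorDataset

# With default arguments, the recorded prefetch_factor differs by torch version.
loader = DataLoader(TensorDataset(torch.zeros(4, 2)), num_workers=0)
assert loader.prefetch_factor == (None if using_torch_2() else 2)
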
tests/test_model.py (4 changes: 2 additions & 2 deletions)
@@ -650,8 +650,8 @@ def test_forward_with_padding(attention_impl: str, pos_emb_config: dict,
     right_pad_v_left_pad_atol = 1e-6 if attention_impl == 'torch' else 1e-8
     if rope and pos_emb_config['rope_impl'] == 'dail':
         # dail implementation of rope uses bf16 precision and hence the rotations have small numerical errors. This causes some differences between the outputs of padded and unpadded inputs.
-        right_pad_v_left_pad_rtol = 1e-2
-        right_pad_v_left_pad_atol = 1e-2
+        right_pad_v_left_pad_rtol = 2e-2
+        right_pad_v_left_pad_atol = 2e-2
     assert torch.allclose(right_padding_output[0, :3],
                           left_padding_output[0, 3:],
                           rtol=right_pad_v_left_pad_rtol,
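
For reference, torch.allclose treats elements as close when |input - other| <= atol + rtol * |other|, so doubling both tolerances from 1e-2 to 2e-2 gives the bf16-based dail rope implementation roughly twice the headroom for rounding noise. An illustrative check with made-up values:

import torch

a = torch.tensor([1.000, 2.000])
b = torch.tensor([1.005, 2.050])

# |2.000 - 2.050| = 0.05 exceeds 1e-2 + 1e-2 * 2.05 (about 0.031) but stays
# within 2e-2 + 2e-2 * 2.05 (about 0.061).
assert not torch.allclose(a, b, rtol=1e-2, atol=1e-2)
assert torch.allclose(a, b, rtol=2e-2, atol=2e-2)
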
