diff --git a/tests/test_fsdp_act_checkpoint.py b/tests/test_fsdp_act_checkpoint.py
index 1a46fcbccd..3b9a746708 100644
--- a/tests/test_fsdp_act_checkpoint.py
+++ b/tests/test_fsdp_act_checkpoint.py
@@ -3,7 +3,7 @@
 
 import pytest
 from composer import Trainer
-from composer.utils import get_device
+from composer.utils import get_device, using_torch_2
 from omegaconf import OmegaConf as om
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \
     CheckpointWrapper
@@ -60,9 +60,12 @@ def test_fsdp_act_checkpoint(activation_checkpointing: bool,
     ) or activation_checkpointing_target == [
             'mptblock', 'grouped_query_attention'
     ]:
-        assert isinstance(
-            trainer.state.model.model._fsdp_wrapped_module.transformer.
-            blocks[0]._fsdp_wrapped_module, CheckpointWrapper)
+        module = trainer.state.model.model._fsdp_wrapped_module.transformer.blocks[
+            0]._fsdp_wrapped_module
+        if not using_torch_2():
+            module = trainer.state.model.model._fsdp_wrapped_module.transformer.blocks[
+                0]._fsdp_wrapped_module._fpw_module
+        assert isinstance(module, CheckpointWrapper)
     elif activation_checkpointing_target == ['grouped_query_attention']:
         assert isinstance(
             trainer.state.model.model._fsdp_wrapped_module.transformer.
diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py
index 1b40c715de..f9191cd701 100644
--- a/tests/test_hf_conversion_script.py
+++ b/tests/test_hf_conversion_script.py
@@ -10,7 +10,7 @@
 
 from composer import Trainer
 from composer.loggers import MLFlowLogger
-from composer.utils import dist, get_device
+from composer.utils import dist, get_device, using_torch_2
 
 from llmfoundry.callbacks import HuggingFaceCheckpointer
 from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM
@@ -501,7 +501,7 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
         'drop_last': False,
         'num_workers': 0,
         'pin_memory': False,
-        'prefetch_factor': None,
+        'prefetch_factor': None if using_torch_2() else 2,
         'persistent_workers': False,
         'timeout': 0
     }
diff --git a/tests/test_model.py b/tests/test_model.py
index 5e589dbd60..4d5b0a4dbc 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -650,8 +650,8 @@ def test_forward_with_padding(attention_impl: str, pos_emb_config: dict,
         right_pad_v_left_pad_atol = 1e-6 if attention_impl == 'torch' else 1e-8
         if rope and pos_emb_config['rope_impl'] == 'dail':
             # dail implementation of rope uses bf16 precision and hence the rotations have small numerical errors. This causes some differences between the outputs of padded and unpadded inputs.
-            right_pad_v_left_pad_rtol = 1e-2
-            right_pad_v_left_pad_atol = 1e-2
+            right_pad_v_left_pad_rtol = 2e-2
+            right_pad_v_left_pad_atol = 2e-2
         assert torch.allclose(right_padding_output[0, :3],
                               left_padding_output[0, 3:],
                               rtol=right_pad_v_left_pad_rtol,