From ae719553116a80077e5d84a94ea780e94853b9cf Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Fri, 17 Nov 2023 22:39:49 -0800
Subject: [PATCH] fix

---
 tests/test_model.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/tests/test_model.py b/tests/test_model.py
index c38db7c9f7..c160c064dc 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -16,7 +16,7 @@
 from composer.core.precision import Precision, get_precision_context
 from composer.optim import DecoupledAdamW
 from composer.trainer.dist_strategy import prepare_fsdp_module
-from composer.utils import dist, get_device
+from composer.utils import dist, get_device, reproducibility
 from omegaconf import DictConfig, ListConfig
 from omegaconf import OmegaConf as om
 from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedModel,
@@ -1400,6 +1400,9 @@ def test_generation_kwargs_dont_crash(attn_impl: str,
             f'dail implementation of rope requires gpu and flash attention 2.')
     composer_device = get_device(None)
 
+    if composer_device.name == 'gpu':
+        torch.use_deterministic_algorithms(False)
+
     hf_config = MPTConfig(
         init_device='cpu',
         d_model=128,
@@ -1434,6 +1437,9 @@ def test_generation_kwargs_dont_crash(attn_impl: str,
                      attention_mask=no_padding_attention_mask,
                      **generation_kwargs)
 
+    if composer_device.name == 'gpu':
+        reproducibility.configure_deterministic_mode()
+
 
 @pytest.mark.gpu
 @pytest.mark.parametrize('attention_impl', ['torch', 'flash', 'triton'])