Commit

fix
dakinggg committed Nov 18, 2023
1 parent 84220de · commit ae71955
Showing 1 changed file with 7 additions and 1 deletion.
tests/test_model.py (8 changes: 7 additions & 1 deletion)
@@ -16,7 +16,7 @@
 from composer.core.precision import Precision, get_precision_context
 from composer.optim import DecoupledAdamW
 from composer.trainer.dist_strategy import prepare_fsdp_module
-from composer.utils import dist, get_device
+from composer.utils import dist, get_device, reproducibility
 from omegaconf import DictConfig, ListConfig
 from omegaconf import OmegaConf as om
 from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedModel,
@@ -1400,6 +1400,9 @@ def test_generation_kwargs_dont_crash(attn_impl: str,
             f'dail implementation of rope requires gpu and flash attention 2.')
     composer_device = get_device(None)
 
+    if composer_device.name == 'gpu':
+        torch.use_deterministic_algorithms(False)
+
     hf_config = MPTConfig(
         init_device='cpu',
         d_model=128,
@@ -1434,6 +1437,9 @@ def test_generation_kwargs_dont_crash(attn_impl: str,
         attention_mask=no_padding_attention_mask,
         **generation_kwargs)
 
+    if composer_device.name == 'gpu':
+        reproducibility.configure_deterministic_mode()
+
 
 @pytest.mark.gpu
 @pytest.mark.parametrize('attention_impl', ['torch', 'flash', 'triton'])
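For readers reusing this pattern elsewhere, a minimal sketch of what the change does: temporarily allow non-deterministic CUDA kernels around a GPU generation call, then restore Composer's deterministic mode so later tests stay reproducible. The context-manager wrapper and its name (non_deterministic_on_gpu) are illustrative assumptions, not part of this commit; only torch.use_deterministic_algorithms, get_device, and reproducibility.configure_deterministic_mode appear in the diff above.

    import contextlib

    import torch
    from composer.utils import get_device, reproducibility


    @contextlib.contextmanager
    def non_deterministic_on_gpu():
        # Hypothetical helper mirroring the commit's pattern: disable
        # deterministic-algorithm enforcement while running on GPU, then
        # restore Composer's deterministic mode afterward.
        composer_device = get_device(None)
        if composer_device.name == 'gpu':
            torch.use_deterministic_algorithms(False)
        try:
            yield
        finally:
            if composer_device.name == 'gpu':
                reproducibility.configure_deterministic_mode()


    # Example usage in a test body (model variable assumed):
    # with non_deterministic_on_gpu():
    #     _ = mpt.generate(input_ids, max_new_tokens=2)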
