diff --git a/tests/e2e/patched/test_llama_s2_attention.py b/tests/e2e/patched/test_llama_s2_attention.py
index 0746b79f56..f1d37eb3ca 100644
--- a/tests/e2e/patched/test_llama_s2_attention.py
+++ b/tests/e2e/patched/test_llama_s2_attention.py
@@ -50,7 +50,7 @@ def test_lora_s2_attn(self, temp_dir):
                 },
             ],
             "num_epochs": 2,
-            "micro_batch_size": 8,
+            "micro_batch_size": 1,
             "gradient_accumulation_steps": 1,
             "output_dir": temp_dir,
             "learning_rate": 0.00001,
@@ -90,7 +90,7 @@ def test_fft_s2_attn(self, temp_dir):
                 },
             ],
             "num_epochs": 2,
-            "micro_batch_size": 8,
+            "micro_batch_size": 1,
             "gradient_accumulation_steps": 1,
             "output_dir": temp_dir,
             "learning_rate": 0.00001,
diff --git a/tests/utils/test_models.py b/tests/utils/test_models.py
index 0515a2e57a..bfa82ccd1a 100644
--- a/tests/utils/test_models.py
+++ b/tests/utils/test_models.py
@@ -28,6 +28,10 @@ def test_cfg_throws_error_with_s2_attention_and_sample_packing(self):
             "axolotl.utils.models.load_model_config"
         ) as mocked_load_model_config:
             mocked_load_model_config.return_value = {}
-            with pytest.raises(ValueError):
+            with pytest.raises(ValueError) as exc:
                 # Should error before hitting tokenizer, so we pass in an empty str
                 load_model(cfg, tokenizer="")
+            assert (
+                "shifted-sparse attention does not currently support sample packing"
+                in str(exc.value)
+            )
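
For context on the new assertion: `pytest.raises` used as a context manager returns an `ExceptionInfo` object, and `exc.value` is the raised exception instance, so the message is checked via `str(exc.value)` (a plain `ValueError` has no `.message` attribute on Python 3). Below is a minimal, self-contained sketch of the pattern; `raise_if_packing()` is a hypothetical stand-in for the validation that `load_model()` performs, not part of the real codebase.

```python
import pytest


def raise_if_packing(sample_packing: bool) -> None:
    # Hypothetical stand-in for load_model()'s config validation.
    if sample_packing:
        raise ValueError(
            "shifted-sparse attention does not currently support sample packing"
        )


def test_message_via_exceptioninfo() -> None:
    # pytest.raises as a context manager yields an ExceptionInfo;
    # .value is the exception instance, so str(exc.value) is its message.
    with pytest.raises(ValueError) as exc:
        raise_if_packing(sample_packing=True)
    assert "does not currently support sample packing" in str(exc.value)


def test_message_via_match() -> None:
    # Equivalent shorthand: match= runs re.search against str(exception).
    with pytest.raises(ValueError, match="sample packing"):
        raise_if_packing(sample_packing=True)
```

The `match=` form is the more compact idiom when only a substring or regex check is needed; the explicit `as exc` form used in the diff is handy when the test wants to inspect the exception object further.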