diff --git a/tests/test_model.py b/tests/test_model.py index 18ce7190a2..3308c65fd3 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1259,7 +1259,7 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict, torch.testing.assert_close( second_output.logits, full_output.logits[:, -1, :].unsqueeze(1), - atol=1e-1, + atol=1.1e-2, rtol=1e-2, )