diff --git a/test/model/mixtral/test_mixtral.py b/test/model/mixtral/test_mixtral.py index 93c9a41..2c14fb5 100644 --- a/test/model/mixtral/test_mixtral.py +++ b/test/model/mixtral/test_mixtral.py @@ -182,7 +182,7 @@ def compare_model_weights(self, base_model, model): continue if isinstance(param, DTensor): param = param.redistribute(param.device_mesh, [Replicate()], async_op=False)._local_tensor - torch.testing.assert_close(param, base_param, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(param, base_param, atol=2e-4, rtol=2e-4) @skip_unless_torch_gpu @with_comms