diff --git a/tests/integration/test_eval_model.py b/tests/integration/test_eval_model.py index 474faac..3fee999 100644 --- a/tests/integration/test_eval_model.py +++ b/tests/integration/test_eval_model.py @@ -54,7 +54,12 @@ def test_model_loading(): assert cfg == model_ret.zanj_model_config assert_model_cfg_equality(model_ret, model_load_auto) - assert_model_output_equality(model_ret, model_load_auto) + vocab_size: int = len(model_ret.zanj_model_config.tokenizer) + assert_model_output_equality( + model_ret, + model_load_auto, + check_argsort_equality=(vocab_size > 2048), + ) # assert_model_exact_equality(model_ret, model_load_auto)