From 3a589b1aeb6da4ce141c0a44c8a385a42c07a6af Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 20 Aug 2024 14:23:01 -0600 Subject: [PATCH] fix failing model loading tests turns out that the outputs were actually equal to within tolerances, but argsort for large enough vocabularies fails (since weight recovery is not perfect and introduces some error) --- .../training/zanj/test_zanj_ht_save_load.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/tests/unit/maze_transformer/training/zanj/test_zanj_ht_save_load.py b/tests/unit/maze_transformer/training/zanj/test_zanj_ht_save_load.py index ff713533..2db1d8da 100644 --- a/tests/unit/maze_transformer/training/zanj/test_zanj_ht_save_load.py +++ b/tests/unit/maze_transformer/training/zanj/test_zanj_ht_save_load.py @@ -37,19 +37,19 @@ ( "raster", MazeTokenizer( - tokenization_mode=TokenizationMode.AOTP_UT_rasterized, max_grid_size=10 + tokenization_mode=TokenizationMode.AOTP_UT_rasterized, max_grid_size=5 ), ), ( "uniform", MazeTokenizer( - tokenization_mode=TokenizationMode.AOTP_UT_uniform, max_grid_size=10 + tokenization_mode=TokenizationMode.AOTP_UT_uniform, max_grid_size=5 ), ), ( "indexed", MazeTokenizer( - tokenization_mode=TokenizationMode.AOTP_CTT_indexed, max_grid_size=10 + tokenization_mode=TokenizationMode.AOTP_CTT_indexed, max_grid_size=5 ), ), ("modular", MazeTokenizerModular()), # only checking default for now @@ -106,7 +106,12 @@ def test_model_save_fold_ln(cfg_model: tuple[ConfigHolder, ZanjHookedTransformer zanj.save(model, fname) model_load = zanj.read(fname) - assert_model_output_equality(model, model_load) + vocab_size: int = len(model.zanj_model_config.tokenizer) + assert_model_output_equality( + model, + model_load, + check_argsort_equality=(vocab_size > 2048), + ) @pytest.mark.parametrize("cfg_model", MODELS, ids=lambda x: x[0].name) @@ -131,4 +136,9 @@ def test_model_save_refactored_attn_matrices( zanj.save(model, fname) model_load = zanj.read(fname) - assert_model_output_equality(model, model_load) + vocab_size: int = len(model.zanj_model_config.tokenizer) + assert_model_output_equality( + model, + model_load, + check_argsort_equality=(vocab_size > 2048), + )