diff --git a/tests/unit-tests/test_embed.py b/tests/unit-tests/test_embed.py index 8b13f1fda9..5c9069b869 100644 --- a/tests/unit-tests/test_embed.py +++ b/tests/unit-tests/test_embed.py @@ -220,7 +220,8 @@ def test_lm_cache(): lm_models = LMModels(g, lm_config, 0, 10) lm_cache = LMCache(g, lm_models, tmpdirname) - lm_cache.update_cache(100) + ret = lm_cache.update_cache(100) + assert ret == True # This is the first time we need to compute the BERT embeddings. assert len(lm_cache) == 1 assert len(lm_cache.ntypes) == 1 assert lm_cache.ntypes[0] == 'n0'