diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 4cbb4f173d..07c6f7001c 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -293,6 +293,10 @@ def load_model( ) if model.get_input_embeddings().num_embeddings < embeddings_len: model.resize_token_embeddings(embeddings_len) + else: + model.vocab_size = len(tokenizer) + model.config.vocab_size = len(tokenizer) + model.tie_weights() if ( hasattr(model.config, "max_position_embeddings")