From 1ffa3866f2500fce827bc60f3907a2103ba3ac54 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Fri, 22 Dec 2023 21:49:07 +0900
Subject: [PATCH] Feat: Warns to add to modules_to_save when adding tokens or
 switching special_tokens (#787)

* Feat: Auto add to modules_to_save when adding tokens

* fix: swap to error instead of warning

* feat: add check when special_tokens differ and add test
---
 src/axolotl/utils/config.py | 14 ++++++++++++++
 src/axolotl/utils/models.py | 17 +++++++++++++++++
 tests/test_tokenizers.py    | 36 ++++++++++++++++++++++++++++++++++++
 tests/test_validation.py    | 37 +++++++++++++++++++++++++++++++++++++
 4 files changed, 104 insertions(+)

diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index 1b4ce92465..d9e56b95a6 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -448,6 +448,20 @@ def validate_config(cfg):
     if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
         raise ValueError("neftune_noise_alpha must be > 0.0")
 
+    if (
+        cfg.adapter
+        and cfg.tokens
+        and (
+            not cfg.lora_modules_to_save
+            or not all(
+                x in cfg.lora_modules_to_save for x in ["embed_tokens", "lm_head"]
+            )
+        )
+    ):
+        raise ValueError(
+            "lora_modules_to_save not properly set yet adding new tokens. Please add `embed_tokens` and `lm_head` to `lora_modules_to_save`."
+        )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 022229af85..8cb9e8426a 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -136,6 +136,23 @@ def load_tokenizer(cfg):
 
     if cfg.special_tokens:
         for k, val in cfg.special_tokens.items():
+            # check if new special token is not already in tokenizer and
+            # is adapter training to make sure lora_modules_to_save is set
+            if (
+                (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
+                and cfg.adapter
+                and (
+                    not cfg.lora_modules_to_save
+                    or not all(
+                        x in cfg.lora_modules_to_save
+                        for x in ["embed_tokens", "lm_head"]
+                    )
+                )
+            ):
+                raise ValueError(
+                    "Please set lora_modules_to_save to ['embed_tokens', 'lm_head'] when using an adapter and changing the special tokens."
+                )
+
             tokenizer.add_special_tokens(
                 {k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
             )
diff --git a/tests/test_tokenizers.py b/tests/test_tokenizers.py
index 5c83391942..bfe4f06af9 100644
--- a/tests/test_tokenizers.py
+++ b/tests/test_tokenizers.py
@@ -3,6 +3,8 @@
 """
 import unittest
 
+import pytest
+
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_tokenizer
 
@@ -31,6 +33,40 @@ def test_dont_use_fast(self):
         tokenizer = load_tokenizer(cfg)
         assert "Fast" not in tokenizer.__class__.__name__
 
+    def test_special_tokens_modules_to_save(self):
+        # setting special_tokens to new token
+        cfg = DictDefault(
+            {
+                "tokenizer_config": "huggyllama/llama-7b",
+                "adapter": "lora",
+                "special_tokens": {"bos_token": "[INST]"},
+            }
+        )
+        with pytest.raises(
+            ValueError,
+            match=r".*Please set lora_modules_to_save*",
+        ):
+            load_tokenizer(cfg)
+
+        # setting special_tokens but not changing from default
+        cfg = DictDefault(
+            {
+                "tokenizer_config": "huggyllama/llama-7b",
+                "adapter": "lora",
+                "special_tokens": {"bos_token": "<s>"},
+            }
+        )
+        load_tokenizer(cfg)
+
+        # non-adapter setting special_tokens
+        cfg = DictDefault(
+            {
+                "tokenizer_config": "huggyllama/llama-7b",
+                "special_tokens": {"bos_token": "[INST]"},
+            }
+        )
+        load_tokenizer(cfg)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_validation.py b/tests/test_validation.py
index fabc23da33..12997b023b 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -682,6 +682,43 @@ def test_warmup_step_no_conflict(self):
 
         validate_config(cfg)
 
+    def test_add_tokens_adapter(self):
+        cfg = DictDefault(
+            {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r".*lora_modules_to_save not properly set yet adding new tokens*",
+        ):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "adapter": "qlora",
+                "load_in_4bit": True,
+                "tokens": ["<|imstart|>"],
+                "lora_modules_to_save": ["embed_tokens"],
+            }
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r".*lora_modules_to_save not properly set yet adding new tokens*",
+        ):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "adapter": "qlora",
+                "load_in_4bit": True,
+                "tokens": ["<|imstart|>"],
+                "lora_modules_to_save": ["embed_tokens", "lm_head"],
+            }
+        )
+
+        validate_config(cfg)
+
 
 class ValidationWandbTest(ValidationTest):
     """