From 1d6a5e2bd638778a42d757ff0cb600f918eb1c31 Mon Sep 17 00:00:00 2001 From: Mengqing Cao Date: Fri, 25 Oct 2024 21:06:56 +0800 Subject: [PATCH] Refactor func load_model to class ModelLoader (#1909) --- cicd/cicd.sh | 2 +- src/axolotl/utils/models.py | 1136 +++++++++++++++++++--------------- tests/e2e/test_load_model.py | 95 +++ tests/utils/test_models.py | 91 ++- 4 files changed, 826 insertions(+), 498 deletions(-) create mode 100644 tests/e2e/test_load_model.py diff --git a/cicd/cicd.sh b/cicd/cicd.sh index 104a8f84ab..483d62a7ad 100755 --- a/cicd/cicd.sh +++ b/cicd/cicd.sh @@ -1,6 +1,6 @@ #!/bin/bash set -e -pytest --ignore=tests/e2e/ /workspace/axolotl/tests/ +pytest -n4 --ignore=tests/e2e/ /workspace/axolotl/tests/ pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/ pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/ diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index c18af9760f..5e53df72cb 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -324,671 +324,823 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase): return processor -def load_model( - cfg: DictDefault, - tokenizer: PreTrainedTokenizerBase, - *, - processor: ProcessorMixin = None, # pylint: disable=unused-argument - inference: bool = False, - reference_model: bool = False, - **kwargs, # pylint: disable=unused-argument -) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: +class ModelLoader: """ - Load a model for a given configuration and tokenizer. + ModelLoader: managing all the config and monkey patches while loading model """ - base_model = cfg.base_model - model_type = cfg.type_of_model - model_config = load_model_config(cfg) - - # load any patches from plugins - from axolotl.integrations.base import PluginManager + def __init__( + self, + cfg: DictDefault, + tokenizer: PreTrainedTokenizerBase, + *, + processor: ProcessorMixin = None, # pylint: disable=unused-argument + inference: bool = False, + reference_model: bool = False, + **kwargs, # pylint: disable=unused-argument + ) -> None: + self.cfg = cfg + self.tokenizer = tokenizer + self.inference: bool = inference + self.reference_model: bool = reference_model + + # init model kwargs + self.model_kwargs: Dict[str, Any] = {} + if cfg.model_kwargs: + for key, val in cfg.model_kwargs.items(): + self.model_kwargs[key] = val + + # init model + self.model: PreTrainedModel + self.base_model = cfg.base_model + self.model_type = cfg.type_of_model + + # init model config + self.model_config = load_model_config(cfg) + if cfg.is_multimodal: + self.text_model_config = self.model_config.text_config + else: + self.text_model_config = self.model_config - plugin_manager = PluginManager.get_instance() - plugin_manager.pre_model_load(cfg) + self.AutoModelLoader = AutoModelForCausalLM # pylint: disable=invalid-name - if cfg.is_multimodal: - text_model_config = model_config.text_config - else: - text_model_config = model_config + def apply_patches(self) -> None: + # load any patches from plugins + from axolotl.integrations.base import PluginManager - # TODO refactor as a kwarg - load_in_8bit = cfg.load_in_8bit + plugin_manager = PluginManager.get_instance() + plugin_manager.pre_model_load(self.cfg) - if cfg.gradient_checkpointing == "unsloth": - transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper + if self.cfg.gradient_checkpointing == "unsloth": + 
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper - if hasattr(model_config, "model_type") and model_config.model_type == "mllama": - if cfg.flash_attention: - from axolotl.monkeypatch.attention.mllama import patch_mllama + if self.cfg.flash_attention: + self.patch_attention() - patch_mllama() + if self.cfg.sample_packing and self.cfg.s2_attention: + raise ValueError( + "Received `sample_packing=true` and `s2_attention=true`; however, \ + shifted-sparse attention does not currently support sample packing." + ) - if hasattr(model_config, "model_type") and model_config.model_type == "btlm": - if cfg.flash_attention: - from axolotl.monkeypatch.btlm_attn_hijack_flash import ( - replace_btlm_attn_with_flash_attn, + if ( + self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES + and self.cfg.flash_attention + and self.cfg.sample_packing + ): + patch_for_multipack( + self.cfg.model_config_type, + model_name=self.cfg.base_model, + is_remote_code=self.cfg.trust_remote_code, ) - replace_btlm_attn_with_flash_attn(cfg.base_model) + if self.cfg.is_llama_derived_model: + self.patch_loss() + if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o: + from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora - if ( - hasattr(model_config, "model_type") - and model_config.model_type == "stablelm_epoch" - ): - if cfg.flash_attention and cfg.sample_packing: - from axolotl.monkeypatch.stablelm_attn_hijack_flash import ( - replace_stablelm_attn_with_flash_attn, + patch_self_attn_lora() + elif self.cfg.is_llama_derived_model: + self.patch_llama_derived_model() + + if ( + self.cfg.model_config_type == "mistral" + and self.cfg.flash_attn_cross_entropy_loss + ): + from axolotl.monkeypatch.mistral_attn_hijack_flash import ( + patch_mistral_cross_entropy, ) - replace_stablelm_attn_with_flash_attn(cfg.base_model) + patch_mistral_cross_entropy() - if cfg.sample_packing and cfg.s2_attention: - raise ValueError( - "Received `sample_packing=true` and `s2_attention=true`; however, \ - shifted-sparse attention does not currently support sample packing." 
- ) + def patch_attention(self) -> None: + if hasattr(self.model_config, "model_type"): + if self.model_config.model_type == "mllama" and self.cfg.flash_attention: + from axolotl.monkeypatch.attention.mllama import patch_mllama - if ( - cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES - and cfg.flash_attention - and cfg.sample_packing - ): - patch_for_multipack( - cfg.model_config_type, - model_name=cfg.base_model, - is_remote_code=cfg.trust_remote_code, - ) + patch_mllama() - if cfg.is_llama_derived_model: - from axolotl.monkeypatch.llama_attn_hijack_flash import ( - patch_llama_cross_entropy, - patch_llama_rms_norm, - ) + if self.model_config.model_type == "btlm": + from axolotl.monkeypatch.btlm_attn_hijack_flash import ( + replace_btlm_attn_with_flash_attn, + ) - if cfg.flash_attn_cross_entropy: - patch_llama_cross_entropy() - if cfg.flash_attn_rms_norm: - patch_llama_rms_norm() - elif cfg.unsloth_rms_norm: - from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm - - patch_unsloth_layernorm() - if cfg.unsloth_cross_entropy_loss: - from axolotl.monkeypatch.unsloth_ import ( - integrate_cross_entropy_loss_patch, + replace_btlm_attn_with_flash_attn(self.cfg.base_model) + + if ( + self.model_config.model_type == "stablelm_epoch" + and self.cfg.sample_packing + ): + from axolotl.monkeypatch.stablelm_attn_hijack_flash import ( + replace_stablelm_attn_with_flash_attn, ) - integrate_cross_entropy_loss_patch(model_type="llama") - if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o: - from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora + replace_stablelm_attn_with_flash_attn(self.cfg.base_model) + + def patch_loss(self) -> None: + """ + Patch loss functions + """ + from axolotl.monkeypatch.llama_attn_hijack_flash import ( + patch_llama_cross_entropy, + patch_llama_rms_norm, + ) + + if self.cfg.flash_attn_cross_entropy: + patch_llama_cross_entropy() + if self.cfg.flash_attn_rms_norm: + patch_llama_rms_norm() + elif self.cfg.unsloth_rms_norm: + from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm + + patch_unsloth_layernorm() + if self.cfg.unsloth_cross_entropy_loss: + from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch - patch_self_attn_lora() - elif cfg.is_llama_derived_model: - # Modify all llama derived models in one block + integrate_cross_entropy_loss_patch(model_type="llama") + if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o: + from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora + + patch_self_attn_lora() - if cfg.flash_attention: + def patch_llama_derived_model(self) -> None: + """ + Modify all llama derived models in one block + """ + + if self.cfg.flash_attention: from axolotl.monkeypatch.llama_attn_hijack_flash import ( replace_llama_attn_with_flash_attn, ) - if cfg.sample_packing: - if cfg.device not in ["mps", "cpu"] and not inference: + if self.cfg.sample_packing: + if self.cfg.device not in ["mps", "cpu"] and not self.inference: LOG.info("patching with flash attention for sample packing") replace_llama_attn_with_flash_attn( packed=True, - cross_entropy=cfg.flash_attn_cross_entropy, - rms_norm=cfg.flash_attn_rms_norm, + cross_entropy=self.cfg.flash_attn_cross_entropy, + rms_norm=self.cfg.flash_attn_rms_norm, ) - elif cfg.s2_attention: + elif self.cfg.s2_attention: LOG.info("patching w/ flash-enabled, shifted-sparse attention") replace_llama_attn_with_flash_attn( packed=False, - cross_entropy=cfg.flash_attn_cross_entropy, - rms_norm=cfg.flash_attn_rms_norm, + cross_entropy=self.cfg.flash_attn_cross_entropy, + 
rms_norm=self.cfg.flash_attn_rms_norm, use_shifted_sparse_attn=True, ) - elif cfg.flash_attn_cross_entropy or cfg.flash_attn_rms_norm: + elif self.cfg.flash_attn_cross_entropy or self.cfg.flash_attn_rms_norm: replace_llama_attn_with_flash_attn( packed=False, - cross_entropy=cfg.flash_attn_cross_entropy, - rms_norm=cfg.flash_attn_rms_norm, + cross_entropy=self.cfg.flash_attn_cross_entropy, + rms_norm=self.cfg.flash_attn_rms_norm, ) - elif cfg.xformers_attention: + elif self.cfg.xformers_attention: from axolotl.monkeypatch.llama_attn_hijack_xformers import ( hijack_llama_attention, ) LOG.info("patching with xformers attention") hijack_llama_attention() - elif cfg.sample_packing: + elif self.cfg.sample_packing: from axolotl.monkeypatch.llama_patch_multipack import ( hijack_llama_prepare_4d_mask, ) LOG.info("patching llama _prepare_4d_causal_attention_mask*") hijack_llama_prepare_4d_mask() - elif cfg.s2_attention: + elif self.cfg.s2_attention: raise NotImplementedError( "Shifted-sparse attention not currently implemented without flash attention." ) - if cfg.unsloth_cross_entropy_loss: + if self.cfg.unsloth_cross_entropy_loss: from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch integrate_cross_entropy_loss_patch(model_type="llama") - if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o: + if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o: from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora patch_self_attn_lora() - # Modify mistral derived models - if cfg.model_config_type == "mistral" and cfg.flash_attn_cross_entropy_loss: - from axolotl.monkeypatch.mistral_attn_hijack_flash import ( - patch_mistral_cross_entropy, - ) - - patch_mistral_cross_entropy() - - model_kwargs: Dict[str, Any] = {} - - if cfg.model_kwargs: - for key, val in cfg.model_kwargs.items(): - model_kwargs[key] = val + def set_auto_model_loader(self) -> None: + """set self.AutoModelLoader + - default value: AutoModelForCausalLM (set at __init__) + - when using a multi modality model, self.AutoModelLoader should + be set according to model type of the model + """ + if self.cfg.is_multimodal: + if self.model_config.model_type == "llava": + self.AutoModelLoader = ( # pylint: disable=invalid-name + LlavaForConditionalGeneration + ) + elif self.model_config.model_type == "mllama": + self.AutoModelLoader = ( # pylint: disable=invalid-name + MllamaForConditionalGeneration + ) + else: + self.AutoModelLoader = ( + AutoModelForVision2Seq # pylint: disable=invalid-name + ) - max_memory = cfg.max_memory - device_map = cfg.device_map + def set_device_map_config(self) -> None: + device_map = self.cfg.device_map + max_memory = self.cfg.max_memory - AutoModelLoader = AutoModelForCausalLM # pylint: disable=invalid-name - if cfg.is_multimodal: - if model_config.model_type == "llava": - AutoModelLoader = ( # pylint: disable=invalid-name - LlavaForConditionalGeneration - ) - elif model_config.model_type == "mllama": - AutoModelLoader = ( # pylint: disable=invalid-name - MllamaForConditionalGeneration + if self.cfg.gpu_memory_limit: + gpu_memory_limit = ( + str(self.cfg.gpu_memory_limit) + "GiB" + if isinstance(self.cfg.gpu_memory_limit, int) + else self.cfg.gpu_memory_limit ) - else: - AutoModelLoader = AutoModelForVision2Seq # pylint: disable=invalid-name - - if cfg.gpu_memory_limit: - gpu_memory_limit = ( - str(cfg.gpu_memory_limit) + "GiB" - if isinstance(cfg.gpu_memory_limit, int) - else cfg.gpu_memory_limit - ) - max_memory = {} - for i in range(torch.cuda.device_count()): - max_memory[i] = gpu_memory_limit - 
max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything + max_memory = {} + for i in range(torch.cuda.device_count()): + max_memory[i] = gpu_memory_limit + max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything - if max_memory is not None: - # Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py - from accelerate import infer_auto_device_map + if max_memory is not None: + # Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py + from accelerate import infer_auto_device_map - with init_empty_weights(): - model_canvas = AutoModelLoader.from_config( - model_config, trust_remote_code=cfg.trust_remote_code or False + with init_empty_weights(): + model_canvas = self.AutoModelLoader.from_config( + self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + ) + model_canvas.tie_weights() + device_map = infer_auto_device_map( + model_canvas, + max_memory=max_memory, + dtype=self.cfg.torch_dtype, ) - model_canvas.tie_weights() - device_map = infer_auto_device_map( - model_canvas, - max_memory=max_memory, - dtype=cfg.torch_dtype, - ) - # We can discard max_memory now as we have a device map set up for us - max_memory = None - - model_kwargs["device_map"] = device_map - model_kwargs["torch_dtype"] = cfg.torch_dtype - - if torch.backends.mps.is_available(): - model_kwargs["device_map"] = "mps:0" - - # TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss - # if cfg.rl: - # if torch.cuda.device_count() > 1: - # if reference_model: - # model_kwargs["device_map"] = "cuda:" + str( - # torch.cuda.current_device() + 1 - # ) - # else: - # model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device()) - - if is_deepspeed_zero3_enabled(): - del model_kwargs["device_map"] + # We can discard max_memory now as we have a device map set up for us + max_memory = None + + self.model_kwargs["device_map"] = device_map + self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype + + if torch.backends.mps.is_available(): + self.model_kwargs["device_map"] = "mps:0" + + # TODO can we put the reference model on it's own gpu? 
I think we have to move logits around to calculate loss + # if cfg.rl: + # if torch.cuda.device_count() > 1: + # if reference_model: + # model_kwargs["device_map"] = "cuda:" + str( + # torch.cuda.current_device() + 1 + # ) + # else: + # model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device()) + + if is_deepspeed_zero3_enabled(): + del self.model_kwargs["device_map"] + + def set_quantization_config(self) -> None: + self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit + self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit + + if self.cfg.gptq: + if not hasattr(self.model_config, "quantization_config"): + LOG.warning( + "model config does not contain quantization_config information" + ) + else: + if self.cfg.gptq_disable_exllama is not None: + self.model_config.quantization_config[ + "disable_exllama" + ] = self.cfg.gptq_disable_exllama + self.model_kwargs["quantization_config"] = GPTQConfig( + **self.model_config.quantization_config + ) + if ( + self.cfg.adapter in ["qlora", "lora"] + and hasattr(self.model_config, "quantization_config") + and self.model_config.quantization_config["quant_method"] + in ["gptq", "awq", "bitsandbytes"] + ): + if self.model_config.quantization_config["quant_method"] == "gptq": + self.model_kwargs["quantization_config"] = GPTQConfig( + **self.model_config.quantization_config + ) + elif self.model_config.quantization_config["quant_method"] == "awq": + self.model_kwargs["quantization_config"] = AwqConfig( + **self.model_config.quantization_config + ) + elif ( + self.model_config.quantization_config["quant_method"] == "bitsandbytes" + ): + self.model_kwargs["quantization_config"] = BitsAndBytesConfig( + **self.model_config.quantization_config + ) + elif self.cfg.adapter == "qlora" and ( + "load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"] + ): + bnb_config = { + "load_in_4bit": True, + "llm_int8_threshold": 6.0, + "llm_int8_has_fp16_weight": False, + "bnb_4bit_compute_dtype": self.cfg.torch_dtype, + "bnb_4bit_use_double_quant": True, + "bnb_4bit_quant_type": "nf4", + "bnb_4bit_quant_storage": torch.bfloat16, + } + if self.cfg.model_config_type in ["jamba", "qwen2_moe"] and not ( + self.cfg.deepspeed or self.cfg.fsdp + ): + # for some reason, this causes the loss to be off by an order of magnitude + # but deepspeed needs this still in bfloat16 + bnb_config["bnb_4bit_quant_storage"] = torch.float32 - if cfg.revision_of_model: - model_kwargs["revision"] = cfg.revision_of_model + if self.cfg.bnb_config_kwargs: + bnb_config.update(self.cfg.bnb_config_kwargs) - if cfg.gptq: - if not hasattr(model_config, "quantization_config"): - LOG.warning("model config does not contain quantization_config information") - else: - if cfg.gptq_disable_exllama is not None: - model_config.quantization_config[ - "disable_exllama" - ] = cfg.gptq_disable_exllama - model_kwargs["quantization_config"] = GPTQConfig( - **model_config.quantization_config + self.model_kwargs["quantization_config"] = BitsAndBytesConfig( + **bnb_config, ) - if ( - cfg.adapter in ["qlora", "lora"] - and hasattr(model_config, "quantization_config") - and model_config.quantization_config["quant_method"] - in ["gptq", "awq", "bitsandbytes"] - ): - if model_config.quantization_config["quant_method"] == "gptq": - model_kwargs["quantization_config"] = GPTQConfig( - **model_config.quantization_config + elif self.cfg.adapter == "lora" and ( + "load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"] + ): + bnb_config = { + "load_in_8bit": True, + } + # Exclude 
mamba blocks from int8 quantization for jamba + if self.cfg.model_config_type == "jamba": + bnb_config["llm_int8_skip_modules"] = ["mamba"] + self.model_kwargs["quantization_config"] = BitsAndBytesConfig( + **bnb_config, ) - elif model_config.quantization_config["quant_method"] == "awq": - model_kwargs["quantization_config"] = AwqConfig( - **model_config.quantization_config + + # no longer needed per https://github.com/huggingface/transformers/pull/26610 + if "quantization_config" in self.model_kwargs or self.cfg.gptq: + if "load_in_8bit" in self.model_kwargs: + del self.model_kwargs["load_in_8bit"] + if "load_in_4bit" in self.model_kwargs: + del self.model_kwargs["load_in_4bit"] + + def set_attention_config(self) -> None: + """ + sample packing uses custom FA2 patch + """ + if self.cfg.flash_attention: + if not self.cfg.sample_packing and self.cfg.s2_attention: + pass + self.model_kwargs["attn_implementation"] = "flash_attention_2" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "flash_attention_2" ) - elif model_config.quantization_config["quant_method"] == "bitsandbytes": - model_kwargs["quantization_config"] = BitsAndBytesConfig( - **model_config.quantization_config + elif self.cfg.sdp_attention: + self.model_kwargs["attn_implementation"] = "sdpa" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "sdpa" + ) + elif self.cfg.eager_attention: + self.model_kwargs["attn_implementation"] = "eager" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "eager" ) - elif cfg.adapter == "qlora" and cfg.load_in_4bit: - bnb_config = { - "load_in_4bit": True, - "llm_int8_threshold": 6.0, - "llm_int8_has_fp16_weight": False, - "bnb_4bit_compute_dtype": cfg.torch_dtype, - "bnb_4bit_use_double_quant": True, - "bnb_4bit_quant_type": "nf4", - "bnb_4bit_quant_storage": torch.bfloat16, - } - if cfg.model_config_type in ["jamba", "qwen2_moe"] and not ( - cfg.deepspeed or cfg.fsdp - ): - # for some reason, this causes the loss to be off by an order of magnitude - # but deepspeed needs this still in bfloat16 - bnb_config["bnb_4bit_quant_storage"] = torch.float32 - - if cfg.bnb_config_kwargs: - bnb_config.update(cfg.bnb_config_kwargs) - - model_kwargs["quantization_config"] = BitsAndBytesConfig( - **bnb_config, - ) - elif cfg.adapter == "lora" and cfg.load_in_8bit: - bnb_config = { - "load_in_8bit": True, - } - # Exclude mamba blocks from int8 quantization for jamba - if cfg.model_config_type == "jamba": - bnb_config["llm_int8_skip_modules"] = ["mamba"] - model_kwargs["quantization_config"] = BitsAndBytesConfig( - **bnb_config, - ) - - if cfg.load_in_8bit and cfg.adapter is not None: - model_kwargs["load_in_8bit"] = True - if cfg.load_in_4bit and cfg.adapter is not None: - model_kwargs["load_in_4bit"] = True - - # no longer needed per https://github.com/huggingface/transformers/pull/26610 - if "quantization_config" in model_kwargs or cfg.gptq: - if "load_in_8bit" in model_kwargs: - del model_kwargs["load_in_8bit"] - if "load_in_4bit" in model_kwargs: - del model_kwargs["load_in_4bit"] - - # sample packing uses custom FA2 patch - if cfg.flash_attention: - if not cfg.sample_packing and cfg.s2_attention: - pass - model_kwargs["attn_implementation"] = "flash_attention_2" - model_config._attn_implementation = ( # pylint: disable=protected-access - "flash_attention_2" - ) - elif cfg.sdp_attention: - model_kwargs["attn_implementation"] = "sdpa" - model_config._attn_implementation = "sdpa" # pylint: disable=protected-access - 
elif cfg.eager_attention: - model_kwargs["attn_implementation"] = "eager" - model_config._attn_implementation = "eager" # pylint: disable=protected-access - - if cfg.low_cpu_mem_usage: - model_kwargs["low_cpu_mem_usage"] = True - qlora_fsdp = cfg.fsdp and cfg.adapter == "qlora" + if self.cfg.low_cpu_mem_usage: + self.model_kwargs["low_cpu_mem_usage"] = True - try: + def build_model(self, qlora_fsdp) -> bool: skip_move_to_device = False if ( # pylint: disable=condition-evals-to-constant) - (cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading) + (self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading) and not qlora_fsdp and False ): - model = load_sharded_model( - base_model, - model_config, - cfg, - torch_dtype=cfg.torch_dtype, + self.model = load_sharded_model( + self.base_model, + self.model_config, + self.cfg, + torch_dtype=self.cfg.torch_dtype, ) skip_move_to_device = True elif ( qlora_fsdp - and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading - and (cfg.model_config_type == "dbrx" or cfg.qlora_sharded_model_loading) + and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading + and ( + self.cfg.model_config_type == "dbrx" + or self.cfg.qlora_sharded_model_loading + ) ): - quant_storage = cfg.torch_dtype + quant_storage = self.cfg.torch_dtype quantization_config = hasattr( - model_config, "quantization_config" - ) and getattr(model_config, "quantization_config") + self.model_config, "quantization_config" + ) and getattr(self.model_config, "quantization_config") quantization_config = ( - quantization_config or model_kwargs["quantization_config"] + quantization_config or self.model_kwargs["quantization_config"] ) - if cfg.is_multimodal: - model_config.text_config = text_model_config - model = load_sharded_model_quant( - base_model, - model_config, - cfg, + if self.cfg.is_multimodal: + self.model_config.text_config = self.text_model_config + self.model = load_sharded_model_quant( + self.base_model, + self.model_config, + self.cfg, quant_storage=quant_storage, quantization_config=quantization_config, ) skip_move_to_device = True elif ( - model_config.model_type == "llama" - and not cfg.trust_remote_code - and not cfg.gptq + self.model_config.model_type == "llama" + and not self.cfg.trust_remote_code + and not self.cfg.gptq ): - if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: + if self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: skip_move_to_device = True - if "device_map" in model_kwargs: - del model_kwargs["device_map"] - - if cfg.is_multimodal: - model_config.text_config = text_model_config - model = AutoModelLoader.from_pretrained( - base_model, - config=model_config, - **model_kwargs, + if "device_map" in self.model_kwargs: + del self.model_kwargs["device_map"] + + if self.cfg.is_multimodal: + self.model_config.text_config = self.text_model_config + self.model = self.AutoModelLoader.from_pretrained( + self.base_model, + config=self.model_config, + **self.model_kwargs, ) - if cfg.flash_attention and not inference: + # TODO (MengqingCao) split these patches seperately + if self.cfg.flash_attention and not self.inference: from axolotl.monkeypatch.llama_attn_hijack_flash import ( is_xformers_swiglu_available, replace_llama_mlp_with_swiglu, replace_llama_qkv_with_fused, ) - if cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available(): + if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available(): LOG.info("patching with SwiGLU") - replace_llama_mlp_with_swiglu(model) + replace_llama_mlp_with_swiglu(self.model) - if 
cfg.flash_attn_fuse_qkv: + if self.cfg.flash_attn_fuse_qkv: LOG.info("patching with fused QKV") - replace_llama_qkv_with_fused(model) - elif model_type == "MambaLMHeadModel": + replace_llama_qkv_with_fused(self.model) + elif self.model_type == "MambaLMHeadModel": # FIXME this is janky at best and hacked together to make it work MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name - model_kwargs["dtype"] = model_kwargs["torch_dtype"] - model_kwargs["device"] = torch.cuda.current_device() - del model_kwargs["torch_dtype"] - del model_kwargs["device_map"] + self.model_kwargs["dtype"] = self.model_kwargs["torch_dtype"] + self.model_kwargs["device"] = torch.cuda.current_device() + del self.model_kwargs["torch_dtype"] + del self.model_kwargs["device_map"] - model = MambaLMHeadModel.from_pretrained( - base_model, - **model_kwargs, + self.model = MambaLMHeadModel.from_pretrained( + self.base_model, + **self.model_kwargs, ) elif ( - model_type - and model_type != "AutoModelForCausalLM" - and not cfg.trust_remote_code + self.model_type + and self.model_type != "AutoModelForCausalLM" + and not self.cfg.trust_remote_code ): - if cfg.gptq: - if cfg.is_multimodal: - model_config.text_config = text_model_config - model = AutoModelLoader.from_pretrained( - base_model, - config=model_config, - trust_remote_code=cfg.trust_remote_code or False, - **model_kwargs, + if self.cfg.is_multimodal: + self.model_config.text_config = self.text_model_config + if self.cfg.gptq: + self.model = self.AutoModelLoader.from_pretrained( + self.base_model, + config=self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + **self.model_kwargs, ) else: - if cfg.is_multimodal: - model_config.text_config = text_model_config - model = getattr(transformers, model_type).from_pretrained( - base_model, - config=model_config, - trust_remote_code=cfg.trust_remote_code or False, - **model_kwargs, + self.model = getattr(transformers, self.model_type).from_pretrained( + self.base_model, + config=self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + **self.model_kwargs, ) else: # Shouldn't be a problem most of the time. 
will obviously error if the model doesn't support this # when training starts if ( - hasattr(text_model_config, "max_seq_len") - and text_model_config.max_seq_len - and cfg.sequence_len > model_config.max_seq_len + hasattr(self.text_model_config, "max_seq_len") + and self.text_model_config.max_seq_len + and self.cfg.sequence_len > self.text_model_config.max_seq_len ): - text_model_config.max_seq_len = cfg.sequence_len - LOG.warning(f"increasing context length to {cfg.sequence_len}") + self.text_model_config.max_seq_len = self.cfg.sequence_len + LOG.warning(f"increasing context length to {self.cfg.sequence_len}") elif ( - hasattr(text_model_config, "max_sequence_length") - and text_model_config.max_sequence_length - and cfg.sequence_len > text_model_config.max_sequence_length + hasattr(self.text_model_config, "max_sequence_length") + and self.text_model_config.max_sequence_length + and self.cfg.sequence_len > self.text_model_config.max_sequence_length ): - text_model_config.max_sequence_length = cfg.sequence_len - LOG.warning(f"increasing context length to {cfg.sequence_len}") - if cfg.gptq: - if cfg.is_multimodal: - model_config.text_config = text_model_config - model = AutoModelLoader.from_pretrained( - base_model, - config=model_config, - trust_remote_code=cfg.trust_remote_code or False, - **model_kwargs, + self.text_model_config.max_sequence_length = self.cfg.sequence_len + LOG.warning(f"increasing context length to {self.cfg.sequence_len}") + if self.cfg.gptq: + if self.cfg.is_multimodal: + self.model_config.text_config = self.text_model_config + self.model = self.AutoModelLoader.from_pretrained( + self.base_model, + config=self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + **self.model_kwargs, ) else: - if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: + if ( + self.cfg.fsdp + and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading + ): # disabling either of these two still leads to VRAM spike before setting back down skip_move_to_device = True - if "device_map" in model_kwargs: - del model_kwargs["device_map"] - - if cfg.is_multimodal: - model_config.text_config = text_model_config - model = AutoModelLoader.from_pretrained( - base_model, - config=model_config, - trust_remote_code=cfg.trust_remote_code or False, - **model_kwargs, + if "device_map" in self.model_kwargs: + del self.model_kwargs["device_map"] + + if self.cfg.is_multimodal: + self.model_config.text_config = self.text_model_config + self.model = self.AutoModelLoader.from_pretrained( + self.base_model, + config=self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + **self.model_kwargs, ) - except Exception as err: # pylint: disable=broad-exception-caught - LOG.exception(err) - raise err - - if isinstance(model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp: - model = model.merge_and_unload() - - embeddings_len = ( - math.ceil(len(tokenizer) / 32) * 32 - if cfg.resize_token_embeddings_to_32x - else len(tokenizer) - ) - if ( - hasattr(model, "get_input_embeddings") - and model.get_input_embeddings().num_embeddings < embeddings_len - ): - model.resize_token_embeddings(embeddings_len) - else: - model.tie_weights() + if is_deepspeed_zero3_enabled(): + skip_move_to_device = True - if ( - hasattr(model, "config") - and hasattr(model.config, "max_position_embeddings") - and model.config.max_position_embeddings - and cfg.sequence_len > model.config.max_position_embeddings - ): - LOG.warning( - f"increasing model.config.max_position_embeddings from 
{model.config.max_position_embeddings} to {cfg.sequence_len}" - ) - model.config.max_position_embeddings = cfg.sequence_len + return skip_move_to_device - if ( - hasattr(model, "config") - and hasattr(model.config, "bos_token_id") - and model.config.bos_token_id - and model.config.bos_token_id != tokenizer.bos_token_id - ): - model.config.bos_token_id = tokenizer.bos_token_id + def ajust_model_config(self) -> None: + if ( + hasattr(self.model, "config") + and hasattr(self.model.config, "max_position_embeddings") + and self.model.config.max_position_embeddings + and self.cfg.sequence_len > self.model.config.max_position_embeddings + ): + LOG.warning( + f"increasing model.config.max_position_embeddings from {self.model.config.max_position_embeddings} to {self.cfg.sequence_len}" + ) + self.model.config.max_position_embeddings = self.cfg.sequence_len - if ( - hasattr(model, "config") - and hasattr(model.config, "eos_token_id") - and model.config.eos_token_id - and model.config.eos_token_id != tokenizer.eos_token_id - ): - model.config.eos_token_id = tokenizer.eos_token_id - - if hasattr(model, "device") and model.device.type in ("cuda", "mps"): - log_gpu_memory_usage(LOG, "after model load", model.device) - - # make sure these are fp32 per Ramesh et al. (2021) - embedding_modules = get_linear_embedding_layers(cfg.model_config_type) - if not cfg.fsdp: - # FSDP doesn't like mixed Float and BFloat16 - for name, module in model.named_modules(): - if "norm" in name or name.endswith(".gate"): - module.to(torch.float32) - if model_config.model_type == "btlm": - # don't upcast lm_head for btlm - continue - if any(m in name for m in embedding_modules): - if hasattr(module, "weight"): - module.to(torch.float32) + if ( + hasattr(self.model, "config") + and hasattr(self.model.config, "bos_token_id") + and self.model.config.bos_token_id + and self.model.config.bos_token_id != self.tokenizer.bos_token_id + ): + self.model.config.bos_token_id = self.tokenizer.bos_token_id - needs_fa2_dtype = cfg.adapter or cfg.fsdp - skip_prepare_model_for_kbit_training = False + if ( + hasattr(self.model, "config") + and hasattr(self.model.config, "eos_token_id") + and self.model.config.eos_token_id + and self.model.config.eos_token_id != self.tokenizer.eos_token_id + ): + self.model.config.eos_token_id = self.tokenizer.eos_token_id - if is_deepspeed_zero3_enabled(): + def set_z3_leaf_modules(self) -> None: from deepspeed.utils import ( # pylint: disable=no-name-in-module set_z3_leaf_modules, ) - if cfg.model_config_type in MOE_ARCH_BLOCK: - moe_blocks = MOE_ARCH_BLOCK[cfg.model_config_type] + if self.cfg.model_config_type in MOE_ARCH_BLOCK: + moe_blocks = MOE_ARCH_BLOCK[self.cfg.model_config_type] moe_blocks = [moe_blocks] if isinstance(moe_blocks, str) else moe_blocks set_z3_leaf_modules( - model, + self.model, [ - get_module_class_from_name(model, module_name) + get_module_class_from_name(self.model, module_name) for module_name in moe_blocks ], ) - if cfg.model_config_type == "qwen" and cfg.adapter == "lora": - # Qwen doesn't play nicely with LoRA if this is enabled - skip_prepare_model_for_kbit_training = True + def prepare_model(self, qlora_fsdp) -> None: + skip_prepare_model_for_kbit_training = False + if self.cfg.model_config_type == "qwen" and self.cfg.adapter == "lora": + # Qwen doesn't play nicely with LoRA if this is enabled + skip_prepare_model_for_kbit_training = True - loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits - if cfg.adapter == "lora" and loftq_bits: - 
skip_prepare_model_for_kbit_training = True + loftq_bits = ( + self.cfg.peft + and self.cfg.peft.loftq_config + and self.cfg.peft.loftq_config.loftq_bits + ) + if self.cfg.adapter == "lora" and loftq_bits: + skip_prepare_model_for_kbit_training = True - if qlora_fsdp or (cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading): - # make sure everything is in the same dtype - skip_prepare_model_for_kbit_training = True + if qlora_fsdp or ( + self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading + ): + # make sure everything is in the same dtype + skip_prepare_model_for_kbit_training = True - if is_deepspeed_zero3_enabled(): - skip_prepare_model_for_kbit_training = True + if is_deepspeed_zero3_enabled(): + skip_prepare_model_for_kbit_training = True + + is_load_in_8bit = ( + "load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"] + ) + is_load_in_4bit = ( + "load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"] + ) - if cfg.adapter in ["lora", "qlora"]: - if cfg.gradient_checkpointing: - model.gradient_checkpointing_enable( - gradient_checkpointing_kwargs=cfg.gradient_checkpointing_kwargs - ) if ( - cfg.load_in_8bit or cfg.load_in_4bit - ) and not skip_prepare_model_for_kbit_training: + not skip_prepare_model_for_kbit_training + and self.cfg.adapter in ["lora", "qlora"] + and (is_load_in_8bit or is_load_in_4bit) + ): LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") - model = prepare_model_for_kbit_training( - model, use_gradient_checkpointing=cfg.gradient_checkpointing + self.model = prepare_model_for_kbit_training( + self.model, use_gradient_checkpointing=self.cfg.gradient_checkpointing ) - needs_fa2_dtype = True - # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to - # convert them back to fp16/bf16 for flash-attn compatibility. 
- if (needs_fa2_dtype or cfg.flash_attention) and not qlora_fsdp: - LOG.info("converting modules to %s for flash attention", cfg.torch_dtype) - for name, module in model.named_modules(): + def convert_embedding_modules_dtype( + self, embedding_modules, dist_dtype, before_kbit_train_or_finetune + ) -> None: + for name, module in self.model.named_modules(): if "norm" in name: - module.to(cfg.torch_dtype) + module.to(dist_dtype) + if before_kbit_train_or_finetune: + if name.endswith(".gate"): + module.to(dist_dtype) + if self.model_config.model_type == "btlm": + # don't upcast lm_head for btlm + continue if any(m in name for m in embedding_modules): if hasattr(module, "weight"): - module.to(cfg.torch_dtype) - - lora_config = None - if not reference_model or cfg.lora_model_dir: - # if we're not loading the reference model, then we're loading the model for training - # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config - if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto"] and not cfg.merge_lora: - _, lora_config = load_lora(model, cfg, inference=False, config_only=True) + module.to(dist_dtype) + + def apply_lora_patch(self) -> None: + if self.cfg.unsloth_lora_mlp: + from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch + + integrate_lora_mlp_patch(self.model) + if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o: + from axolotl.monkeypatch.unsloth_ import integrate_lora_patch + + integrate_lora_patch(self.model, self.cfg) + if self.cfg.unsloth_rope: + from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings + + integrate_rope_embeddings() + + def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: + self.apply_patches() + self.set_auto_model_loader() + self.set_device_map_config() + if self.cfg.revision_of_model: + self.model_kwargs["revision"] = self.cfg.revision_of_model + self.set_quantization_config() + self.set_attention_config() + + qlora_fsdp = self.cfg.fsdp and self.cfg.adapter == "qlora" + skip_move_to_device = False + + try: + skip_move_to_device = self.build_model(qlora_fsdp) + except Exception as err: # pylint: disable=broad-exception-caught + LOG.exception(err) + raise err + + if isinstance(self.model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp: + self.model = self.model.merge_and_unload() + + embeddings_len = ( + math.ceil(len(self.tokenizer) / 32) * 32 + if self.cfg.resize_token_embeddings_to_32x + else len(self.tokenizer) + ) + if ( + hasattr(self.model, "get_input_embeddings") + and self.model.get_input_embeddings().num_embeddings < embeddings_len + ): + self.model.resize_token_embeddings(embeddings_len) else: - model, lora_config = load_adapter(model, cfg, cfg.adapter) + self.model.tie_weights() + + self.ajust_model_config() + + # log device memory usage + if hasattr(self.model, "device") and self.model.device.type in ("cuda", "mps"): + log_gpu_memory_usage(LOG, "after model load", self.model.device) + + # make sure these are fp32 per Ramesh et al. 
(2021) + embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type) + if not self.cfg.fsdp: + # FSDP doesn't like mixed Float and BFloat16 + self.convert_embedding_modules_dtype( + embedding_modules, + dist_dtype=torch.float32, + before_kbit_train_or_finetune=True, + ) - if is_deepspeed_zero3_enabled(): - skip_move_to_device = True + if is_deepspeed_zero3_enabled(): + self.set_z3_leaf_modules() - if ( - cfg.ddp - and not load_in_8bit - and not (cfg.rl and cfg.load_in_4bit) - and not skip_move_to_device - ): - # TODO revaldate this conditional - model.to(f"cuda:{cfg.local_rank}") + needs_fa2_dtype = self.cfg.adapter or self.cfg.fsdp + if self.cfg.adapter in ["lora", "qlora"]: + needs_fa2_dtype = True + if self.cfg.gradient_checkpointing: + self.model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs + ) + + self.prepare_model(qlora_fsdp) - if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1: - setattr(model, "is_parallelizable", True) - setattr(model, "model_parallel", True) + # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to + # convert them back to fp16/bf16 for flash-attn compatibility. + if (needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp: + LOG.info( + "converting modules to %s for flash attention", self.cfg.torch_dtype + ) + self.convert_embedding_modules_dtype( + embedding_modules, + dist_dtype=self.cfg.torch_dtype, + before_kbit_train_or_finetune=False, + ) + + # --------------------------------------------------------- + # load lora or adapter + # --------------------------------------------------------- + lora_config = None + if not self.reference_model or self.cfg.lora_model_dir: + # if we're not loading the reference model, then we're loading the model for training + # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config + if ( + self.cfg.adapter + and self.cfg.rl in ["dpo", "ipo", "kto"] + and not self.cfg.merge_lora + ): + _, lora_config = load_lora( + self.model, self.cfg, inference=False, config_only=True + ) + else: + self.model, lora_config = load_adapter( + self.model, self.cfg, self.cfg.adapter + ) - requires_grad = [] - for name, param in model.named_parameters(recurse=True): - if param.requires_grad: - requires_grad.append(f"{name}: {param.requires_grad}") - if len(requires_grad) == 0: - LOG.warning("there are no parameters that require gradient updates") - if hasattr(model, "config"): - model.config.use_cache = False + # --------------------------------------------------------- + # put model to accelerator + # --------------------------------------------------------- + is_load_in_8bit = ( + "load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"] + ) + is_load_in_4bit = ( + "load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"] + ) + if ( + self.cfg.ddp + and not is_load_in_8bit + and not (self.cfg.rl and is_load_in_4bit) + and not skip_move_to_device + ): + # TODO revaldate this conditional + self.model.to(f"cuda:{self.cfg.local_rank}") - if cfg.flash_optimum: - from optimum.bettertransformer import BetterTransformer + if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1: + setattr(self.model, "is_parallelizable", True) + setattr(self.model, "model_parallel", True) - model = BetterTransformer.transform(model) + # --------------------------------------------------------- + # parameters that require gradient updates + # 
--------------------------------------------------------- + requires_grad = [] + for name, param in self.model.named_parameters(recurse=True): + if param.requires_grad: + requires_grad.append(f"{name}: {param.requires_grad}") + if len(requires_grad) == 0: + LOG.warning("there are no parameters that require gradient updates") + if hasattr(self.model, "config"): + self.model.config.use_cache = False - if cfg.adapter is not None: - log_gpu_memory_usage(LOG, "after adapters", model.device) + if self.cfg.flash_optimum: + from optimum.bettertransformer import BetterTransformer - if cfg.unsloth_lora_mlp: - from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch + self.model = BetterTransformer.transform(self.model) - integrate_lora_mlp_patch(model) - if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o: - from axolotl.monkeypatch.unsloth_ import integrate_lora_patch + if self.cfg.adapter is not None: + log_gpu_memory_usage(LOG, "after adapters", self.model.device) - integrate_lora_patch(model, cfg) + self.apply_lora_patch() - if cfg.unsloth_rope: - from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings + for _ in range(3): + gc.collect() + torch.cuda.empty_cache() - integrate_rope_embeddings() + # TODO resume_from_checkpoint handling + return self.model, lora_config - for _ in range(3): - gc.collect() - torch.cuda.empty_cache() - # TODO resume_from_checkpoint handling - return model, lora_config +def load_model( + cfg: DictDefault, + tokenizer: PreTrainedTokenizerBase, + *, + processor: ProcessorMixin = None, # pylint: disable=unused-argument + inference: bool = False, + reference_model: bool = False, + **kwargs, # pylint: disable=unused-argument +) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: + """ + Load a model for a given configuration and tokenizer. + """ + loader = ModelLoader( + cfg, + tokenizer, + processor=processor, + inference=inference, + reference_model=reference_model, + **kwargs, + ) + return loader.load_model() def load_adapter(model, cfg, adapter, inference=False): diff --git a/tests/e2e/test_load_model.py b/tests/e2e/test_load_model.py new file mode 100644 index 0000000000..31a9b1a878 --- /dev/null +++ b/tests/e2e/test_load_model.py @@ -0,0 +1,95 @@ +"""Module for testing ModelLoader.""" + +import shutil +import tempfile + +import pytest +import torch + +from axolotl.utils.dict import DictDefault +from axolotl.utils.models import ModelLoader, load_model, load_tokenizer + + +@pytest.fixture(name="temp_dir") +def fixture_temp_dir(): + temp_dir = tempfile.mkdtemp() + yield temp_dir + shutil.rmtree(temp_dir) + + +class TestLoadModelUtils: + """ + Testing module testing ModelLoader. 
+ """ + + def setup_method(self): + # load config + self.cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "tokenizer_config": "JackFram/llama-68m", + "sequence_len": 1024, + "load_in_8bit": False, + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "val_set_size": 0.1, + "special_tokens": { + "unk_token": "", + "bos_token": "", + "eos_token": "", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "micro_batch_size": 8, + "gradient_accumulation_steps": 1, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + } + ) + self.model_loader = ( # pylint: disable=attribute-defined-outside-init + ModelLoader( + cfg=self.cfg, + tokenizer="", + ) + ) + + @pytest.mark.parametrize("embedding_modules", ["embed_tokens", "lm_head"]) + @pytest.mark.parametrize( + "dist_dtype", [torch.bfloat16, torch.float16, torch.float32] + ) + @pytest.mark.parametrize("before_kbit_train_or_finetune", [True, False]) + def test_convert_embedding_modules_dtype( + self, temp_dir, embedding_modules, dist_dtype, before_kbit_train_or_finetune + ): + self.cfg.output_dir = temp_dir + self.model_loader.tokenizer = load_tokenizer(self.cfg) # pylint: disable=all + self.model_loader.model, _ = load_model( + self.cfg, + self.model_loader.tokenizer, + inference=False, + reference_model=True, + ) + self.model_loader.convert_embedding_modules_dtype( + embedding_modules, dist_dtype, before_kbit_train_or_finetune + ) + for name, module in self.model_loader.model.named_modules(): + if ( + "norm" in name + or (before_kbit_train_or_finetune and name.endswith(".gate")) + or ( + any(m in name for m in embedding_modules) + and hasattr(module, "weight") + ) + ): + for _, param in module.named_parameters(): + assert param.dtype == dist_dtype diff --git a/tests/utils/test_models.py b/tests/utils/test_models.py index e06bb6c250..31698f05fb 100644 --- a/tests/utils/test_models.py +++ b/tests/utils/test_models.py @@ -1,18 +1,64 @@ """Module for testing models utils file.""" - -import unittest -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest +from transformers import BitsAndBytesConfig, PreTrainedTokenizerBase +from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled +from transformers.utils.import_utils import is_torch_mps_available from axolotl.utils.dict import DictDefault -from axolotl.utils.models import load_model +from axolotl.utils.models import ModelLoader, load_model -class ModelsUtilsTest(unittest.TestCase): +class TestModelsUtils: """Testing module for models utils.""" + def setup_method(self) -> None: + # load config + self.cfg = DictDefault( # pylint: disable=attribute-defined-outside-init + { + "base_model": "JackFram/llama-68m", + "model_type": "LlamaForCausalLM", + "tokenizer_type": "LlamaTokenizer", + "load_in_8bit": True, + "load_in_4bit": False, + "adapter": "lora", + "flash_attention": False, + "sample_packing": True, + "device_map": "auto", + } + ) + self.tokenizer = MagicMock( # pylint: disable=attribute-defined-outside-init + spec=PreTrainedTokenizerBase + ) + self.inference = False # pylint: disable=attribute-defined-outside-init + self.reference_model = True # pylint: disable=attribute-defined-outside-init + + # init ModelLoader + self.model_loader = ( # pylint: disable=attribute-defined-outside-init + ModelLoader( + cfg=self.cfg, + tokenizer=self.tokenizer, + 
inference=self.inference, + reference_model=self.reference_model, + ) + ) + + def test_set_device_map_config(self): + # check device_map + device_map = self.cfg.device_map + if is_torch_mps_available(): + device_map = "mps" + self.model_loader.set_device_map_config() + if is_deepspeed_zero3_enabled(): + assert "device_map" not in self.model_loader.model_kwargs + else: + assert device_map in self.model_loader.model_kwargs["device_map"] + + # check torch_dtype + assert self.cfg.torch_dtype == self.model_loader.model_kwargs["torch_dtype"] + def test_cfg_throws_error_with_s2_attention_and_sample_packing(self): cfg = DictDefault( { @@ -35,3 +81,38 @@ def test_cfg_throws_error_with_s2_attention_and_sample_packing(self): "shifted-sparse attention does not currently support sample packing" in str(exc.value) ) + + @pytest.mark.parametrize("adapter", ["lora", "qlora", None]) + @pytest.mark.parametrize("load_in_8bit", [True, False]) + @pytest.mark.parametrize("load_in_4bit", [True, False]) + @pytest.mark.parametrize("gptq", [True, False]) + def test_set_quantization_config( + self, + adapter, + load_in_8bit, + load_in_4bit, + gptq, + ): + # init cfg as args + self.cfg.load_in_8bit = load_in_8bit + self.cfg.load_in_4bit = load_in_4bit + self.cfg.gptq = gptq + self.cfg.adapter = adapter + + self.model_loader.set_quantization_config() + if "quantization_config" in self.model_loader.model_kwargs or self.cfg.gptq: + assert not ( + hasattr(self.model_loader.model_kwargs, "load_in_8bit") + and hasattr(self.model_loader.model_kwargs, "load_in_4bit") + ) + elif load_in_8bit and self.cfg.adapter is not None: + assert self.model_loader.model_kwargs["load_in_8bit"] + elif load_in_4bit and self.cfg.adapter is not None: + assert self.model_loader.model_kwargs["load_in_4bit"] + + if (self.cfg.adapter == "qlora" and load_in_4bit) or ( + self.cfg.adapter == "lora" and load_in_8bit + ): + assert self.model_loader.model_kwargs.get( + "quantization_config", BitsAndBytesConfig + )
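
Note for reviewers: a minimal usage sketch of the API after this refactor, kept outside the patch itself. The names (`ModelLoader`, `load_model`, `load_tokenizer`, the `set_*` stage methods) all come from the diff above; the config values are placeholders borrowed from the new e2e test, and driving the loader stage by stage is shown only to illustrate how the unit tests exercise it, not as a required calling pattern.

    from axolotl.utils.dict import DictDefault
    from axolotl.utils.models import ModelLoader, load_model, load_tokenizer

    # Minimal config, mirroring tests/e2e/test_load_model.py above.
    cfg = DictDefault(
        {
            "base_model": "JackFram/llama-68m",
            "tokenizer_type": "LlamaTokenizer",
            "tokenizer_config": "JackFram/llama-68m",
            "sequence_len": 1024,
            "load_in_8bit": False,
        }
    )
    tokenizer = load_tokenizer(cfg)

    # Existing call sites are unchanged: load_model() now just wraps ModelLoader.
    model, lora_config = load_model(cfg, tokenizer, inference=False, reference_model=False)

    # The class can also be driven one stage at a time, which is what the new
    # unit tests in tests/utils/test_models.py do for device map and quantization.
    loader = ModelLoader(cfg, tokenizer)
    loader.set_auto_model_loader()
    loader.set_device_map_config()
    loader.set_quantization_config()
    loader.set_attention_config()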