diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 6e520bd50e..6c8e7b8f0c 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -176,6 +176,10 @@ def load_model( hijack_expand_mask() model_kwargs = {} + + model_kwargs["device_map"] = cfg.device_map + model_kwargs["torch_dtype"] = cfg.torch_dtype + if cfg.model_revision: model_kwargs["revision"] = cfg.model_revision if cfg.gptq: @@ -206,6 +210,7 @@ def load_model( or cfg.is_mistral_derived_model ): model_kwargs["use_flash_attention_2"] = True + try: if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq: from transformers import LlamaForCausalLM @@ -220,10 +225,8 @@ def load_model( model = LlamaForCausalLM.from_pretrained( base_model, config=config, - device_map=cfg.device_map, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, - torch_dtype=cfg.torch_dtype, **model_kwargs, ) # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention: @@ -257,28 +260,22 @@ def load_model( model = MixFormerSequentialForCausalLM.from_pretrained( base_model, - device_map=cfg.device_map, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, - torch_dtype=cfg.torch_dtype, **model_kwargs, ) elif model_type and not cfg.trust_remote_code: if cfg.gptq: model = AutoModelForCausalLM.from_pretrained( base_model, - device_map=cfg.device_map, - torch_dtype=cfg.torch_dtype, trust_remote_code=cfg.trust_remote_code or False, **model_kwargs, ) else: model = getattr(transformers, model_type).from_pretrained( base_model, - device_map=cfg.device_map, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, - torch_dtype=cfg.torch_dtype, trust_remote_code=cfg.trust_remote_code or False, **model_kwargs, ) @@ -307,8 +304,6 @@ def load_model( model = AutoModelForCausalLM.from_pretrained( base_model, config=config, - device_map=cfg.device_map, - torch_dtype=cfg.torch_dtype, trust_remote_code=cfg.trust_remote_code or False, **model_kwargs, ) @@ -316,10 +311,8 @@ def load_model( model = AutoModelForCausalLM.from_pretrained( base_model, config=config, - device_map=cfg.device_map, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, - torch_dtype=cfg.torch_dtype, trust_remote_code=cfg.trust_remote_code or False, **model_kwargs, ) @@ -330,10 +323,8 @@ def load_model( LOG.exception(err) model = AutoModelForCausalLM.from_pretrained( base_model, - device_map=cfg.device_map, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, - torch_dtype=cfg.torch_dtype, trust_remote_code=cfg.trust_remote_code or False, **model_kwargs, )