diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 3287c0ee93..7fe230c997 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -176,6 +176,10 @@ def load_model(
         hijack_expand_mask()
 
     model_kwargs = {}
+
+    model_kwargs["device_map"] = cfg.device_map
+    model_kwargs["torch_dtype"] = cfg.torch_dtype
+
     if cfg.model_revision:
         model_kwargs["revision"] = cfg.model_revision
     if cfg.gptq:
@@ -202,6 +206,7 @@ def load_model(
     if cfg.flash_attention and not cfg.sample_packing:
         if cfg.is_llama_derived_model or cfg.is_falcon_derived_model:
             model_kwargs["use_flash_attention_2"] = True
+
     try:
         if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
             from transformers import LlamaForCausalLM
@@ -216,10 +221,8 @@ def load_model(
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
                 config=config,
-                device_map=cfg.device_map,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-                torch_dtype=cfg.torch_dtype,
                 **model_kwargs,
             )
         # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
@@ -253,28 +256,22 @@ def load_model(
 
             model = MixFormerSequentialForCausalLM.from_pretrained(
                 base_model,
-                device_map=cfg.device_map,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-                torch_dtype=cfg.torch_dtype,
                 **model_kwargs,
             )
         elif model_type and not cfg.trust_remote_code:
             if cfg.gptq:
                 model = AutoModelForCausalLM.from_pretrained(
                     base_model,
-                    device_map=cfg.device_map,
-                    torch_dtype=cfg.torch_dtype,
                     trust_remote_code=cfg.trust_remote_code or False,
                     **model_kwargs,
                 )
             else:
                 model = getattr(transformers, model_type).from_pretrained(
                     base_model,
-                    device_map=cfg.device_map,
                     load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                     load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-                    torch_dtype=cfg.torch_dtype,
                     trust_remote_code=cfg.trust_remote_code or False,
                     **model_kwargs,
                 )
@@ -303,8 +300,6 @@ def load_model(
             model = AutoModelForCausalLM.from_pretrained(
                 base_model,
                 config=config,
-                device_map=cfg.device_map,
-                torch_dtype=cfg.torch_dtype,
                 trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
@@ -312,10 +307,8 @@ def load_model(
             model = AutoModelForCausalLM.from_pretrained(
                 base_model,
                 config=config,
-                device_map=cfg.device_map,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-                torch_dtype=cfg.torch_dtype,
                 trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
@@ -326,10 +319,8 @@ def load_model(
         LOG.exception(err)
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
-            device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-            torch_dtype=cfg.torch_dtype,
             trust_remote_code=cfg.trust_remote_code or False,
             **model_kwargs,
         )
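
This refactor hoists `device_map` and `torch_dtype` out of every `from_pretrained` call site and into the shared `model_kwargs` dict, so each branch forwards them via `**model_kwargs` and the two arguments can no longer drift out of sync between branches. Below is a minimal sketch of the pattern in isolation; `cfg` is a stand-in namespace with hypothetical values, and the branch structure only mirrors, in simplified form, the real `load_model`:

```python
from types import SimpleNamespace

import torch
from transformers import AutoModelForCausalLM, LlamaForCausalLM

# Stand-in for axolotl's cfg object (hypothetical values).
cfg = SimpleNamespace(
    device_map="auto",
    torch_dtype=torch.bfloat16,
    model_revision=None,
    is_llama_derived_model=True,
)

# Build the shared kwargs once, up front.
model_kwargs = {
    "device_map": cfg.device_map,
    "torch_dtype": cfg.torch_dtype,
}
if cfg.model_revision:
    model_kwargs["revision"] = cfg.model_revision

# Every loading branch then unpacks the same dict, so no branch can
# silently miss device_map or torch_dtype.
base_model = "meta-llama/Llama-2-7b-hf"  # placeholder model id
if cfg.is_llama_derived_model:
    model = LlamaForCausalLM.from_pretrained(base_model, **model_kwargs)
else:
    model = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)
```

Per-branch arguments such as `load_in_8bit`/`load_in_4bit` stay at the call sites because they depend on branch-specific conditions (`cfg.adapter is not None`); only arguments common to every branch belong in the shared dict.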