Skip to content

Commit

Permalink
chore: Clean up repetitive model kwargs (#670)
Browse files Browse the repository at this point in the history
  • Loading branch information
NanoCode012 authored Oct 4, 2023
1 parent 697c50d commit e62d590
Showing 1 changed file with 5 additions and 14 deletions.
19 changes: 5 additions & 14 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ def load_model(
hijack_expand_mask()

model_kwargs = {}

model_kwargs["device_map"] = cfg.device_map
model_kwargs["torch_dtype"] = cfg.torch_dtype

if cfg.model_revision:
model_kwargs["revision"] = cfg.model_revision
if cfg.gptq:
Expand Down Expand Up @@ -206,6 +210,7 @@ def load_model(
or cfg.is_mistral_derived_model
):
model_kwargs["use_flash_attention_2"] = True

try:
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
from transformers import LlamaForCausalLM
Expand All @@ -220,10 +225,8 @@ def load_model(
model = LlamaForCausalLM.from_pretrained(
base_model,
config=config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
**model_kwargs,
)
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
Expand Down Expand Up @@ -257,28 +260,22 @@ def load_model(

model = MixFormerSequentialForCausalLM.from_pretrained(
base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
**model_kwargs,
)
elif model_type and not cfg.trust_remote_code:
if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained(
base_model,
device_map=cfg.device_map,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
model = getattr(transformers, model_type).from_pretrained(
base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
Expand Down Expand Up @@ -307,19 +304,15 @@ def load_model(
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
device_map=cfg.device_map,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
Expand All @@ -330,10 +323,8 @@ def load_model(
LOG.exception(err)
model = AutoModelForCausalLM.from_pretrained(
base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
Expand Down

0 comments on commit e62d590

Please sign in to comment.