Update loading of models with new attention methods
Cyrilvallez committed Mar 5, 2024
1 parent 1a717f5 commit a464e7b
Showing 3 changed files with 36 additions and 15 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -25,8 +25,8 @@ dependencies = [
     "scipy",
     "pyyaml",
     "packaging",
-    "torch>=2.0.1",
-    "transformers>=4.33.1",
+    "torch>=2.2.0",
+    "transformers>=4.37",
     "tokenizers>=0.13.3",
     "sentencepiece",
     "protobuf",
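The bumped lower bounds line up with the `attn_implementation` argument used in `loader.py` below, which relies on recent `transformers` and `torch` releases for Flash Attention 2 and native SDPA support. A minimal sketch (not part of the commit) for checking an existing environment against the new minimums, using `importlib.metadata` and the already-listed `packaging` dependency:

# Hypothetical helper, not textwiz code: verify the new minimum versions.
from importlib.metadata import version

from packaging.specifiers import SpecifierSet
from packaging.version import Version

# Lower bounds introduced by this commit
requirements = {'torch': SpecifierSet('>=2.2.0'), 'transformers': SpecifierSet('>=4.37')}

for package, specifier in requirements.items():
    installed = Version(version(package))
    status = 'ok' if installed in specifier else f'needs upgrade to {specifier}'
    print(f'{package} {installed}: {status}')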
2 changes: 1 addition & 1 deletion textwiz/__init__.py
@@ -12,7 +12,7 @@
 from . import loader, conversation_template, prompt_template
 
 
-__version__ = '0.3.0'
+__version__ = '0.4.0'
 
 
 def is_chat_model(model_name: str) -> bool:
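Downstream code that depends on the new loading behaviour can gate on this version bump; a minimal sketch (not textwiz code), assuming `packaging` is available:

# Hypothetical guard, not part of the commit: require the release that
# introduces the attention-implementation fallbacks in load_model.
from packaging.version import Version

import textwiz

if Version(textwiz.__version__) < Version('0.4.0'):
    raise RuntimeError('textwiz >= 0.4.0 is required for the new attention handling')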
45 changes: 33 additions & 12 deletions textwiz/loader.py
@@ -708,25 +708,46 @@ def load_model(model_name: str, quantization_8bits: bool = False, quantization_4
         device_map = 'balanced'
 
     # Load model
-    model = AutoModelForCausalLM.from_pretrained(ALL_MODELS_MAPPING[model_name], device_map=device_map,
-                                                 torch_dtype=dtype, load_in_8bit=quantization_8bits,
-                                                 load_in_4bit=quantization_4bits, low_cpu_mem_usage=True,
-                                                 **additional_kwargs)
+    # We first try with flash attention 2
+    try:
+        model = AutoModelForCausalLM.from_pretrained(ALL_MODELS_MAPPING[model_name], attn_implementation='flash_attention_2',
+                                                     device_map=device_map, torch_dtype=dtype, load_in_8bit=quantization_8bits,
+                                                     load_in_4bit=quantization_4bits, low_cpu_mem_usage=True, **additional_kwargs)
+        success = True
+    except:
+        success = False
+
+    # Second try with Pytorch native sdpa (which may sometimes but not for all models also use flash attention 2)
+    if not success:
+        try:
+            model = AutoModelForCausalLM.from_pretrained(ALL_MODELS_MAPPING[model_name], attn_implementation='sdpa',
+                                                         device_map=device_map, torch_dtype=dtype, load_in_8bit=quantization_8bits,
+                                                         load_in_4bit=quantization_4bits, low_cpu_mem_usage=True, **additional_kwargs)
+            success = True
+        except:
+            success = False
+
+    # Last try with BetterTransformer, which is the same as sdpa but with coverage for more models
+    if not success:
+        model = AutoModelForCausalLM.from_pretrained(ALL_MODELS_MAPPING[model_name], attn_implementation='eager', device_map=device_map,
+                                                     torch_dtype=dtype, load_in_8bit=quantization_8bits, load_in_4bit=quantization_4bits,
+                                                     low_cpu_mem_usage=True, **additional_kwargs)
+        # For some reason bettertransformer is supported for codegen2 models but makes them crash during the forward
+        if not ('codegen2-' in model_name):
+            # Convert to better transformer to use Pytorch optimizations if supported by the model
+            try:
+                model = model.to_bettertransformer()
+            except:
+                warnings.warn(('The default manual attention implementation will be used. This will result in slower generation and '
+                               'higher memory usage. This should not be an issue for small models.'))
+
 
     # If the flag is active we directly put our model on one gpu without using any device_map (this is
     # more efficient). But if the model is quantized, this is already done automatically because quantization
     # happen only on gpu
     if only_move_to_one_gpu and not quantization:
         # This operation is in-place for nn.Module
         model.cuda(gpu_rank)
 
-    # For some reason bettertransformer is supported for codegen2 models but makes them crash during the forward
-    if not ('codegen2-' in model_name):
-        # Convert to better transformer to use Pytorch optimizations if supported by the model
-        try:
-            model = model.to_bettertransformer()
-        except:
-            pass
-
     model.eval()
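Condensed, the new loading path tries progressively more conservative attention backends until one succeeds: Flash Attention 2 first, then PyTorch's native SDPA, and finally the eager (manual) implementation, which the commit additionally converts with BetterTransformer when the model supports it. A standalone sketch of that fallback idea, assuming only `transformers` and `torch`; the helper name and checkpoint are illustrative, not textwiz code:

import torch
from transformers import AutoModelForCausalLM


def load_with_attention_fallback(checkpoint: str, dtype: torch.dtype = torch.float16):
    """Try increasingly conservative attention implementations until one loads."""
    # 'flash_attention_2' needs the flash-attn package and a supported GPU,
    # 'sdpa' uses torch.nn.functional.scaled_dot_product_attention,
    # 'eager' is the manual implementation that works everywhere.
    for implementation in ('flash_attention_2', 'sdpa', 'eager'):
        try:
            return AutoModelForCausalLM.from_pretrained(checkpoint, attn_implementation=implementation,
                                                        torch_dtype=dtype, low_cpu_mem_usage=True)
        except Exception:
            continue
    raise RuntimeError(f'Could not load {checkpoint} with any attention implementation')


# Example usage (any causal LM checkpoint):
# model = load_with_attention_fallback('mistralai/Mistral-7B-v0.1')

Unlike the commit, this sketch drops the quantization and device-map arguments and the BetterTransformer conversion so the fallback structure stays visible.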

