
Commit

Fix example int8_inference_huggingface.py (#414)
* Fix example int8_inference_huggingface.py

* Update examples/int8_inference_huggingface.py

Co-authored-by: Younes Belkada <[email protected]>

---------

Co-authored-by: Younes Belkada <[email protected]>
alexrs and younesbelkada authored Feb 27, 2024
1 parent cc5f8cd commit 4b232ed
Showing 1 changed file with 5 additions and 5 deletions.
examples/int8_inference_huggingface.py (10 changes: 5 additions & 5 deletions)
@@ -1,24 +1,24 @@
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import LlamaForCausalLM, LlamaTokenizer

 MAX_NEW_TOKENS = 128
-model_name = 'decapoda-research/llama-7b-hf'
+model_name = 'meta-llama/Llama-2-7b-hf'

 text = 'Hamburg is in which country?\n'
-tokenizer = AutoTokenizer.from_pretrained(model_name)
+tokenizer = LlamaTokenizer.from_pretrained(model_name)
 input_ids = tokenizer(text, return_tensors="pt").input_ids

-free_in_GB = int(torch.cuda.mem_get_info()[0]/1024**3)
+max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'

 n_gpus = torch.cuda.device_count()
 max_memory = {i: max_memory for i in range(n_gpus)}

-model = AutoModelForCausalLM.from_pretrained(
+model = LlamaForCausalLM.from_pretrained(
     model_name,
     device_map='auto',
     load_in_8bit=True,
     max_memory=max_memory
 )

 generated_ids = model.generate(input_ids, max_length=MAX_NEW_TOKENS)
 print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
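As a side note, not part of this commit: on newer transformers releases the same int8 inference can be requested through an explicit BitsAndBytesConfig passed as quantization_config, instead of the bare load_in_8bit flag on from_pretrained. The following is a minimal sketch under those assumptions (a recent transformers release with bitsandbytes installed and access to the gated meta-llama/Llama-2-7b-hf checkpoint); it also passes max_new_tokens to generate(), so MAX_NEW_TOKENS bounds the generated tokens rather than the total sequence length.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

MAX_NEW_TOKENS = 128
model_name = 'meta-llama/Llama-2-7b-hf'  # gated checkpoint; access is assumed

tokenizer = AutoTokenizer.from_pretrained(model_name)
input_ids = tokenizer('Hamburg is in which country?\n', return_tensors='pt').input_ids

# Leave ~2 GB of headroom per GPU, mirroring the max_memory logic in the example.
per_gpu = f'{int(torch.cuda.mem_get_info()[0] / 1024**3) - 2}GB'
max_memory = {i: per_gpu for i in range(torch.cuda.device_count())}

# Request 8-bit weights via the explicit quantization config object.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map='auto',
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    max_memory=max_memory,
)

generated_ids = model.generate(input_ids, max_new_tokens=MAX_NEW_TOKENS)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))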
