Hi,
Thanks for your work!
How can we generate tokens using `model.generate()` or `model(inputs)`?
The following code prints the model successfully but then raises an error:

```python
import transformers

# get_llama_marlin is the Marlin model loader used here (defined elsewhere; its
# import is not shown), and path is the location of the model checkpoint.
transformers.AutoModelForCausalLM.from_pretrained = staticmethod(get_llama_marlin)
model = transformers.AutoModelForCausalLM.from_pretrained(path, torch_dtype="auto", device_map="auto")

text = "Hello, my dog is cute and"
print(model)

tokenizer = transformers.AutoTokenizer.from_pretrained(path, use_fast=False)
inputs = tokenizer(text, return_tensors="pt").input_ids
inputs = inputs.to(model.device)

generation_output = model(inputs)  # or model.generate(inputs)
print(generation_output)
```
Output of `print(model)`:
The error is:
When we comment out `transformers.AutoModelForCausalLM.from_pretrained = staticmethod(get_llama_marlin)`, the rest of the code runs properly.
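For context on the two calls in the question: in Hugging Face Transformers, `model(inputs)` runs a single forward pass and returns an output object whose `logits` hold next-token scores for every position, while `model.generate(inputs)` repeatedly runs forward passes and appends one token at a time. The sketch below illustrates that difference with a minimal greedy decoding loop built only from forward passes. It is not the Transformers API: `toy_forward` is a hypothetical stand-in for the real model, tokens are plain Python ints, and the loop omits KV caching and batching.

```python
from typing import Callable, List, Optional

# One row of next-token logits per input position, like model(inputs).logits[0].
Logits = List[List[float]]


def greedy_generate(forward: Callable[[List[int]], Logits],
                    input_ids: List[int],
                    max_new_tokens: int,
                    eos_id: Optional[int] = None) -> List[int]:
    """Greedy decoding: call the forward pass repeatedly, taking the argmax each step."""
    ids = list(input_ids)
    for _ in range(max_new_tokens):
        logits = forward(ids)   # full forward pass over the whole sequence (no cache)
        last = logits[-1]       # scores for the token after the last position
        next_id = max(range(len(last)), key=last.__getitem__)
        ids.append(next_id)
        if eos_id is not None and next_id == eos_id:
            break               # stop early once the end-of-sequence token appears
    return ids


def toy_forward(ids: List[int]) -> Logits:
    """Hypothetical toy model: always predicts (previous token + 1) mod 5."""
    vocab_size = 5
    rows = []
    for t in ids:
        row = [0.0] * vocab_size
        row[(t + 1) % vocab_size] = 1.0
        rows.append(row)
    return rows


print(greedy_generate(toy_forward, [0, 1], max_new_tokens=4))  # [0, 1, 2, 3, 4, 0]
```

In this picture, a single `model(inputs)` call corresponds to one iteration of the loop, which is why it returns scores rather than new tokens; `generate` wraps the loop (with caching and other decoding strategies) for you.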