Skip to content

Commit

Permalink
UPDATED: GALACTICA support
Browse files Browse the repository at this point in the history
  • Loading branch information
Axlfc committed Mar 17, 2023
1 parent 0a7baf2 commit 715c61f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions Content/Python/chatBot/GALACTICA.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import sys


def process_bot_answer(input_text):
def process_bot_answer(input_text, text_length=200):
tokenizer = AutoTokenizer.from_pretrained("facebook/galactica-125m")
model = OPTForCausalLM.from_pretrained("facebook/galactica-125m", device_map="auto")
# model = OPTForCausalLM.from_pretrained("facebook/galactica-125m")

# Tokenize the prompt and generate text using the BLOOM model
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
# input_ids = tokenizer(input_text, return_tensors="pt").input_ids
outputs = model.generate(input_ids, max_length=200, do_sample=True)
outputs = model.generate(input_ids, max_length=text_length, do_sample=True)

# Decode the generated text and print it
generated_text = tokenizer.decode(outputs[0])
Expand All @@ -32,7 +32,7 @@ def main():
input_text = sys.argv[1]


print(process_bot_answer(input_text))
print(process_bot_answer(input_text, text_length))
# return process_bot_answer(input_text)


Expand Down

0 comments on commit 715c61f

Please sign in to comment.