From 715c61ffca4d3bca1c383a4ec81b41c27eecd81c Mon Sep 17 00:00:00 2001
From: Axlfc
Date: Fri, 17 Mar 2023 13:18:27 +0100
Subject: [PATCH] UPDATED: GALACTICA support

---
 Content/Python/chatBot/GALACTICA.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/Content/Python/chatBot/GALACTICA.py b/Content/Python/chatBot/GALACTICA.py
index b5aa839..e957d16 100644
--- a/Content/Python/chatBot/GALACTICA.py
+++ b/Content/Python/chatBot/GALACTICA.py
@@ -3,7 +3,7 @@
 import sys
 
 
-def process_bot_answer(input_text):
+def process_bot_answer(input_text, text_length=200):
     tokenizer = AutoTokenizer.from_pretrained("facebook/galactica-125m")
     model = OPTForCausalLM.from_pretrained("facebook/galactica-125m", device_map="auto")
     # model = OPTForCausalLM.from_pretrained("facebook/galactica-125m")
@@ -11,7 +11,7 @@ def process_bot_answer(input_text):
     # Tokenize the prompt and generate text using the BLOOM model
     input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
     # input_ids = tokenizer(input_text, return_tensors="pt").input_ids
-    outputs = model.generate(input_ids, max_length=200, do_sample=True)
+    outputs = model.generate(input_ids, max_length=text_length, do_sample=True)
 
     # Decode the generated text and print it
     generated_text = tokenizer.decode(outputs[0])
@@ -32,7 +32,7 @@ def main():
 
     input_text = sys.argv[1]
 
-    print(process_bot_answer(input_text))
+    print(process_bot_answer(input_text, text_length))
     # return process_bot_answer(input_text)
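
The patch threads a new `text_length` parameter into `process_bot_answer()` and uses it as `max_length` in `model.generate()`, but the visible hunks never show where `main()` defines the `text_length` it passes along. Below is a sketch of the full script after the patch under one plausible assumption: `text_length` is read from an optional second command-line argument, defaulting to the old hard-coded 200. The `sys.argv[2]` handling and the usage message are illustrative additions, not part of the diff.

```python
# Sketch of Content/Python/chatBot/GALACTICA.py after the patch.
# Assumption: main() takes text_length from sys.argv[2] (default 200);
# the diff's context lines do not show where text_length actually comes from.
import sys

from transformers import AutoTokenizer, OPTForCausalLM


def process_bot_answer(input_text, text_length=200):
    tokenizer = AutoTokenizer.from_pretrained("facebook/galactica-125m")
    model = OPTForCausalLM.from_pretrained("facebook/galactica-125m", device_map="auto")

    # Tokenize the prompt and generate text with the GALACTICA model
    # (the upstream comment says BLOOM, but the checkpoint is galactica-125m).
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
    outputs = model.generate(input_ids, max_length=text_length, do_sample=True)

    # Decode the generated token ids back into a string
    return tokenizer.decode(outputs[0])


def main():
    if len(sys.argv) < 2:
        print("usage: GALACTICA.py <prompt> [max_length]")
        sys.exit(1)

    input_text = sys.argv[1]
    # Hypothetical: optional second argument caps the generation length.
    text_length = int(sys.argv[2]) if len(sys.argv) > 2 else 200

    print(process_bot_answer(input_text, text_length))


if __name__ == "__main__":
    main()
```

With that assumption, `python GALACTICA.py "The Transformer architecture" 120` would generate up to 120 tokens, while omitting the second argument keeps the previous behavior (`max_length=200`). Note that `.to("cuda")` requires a CUDA-capable GPU, matching the active code path in the patched file.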