diff --git a/Makefile b/Makefile index 84ae1b2..3237399 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,8 @@ -# Variables -APP_NAME = playground -APP_VERSION = v0.1.0 - # Install project dependencies +local: + chmod +x local_setup.sh && ./local_setup.sh + + install: pip install -r requirements.txt @@ -19,4 +19,4 @@ test: pytest tests/ # Phony targets -.PHONY: install run format test local \ No newline at end of file +.PHONY: install run lint test local diff --git a/app/config/db.py b/app/config/db.py index b72b20f..c4340de 100644 --- a/app/config/db.py +++ b/app/config/db.py @@ -22,6 +22,7 @@ async def initialize_database(self): return client except Exception as error: print(f"Unable to connect to the MongoDB server with error: {error}.") + raise error class Config: env_file = ".env" diff --git a/app/config/template.py b/app/config/template.py index 366ed18..52786a4 100644 --- a/app/config/template.py +++ b/app/config/template.py @@ -4,12 +4,7 @@ Taking into account that you are the following character: Name is {character_name}, description is {character_description}. Give me some videogame bites that character would say in {additional_context}, taking into account this character is {character_traits}. -I need {number_of_lines} as maximum amount of lines. The output format should be something like this: - -- $Content of Line -- $Content of Line - -until not more lines rest, remember that "$Content of Line" is just the string value +I need {number_of_lines} as maximum amount of lines. The output format should be no introduction, just the lines separated by a new line. """ FINETUNE_PROMPT = """ diff --git a/app/main.py b/app/main.py index 20dd99a..9b021ca 100644 --- a/app/main.py +++ b/app/main.py @@ -45,8 +45,7 @@ async def startup_db_client(): @app.on_event("shutdown") async def shutdown_db_client(): - # await app.mongodb_client.close() - pass + app.mongodb_client.close() # default routes diff --git a/app/routes/dialogue.py b/app/routes/dialogue.py index 18adeff..2b23cfa 100644 --- a/app/routes/dialogue.py +++ b/app/routes/dialogue.py @@ -1,5 +1,6 @@ import os from typing import Annotated +import regex as re import openai import replicate @@ -24,17 +25,11 @@ def get_openai_lines(prompt: str): response = openai.ChatCompletion.create( model="gpt-3.5-turbo", messages=[{"role": "user", "content": prompt}], - max_tokens=100, + max_tokens=200, temperature=0.8, ) - lines = [] - - char1 = '"' - char2 = '"' - for line in response.choices[0].message.content.split("\n"): - if line != "": - lines.append(line) + lines = [re.sub(r"[^a-zA-Z0-9 ']", '', item.strip()) for item in response.choices[0].message.content.split("\n") if item != ""] return DialogueResponse( lines=lines @@ -44,11 +39,14 @@ def get_openai_lines(prompt: str): def get_llama_lines(prompt: str): response = replicate.run( "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3", - input={"prompt": prompt, "system_prompt": SYSTEM_PROMPT, "max_new_tokens": 100}, + input={"prompt": prompt, "system_prompt": SYSTEM_PROMPT, "max_new_tokens": 200}, ) response = [item for item in response if item != ""] - response = "".join(response).split("\n\n") + response = "".join(response).split("\n") + # Llama needs some cleaning up for the response, very difficult to remove via prompt. + response = [re.sub(r"[^a-zA-Z0-9 ']", '', item[2:].strip()) for item in response if item != ""] + response.pop(0) return DialogueResponse(lines=response) diff --git a/requirements.txt b/requirements.txt index 35eb4df..8fab81b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,4 +20,5 @@ black==23.3.0 isort==5.12.0 autoflake==2.2.0 flake8==6.0.0 -pytest==7.4.0 \ No newline at end of file +pytest==7.4.0 +regex==2023.10.3 \ No newline at end of file