diff --git a/biochatter/llm_connect.py b/biochatter/llm_connect.py
index 52e28842..6d8c9fd2 100644
--- a/biochatter/llm_connect.py
+++ b/biochatter/llm_connect.py
@@ -303,6 +303,60 @@ def get_msg_json(self):
         return json.dumps(d)
 
 
+class WasmConversation(Conversation):
+    """
+    This class is used to return the complete query as a string to be used
+    in the frontend running the wasm model. It does not call the API itself,
+    but updates the message history similarly to the other conversation
+    classes. It overrides the `query` method from the `Conversation` class
+    to return a plain string instead of a tuple that contains the entire
+    message for the model.
+    """
+    def __init__(
+        self,
+        model_name: str,
+        prompts: dict,
+        correct: bool = True,
+        split_correction: bool = False,
+        rag_agent: DocumentEmbedder | None = None,
+    ):
+        super().__init__(
+            model_name=model_name,
+            prompts=prompts,
+            correct=correct,
+            split_correction=split_correction,
+            rag_agent=rag_agent,
+        )
+
+    def query(self, text: str, collection_name: Optional[str] = None):
+        self.append_user_message(text)
+
+        if self.rag_agent:
+            if self.rag_agent.use_prompt:
+                self._inject_context(text, collection_name)
+
+        return self._primary_query()
+
+    def _primary_query(self):
+        """
+        Concatenate all messages in the conversation into a single string and
+        return it. Currently discards information about roles (system, user).
+        """
+        return "\n".join([m.content for m in self.messages])
+
+    def _correct_response(self, msg: str):
+        """
+        This method is not used for the wasm model.
+        """
+        return "ok"
+
+    def set_api_key(self, api_key: str, user: str | None = None):
+        """
+        This method is not used for the wasm model.
+        """
+        return True
+
+
 class XinferenceConversation(Conversation):
     def __init__(
         self,
diff --git a/test/test_llm_connect.py b/test/test_llm_connect.py
index 5ed94ff4..c420a39c 100644
--- a/test/test_llm_connect.py
+++ b/test/test_llm_connect.py
@@ -8,6 +8,7 @@
     HumanMessage,
     AIMessage,
     XinferenceConversation,
+    WasmConversation,
 )
 import pytest
 from unittest.mock import patch, Mock
@@ -212,3 +213,39 @@ def test_generic_chatting():
     )
     (msg, token_usage, correction) = convo.query("Hello, world!")
     assert token_usage["completion_tokens"] > 0
+
+
+def test_wasm_conversation():
+    # Initialize the class
+    wasm_convo = WasmConversation(
+        model_name="test_model",
+        prompts={},
+        correct=True,
+        split_correction=False,
+        rag_agent=None,
+    )
+
+    # Check if the model_name is correctly set
+    assert wasm_convo.model_name == "test_model"
+
+    # Check if the prompts are correctly set
+    assert wasm_convo.prompts == {}
+
+    # Check if the correct flag is correctly set
+    assert wasm_convo.correct is True
+
+    # Check if the split_correction flag is correctly set
+    assert wasm_convo.split_correction is False
+
+    # Check if the rag_agent is correctly set
+    assert wasm_convo.rag_agent is None
+
+    # Test the query method
+    test_query = "Hello, world!"
+    result = wasm_convo.query(test_query)
+    assert result == test_query  # assuming the messages list is initially empty
+
+    # Test the _primary_query method, add another message to the messages list
+    wasm_convo.append_system_message("System message")
+    result = wasm_convo._primary_query()
+    assert result == test_query + "\nSystem message"