Skip to content

Commit

Permalink
Maintain the historical chat conversation per user; (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
ranjan-stha authored Sep 26, 2024
1 parent 7aa0f31 commit ba18a99
Showing 1 changed file with 35 additions and 8 deletions.
43 changes: 35 additions & 8 deletions chatbot-core/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,28 @@
class LLMBase:
"""LLM Base containing common methods"""

mem_key: str = field(init=False)
conversation_max_window: int = field(init=False)
qdrant_client: QdrantClient = field(init=False)
llm_model: Any = field(init=False)
user_memory_mapping: dict = field(init=False)
memory: Any = field(init=False)
embedding_model: CustomEmbeddingsWrapper = field(init=False)
rag_chain: Optional[Any] = None

def __post_init__(self, mem_key: str = "chat_history", conversation_max_window: int = 3):
self.llm_model = None
self.qdrant_client = None
self.memory = None

self.mem_key = mem_key
self.conversation_max_window = conversation_max_window

try:
self.qdrant_client = QdrantClient(host=settings.QDRANT_DB_HOST, port=settings.QDRANT_DB_PORT)
except Exception as e:
raise Exception(f"Qdrant client is not properly setup. {str(e)}")
self.memory = ConversationBufferWindowMemory(k=conversation_max_window, memory_key=mem_key, return_messages=True)

self.user_memory_mapping = {}

self.embedding_model = CustomEmbeddingsWrapper(
url=settings.EMBEDDING_MODEL_URL,
Expand Down Expand Up @@ -113,23 +119,44 @@ def create_chain(self, db_collection_name: str):
rag_chain = create_retrieval_chain(history_aware_retriever, chat_response_chain)
return rag_chain

def execute_chain(self, query: str, db_collection_name: str = settings.QDRANT_DB_COLLECTION_NAME):
def execute_chain(self, user_id: str, query: str, db_collection_name: str = settings.QDRANT_DB_COLLECTION_NAME):
"""
Executes the chain
"""
if not self.rag_chain:
self.rag_chain = self.create_chain(db_collection_name=db_collection_name)

response = self.rag_chain.invoke({"input": query, "chat_history": self.get_message_history()["chat_history"]})
self.memory.chat_memory.add_message(HumanMessage(content=query))
self.memory.chat_memory.add_message(AIMessage(content=response["answer"]))
if "user_id" not in self.user_memory_mapping:
self.user_memory_mapping[user_id] = ConversationBufferWindowMemory(
k=self.conversation_max_window, memory_key=self.mem_key, return_messages=True
)

memory = self.user_memory_mapping[user_id]

response = self.rag_chain.invoke(
{"input": query, "chat_history": self.get_message_history(user_id=user_id)["chat_history"]}
)
memory.chat_memory.add_message(HumanMessage(content=query))
memory.chat_memory.add_message(AIMessage(content=response["answer"]))
self.user_memory_mapping[user_id] = memory

return response["answer"] if "answer" in response else ""

def get_message_history(self):
def get_message_history(self, user_id: str):
"""
Returns the historical conversational data
"""
return self.memory.load_memory_variables({})
if "user_id" in self.user_memory_mapping:
return self.user_memory_mapping[user_id].load_memory_variables({})
return {}

def delete_message_history_by_user(self, user_id: str) -> bool:
"""Deletes the message history based on user id"""
if "user_id" in self.user_memory_mapping:
del self.user_memory_mapping[user_id]
logger.info(f"Successfully delete the {user_id} conversational history.")
return True
return False


@dataclass
Expand Down

0 comments on commit ba18a99

Please sign in to comment.