# law_rag.py — RAG assistant over the Indian Penal Code and user-supplied legal PDFs
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain_chroma import Chroma
from langchain.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
import chromadb
import os
import hashlib
class LawDocumentRAG:
    """Retrieval-augmented QA over legal PDFs.

    Maintains two Chroma collections inside one persistent on-disk client:
    "ipc_law" (the bundled Indian Penal Code PDF, indexed once and reused
    across runs) and "custom_doc" (a user-supplied PDF, replaced whenever a
    different file is loaded). Questions are answered by a Groq-hosted LLM
    over the top retrieved chunks plus the running chat history.
    """

    def __init__(self, groq_api_key):
        """Initialize embeddings, LLM, splitter, Chroma client and prompts.

        Args:
            groq_api_key: Groq API key. Exported to the process environment
                because ChatGroq reads GROQ_API_KEY at request time.
        """
        os.environ["GROQ_API_KEY"] = groq_api_key
        # Persistent client so the IPC index survives process restarts.
        self.chroma_client = chromadb.PersistentClient(path="./chroma_db")
        self.embed_model = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5")
        self.llm = ChatGroq(model="llama3-70b-8192", temperature=0.1)
        self.text_splitter = RecursiveCharacterTextSplitter(
            separators=["\n\n", "\n", " ", ""],
            chunk_size=1000,
            # Overlap keeps clauses that straddle a chunk boundary retrievable.
            chunk_overlap=200,
            length_function=len,
            is_separator_regex=False,
        )
        self.ipc_vectorstore = None
        self.custom_vectorstore = None
        # MD5 of the last custom PDF, used to skip redundant re-indexing.
        self.current_custom_doc_hash = None
        self.load_default_ipc()
        self.LEGAL_RAG_SYSTEM_PROMPT = """You are a helpful legal assistant that explains legal concepts in simple, easy-to-understand language.
Use the following pieces of information to answer the human's questions:
1. Retrieved Context (relevant legal text):
```
{context}
```
2. Conversation History:
```
{chat_history}
```
Current Question: {input}
When answering:
1. Use the context and chat history to provide comprehensive answers
2. Maintain consistency with previous responses
3. If referring to previous discussion, be explicit about it
4. Avoid using complex legal jargon and explain any technical terms
5. If you don't know something based on the provided context, say so
"""
        self.RAG_HUMAN_PROMPT = "{input}"
        self.rag_prompt = ChatPromptTemplate.from_messages([
            ("system", self.LEGAL_RAG_SYSTEM_PROMPT),
            ("human", self.RAG_HUMAN_PROMPT)
        ])

    def load_default_ipc(self):
        """Attach to the persisted "ipc_law" collection, ingesting the bundled
        IPC PDF when the collection does not exist yet or is empty.
        Failures are reported to stdout and leave `ipc_vectorstore` as None.
        """
        try:
            # BUG FIX: langchain_chroma's Chroma(...) get-or-creates its
            # collection and never raises ValueError for a missing one, so
            # the old except-ValueError fallback was dead code and an empty
            # index could be served. Check the stored chunk count instead.
            if self.chroma_client.get_or_create_collection("ipc_law").count() > 0:
                self.ipc_vectorstore = Chroma(
                    client=self.chroma_client,
                    collection_name="ipc_law",
                    embedding_function=self.embed_model
                )
            else:
                ipc_path = "IPC_pdf.pdf"
                if os.path.exists(ipc_path):
                    self.load_ipc_document(ipc_path)
        except Exception as e:
            print(f"Error loading default IPC document: {str(e)}")

    def get_file_hash(self, file_path):
        """Return the MD5 hex digest of *file_path*.

        Streamed in 1 MiB chunks so large PDFs are not pulled into memory
        at once. MD5 serves only as a change-detection fingerprint here,
        not as a security primitive.
        """
        digest = hashlib.md5()
        with open(file_path, 'rb') as f:
            for chunk in iter(lambda: f.read(1024 * 1024), b''):
                digest.update(chunk)
        return digest.hexdigest()

    def format_chat_history(self, messages):
        """Render chat messages ({"role": ..., "content": ...} dicts) as
        newline-joined "Role: content" lines for prompt interpolation."""
        return "\n".join(
            f"{msg['role'].capitalize()}: {msg['content']}" for msg in messages
        )

    def _ingest_pdf(self, pdf_path, collection_name):
        """Load a PDF, split it into chunks, and index it into the named
        Chroma collection; returns the resulting vector store."""
        pages = PyPDFLoader(pdf_path).load()
        texts = self.text_splitter.split_documents(pages)
        return Chroma.from_documents(
            documents=texts,
            embedding=self.embed_model,
            collection_name=collection_name,
            client=self.chroma_client
        )

    def load_ipc_document(self, ipc_pdf_path):
        """Index the IPC PDF into "ipc_law" unless already loaded.

        Returns a human-readable status string; errors are folded into the
        return value so callers never see an exception.
        """
        try:
            if self.ipc_vectorstore is None:
                self.ipc_vectorstore = self._ingest_pdf(ipc_pdf_path, "ipc_law")
            return "IPC document loaded and indexed successfully"
        except Exception as e:
            return f"Error loading IPC document: {str(e)}"

    def load_custom_document(self, pdf_path):
        """Index a user PDF into "custom_doc", replacing any previous one.

        A content hash short-circuits re-indexing when the same file is
        loaded twice in a row. Returns a human-readable status string.
        """
        try:
            new_doc_hash = self.get_file_hash(pdf_path)
            if new_doc_hash == self.current_custom_doc_hash and self.custom_vectorstore is not None:
                return "Document already loaded"
            try:
                # Best-effort cleanup. BUG FIX: depending on the chromadb
                # version, a missing collection raises ValueError or
                # NotFoundError; the old ValueError-only catch crashed on
                # newer clients, so catch broadly here.
                self.chroma_client.delete_collection("custom_doc")
            except Exception:
                pass
            self.custom_vectorstore = self._ingest_pdf(pdf_path, "custom_doc")
            self.current_custom_doc_hash = new_doc_hash
            return "Custom document loaded and indexed successfully"
        except Exception as e:
            return f"Error loading custom document: {str(e)}"

    def format_docs(self, docs):
        """Join retrieved document chunks into one context string."""
        return "\n\n".join(doc.page_content for doc in docs)

    def create_rag_chain(self, vectorstore, chat_history):
        """Build a retrieve -> prompt -> LLM -> str chain over *vectorstore*.

        The question flows through RunnablePassthrough into both the
        retriever and the {input} prompt slot; *chat_history* is captured by
        the lambda at build time, so build a fresh chain per query.
        """
        retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
        rag_chain = (
            {
                "context": retriever | self.format_docs,
                "chat_history": lambda x: self.format_chat_history(chat_history),
                "input": RunnablePassthrough()
            }
            | self.rag_prompt
            | self.llm
            | StrOutputParser()
        )
        return rag_chain

    async def query_ipc(self, question: str, chat_history: list) -> str:
        """Answer *question* against the IPC index; error string if unloaded."""
        if self.ipc_vectorstore is None:
            return "Error: IPC document not properly loaded"
        rag_chain = self.create_rag_chain(self.ipc_vectorstore, chat_history)
        return await rag_chain.ainvoke(question)

    async def query_custom_document(self, question: str, chat_history: list) -> str:
        """Answer *question* against the custom-document index; prompt the
        user to load a document first when none is indexed."""
        if self.custom_vectorstore is None:
            return "Please load a custom document first"
        rag_chain = self.create_rag_chain(self.custom_vectorstore, chat_history)
        return await rag_chain.ainvoke(question)