update retrieve_and_generate method with Query result pydantic model
KevKibe committed Apr 30, 2024
1 parent aa826f3 commit c197796
Showing 3 changed files with 43 additions and 50 deletions.
src/_cohere/doc_index.py: 14 additions, 10 deletions
@@ -16,7 +16,9 @@
 from langchain_core.prompts import PromptTemplate
 from langchain_core.output_parsers import StrOutputParser
 from operator import itemgetter
-from _cohere.config import Config
+from utils.config import Config
+from utils.response_model import QueryResult
+from langchain.output_parsers import PydanticOutputParser
 
 class CoherePineconeIndexer:
     """
@@ -218,7 +220,7 @@ def initialize_vectorstore(self, index_name):
         vectorstore = PineconeVectorStore(index,embed, "text")
         return vectorstore
 
-    def retrieve_and_generate(self,query: str, index_name: str, model_name: str = 'gpt-3.5-turbo-1106', top_k: int =5):
+    def retrieve_and_generate(self,query: str, vector_store: str, model_name: str = 'gpt-3.5-turbo-1106', top_k: int =5):
         """
         Retrieve documents from the Pinecone index and generate a response.
         Args:
@@ -228,17 +230,19 @@ def retrieve_and_generate(self,query: str, index_name: str, model_name: str = 'gpt-3.5-turbo-1106', top_k: int =5):
             top_k: The number of documents to retrieve from the index : defaults to 5
         """
         llm = Cohere(model="command", cohere_api_key = self.cohere_api_key)
-        rag_prompt = PromptTemplate(template = Config.template_str, input_variables = ["query", "context"])
-
-        vector_store = self.initialize_vectorstore(index_name)
+        parser = PydanticOutputParser(pydantic_object=QueryResult)
+        rag_prompt = PromptTemplate(template = Config.template_str,
+                                    input_variables = ["query", "context"],
+                                    partial_variables={"format_instructions": parser.get_format_instructions()})
         retriever = vector_store.as_retriever(search_kwargs = {"k": top_k})
 
         rag_chain = (
             {"context": itemgetter("query")| retriever,
              "query": itemgetter("query"),
              }
             | rag_prompt
             | llm
-            | StrOutputParser()
+            | parser
         )
 
         return rag_chain.invoke({"query": query})
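
The QueryResult model comes from utils/response_model.py, which is not part of this diff, so its actual fields are unknown. A minimal sketch of what such a model and the parser wiring could look like (the field names here are assumptions, not the repo's schema):

from pydantic import BaseModel, Field
from langchain.output_parsers import PydanticOutputParser

class QueryResult(BaseModel):
    # Hypothetical fields; utils/response_model.py is not shown in this commit.
    answer: str = Field(description="The generated answer to the query")
    sources: list[str] = Field(default_factory=list, description="Supporting snippets")

parser = PydanticOutputParser(pydantic_object=QueryResult)
# get_format_instructions() returns text plus a JSON schema telling the LLM how
# to format its output; the diff binds it into the prompt via partial_variables.
print(parser.get_format_instructions())

The practical effect is that rag_chain.invoke() now returns a validated QueryResult instance instead of the raw string that StrOutputParser produced.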
src/_google/doc_index.py: 14 additions, 20 deletions
@@ -14,11 +14,11 @@
 from langchain_community.document_loaders import UnstructuredHTMLLoader
 from langchain_pinecone import PineconeVectorStore
 from langchain_core.prompts import PromptTemplate
-from langchain_core.output_parsers import StrOutputParser
 from operator import itemgetter
 from langchain_google_genai import ChatGoogleGenerativeAI
-from _google.config import Config
-
+from utils.config import Config
+from utils.response_model import QueryResult
+from langchain.output_parsers import PydanticOutputParser
 
 class GooglePineconeIndexer:
     """
@@ -97,31 +97,22 @@ def load_document(self, file_url: str) -> List[str]:
         """
         pages = []
         file_path = Path(file_url)
 
-        # Determine file type and use the appropriate loader
         file_extension = file_path.suffix
 
-        # Load and split PDF files
         if file_extension == ".pdf":
             loader = PyPDFLoader(file_url)
             pages = loader.load_and_split()
 
-        # Load and split DOCX and DOC files
         elif file_extension in ('.docx', '.doc'):
             loader = UnstructuredWordDocumentLoader(file_url)
             pages = loader.load_and_split()
 
-        # Load and split Markdown files
         elif file_extension == '.md':
             loader = UnstructuredMarkdownLoader(file_url)
             pages = loader.load_and_split()
 
-        # Load and split HTML files
         elif file_extension == '.html':
             loader = UnstructuredHTMLLoader(file_url)
             pages = loader.load_and_split()
 
-        # Return the list of pages
         return pages
 
     def tiktoken_len(self, text: str) -> int:
Expand Down Expand Up @@ -242,7 +233,7 @@ def initialize_vectorstore(self, index_name):
return vectorstore


def retrieve_and_generate(self,query: str, index_name: str, model_name: str = 'gemini-pro', top_k: int =5):
def retrieve_and_generate(self,query: str, vector_store: str, model_name: str = 'gemini-pro', top_k: int =5):
"""
Retrieve documents from the Pinecone index and generate a response.
Args:
Expand All @@ -252,16 +243,19 @@ def retrieve_and_generate(self,query: str, index_name: str, model_name: str = 'g
top_k: The number of documents to retrieve from the index : defaults to 5
"""
llm = ChatGoogleGenerativeAI(model = Config.default_google_model, google_api_key=self.google_api_key)
rag_prompt = PromptTemplate(template = Config.template_str, input_variables = ["query", "context"])
vector_store = self.initialize_vectorstore(index_name)
parser = PydanticOutputParser(pydantic_object=QueryResult)
rag_prompt = PromptTemplate(template = Config.template_str,
input_variables = ["query", "context"],
partial_variables={"format_instructions": parser.get_format_instructions()})
retriever = vector_store.as_retriever(search_kwargs = {"k": top_k})

rag_chain = (
{"context": itemgetter("query")| retriever,
"query": itemgetter("query"),
}
| rag_prompt
| llm
| StrOutputParser()
"query": itemgetter("query"),
}
| rag_prompt
| llm
| parser
)

return rag_chain.invoke({"query": query})
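
With this change, callers initialize the vector store themselves and pass it in, instead of handing retrieve_and_generate an index name (note the new parameter is still annotated str in the diff, though it receives the object returned by initialize_vectorstore). A usage sketch under that assumption; the index name and query are made-up examples:

indexer = GooglePineconeIndexer(...)  # constructor args omitted here
# Hypothetical index name; any existing Pinecone index would do.
vector_store = indexer.initialize_vectorstore("my-docs-index")
result = indexer.retrieve_and_generate(
    query="What file formats does the indexer accept?",  # example query
    vector_store=vector_store,
    top_k=5,
)
print(result)  # a QueryResult instance rather than a plain string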
src/_openai/doc_index.py: 15 additions, 20 deletions
@@ -13,12 +13,11 @@
 from langchain_community.document_loaders import UnstructuredHTMLLoader
 from langchain_pinecone import PineconeVectorStore
 from langchain_core.prompts import PromptTemplate
-from langchain_core.output_parsers import StrOutputParser
 from operator import itemgetter
 from langchain_openai import ChatOpenAI
-from _openai.config import Config
-
-
+from utils.config import Config
+from utils.response_model import QueryResult
+from langchain.output_parsers import PydanticOutputParser
 
 class OpenaiPineconeIndexer:
     """
@@ -99,30 +98,23 @@ def load_document(self, file_url: str) -> List[str]:
         pages = []
         file_path = Path(file_url)
 
-        # Determine file type and use the appropriate loader
         file_extension = file_path.suffix
 
-        # Load and split PDF files
         if file_extension == ".pdf":
             loader = PyPDFLoader(file_url)
             pages = loader.load_and_split()
 
-        # Load and split DOCX and DOC files
         elif file_extension in ('.docx', '.doc'):
             loader = UnstructuredWordDocumentLoader(file_url)
             pages = loader.load_and_split()
 
-        # Load and split Markdown files
         elif file_extension == '.md':
             loader = UnstructuredMarkdownLoader(file_url)
             pages = loader.load_and_split()
 
-        # Load and split HTML files
         elif file_extension == '.html':
             loader = UnstructuredHTMLLoader(file_url)
             pages = loader.load_and_split()
 
-        # Return the list of pages
         return pages
@@ -240,7 +232,7 @@ def initialize_vectorstore(self, index_name):
         return vectorstore
 
 
-    def retrieve_and_generate(self,query: str, index_name: str, model_name: str = 'gpt-3.5-turbo-1106', top_k: int =5):
+    def retrieve_and_generate(self,query: str, vector_store: str, model_name: str = 'gpt-3.5-turbo-1106', top_k: int =5):
         """
         Retrieve documents from the Pinecone index and generate a response.
         Args:
@@ -250,17 +242,19 @@ def retrieve_and_generate(self,query: str, index_name: str, model_name: str = 'gpt-3.5-turbo-1106', top_k: int =5):
             top_k: The number of documents to retrieve from the index : defaults to 5
         """
         llm = ChatOpenAI(model = Config.default_openai_model, openai_api_key = self.openai_api_key)
-        rag_prompt = PromptTemplate(template = Config.template_str, input_variables = ["query", "context"])
-
-        vector_store = self.initialize_vectorstore(index_name)
+        parser = PydanticOutputParser(pydantic_object=QueryResult)
+        rag_prompt = PromptTemplate(template = Config.template_str,
+                                    input_variables = ["query", "context"],
+                                    partial_variables={"format_instructions": parser.get_format_instructions()})
         retriever = vector_store.as_retriever(search_kwargs = {"k": top_k})
 
         rag_chain = (
             {"context": itemgetter("query")| retriever,
              "query": itemgetter("query"),
              }
             | rag_prompt
             | llm
-            | StrOutputParser()
+            | parser
         )
 
         return rag_chain.invoke({"query": query})
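
The rag_chain body is identical across all three providers; only the llm step differs. A commented sketch of the data flow through one invoke call, using the names as they appear in the diff:

# rag_chain.invoke({"query": question}) proceeds in four steps:
# 1. The leading dict runs both entries on the same input:
#      "context": itemgetter("query") | retriever  -> extracts the query string,
#          then fetches the top_k most similar documents from Pinecone;
#      "query":   itemgetter("query")              -> passes the query through.
# 2. rag_prompt fills Config.template_str with query, context, and the
#    format_instructions pre-bound via partial_variables.
# 3. The llm generates text that should follow those instructions.
# 4. parser validates that text against QueryResult; the chain returns a
#    QueryResult instance where it previously returned StrOutputParser's string.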
@@ -270,3 +264,4 @@ def retrieve_and_generate(self,query: str, index_name: str, model_name: str = 'gpt-3.5-turbo-1106', top_k: int =5):
(whitespace-only change: one blank line added at the end of the file)