Adding LCEL Chain and Favicon #20

Draft · wants to merge 1 commit into base: vector_graph
6 changes: 4 additions & 2 deletions .gitignore
@@ -1,10 +1,12 @@
# IDE
.idea/

# Notebook output
failed_files.json

# Jupyter Notebook
.ipynb_checkpoints
__pycache__

# Mac System file
.DS_STORE

@@ -13,4 +15,4 @@ __pycache__
secrets.toml

# Dependency managers
*.lock
*.lock
Binary file removed rag_demo/__pycache__/neo4j_driver.cpython-311.pyc
16 changes: 7 additions & 9 deletions rag_demo/main.py
@@ -20,6 +20,11 @@
"appStarted",
{})

st.set_page_config(
page_title="Neo4j RAG Demo",
page_icon="static/logo-mark-fullcolor-CMYK-transBG.png"
)

set_llm_cache(InMemoryCache())

st.markdown(f"""
@@ -69,14 +74,7 @@

# Vector only response
vector_response = rag_vector_only.get_results(user_input)
content = f"##### Vector only: \n" + vector_response['answer']

# Cite sources, if any
sources = vector_response['sources']
sources_split = sources.split(', ')
for source in sources_split:
if source != "" and source != "N/A" and source != "None":
content += f"\n - [{source}]({source})"
content = f"##### Vector only: \n" + vector_response

track("rag_demo", "ai_response", {"type": "vector", "answer": content})
new_message = {"role": "ai", "content": content}
@@ -90,7 +88,7 @@

vgraph_response = rag_vector_graph.get_results(user_input)
# content = f"##### Vector + Graph: \n" + vgraph_response['answer']
content = f"##### Vector + Graph: \n" + vgraph_response.content
content = f"##### Vector + Graph: \n" + vgraph_response

# Cite sources, if any
# sources = vgraph_response['sources']
25 changes: 25 additions & 0 deletions rag_demo/rag.py
@@ -0,0 +1,25 @@
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.pydantic_v1 import BaseModel


# add typing for input
class QuestionPrompt(BaseModel):
question: str


# utility for formatting retrieved data
def format_docs(docs):
return "\n\n".join([d.page_content for d in docs])


def generate_chain(retriever, model, prompt):
return (
{
"context": itemgetter("question") | retriever | format_docs,
"question": itemgetter("question")
}
| prompt
| model
| StrOutputParser()
)
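
For reviewers who want to try the new helper in isolation, here is a minimal usage sketch of generate_chain. The stand-in retriever, prompt text, and question are made up for illustration; in this PR the real retriever comes from Neo4jVector.from_existing_index(...).as_retriever() and the prompt from the PROMPT_TEMPLATE strings in the files below.

```python
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda
from langchain_openai import ChatOpenAI

from rag import generate_chain

# Stand-in retriever for illustration only; the PR wires in store.as_retriever()
# backed by a Neo4jVector index instead.
toy_retriever = RunnableLambda(
    lambda question: [Document(page_content=f"Toy context for: {question}")]
)

prompt = ChatPromptTemplate.from_template(
    "Answer using only this context:\n{context}\n\nQuestion: {question}"
)

chain = generate_chain(toy_retriever, ChatOpenAI(temperature=0), prompt)

# Because the chain ends with StrOutputParser, invoke() returns a plain string.
print(chain.invoke({"question": "Which companies filed a Form 10-K?"}))
```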
89 changes: 34 additions & 55 deletions rag_demo/rag_vector_graph.py
@@ -1,7 +1,12 @@
from operator import itemgetter

from langchain.chains import GraphCypherQAChain
from langchain_community.graphs import Neo4jGraph
from langchain.prompts.prompt import PromptTemplate
from langchain.llms.bedrock import Bedrock
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel
from retry import retry
from timeit import default_timer as timer
import streamlit as st
@@ -13,6 +18,8 @@
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains.conversation.memory import ConversationBufferMemory

from rag import generate_chain

PROMPT_TEMPLATE = """Human: You are a Financial expert with SEC filings who can answer questions only based on the context below.
* Answer the question STRICTLY based on the context provided in JSON below.
* Do not assume or retrieve any information outside of the context
@@ -33,12 +40,25 @@
</context>

Assistant:"""
PROMPT = PromptTemplate(
input_variables=["question","context"], template=PROMPT_TEMPLATE
)
PROMPT = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
# PROMPT_TEMPLATE = '''You are a Financial expert with SEC filings. A user is wondering "{question}".
#
# Please answer the questions only based on the context below.
# * Do not assume or retrieve any information outside of the context
# * Use three sentences maximum and keep the answer concise
# * Think step by step before answering.
# * Do not return helpful or extra text or apologies
# * Just return summary to the user. DO NOT start with Here is a summary
# * List the results in rich text format if there are more than one results
# * If the context is empty, just respond None
#
# Here is the context:
# {context}
# '''

EMBEDDING_MODEL = OpenAIEmbeddings()
MEMORY = ConversationBufferMemory(memory_key="chat_history", input_key='question', output_key='answer', return_messages=True)
MEMORY = ConversationBufferMemory(memory_key="chat_history", input_key='question', output_key='answer',
return_messages=True)

url = st.secrets["NEO4J_URI"]
username = st.secrets["NEO4J_USERNAME"]
@@ -48,27 +68,27 @@
url=url,
username=username,
password=password,
sanitize = True
sanitize=True
)
# TEMP
llm_key = st.secrets["OPENAI_API_KEY"]

@retry(tries=5, delay=5)
def get_results(question):

@retry(tries=5, delay=5)
def get_results(question: str):
# TODO: Update index and node property names to reflect the embedding origin LLM,
# ie "document_text_openai" index and "text_openai_embedding"
# Currently the try-except block below only works with small datasets, it needs to be replaced
# with a large node count variation

index_name = "form_10k_chunks"
node_property_name = "textEmbedding"
url=st.secrets["NEO4J_URI"]
username=st.secrets["NEO4J_USERNAME"]
password=st.secrets["NEO4J_PASSWORD"]
url = st.secrets["NEO4J_URI"]
username = st.secrets["NEO4J_USERNAME"]
password = st.secrets["NEO4J_PASSWORD"]
retrieval_query = """
WITH node AS doc, score as similarity
ORDER BY similarity DESC LIMIT 5
ORDER BY similarity DESC LIMIT 20
CALL { WITH doc
OPTIONAL MATCH (prevDoc:Document)-[:NEXT]->(doc)
OPTIONAL MATCH (doc)-[:NEXT]->(nextDoc:Document)
@@ -87,31 +107,6 @@ def get_results(question):
RETURN coalesce(prevDoc.text,'') + coalesce(document.text,'') + coalesce(nextDoc.text,'') as text, similarity as score,
{documentId: coalesce(document.documentId,''), company: coalesce(companyName,''), managers: coalesce(managers,''), source: document.source} AS metadata
"""
# retrieval_query = """
# WITH node AS doc, score as similarity
# CALL { WITH doc
# OPTIONAL MATCH (prevDoc:Document)-[:NEXT]->(doc)
# OPTIONAL MATCH (doc)-[:NEXT]->(nextDoc:Document)
# RETURN prevDoc, doc AS result, nextDoc
# }
# WITH result, prevDoc, nextDoc, similarity
# CALL {
# WITH result
# OPTIONAL MATCH (result)<-[:HAS_CHUNK]-(:Form)-[:FILED]->(company:Company), (company)<-[:OWNS_STOCK_IN]-(manager:Manager)
# WITH result, company.name as companyName, apoc.text.join(collect(manager.managerName),';') as managers
# WHERE companyName IS NOT NULL OR managers > ""
# WITH result, companyName, managers
# ORDER BY result.score DESC
# RETURN result as document, result.score as popularity, companyName, managers
# }
# RETURN '##DocumentID: ' + coalesce(document.documentId,'') +'\n'+
# '##Text: ' + coalesce(prevDoc.text+'\n','') + coalesce(document.text+'\n','') + coalesce(nextDoc.text+'\n','') +
# '###Company: ' + coalesce(companyName,'') +'\n'+ '###Managers: ' + coalesce(managers,'') as text,
# similarity as score, {source: document.source} AS metadata
# ORDER BY similarity ASC // so that best answers are the last
# """


try:
store = Neo4jVector.from_existing_index(
embedding=EMBEDDING_MODEL,
@@ -136,26 +131,10 @@ def get_results(question):
)

retriever = store.as_retriever()
chat_llm = ChatOpenAI(temperature=0)
chain = generate_chain(retriever, chat_llm, PROMPT)

context = retriever.get_relevant_documents(question)
print(context)
completePrompt = PROMPT.format(question=question, context=context)
print(completePrompt)
chat_llm = ChatOpenAI(openai_api_key=llm_key)
result = chat_llm.invoke(completePrompt)
# chain = RetrievalQAWithSourcesChain.from_chain_type(
# ChatOpenAI(temperature=0),
# chain_type="stuff",
# retriever=retriever,
# memory=MEMORY
# )

# result = chain.invoke({
# "question": question},
# prompt=PROMPT,
# return_only_outputs = True,
# )
result = chain.invoke({"question": question})

print(f'result: {result}')
# Will return a plain string (the chain ends with StrOutputParser)
return result
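
Reviewer note: because the LCEL chain ends in StrOutputParser, get_results now returns a plain string instead of the dict with answer/sources keys that RetrievalQAWithSourcesChain produced, which is why main.py above concatenates the response directly. A small, hypothetical sketch of the caller side:

```python
# Illustrative only: how main.py consumes the new return value.
import rag_vector_graph

response = rag_vector_graph.get_results("Which managers own stock in Apple?")  # made-up question
assert isinstance(response, str)  # no longer a dict with 'answer'/'sources'
content = "##### Vector + Graph: \n" + response
```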
46 changes: 15 additions & 31 deletions rag_demo/rag_vector_only.py
@@ -1,16 +1,12 @@
from langchain.chains import GraphCypherQAChain
from langchain_community.graphs import Neo4jGraph
from langchain.prompts.prompt import PromptTemplate
from langchain.vectorstores.neo4j_vector import Neo4jVector
from langchain.chains import RetrievalQAWithSourcesChain
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain.chains.conversation.memory import ConversationBufferMemory
from retry import retry
from timeit import default_timer as timer
import streamlit as st
from neo4j_driver import run_query
from json import loads, dumps


from rag import generate_chain

PROMPT_TEMPLATE = """Human: You are a Financial expert with SEC filings who can answer questions only based on the context below.
* Answer the question STRICTLY based on the context provided in JSON below.
@@ -23,7 +19,7 @@
* If the context is empty, just respond None

<question>
{input}
{question}
</question>

Here is the context:
@@ -32,26 +28,26 @@
</context>

Assistant:"""
PROMPT = PromptTemplate(
input_variables=["input","context"], template=PROMPT_TEMPLATE
)
PROMPT = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)

EMBEDDING_MODEL = OpenAIEmbeddings()
MEMORY = ConversationBufferMemory(memory_key="chat_history", input_key='question', output_key='answer', return_messages=True)
MEMORY = ConversationBufferMemory(memory_key="chat_history", input_key='question', output_key='answer',
return_messages=True)


def df_to_context(df):
result = df.to_json(orient="records")
parsed = loads(result)
return dumps(parsed)


@retry(tries=5, delay=5)
def get_results(question):

index_name = "form_10k_chunks"
node_property_name = "textEmbedding"
url=st.secrets["NEO4J_URI"]
username=st.secrets["NEO4J_USERNAME"]
password=st.secrets["NEO4J_PASSWORD"]
url = st.secrets["NEO4J_URI"]
username = st.secrets["NEO4J_USERNAME"]
password = st.secrets["NEO4J_PASSWORD"]

try:
store = Neo4jVector.from_existing_index(
@@ -75,22 +71,10 @@ def get_results(question):
)

retriever = store.as_retriever()
chat_llm = ChatOpenAI(temperature=0)
chain = generate_chain(retriever, chat_llm, PROMPT)

chain = RetrievalQAWithSourcesChain.from_chain_type(
ChatOpenAI(temperature=0),
chain_type="stuff",
retriever=retriever,
memory=MEMORY
)

result = chain.invoke({
"question": question},
prompt=PROMPT,
return_only_outputs = True
)
result = chain.invoke({"question": question})

print(f'result: {result}')
# Will return a plain string (the chain ends with StrOutputParser)
return result


Binary file added static/logo-mark-fullcolor-CMYK-transBG.png