diff --git a/README.md b/README.md deleted file mode 100644 index 1f05b87..0000000 --- a/README.md +++ /dev/null @@ -1,180 +0,0 @@ -# RAG Me Up -RAG Me Up is a generic framework (server + UIs) that enables you do to RAG on your own dataset **easily**. Its essence is a small and lightweight server and a couple of ways to run UIs to communicate with the server (or write your own). - -RAG Me Up can run on CPU but is best run on any GPU with at least 16GB of vRAM when using the default instruct model. - -Combine the power of RAG with the power of fine-tuning - check out our [LLaMa2Lang repository](https://github.com/UnderstandLingBV/LLaMa2Lang) on fine-tuning LLMs which can then be used in RAG Me Up. - -# Updates -- **2024-09-23** Hybrid retrieval with Postgres only (dense vectors with pgvector and sparse BM25 with pg_search) -- **2024-09-06** Implemented [Re2](https://arxiv.org/abs/2309.06275) -- **2024-09-04** Added an evaluation script that uses Ragas to evaluate your RAG pipeline -- **2024-08-30** Added Ollama compatibility -- **2024-08-27** Using cross encoders now so you can specify your own reranking model -- **2024-07-30** Added multiple provenance attribution methods -- **2024-06-26** Updated readme, added more file types, robust self-inflection -- **2024-06-05** Upgraded to Langchain v0.2 - -# Installation -## Server -``` -git clone https://github.com/UnderstandLingBV/RAGMeUp.git -cd server -pip install -r requirements.txt -``` -Then run the server using `python server.py` from the server subfolder. - -## Scala UI -Make sure you have JDK 17+. Download and install [SBT](https://www.scala-sbt.org/) and run `sbt run` from the `server/scala` directory or alternatively download the [compiled binary](https://github.com/UnderstandLingBV/RAGMeUp/releases/tag/scala-ui) and run `bin/ragemup(.bat)` - -## Using Postgres (adviced for production) -RAG Me Up supports Postgres as hybrid retrieval database with both pgvector and pg_search installed. To run Postgres instead of Milvus, follow these steps. - -- In the postgres folder is a Dockerfile, build it using `docker build -t ragmeup-pgvector-pgsearch .` -- Run the container using `docker run --name ragmeup-pgvector-pgsearch -e POSTGRES_USER=langchain -e POSTGRES_PASSWORD=langchain -e POSTGRES_DB=langchain -p 6024:5432 -d ragmeup-pgvector-pgsearch` -- Once in use, our custom PostgresBM25Retriever will automatically create the right indexes for you. -- pgvector however, will not do this automatically so you have to create them yourself (perhaps after loading the documents first so the right tables are created): - - Make sure the vector column is an actual vector (it's not by default): `ALTER TABLE langchain_pg_embedding ALTER COLUMN embedding TYPE vector(384);` - - Create the index (may take a while with a lot of data): `CREATE INDEX ON langchain_pg_embedding USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64);` -- Be sure to set up the right paths in your .env file `vector_store_uri='postgresql+psycopg://langchain:langchain@localhost:6024/langchain'` and `vector_store_sparse_uri='postgresql://langchain:langchain@localhost:6024/langchain'` - -# How does RAG Me Up work? -RAG Me Up aims to provide a robust RAG pipeline that is configurable without necessarily writing any code. To achieve this, a couple of strategies are used to make sure that the user query can be accurately answered through the documents provided. 
- -The RAG pipeline is visualized in the image below: -![RAG pipeline drawing](./ragmeup.drawio.svg) - -The following steps are executed. Take note that some steps are optional and can be turned off through configuring the `.env` file. - -__Top part - Indexing__ -1. You collect and make your documents available to RAG Me Up. -2. Using different file type loaders, RAG Me Up will read the contents of your documents. Note that for some document types like JSON and XML, you need to specify additional configuration to tell RAG Me Up what to extract. -3. Your documents get chunked using a recursive splitter. -4. The chunks get converted into document (chunk) embeddings using an embedding model. Note that this model is usually a different one than the LLM you intend to use for chat. -5. RAG Me Up uses a hybrid search strategy, combining dense vectors in the vector database with sparse vectors using BM25. By default, RAG Me Up uses a local [Milvus database](https://milvus.io/). - -__Bottom part - Inference__ -1. Inference starts with a user asking a query. This query can either be an initial query or a follow-up query with an associated history and documents retrieved before. Note that both (chat history, documents) need to be passed on by a UI to properly handle follow-up querying. -2. A check is done if new documents need to be fetched, this can be due to one of two cases: - - There is no history given in which case we always need to fetch documents - - **[OPTIONAL]** The LLM itself will judge whether or not the question - in isolation - is phrased in such a way that new documents are fetched or whether it is a follow-up question on existing documents. A flag called `fetch_new_documents` is set to indicate whether or not new documents need to be fetched. -3. Documents are fetched from both the vector database (dense) and the BM25 index (sparse). Only executed if `fetch_new_documents` is set. -4. **[OPTIONAL]** Reranking is applied to extract the most relevant documents returned by the previous step. Only executed if `fetch_new_documents` is set. -5. **[OPTIONAL]** The LLM is asked to judge whether or not the documents retrieved contain an accurate answer to the user's query. Only executed if `fetch_new_documents` is set. - - If this is not the case, the LLM is used to rewrite the query with the instruction to optimize for distance based similarity search. This is then fed back into step 3. **but only once** to avoid lengthy or infinite loops. -6. The documents are injected into the prompt with the user query. The documents can come from: - - The retrieval and reranking of the document databases, if `fetch_new_documents` is set. - - The history passed on with the initial user query, if `fetch_new_documents` is **not** set. -7. The LLM is asked to answer the query with the given chat history and documents. -8. The answer, chat history and documents are returned. - -# Configuration -RAG Me Up uses a `.env` file for configuration, see `.env.template`. The following fields can be configured: - -## LLM configuration -- `llm_model` This is the main LLM (instruct or chat) model to use that you will converse with. 
Default is LLaMa3-8B -- `llm_assistant_token` This should contain the unique query (sub)string that indicates where in a prompt template the assistant's answer starts -- `embedding_model` The model used to convert your documents' chunks into vectors that will be stored in the vector store -- `trust_remote_code` Set this to true if your LLM needs to execute remote code -- `force_cpu` When set to True, forces RAG Me Up to run fully on CPU (not recommended) - -### Use OpenAI -If you want to use OpenAI as LLM backend, make sure to set `use_openai` to True and make sure you (externally) set the environment variable `OPENAI_API_KEY` to be your OpenAI API Key. - -### Use Gemini -If you want to use Gemini as LLM backend, make sure to set `use_gemini` to True and make sure you (externally) set the environment variable `GOOGLE_API_KEY` to be your Gemini API Key. - -### Use Azure OpenAI -If you want to use Azure OpenAI as LLM backend, make sure to set `use_azure` to True and make sure you (externally) set the following environment variables: -- `AZURE_OPENAI_API_KEY` -- `AZURE_OPENAI_API_VERSION` -- `AZURE_OPENAI_ENDPOINT` -- `AZURE_OPENAI_CHAT_DEPLOYMENT_NAME` - -## Use Ollama -If you want to use Ollama as LLM backend, make sure to install Ollama and set `use_ollama` to True. The model to use should be given in `ollama_model`. - -## RAG Provenance -One of the biggest, arguably unsolved, challenges of RAG is to do good provenance attribution: tracking which of the source documents retrieved from your database led to the LLM generating its answer (the most). RAG Me Up implements several ways of achieving this, each with its own pros and cons. - -The following environment variables can be set for provenance attribution. - -- `provenance_method` Can be one of `rerank, attention, similarity, llm`. If `rerank` is `False` and the value of `provenance_method` is either `rerank` or none of the allowed values, provenance attribution is turned completely off -- `provenance_similarity_llm` If `provenance_method` is set to `similarity`, this model will be used to compute the similarity scores -- `provenance_include_query` Set to True or False to include the query itself when attributing provenance -- `provenance_llm_prompt` If `provenance_method` is set to `llm`, this prompt will be used to let the LLM attribute the provenance of each document in isolation. - -The different provenance attribution metrics are described below. - -### `provenance_method=rerank` (preferred for closed LLMs) -This uses the reranker as the provenance method. While the reranking is already used when retrieving documents (if reranking is turned on), this only applies the rerankers cross-attention to the documents and the *query*. For provenance attribution, we use the same reranking to apply cross-attention to the *answer* (and potentially the query too). - -### `provenance_method=attention` (preferred for OS LLMs) -This is probably the most accurate way of tracking provenance but it can only be used with OS LLMs that allow to return the attention weights. The way we track provenance is by looking at the actual attention weights (of the last attention layer in the model) for each token from the answer to the document and vice versa, optionally we do the same for the query if `provenance_include_query=True`. - -### `provenance_method=similarity` -This method uses a sentence transformer (LM) to get dense vectors for each document as well as for the answer (and potentially query). 
We then use a cosine similarity to get the similarity of the document vectors to the answer (+ query). - -### `provenance_method=llm` -The LLM that is used to generate messages is now also used to attribute the provenance of each document in isolation. We use the `provenance_llm_prompt` as the prompt to ask the LLM to perform this task. Note that the outcome of this provenance method is highly influenced by the prompt and the strength of the model. As a good practice, make sure you force the LLM to return numbers on a relatively small scale (eg. score from 1 to 3). Using something like a percentage for each document will likely result in random outcomes. - -## Data configuration -- `data_directory` The directory that contains your (initial) documents to load into the vector store -- `file_types` Comma-separated list of file types to load. Supported file types: `PDF, JSON, DOCX, XSLX, PPTX, CSV, XML` -- `json_schema` If you are loading JSON, this should be the schema (using `jq_schema`). For example, use `.` for the root of a JSON object if your data contains JSON objects only and `.[0]` for the first element in each JSON array if your data contains JSON arrays with one JSON object in them -- `json_text_content` Whether or not the JSON data should be loaded as textual content or as structured content (in case of a JSON object) -- `xml_xpath` If you are loading XML, this should be the XPath of the documents to load (the tags that contain your text) - -## Retrieval configuration -- `vector_store_uri` RAG Me Up caches your vector store on disk if possible to make loading a next time faster. This is the location where the vector store is stored. Remove this file to force a reload of all your documents -- `vector_store_k` The number of documents to retrieve from the vector store -- `rerank` Set to either True or False to enable reranking -- `rerank_k` The number of documents to keep after reranking. Note that if you use reranking, this should be your final target for `k` and `vector_store_k` should be set (significantly) higher. For example, set `vector_store_k` to 10 and `rerank_k` to 3 -- `rerank_model` The cross encoder reranking retrieval model to use. Sensible defaults are `cross-encoder/ms-marco-TinyBERT-L-2-v2` for speed and `colbert-ir/colbertv2.0` for accuracy (`antoinelouis/colbert-xm` for multilingual). Set this value to `flashrank` to use the FlashrankReranker. - -## LLM parameters -- `temperature` The chat LLM's temperature. Increase this to create more diverse answers -- `repetition_penalty` The penalty for repeating outputs in the chat answers. Some models are very sensitive to this parameter and need a value bigger than 1.0 (penalty) while others benefit from inversing it (lower than 1.0) -- `max_new_tokens` This caps how much tokens the LLM can generate in its answer. More tokens means slower throughput and more memory usage - -## Prompt configuration -- `rag_instruction` An instruction message for the LLM to let it know what to do. Should include a mentioning of it performing RAG and that documents will be given as input context to generate the answer from. -- `rag_question_initial` The initial question prompt that will be given to the LLM only for the first question a user asks, that is, without chat history -- `rag_question_followup` This is a follow-up question the user is asking. 
While the context resulting from the prompt will be populated by RAG from the vector store, if chat history is present, this prompt will be used instead of `rag_question_initial` - -### Document retrieval -- `rag_fetch_new_instruction` RAG Me Up automatically determines whether or not new documents should be fetched from the vector store or whether the user is asking a follow-up question on the already fetched documents by leveraging the same LLM that is used for chat. This environment variable determines the prompt to use to make this decision. Be very sure to instruct your LLM to answer with yes or no only and make sure your LLM is capable enough to follow this instruction -- `rag_fetch_new_question` The question prompt used in conjunction with `rag_fetch_new_instruction` to decide if new documents should be fetched or not - -### Rewriting (self-inflection) -- `user_rewrite_loop` Set to either True or False to enable the rewriting of the initial query. Note that a rewrite will always occur at most once -- `rewrite_query_instruction` This is the instruction of the prompt that is used to ask the LLM to judge whether a rewrite is necessary or not. Make sure you force the LLM to answer with yes or no only -- `rewrite_query_question` This is the actual query part of the prompt that isued to ask the LLM to judge a rewrite -- `rewrite_query_prompt` If the rewrite loop is on and the LLM judges a rewrite is required, this is the instruction with question asked to the LLM to rewrite the user query into a phrasing more optimized for RAG. Make sure to instruct your model adequately. - -### Re2 -- `use_re2` Set to either True or False to enable [Re2 (Re-reading)](https://arxiv.org/abs/2309.06275) which repeats the question, generally improving the quality of the answer generated by the LLM. -- `re2_prompt` The prompt used in between the question and the repeated question to signal that we are re-asking. - -## Document splitting configuration -- `splitter` The Langchain document splitter to use. Supported splitters are `RecursiveCharacterTextSplitter` and `SemanticChunker`. -- `chunk_size` The chunk size to use when splitting up documents for `RecursiveCharacterTextSplitter` -- `chunk_overlap` The chunk overlap for `RecursiveCharacterTextSplitter` -- `breakpoint_threshold_type` Sets the breakpoint threshold type when using the `SemanticChunker` ([see here](https://python.langchain.com/v0.2/docs/how_to/semantic-chunker/)). Can be one of: percentile, standard_deviation, interquartile, gradient -- `breakpoint_threshold_amount` The amount to use for the threshold type, in float. Set to `None` to leave default -- `number_of_chunks` The number of chunks to use for the threshold type, in int. Set to `None` to leave default - -# Evaluation -While RAG evaluation is difficult and subjective to begin with, frameworks such as [Ragas](https://docs.ragas.io/en/stable/) can give some metrics as to how well your RAG pipeline and its prompts are working, allowing us to benchmark one approach over the other quantitatively. - -RAG Me Up uses Ragas to evaluate your pipeline. You can run an evaluation based on your `.env` using `python Ragas_eval.py`. The following configuration parameters can be set for evaluation: - -- `ragas_sample_size` The amount of document (chunks) to use in evaluation. These are sampled from your data directory after chunking. -- `ragas_qa_pairs` Ragas works upon questions and ground truth answers. The amount of such pairs to create based on the sampled document chunks is set by this parameter. 
-- `ragas_question_instruction` The instruction prompt used to generate the questions of the Ragas input pairs. -- `ragas_question_query` The query prompt used to generate the questions of the Ragas input pairs. -- `ragas_answer_instruction` The instruction prompt used to generate the answers of the Ragas input pairs. -- `ragas_answer_query` The query prompt used to generate the answers of the Ragas input pairs. - -# Funding -We are actively looking for funding to democratize AI and advance its applications. Contact us at info@commandos.ai if you want to invest. diff --git a/neo4j/README.md b/neo4j/README.md new file mode 100644 index 0000000..38d3481 --- /dev/null +++ b/neo4j/README.md @@ -0,0 +1,115 @@ +# GraphRAG: Extending RAG Me Up with Neo4j Graph Integration + +- Neo4j integration with RAG Me Up to store and retrieve data via graph queries. +- The RagHelperCloud configuration orchestrates retrieval and LLM processing. +- Tested using Neo4j Desktop & GEMINI. + +Key additions: + +1. Separate Neo4j server with REST endpoints +2. Graph-based CSV and PDF loaders +3. Graph-based retrieval + +# How GraphRAG Works + +### Graph-Based Retrieval + +The `graph_retriever` function queries a Neo4j database with **schema-aware Cypher**. Key points: + +1. **Neo4j Schema Integration** + + - Dynamically retrieves the schema via `/schema`. + - Formats it into a prompt-friendly representation. + +2. **LLM-Driven Query Generation** + + - Combines schema info + user queries to form schema-aware prompts. + - Uses few-shot learning to guide Cypher generation. + +3. **Query Execution & Data Retrieval** + + - Executes LLM-generated queries via `/query`. + - Converts results into LangChain `Document` objects (with metadata). + +4. **Fallback Mechanism** + + - If no valid query applies, returns `None` to skip redundant computation. + +5. **Integration with Other Retrievers** + - Graph-based documents are prioritized as “document 0.” + - Remaining “slots” (based on `chunk_size`) are filled by other retrievers. + +### Graph-Based Document Uploading + +GraphRAG provides two main functions for adding data to Neo4j: + +#### 1. `add_csv_to_graphdb` + +1. **Reads the CSV** + + - Uses Python’s `csv.DictReader` for parsing. + +2. **Defines the Graph Schema** + + - Converts rows into Cypher queries, e.g.: + ```cypher + MERGE (q:Quote {text: $quoteText}) + MERGE (t:Topic {name: $topicName}) + MERGE (q)-[:IS_PART_OF]->(t) + ``` + +3. **Sends Data to Neo4j** + + - Packages Cypher + parameters into JSON, sent via `/add_instances`. + +4. **Logs Status** + - Tracks uploaded records and server responses. + +**Use Case Example (NPS Feedback)** + +- _Quotes_: Customer responses to “Please tell us more...” +- _Topics_: Automatically derived themes (e.g., “Relationship with contact”). + +#### 2. `add_document_to_graphdb` + +1. **Metadata Identification** + + - Determines file type (PDF, etc.). + +2. **Schema Configuration** + + - Fetches schema dynamically (`dynamic_neo4j_schema = True`) or uses a predefined `.env` schema. + +3. **Triplet Extraction with LLM** + + - Combines content + schema to form Cypher prompts. + - Few-shot examples guide the LLM’s query creation. + +4. **Query Execution** + - Escapes special chars, then sends queries as JSON via `/add_instances`. + - Logs success/failure. + +# Setup & Configuration: + +1. install Neo4j Desktop & run to host a local Neo4j database + +2. set environment variables: + + - ngrok authentication token in `ngrok_token`. 
   - the URI of your Neo4j database in `neo4j_location`; if you use Neo4j Desktop, this is: bolt://localhost:7687
+   - your Neo4j username in `neo4j_user`
+   - your Neo4j password in `neo4j_password`
+
+3. Configure the RAG Me Up server
+
+   - Run the neo4j `server.py` file and save the public ngrok URL in the RAG Me Up server's `neo4j_location`.
+   - Set `use_gemini` to True so that `RAGHelperCloud` uses Gemini as the LLM.
+   - Set `GOOGLE_API_KEY` for Gemini authentication.
+   - Launch `server.py`
+
+4. OPTIONAL improvements:
+
+- `max_document_limit` sets the maximum number of document chunks returned by retrieval
+- `rag_retrieval_instruction`, `retrieval_few_shot`, `rag_retrieval_question` to improve the LLM prompt
+- `neo4j_insert_instruction`, `neo4j_insert_schema`, `neo4j_insert_data_only`, `neo4j_insert_few_shot` to change the document upload LLM prompt
+- `dynamic_neo4j_schema`: if set to True, the schema is fetched from the Neo4j server; if set to False, the LLM instruction in `neo4j_insert_data_only` is used. When you define a schema there, the LLM will follow it; if you leave the schema open, the LLM will dynamically create new nodes and relationships.
diff --git a/neo4j/server.py b/neo4j/server.py
new file mode 100644
index 0000000..45af9fc
--- /dev/null
+++ b/neo4j/server.py
@@ -0,0 +1,196 @@
+import csv
+from neo4j import GraphDatabase
+from flask import Flask, jsonify, request
+from pyngrok import ngrok
+import os
+
+# Define the Graph_whisperer class to interact with Neo4j
+class Graph_whisperer:
+
+    def __init__(self, uri, user, password):
+        self.driver = GraphDatabase.driver(uri, auth=(user, password))
+
+    def close(self):
+        self.driver.close()
+
+    def create_instance(self, payload):
+        with self.driver.session() as session:
+            return session.execute_write(self._create_instance, payload)
+
+    def add_document(self, payload):
+        with self.driver.session() as session:
+            return session.execute_write(self._add_document, payload)
+
+    def get_meta_schema(self):
+        """
+        Retrieve detailed schema information, including node labels, properties, and relationship types.
+
+        Returns:
+            dict: A detailed schema including labels, properties, and relationship types.
+        """
+        with self.driver.session() as session:
+            # Retrieve node labels and their properties
+            nodes_query = """
+                MATCH (n)
+                UNWIND labels(n) AS label
+                RETURN label, collect(DISTINCT keys(n)) AS properties
+            """
+            node_results = session.run(nodes_query)
+            nodes = {}
+            for record in node_results:
+                label = record["label"]
+                properties = set()
+                for prop_list in record["properties"]:
+                    properties.update(prop_list)
+                nodes[label] = list(properties)
+
+            # Retrieve relationship types and their properties
+            rels_query = """
+                MATCH ()-[r]->()
+                RETURN type(r) AS type, collect(DISTINCT keys(r)) AS properties
+            """
+            rel_results = session.run(rels_query)
+            relationships = {}
+            for record in rel_results:
+                rel_type = record["type"]
+                properties = set()
+                for prop_list in record["properties"]:
+                    properties.update(prop_list)
+                relationships[rel_type] = list(properties)
+
+            return {"nodes": nodes, "relationships": relationships}
+
+    def run_query(self, query):
+        """
+        Executes a Cypher query against the Neo4j database.
+
+        Args:
+            query (str): The Cypher query to execute.
+
+        Returns:
+            list: A list of query results, where each result is a dictionary.
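+
+        Example (illustrative only; the connection values are placeholders and it
+        assumes Topic nodes already exist in your graph):
+            db = Graph_whisperer("bolt://localhost:7687", "neo4j", "password")
+            rows = db.run_query("MATCH (t:Topic) RETURN t.name AS name LIMIT 5")
+            # rows is a list of dicts keyed by the returned column names, e.g. [{"name": ...}]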
+ """ + with self.driver.session() as session: + result = session.run(query) + return [record.data() for record in result] + + @staticmethod + def _create_instance(tx, payload): + for instance in payload: + tx.run(instance["query"], instance["parameters"]) + return instance + + @staticmethod + def _add_document(self, csv_file_path): + """ + Loads a CSV file into Neo4j by constructing and executing queries for each row. + + Args: + csv_file_path (str): The path to the CSV file to be loaded. + + Returns: + dict: A summary of the import process, including the number of records processed. + """ + payloads = [] + try: + with open(csv_file_path, mode="r", encoding="utf-8") as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + # Construct the payload for each row + payloads.append( + { + "query": "MERGE (q:Quote {text: $quoteText}) " + "MERGE (t:Topic {name: $topicName}) " + "MERGE (q)-[:IS_PART_OF]->(t)", + "parameters": { + "quoteText": row.get("quoteText"), + "topicName": row.get("topicName"), + }, + } + ) + # Execute all queries in the payload + self._create_instance(self, payloads) + return { + "message": f"Successfully loaded {len(payloads)} records into Neo4j." + } + except Exception as e: + return {"error": str(e)} + + +# Initialize Flask app +app = Flask(__name__) + + +neo4j_location = os.getenv('neo4j_location') +neo4j_user = os.getenv('neo4j_user') +neo4j_password = os.getenv('neo4j_password') +# Initialize Neo4j database connection +neo4j_db = Graph_whisperer(neo4j_location, neo4j_user, neo4j_password) + + +@app.route("/add_instances", methods=["POST"]) +def add_instance(): + json_data = request.get_json() + # print(json_data) + try: + # Use the json data to insert directly into Neo4j + insert_result = neo4j_db.create_instance(json_data) + return jsonify({"last inserted instance": insert_result}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/add_csv", methods=["POST"]) +def add_csv(): + json_data = request.get_json() + # print(json_data) + try: + # Use the json data to insert directly into Neo4j + insert_result = neo4j_db.add_document(json_data) + return jsonify({"last inserted instance": insert_result}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + +@app.route("/close_db") +def close_db(): + try: + neo4j_db.close() + return jsonify({"message": "Database connection closed."}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + +@app.route("/schema", methods=["GET"]) +def get_meta_schema(): + try: + schema = neo4j_db.get_meta_schema() + app.logger.info(f"Retrieved schema: {schema}") + return jsonify(schema) + except Exception as e: + app.logger.error(f"Error retrieving schema: {e}") + return jsonify({"error": str(e)}), 500 + +@app.route("/run_query", methods=["POST"]) +def run_query(): + try: + # Extract the Cypher query from the request body + query = request.json.get("query") + if not query: + return jsonify({"error": "No query provided"}), 400 + + # Execute the query + results = neo4j_db.run_query(query) + return jsonify({"results": results}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + + +if __name__ == "__main__": + # # Set ngrok auth token and expose the app + ngrok_token = os.getenv('ngrok_token') + ngrok.set_auth_token(ngrok_token) # Replace with your actual ngrok auth token + public_url = ngrok.connect(4000) # Expose port 5000 + print(f"ngrok tunnel available at: {public_url}") + + # Start Flask app + app.run(host="0.0.0.0",port=4000) diff --git 
index cc0ac93..e121bfe 100644
--- a/server/.env.template
+++ b/server/.env.template
@@ -23,11 +23,13 @@ The source document that you need to score is the following:
 {context}"
 
 data_directory='data'
-file_types="pdf,json,docx,pptx,xlsx,csv,xml"
+file_types="pdf,json,docx,pptx,xlsx,csv,xml,txt"
 json_schema="."
 json_text_content=False
 xml_xpath="//"
+max_document_limit=10
+neo4j_location='URL_to_neo4j_server'
 
 vector_store=milvus
 vector_store_uri='data.db'
 vector_store_collection=ragmeup_documents
@@ -35,6 +37,8 @@ vector_store_sparse_uri=bm25_db.pickle
 vector_store_initial_load=True
 vector_store_k=10
 document_chunks_pickle=rag_chunks.pickle
+file_upload_using_llm=True
+dynamic_neo4j_schema=False
 rerank=True
 rerank_k=3
 rerank_model=flashrank
@@ -78,6 +82,7 @@ use_openai=False
 openai_model_name='gpt-4o-mini'
 use_gemini=False
 gemini_model_name='gemini-pro'
+GOOGLE_API_KEY='Your_API_key'
 use_azure=False
 use_ollama=False
 ollama_model='llama3.1'
@@ -95,4 +100,50 @@ ragas_answer_instruction="You are a digital librarian and need to answer questio
 {context}"
 
 ragas_answer_query="Answer the following question, never give any explanation or other output than the generated article itself:
-{question}"
\ No newline at end of file
+{question}"
+
+rag_retrieval_instruction="Instruction: You are a graph database query assistant. Based on the graph schema below, generate a Cypher query to search for the answer to the user's question. If the schema does not support the query, respond with 'None'.
+Schema:
+{schema}"
+retrieval_few_shot="Few-shot examples:
+Example 1:
+User query: \'What topics are available?\'
+Output: MATCH (t:Topic) RETURN t.name
+
+Example 2:
+User query: \'What is the size of an elephant?\'
+Output: None"
+rag_retrieval_question="The user question is:
+
+{question}
+
+Please generate a Cypher query to answer it, or return None if it does not fit the schema"
+
+neo4j_insert_instruction="You are a Neo4j database assistant. Your task is to generate Cypher queries for inserting data into the Neo4j graph database. Use only the nodes, properties, and relationships specified in the provided schema. Ensure that all generated queries are valid Cypher, returned in valid JSON format, and conform to the schema. Make a maximum of 9 additions. If the input data cannot be mapped to the schema, return 'None' and do not generate any invalid query."
+neo4j_insert_schema="Instruction: You are tasked with generating Cypher queries to insert data into the Neo4j graph database. Use only the nodes, properties, and relationships defined in the following schema. Ensure the queries are valid and align with the schema. If the input data cannot be mapped to the schema, return 'None'.
+
+Schema:
+{schema}
+
+Input data:
+{data}
+
+Output: "
+
+
+neo4j_insert_data_only="Instruction: You are tasked with generating Cypher queries to insert data into the Neo4j graph database. Use only the nodes, properties, and relationships defined in the following schema. Ensure the queries are valid and align with the schema. If the input data cannot be mapped to the schema, return 'None'.
+ +Schema: +Nodes: +- Topic: name +- Fact: name +Relationships: +- IS_PART_OF: No properties + +Input data: +{data} + +Output: " + +neo4j_insert_few_shot="Few-shot examples:Example 1: Schema: Nodes: - Quote: text - Topic: name Relationships: - IS_PART_OF: None Input data:Course block 4 Pitching Tools you need (all available on Canvas > Files): • A series of short videos on pitching by Nathalie Mangelaars (links available on Canvas) • Pitch Toolkit by Pitch Academy • Example Pitch Deck by Horseplay Ventures Expected deliverables: • Pitch script • Slide deck • A Minimum Viable Product (MVP) (also see here and here) Notes: • You are strongly encouraged to already draft a pitch script and create a preliminary slide deck before the pitch training takes place (i.e., on Wednesday November 22). If you come prepared, then Cyrille van Hoof and Nathalie Mangelaars can focus on important opportunities for improvement instead of starting from scratch, which saves us valuable time. In case you do so, include both your draft and final versions to your portfolio. Output:[{\"query\": \"MERGE (q:Quote {text: $quoteText}) MERGE (t:Topic {name: $topicName}) MERGE (q)-[:IS_PART_OF]->(t)\",\"parameters\": { \"quoteText\": \"Pitch Toolkit by Pitch Academy\",\"topicName\": \"Needed tools\"}},{\"query\": \"MERGE (q:Quote {text: $quoteText}) MERGE (t:Topic {name: $topicName}) MERGE (q)-[:IS_PART_OF]->(t)\",\"parameters\": {\"quoteText\": \"A Minimum Viable Product (MVP) (also see here and here) \",\"topicName\": \"Deliverables\"}}]" + diff --git a/server/RAGHelper.py b/server/RAGHelper.py index b72fd58..7ab5eae 100644 --- a/server/RAGHelper.py +++ b/server/RAGHelper.py @@ -1,18 +1,28 @@ import hashlib import os import pickle +import requests +import base64 +import csv +import json +import re -from langchain.retrievers import (ContextualCompressionRetriever, - EnsembleRetriever) +from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever from langchain.retrievers.document_compressors import FlashrankRerank from langchain_community.cross_encoders import HuggingFaceCrossEncoder -from langchain_community.document_loaders import (CSVLoader, DirectoryLoader, - Docx2txtLoader, JSONLoader, - PyPDFDirectoryLoader, - PyPDFLoader, TextLoader, - UnstructuredExcelLoader, - UnstructuredPowerPointLoader) +from langchain_community.document_loaders import ( + CSVLoader, + DirectoryLoader, + Docx2txtLoader, + JSONLoader, + PyPDFDirectoryLoader, + PyPDFLoader, + TextLoader, + UnstructuredExcelLoader, + UnstructuredPowerPointLoader, +) from langchain_community.retrievers import BM25Retriever +from langchain.prompts import ChatPromptTemplate from langchain_core.documents.base import Document from langchain_experimental.text_splitter import SemanticChunker from langchain_milvus.vectorstores import Milvus @@ -44,14 +54,16 @@ def __init__(self, logger): self.rerank_retriever = None self._batch_size = 1000 # Load environment variables - self.vector_store_sparse_uri = os.getenv('vector_store_sparse_uri') - self.vector_store_uri = os.getenv('vector_store_uri') - self.document_chunks_pickle = os.getenv('document_chunks_pickle') - self.data_dir = os.getenv('data_directory') + self.vector_store_sparse_uri = os.getenv("vector_store_sparse_uri") + self.vector_store_uri = os.getenv("vector_store_uri") + self.document_chunks_pickle = os.getenv("document_chunks_pickle") + self.data_dir = os.getenv("data_directory") self.file_types = os.getenv("file_types").split(",") - self.splitter_type = os.getenv('splitter') + 
self.splitter_type = os.getenv("splitter") self.vector_store = os.getenv("vector_store") - self.vector_store_initial_load = os.getenv("vector_store_initial_load") == "True" + self.vector_store_initial_load = ( + os.getenv("vector_store_initial_load") == "True" + ) self.rerank = os.getenv("rerank") == "True" self.rerank_model = os.getenv("rerank_model") self.rerank_k = int(os.getenv("rerank_k")) @@ -64,8 +76,13 @@ def __init__(self, logger): self.breakpoint_threshold_type = os.getenv('breakpoint_threshold_type') self.vector_store_collection = os.getenv("vector_store_collection") self.xml_xpath = os.getenv("xml_xpath") - self.json_text_content = os.getenv("json_text _content", "false").lower() == 'true' + self.json_text_content = ( + os.getenv("json_text _content", "false").lower() == "true" + ) self.json_schema = os.getenv("json_schema") + self.neo4j = os.getenv("neo4j_location") + self.add_docs_to_neo4j = os.getenv("file_upload_using_llm") + self.dynamic_neo4j_schema = os.getenv("dynamic_neo4j_schema") == "True" @staticmethod def format_documents(docs): @@ -80,13 +97,17 @@ def format_documents(docs): """ doc_strings = [] for i, doc in enumerate(docs): - metadata_string = ", ".join([f"{md}: {doc.metadata[md]}" for md in doc.metadata.keys()]) - doc_strings.append(f"Document {i} content: {doc.page_content}\nDocument {i} metadata: {metadata_string}") + metadata_string = ", ".join( + [f"{md}: {doc.metadata[md]}" for md in doc.metadata.keys()] + ) + doc_strings.append( + f"Document {i} content: {doc.page_content}\nDocument {i} metadata: {metadata_string}" + ) return "\n\n\n\n".join(doc_strings) def _load_chunked_documents(self): """Loads previously chunked documents from a pickle file.""" - with open(self.document_chunks_pickle, 'rb') as f: + with open(self.document_chunks_pickle, "rb") as f: self.logger.info("Loading chunked documents.") self.chunked_documents = pickle.load(f) @@ -98,10 +119,7 @@ def _load_json_files(self): list: A list of loaded Document objects from JSON files. 
""" text_content = self.json_text_content - loader_kwargs = { - 'jq_schema': self.json_schema, - 'text_content': text_content - } + loader_kwargs = {"jq_schema": self.json_schema, "text_content": text_content} loader = DirectoryLoader( path=self.data_dir, glob="*.json", @@ -130,12 +148,18 @@ def _load_xml_files(self): newdocs = [] for index, doc in enumerate(xmldocs): try: - xmltree = etree.fromstring(doc.page_content.encode('utf-8')) + xmltree = etree.fromstring(doc.page_content.encode("utf-8")) elements = xmltree.xpath(self.xml_xpath) - elements = [etree.tostring(element, pretty_print=True).decode() for element in elements] + elements = [ + etree.tostring(element, pretty_print=True).decode() + for element in elements + ] metadata = doc.metadata - metadata['index'] = index - newdocs += [Document(page_content=content, metadata=metadata) for content in elements] + metadata["index"] = index + newdocs += [ + Document(page_content=content, metadata=metadata) + for content in elements + ] except Exception as e: self.logger.error(f"Error processing XML document: {e}") return newdocs @@ -171,7 +195,9 @@ def _filter_metadata(docs, filters=None): # Filter metadata for each document for doc in docs: - doc.metadata = {key: doc.metadata.get(key) for key in filters if key in doc.metadata} + doc.metadata = { + key: doc.metadata.get(key) for key in filters if key in doc.metadata + } return docs @@ -248,20 +274,20 @@ def _load_json_document(self, filename): return JSONLoader( file_path=filename, jq_schema=self.json_schema, - text_content=self.json_text_content + text_content=self.json_text_content, ) def _load_document(self, filename): """Load documents from the specified file based on its extension.""" - file_type = filename.lower().split('.')[-1] + file_type = filename.lower().split(".")[-1] loaders = { - 'pdf': PyPDFLoader, - 'json': self._load_json_document, - 'txt': TextLoader, - 'csv': CSVLoader, - 'docx': Docx2txtLoader, - 'xlsx': UnstructuredExcelLoader, - 'pptx': UnstructuredPowerPointLoader + "pdf": PyPDFLoader, + "json": self._load_json_document, + "txt": TextLoader, + "csv": CSVLoader, + "docx": Docx2txtLoader, + "xlsx": UnstructuredExcelLoader, + "pptx": UnstructuredPowerPointLoader, } self.logger.info(f"Loading {file_type} document....") if file_type in loaders: @@ -283,8 +309,20 @@ def _create_recursive_text_splitter(self): length_function=len, keep_separator=True, separators=[ - "\n \n", "\n\n", "\n", ".", "!", "?", " ", - ",", "\u200b", "\uff0c", "\u3001", "\uff0e", "\u3002", "" + "\n \n", + "\n\n", + "\n", + ".", + "!", + "?", + " ", + ",", + "\u200b", + "\uff0c", + "\u3001", + "\uff0e", + "\u3002", + "", ], ) @@ -299,15 +337,15 @@ def _create_semantic_chunker(self): self.embeddings, breakpoint_threshold_type=self.breakpoint_threshold_type, breakpoint_threshold_amount=self.breakpoint_threshold_amount, - number_of_chunks=self.number_of_chunks + number_of_chunks=self.number_of_chunks, ) def _initialize_text_splitter(self): """Initialize the text splitter based on the environment settings.""" self.logger.info(f"Initializing {self.splitter_type} splitter.") - if self.splitter_type == 'RecursiveCharacterTextSplitter': + if self.splitter_type == "RecursiveCharacterTextSplitter": self.text_splitter = self._create_recursive_text_splitter() - elif self.splitter_type == 'SemanticChunker': + elif self.splitter_type == "SemanticChunker": self.text_splitter = self._create_semantic_chunker() def _split_documents(self, docs): @@ -320,8 +358,13 @@ def _split_documents(self, docs): 
self._initialize_text_splitter() self.logger.info("Chunking document(s).") chunked_documents = [ - Document(page_content=doc.page_content, - metadata={**doc.metadata, 'id': hashlib.md5(doc.page_content.encode()).hexdigest()}) + Document( + page_content=doc.page_content, + metadata={ + **doc.metadata, + "id": hashlib.md5(doc.page_content.encode()).hexdigest(), + }, + ) for doc in self.text_splitter.split_documents(docs) ] return chunked_documents @@ -336,14 +379,15 @@ def _split_and_store_documents(self, docs): self.chunked_documents = self._split_documents(docs) # Store the chunked documents self.logger.info("Storing chunked document(s).") - with open(self.document_chunks_pickle, 'wb') as f: + with open(self.document_chunks_pickle, "wb") as f: pickle.dump(self.chunked_documents, f) def _initialize_milvus(self): """Initializes the Milvus vector store.""" self.logger.info("Setting up Milvus Vector DB.") self.db = Milvus.from_documents( - [], self.embeddings, + [], + self.embeddings, drop_old=not self.vector_store_initial_load, connection_args={"uri": self.vector_store_uri}, collection_name=self.vector_store_collection, @@ -356,7 +400,7 @@ def _initialize_postgres(self): embeddings=self.embeddings, collection_name=self.vector_store_collection, connection=self.vector_store_uri, - use_jsonb=True + use_jsonb=True, ) def _initialize_vector_store(self): @@ -368,14 +412,17 @@ def _initialize_vector_store(self): else: raise ValueError( "Only 'milvus' or 'postgres' are supported as vector stores! Please set vector_store in your " - "environment variables.") + "environment variables." + ) if self.vector_store_initial_load: self.logger.info("Loading data from existing store.") # Add the documents 1 by 1, so we can track progress - with tqdm(total=len(self.chunked_documents), desc="Vectorizing documents") as pbar: + with tqdm( + total=len(self.chunked_documents), desc="Vectorizing documents" + ) as pbar: for i in range(0, len(self.chunked_documents), self._batch_size): # Slice the documents for the current batch - batch = self.chunked_documents[i:i + self._batch_size] + batch = self.chunked_documents[i : i + self._batch_size] # Prepare documents and their IDs for batch insertion documents = [d for d in batch] ids = [d.metadata["id"] for d in batch] @@ -389,19 +436,27 @@ def _initialize_vector_store(self): def _initialize_bm25retriever(self): """Initializes in memory BM25Retriever.""" self.logger.info("Initializing BM25Retriever.") + self.sparse_retriever = BM25Retriever.from_texts( [x.page_content for x in self.chunked_documents], - metadatas=[x.metadata for x in self.chunked_documents] + metadatas=[x.metadata for x in self.chunked_documents], ) def _initialize_postgresbm25retriever(self): """Initializes in memory PostgresBM25Retriever.""" self.logger.info("Initializing PostgresBM25Retriever.") - self.sparse_retriever = PostgresBM25Retriever(connection_uri=self.vector_store_sparse_uri, - table_name="sparse_vectors", k=self.vector_store_k) + self.sparse_retriever = PostgresBM25Retriever( + connection_uri=self.vector_store_sparse_uri, + table_name="sparse_vectors", + k=self.vector_store_k, + ) if self.vector_store_initial_load == "True": - self.logger.info("Loading data from existing store into the PostgresBM25Retriever.") - with tqdm(total=len(self.chunked_documents), desc="Vectorizing documents") as pbar: + self.logger.info( + "Loading data from existing store into the PostgresBM25Retriever." 
+ ) + with tqdm( + total=len(self.chunked_documents), desc="Vectorizing documents" + ) as pbar: for d in self.chunked_documents: self.sparse_retriever.add_documents([d], ids=[d.metadata["id"]]) pbar.update(1) @@ -415,7 +470,8 @@ def _initialize_retrievers(self): else: raise ValueError( "Only 'milvus' or 'postgres' are supported as vector stores! Please set vector_store in your " - "environment variables.") + "environment variables." + ) def _initialize_reranker(self): """Initialize the reranking model based on environment settings.""" @@ -426,7 +482,7 @@ def _initialize_reranker(self): self.logger.info("Setting up the ScoredCrossEncoderReranker.") self.compressor = ScoredCrossEncoderReranker( model=HuggingFaceCrossEncoder(model_name=self.rerank_model), - top_n=self.rerank_k + top_n=self.rerank_k, ) self.logger.info("Setting up the ContextualCompressionRetriever.") self.rerank_retriever = ContextualCompressionRetriever( @@ -439,7 +495,7 @@ def _setup_retrievers(self): # Set up the vector retriever self.logger.info("Setting up the Vector Retriever.") retriever = self.db.as_retriever( - search_type="mmr", search_kwargs={'k': self.vector_store_k} + search_type="mmr", search_kwargs={"k": self.vector_store_k} ) self.logger.info("Setting up the hybrid retriever.") self.ensemble_retriever = EnsembleRetriever( @@ -450,13 +506,13 @@ def _setup_retrievers(self): def _update_chunked_documents(self, new_chunks): """Update the chunked documents list and store them.""" - if self.vector_store == 'milvus': + if self.vector_store == "milvus": if not self.chunked_documents: if os.path.exists(self.document_chunks_pickle): self.logger.info("documents chunk pickle exists, loading it.") self._load_chunked_documents() self.chunked_documents += new_chunks - with open(f"{self.vector_store_uri}_sparse.pickle", 'wb') as f: + with open(f"{self.vector_store_uri}_sparse.pickle", "wb") as f: pickle.dump(self.chunked_documents, f) def _add_to_vector_database(self, new_chunks): @@ -475,24 +531,21 @@ def _add_to_vector_database(self, new_chunks): self._initialize_bm25retriever() # Update full retriever too retriever = self.db.as_retriever( - search_type="mmr", - search_kwargs={'k': self.vector_store_k} + search_type="mmr", search_kwargs={"k": self.vector_store_k} ) self.ensemble_retriever = EnsembleRetriever( - retrievers=[self.sparse_retriever, retriever], - weights=[0.5, 0.5] + retrievers=[self.sparse_retriever, retriever], weights=[0.5, 0.5] ) def _parse_cv(self, doc): """Extract skills from the CV document.""" # Implement your skill extraction logic here return [] - + def _deduplicate_chunks(self): """Ensure there are no duplicate entries in the data.""" - self.chunked_documents = list({ - doc.metadata["id"]: doc for doc in self.chunked_documents - }.values() + self.chunked_documents = list( + {doc.metadata["id"]: doc for doc in self.chunked_documents}.values() ) def load_data(self): @@ -512,6 +565,150 @@ def load_data(self): self._initialize_vector_store() self._setup_retrievers() + def add_csv_to_graphdb(self, filename): + path = os.path.join(self.data_dir, filename) + url_add_instance = f"{self.neo4j}/add_instances" + try: + self.logger.info("Uploading csv instances using json") + with open(path, mode="r", encoding="utf-8") as csvfile: + reader = csv.DictReader(csvfile, delimiter=";") + self.logger.info( + reader.fieldnames + ) # This can be changed to a format you have to follow and then the csv will always upload + payloads = [] + for row in reader: + payloads.append( + { + "query": "MERGE (q:Quote {text: 
$quoteText}) " + "MERGE (t:Topic {name: $topicName}) " + "MERGE (q)-[:IS_PART_OF]->(t)", + "parameters": { + "quoteText": row.get("quote"), + "topicName": row.get("topics"), + }, + } + ) + self.logger.info(f"JSON is: {payloads}") + self.logger.info(f"URL is: {url_add_instance}") + response = requests.post(url=url_add_instance, json=payloads) + self.logger.info( + f"Succesfully loaded {len(payloads)} records into payloads" + ) + except: + self.logger.info(f"server responded with: {response.text}") + + def get_llm(self): + """Accessor method to get the LLM. Subclasses can override this.""" + return None + + def escape_curly_braces_in_query(self, json_string): + # Function to escape braces in the matched 'query' string + def escape_braces(match): + query_content = match.group(1) + escaped_content = query_content.replace("{", "\\\\{").replace("}", "\\\\}") + return '"query": "' + escaped_content + '"' + + # Regular expression to find 'query' fields + pattern = r'"query":\s*"([^"]*)"' + return re.sub(pattern, escape_braces, json_string) + + def add_document_to_graphdb(self, page_content, metadata): + llm = self.get_llm() + if llm is None: + self.logger.error("LLM is not available in RAGHelper.") + return None + if metadata.get("source").lower().split(".")[-1] == "pdf": + try: + if self.dynamic_neo4j_schema == True: + schema_response = requests.get(url=self.neo4j + "/schema") + if schema_response.status_code != 200: + self.logger.info( + "Failed to retrieve schema from the graph database." + ) + return None + schema = schema_response.json() + # schema = "\n".join([f"{key}: {value}" for key, value in schema.items()]) + + # Construct schema text for the prompt + schema_text = self.format_schema_for_prompt(schema) + + self.logger.info(f"this is the text: {schema_text}") + + retrieval_question = ( + os.getenv("neo4j_insert_schema") + .replace("{schema}", schema_text) + .replace("{data}", page_content) + ) + else: + retrieval_question = os.getenv("neo4j_insert_data_only").replace( + "{data}", page_content + ) + + # Load prompt components from .env + retrieval_instruction = os.getenv("neo4j_insert_instruction") + retrieval_few_shot = os.getenv("neo4j_insert_few_shot") + + retrieval_instruction = retrieval_instruction.replace( + "{", "{{" + ).replace("}", "}}") + retrieval_few_shot = retrieval_few_shot.replace("{", "{{").replace( + "}", "}}" + ) + retrieval_question = retrieval_question.replace("{", "{{").replace( + "}", "}}" + ) + + # Combine into a single prompt + retrieval_thread = [ + ("system", retrieval_instruction + "\n\n" + retrieval_few_shot), + ("human", retrieval_question), + ] + + rag_prompt = ChatPromptTemplate.from_messages(retrieval_thread) + self.logger.info("Initializing retrieval for RAG.") + + # Create an LLM chain + llm_chain = rag_prompt | llm + # Invoke the LLM chain and get the response + try: + llm_response = llm_chain.invoke({}) + # self.logger.info(f"llm response is: {llm_response}") + response_text = self.extract_response_content(llm_response).strip() + self.logger.info(f"The LLM response is: {response_text}") + + # Escape the curly braces in 'query' strings + escaped_data = self.escape_curly_braces_in_query(response_text) + + # Now parse the JSON + try: + json_data = json.loads(escaped_data) + print("Parsed JSON data:", json_data) + except json.JSONDecodeError as e: + print("Error parsing JSON:", e) + + def unescape_curly_braces(json_data): + for item in json_data: + item["query"] = ( + item["query"].replace("\\{", "{").replace("\\}", "}") + ) + return json_data + + 
json_data = unescape_curly_braces(json_data) + + response = requests.post( + url=self.neo4j + "/add_instances", json=json_data + ) + self.logger.info(f"{response}") + if response == "": + self.logger.info( + f"Succesfully loaded {len(json_data)} records into payloads" + ) + except Exception as e: + self.logger.error(f"Error during LLM invocation: {e}") + return None + except Exception as e: + self.logger.error(f"Error adding document to the graph database: {e}") + def add_document(self, filename): """ Load documents from various file types, extract metadata, @@ -523,7 +720,13 @@ def add_document(self, filename): Raises: ValueError: If the file type is unsupported. """ + if filename.lower().split(".")[-1] == "csv": + self.add_csv_to_graphdb(filename) new_docs = self._load_document(filename) + self.logger.info("adding documents to graphdb.") + if self.add_docs_to_neo4j: + for doc in new_docs: + self.add_document_to_graphdb(doc.page_content, doc.metadata) self.logger.info("chunking the documents.") new_chunks = self._split_documents(new_docs) diff --git a/server/RAGHelper_cloud.py b/server/RAGHelper_cloud.py index 75b76a1..c1a440d 100644 --- a/server/RAGHelper_cloud.py +++ b/server/RAGHelper_cloud.py @@ -2,16 +2,20 @@ import re from langchain.prompts import ChatPromptTemplate -from langchain.schema.runnable import RunnablePassthrough +from langchain.schema.runnable import RunnablePassthrough, RunnableLambda from langchain_core.output_parsers import StrOutputParser from langchain_google_genai import ChatGoogleGenerativeAI from langchain_huggingface.embeddings import HuggingFaceEmbeddings from langchain_ollama.llms import OllamaLLM from langchain_openai import AzureChatOpenAI, ChatOpenAI -from provenance import (DocumentSimilarityAttribution, - compute_llm_provenance_cloud, - compute_rerank_provenance) +from provenance import ( + DocumentSimilarityAttribution, + compute_llm_provenance_cloud, + compute_rerank_provenance, +) from RAGHelper import RAGHelper +import requests +from langchain.schema import Document def combine_results(inputs: dict) -> dict: @@ -40,6 +44,9 @@ def __init__(self, logger): self.logger = logger self.llm = self.initialize_llm() self.embeddings = self.initialize_embeddings() + self.max_length = int( + os.getenv("max_document_limit", 10) + ) # Default to 10 if not specified # Load the data self.load_data() @@ -47,19 +54,32 @@ def __init__(self, logger): self.initialize_provenance_attribution() self.initialize_rewrite_loops() + def get_llm(self): + return self.llm + def initialize_llm(self): """Initialize the Language Model based on environment configurations.""" if os.getenv("use_openai") == "True": self.logger.info("Initializing OpenAI conversation.") - return ChatOpenAI(model=os.getenv("openai_model_name"), temperature=0, max_tokens=None, timeout=None, - max_retries=2) + return ChatOpenAI( + model=os.getenv("openai_model_name"), + temperature=0, + max_tokens=None, + timeout=None, + max_retries=2, + ) if os.getenv("use_gemini") == "True": self.logger.info("Initializing Gemini conversation.") - return ChatGoogleGenerativeAI(model=os.getenv("gemini_model_name"), convert_system_message_to_human=True) + return ChatGoogleGenerativeAI( + model=os.getenv("gemini_model_name"), + convert_system_message_to_human=True, + ) if os.getenv("use_azure") == "True": self.logger.info("Initializing Azure OpenAI conversation.") - return AzureChatOpenAI(openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"], - azure_deployment=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"]) + return 
AzureChatOpenAI( + openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"], + azure_deployment=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"], + ) if os.getenv("use_ollama") == "True": self.logger.info("Initializing Ollama conversation.") return OllamaLLM(model=os.getenv("ollama_model")) @@ -69,16 +89,24 @@ def initialize_llm(self): def initialize_embeddings(self): """Initialize the embeddings based on the CPU/GPU configuration.""" - embedding_model = os.getenv('embedding_model') - model_kwargs = {'device': 'cpu'} if os.getenv('force_cpu') == "True" else {'device': 'cuda'} - self.logger.info(f"Initializing embedding model {embedding_model} with params {model_kwargs}.") - return HuggingFaceEmbeddings(model_name=embedding_model, model_kwargs=model_kwargs) + embedding_model = os.getenv("embedding_model") + model_kwargs = ( + {"device": "cpu"} + if os.getenv("force_cpu") == "True" + else {"device": "cuda"} + ) + self.logger.info( + f"Initializing embedding model {embedding_model} with params {model_kwargs}." + ) + return HuggingFaceEmbeddings( + model_name=embedding_model, model_kwargs=model_kwargs + ) def initialize_rag_chains(self): """Create the RAG chain for fetching new documents.""" rag_thread = [ - ('system', os.getenv('rag_fetch_new_instruction')), - ('human', os.getenv('rag_fetch_new_question')) + ("system", os.getenv("rag_fetch_new_instruction")), + ("human", os.getenv("rag_fetch_new_question")), ] self.logger.info("Initializing RAG chains for fetching new documents.") rag_prompt = ChatPromptTemplate.from_messages(rag_thread) @@ -99,18 +127,22 @@ def initialize_rewrite_loops(self): def create_rewrite_ask_chain(self): """Create the chain to ask if a rewrite is needed.""" rewrite_ask_thread = [ - ('system', os.getenv('rewrite_query_instruction')), - ('human', os.getenv('rewrite_query_question')) + ("system", os.getenv("rewrite_query_instruction")), + ("human", os.getenv("rewrite_query_question")), ] rewrite_ask_prompt = ChatPromptTemplate.from_messages(rewrite_ask_thread) rewrite_ask_llm_chain = rewrite_ask_prompt | self.llm - context_retriever = self.rerank_retriever if self.rerank else self.ensemble_retriever - return {"context": context_retriever | RAGHelper.format_documents, - "question": RunnablePassthrough()} | rewrite_ask_llm_chain + context_retriever = ( + self.rerank_retriever if self.rerank else self.ensemble_retriever + ) + return { + "context": context_retriever | RAGHelper.format_documents, + "question": RunnablePassthrough(), + } | rewrite_ask_llm_chain def create_rewrite_chain(self): """Create the chain to perform the actual rewrite.""" - rewrite_thread = [('human', os.getenv('rewrite_query_prompt'))] + rewrite_thread = [("human", os.getenv("rewrite_query_prompt"))] rewrite_prompt = ChatPromptTemplate.from_messages(rewrite_thread) rewrite_llm_chain = rewrite_prompt | self.llm return {"question": RunnablePassthrough()} | rewrite_llm_chain @@ -129,10 +161,44 @@ def handle_rewrite(self, user_query: str) -> str: self.logger.info(f"The response of the rewrite loop is - {response}") response = self.extract_response_content(response) - if re.sub(r'\W+ ', '', response).lower().startswith('yes'): - return self.extract_response_content(self.rewrite_chain.invoke(user_query)) + if re.sub(r"\W+ ", "", response).lower().startswith("yes"): + return self.extract_response_content( + self.rewrite_chain.invoke(user_query) + ) return user_query + def combine_and_limit_documents(self, graph_docs, retriever_docs, question): + """ + Combines graph documents and retriever documents, limits 
+        and formats the combined documents for downstream processing.
+
+        Args:
+            graph_docs (list): Documents retrieved from the graph database.
+            retriever_docs (list): Documents retrieved from other retrievers.
+            question (str): The user query.
+
+        Note:
+            The maximum number of documents is taken from self.max_length.
+
+        Returns:
+            dict: A dictionary containing the limited docs, formatted context, and question.
+        """
+        # This runs before formatting; document metadata should at least include: source, id.
+        if graph_docs is not None:
+            # Estimate how many chunks the graph document occupies and reserve room for it.
+            length = len(graph_docs[0].page_content) // self.chunk_size
+            combined_docs = graph_docs + retriever_docs
+            # Ensure at least one document is retained
+            retain_count = max(1, self.max_length - length)
+            combined_docs = combined_docs[:retain_count]
+        else:
+            combined_docs = retriever_docs
+        limited_docs = combined_docs[: self.max_length]
+        return {
+            "docs": limited_docs,
+            "context": RAGHelper.format_documents(limited_docs),
+            "question": question,
+        }
+
     def handle_user_interaction(self, user_query: str, history: list) -> tuple:
         """Handle user interaction by processing their query and maintaining conversation history.
@@ -153,19 +219,34 @@ def handle_user_interaction(self, user_query: str, history: list) -> tuple:
         llm_chain = prompt | self.llm

         if fetch_new_documents:
-            context_retriever = self.ensemble_retriever if self.rerank else self.rerank_retriever
+            graph_retrieved_docs = self.graph_retriever(
+                user_query
+            )  # Fetch documents from the graph database
+            # Use the reranking retriever when reranking is enabled
+            context_retriever = (
+                self.rerank_retriever if self.rerank else self.ensemble_retriever
+            )
             retriever_chain = {
-                "docs": context_retriever,
-                "context": context_retriever | RAGHelper.format_documents,
-                "question": RunnablePassthrough()
-            }
+                "retriever_docs": context_retriever,  # Lazy retrieval from context retriever
+                "question": RunnablePassthrough(),
+            } | RunnableLambda(
+                lambda input_data: self.combine_and_limit_documents(
+                    graph_docs=graph_retrieved_docs,
+                    retriever_docs=input_data["retriever_docs"],
+                    question=user_query,
+                )
+            )
             llm_chain = prompt | self.llm | StrOutputParser()

             rag_chain = (
                 retriever_chain
                 | RunnablePassthrough.assign(
                     answer=lambda x: llm_chain.invoke(
-                        {"docs": x["docs"], "context": x["context"], "question": x["question"]}
-                    ))
+                        {
+                            "docs": x["docs"],
+                            "context": x["context"],
+                            "question": x["question"],
+                        }
+                    )
+                )
                 | combine_results
             )
         else:
@@ -174,12 +255,10 @@ def handle_user_interaction(self, user_query: str, history: list) -> tuple:
             rag_chain = (
                 retriever_chain
                 | RunnablePassthrough.assign(
-                    answer=lambda x: llm_chain.invoke(
-                        {"question": x["question"]}
-                    ))
+                    answer=lambda x: llm_chain.invoke({"question": x["question"]})
+                )
                 | combine_results
             )
-        user_query = self.handle_rewrite(user_query)

         # Check if we need to apply Re2 to mention the question twice
         if os.getenv("use_re2") == "True":
@@ -187,9 +266,13 @@ def handle_user_interaction(self, user_query: str, history: list) -> tuple:

         # Invoke RAG pipeline
         reply = rag_chain.invoke(user_query)

-        # Track provenance if needed
-        if fetch_new_documents and os.getenv("provenance_method") in ['rerank', 'attention', 'similarity', 'llm']:
+        if fetch_new_documents and os.getenv("provenance_method") in [
+            "rerank",
+            "attention",
+            "similarity",
+            "llm",
+        ]:
             self.track_provenance(reply, user_query)

         return (thread, reply)
@@ -209,7 +292,7 @@ def should_fetch_new_documents(self, user_query: str, history: list) -> bool:
             return True

         response = self.rag_fetch_new_chain.invoke(user_query)
         response = self.extract_response_content(response)
-        return re.sub(r'\W+ ', '', response).lower().startswith('yes')
+        return re.sub(r"\W+ ", "", response).lower().startswith("yes")

     @staticmethod
     def create_interaction_thread(history: list, fetch_new_documents: bool) -> list:
@@ -224,11 +307,17 @@ def create_interaction_thread(history: list, fetch_new_documents: bool) -> list:
             list: The constructed conversation thread.
         """
         # Create prompt template based on whether we have history or not
-        thread = [(x["role"], x["content"].replace("{", "(").replace("}", ")")) for x in history]
+        thread = [
+            (x["role"], x["content"].replace("{", "(").replace("}", ")"))
+            for x in history
+        ]
         if fetch_new_documents:
-            thread = [('system', os.getenv('rag_instruction')), ('human', os.getenv('rag_question_initial'))]
+            thread = [
+                ("system", os.getenv("rag_instruction")),
+                ("human", os.getenv("rag_question_initial")),
+            ]
         else:
-            thread.append(('human', os.getenv('rag_question_followup')))
+            thread.append(("human", os.getenv("rag_question_followup")))
         return thread

     def create_rag_chain(self, retriever_chain: dict, llm_chain: str) -> str:
@@ -252,8 +341,8 @@ def track_provenance(self, reply: str, user_query: str) -> None:
         """
         # Add the user question and the answer to our thread for provenance computation
         # Retrieve answer and context
-        answer = reply.get('answer')
-        context = reply.get('docs')
+        answer = reply.get("answer")
+        context = reply.get("docs")

         provenance_method = os.getenv("provenance_method")
         self.logger.info(f"Provenance method: {provenance_method}")
@@ -262,37 +351,55 @@ def track_provenance(self, reply: str, user_query: str) -> None:
         if provenance_method == "rerank":
             self.logger.info("Using reranking for provenance attribution.")
             if not self.rerank:
-                raise ValueError("Provenance attribution is set to rerank but reranking is not enabled. "
-                                 "Please choose another method or enable reranking.")
+                raise ValueError(
+                    "Provenance attribution is set to rerank but reranking is not enabled. "
+                    "Please choose another method or enable reranking."
+                )

-            reranked_docs = compute_rerank_provenance(self.compressor, user_query, context, answer)
-            self.logger.debug(f"Reranked documents computed: {len(reranked_docs)} docs reranked.")
+            reranked_docs = compute_rerank_provenance(
+                self.compressor, user_query, context, answer
+            )
+            self.logger.debug(
+                f"Reranked documents computed: {len(reranked_docs)} docs reranked."
+            )

             # Build provenance scores based on reranked docs
             provenance_scores = []
             for doc in context:
                 reranked_score = next(
-                    (d.metadata['relevance_score'] for d in reranked_docs if d.page_content == doc.page_content), None)
+                    (
+                        d.metadata["relevance_score"]
+                        for d in reranked_docs
+                        if d.page_content == doc.page_content
+                    ),
+                    None,
+                )
                 if reranked_score is None:
-                    self.logger.warning(f"Document not found in reranked docs: {doc.page_content}")
+                    self.logger.warning(
+                        f"Document not found in reranked docs: {doc.page_content}"
+                    )
                 provenance_scores.append(reranked_score)
             self.logger.debug("Provenance scores computed using reranked documents.")

         # Use similarity-based provenance if method is 'similarity'
         elif provenance_method == "similarity":
             self.logger.info("Using similarity-based provenance attribution.")
-            provenance_scores = self.attributor.compute_similarity(user_query, context, answer)
+            provenance_scores = self.attributor.compute_similarity(
+                user_query, context, answer
+            )
             self.logger.debug("Provenance scores computed using similarity method.")

         # Use LLM-based provenance if method is 'llm'
         elif provenance_method == "llm":
             self.logger.info("Using LLM-based provenance attribution.")
-            provenance_scores = compute_llm_provenance_cloud(self.llm, user_query, context, answer)
+            provenance_scores = compute_llm_provenance_cloud(
+                self.llm, user_query, context, answer
+            )
             self.logger.debug("Provenance scores computed using LLM-based method.")

         # Add provenance scores to documents
         for i, score in enumerate(provenance_scores):
-            reply['docs'][i].metadata['provenance'] = score
+            reply["docs"][i].metadata["provenance"] = score
             self.logger.debug(f"Provenance score added to doc {i}: {score}")

     @staticmethod
@@ -306,10 +413,155 @@ def extract_response_content(response: dict) -> str:
             str: The extracted content.
         """
         # return getattr(response, 'content', getattr(response, 'answer', response['answer']))
-        if hasattr(response, 'content'):
+        if hasattr(response, "content"):
             response = response.content
-        elif hasattr(response, 'answer'):
+        elif hasattr(response, "answer"):
             response = response.answer
-        elif 'answer' in response:
+        elif "answer" in response:
             response = response["answer"]
         return response
+
+    def graph_retriever(self, user_query):
+        """
+        Retrieves relevant data from the Neo4j graph database using a schema-aware query generated by an LLM.
+
+        Args:
+            user_query (str): The user-provided query.
+
+        Returns:
+            list or None: A list of LangChain Document objects if a valid query is generated; None otherwise.
+ """ + # get schema from graph endpoint + schema_url = f"{self.neo4j}/schema" + response = requests.get(schema_url) + if response.status_code != 200: + self.logger.error(f"Failed to retrieve schema from {schema_url}.") + return None + + schema = response.json() + + # Construct schema text for the prompt + schema_text = self.format_schema_for_prompt(schema) + + # Load prompt components from .env + retrieval_instruction = os.getenv("rag_retrieval_instruction").replace( + "{schema}", schema_text + ) + retrieval_few_shot = os.getenv("retrieval_few_shot") + retrieval_question = os.getenv("rag_retrieval_question").replace( + "{question}", user_query + ) + + # Combine into a single prompt + retrieval_thread = [ + ("system", retrieval_instruction + "\n\n" + retrieval_few_shot), + ("human", retrieval_question), + ] + self.logger.info(f"the retrieval thread LLM is: {retrieval_thread}") + + rag_prompt = ChatPromptTemplate.from_messages(retrieval_thread) + self.logger.info("Initializing retrieval for RAG.") + + # Create an LLM chain + llm_chain = rag_prompt | self.llm + # Invoke the LLM chain and get the response + try: + llm_response = llm_chain.invoke({}) + # self.logger.info(f"llm response is: {llm_response}") + response_text = self.extract_response_content(llm_response).strip() + self.logger.info(f"The LLM response is: {response_text}") + + except Exception as e: + self.logger.error(f"Error during LLM invocation: {e}") + return None + + query_url = f"{self.neo4j}/run_query" + + # if re.sub(r'\W+ ', '', response_text).lower().startswith('None'): + if response.text.startswith("None"): + return None + else: + # Execute the generated Cypher query + try: + query_response = requests.post(query_url, json={"query": response_text}) + if query_response.status_code != 200: + self.logger.error(f"Failed to execute query: {response_text}") + return None + + query_results = query_response.json().get("results", []) + self.logger.info(f"The found query results: {query_results}") + + # Combine all results into a single document + combined_content = "\n".join( + ", ".join(f"{key}: {value}" for key, value in record.items()) + for record in query_results + ) + + # Create a single LangChain Document + document = Document( + page_content=combined_content, metadata={"source": "graph_db"} + ) + + # Log the combined document + self.logger.info(f"The combined document: {document}") + + # Return as a list containing the single combined document + return [document] + + except Exception as e: + self.logger.error(f"Error executing query or formatting results: {e}") + return None + + def format_schema_for_prompt(self, schema): + """ + Formats the schema dictionary into a text string suitable for inclusion in an LLM prompt. + + Args: + schema (dict): The schema dictionary with nodes and relationships. + + Returns: + str: A formatted string representation of the schema. + """ + schema_lines = [] + schema_lines.append("Nodes:") + for label, properties in schema["nodes"].items(): + props = ", ".join(properties) if properties else "No properties" + schema_lines.append(f" - {label}: {props}") + schema_lines.append("\nRelationships:") + for rel_type, properties in schema["relationships"].items(): + props = ", ".join(properties) if properties else "No properties" + schema_lines.append(f" - {rel_type}: {props}") + return "\n".join(schema_lines) + + def generate_few_shot_examples(self, schema): + """ + Dynamically generate few-shot examples based on the schema. + + Args: + schema (dict): The graph schema with nodes and relationships. 
+
+        Returns:
+            list[dict]: A list of input-output example pairs for few-shot prompting.
+        """
+        examples = []
+
+        # Generate examples for nodes
+        for label, properties in schema.get("nodes", {}).items():
+            example_query = f"What are the {label.lower()}s?"
+            example_output = f"MATCH (n:{label}) RETURN n"
+            examples.append({"input": example_query, "output": example_output})
+
+            # Add a property-specific example
+            if properties:
+                prop = properties[0]  # Use the first property for the example
+                example_query = f"Find {label.lower()}s by {prop}."
+                example_output = f"MATCH (n:{label}) RETURN n.{prop}"
+                examples.append({"input": example_query, "output": example_output})
+
+        # Generate examples for relationships
+        for rel_type, properties in schema.get("relationships", {}).items():
+            example_query = f"What are the relationships of type {rel_type.lower()}?"
+            example_output = f"MATCH ()-[r:{rel_type}]->() RETURN r"
+            examples.append({"input": example_query, "output": example_output})
+
+        return examples
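
For reference, the graph helpers above (`format_schema_for_prompt` and `generate_few_shot_examples`) assume the `/schema` endpoint returns a dictionary with `nodes` and `relationships` keys, each mapping a label or relationship type to a list of property names. The sketch below illustrates that assumed payload shape and the few-shot pairs the code would derive from it; the labels and properties are hypothetical and do not come from the server.

```
# Illustrative schema payload (assumed shape, hypothetical labels/properties).
example_schema = {
    "nodes": {
        "Person": ["name", "age"],
        "Company": ["name"],
    },
    "relationships": {
        "WORKS_AT": ["since"],
    },
}

# generate_few_shot_examples(example_schema) would then produce pairs such as:
#   {"input": "What are the persons?", "output": "MATCH (n:Person) RETURN n"}
#   {"input": "Find persons by name.", "output": "MATCH (n:Person) RETURN n.name"}
#   {"input": "What are the relationships of type works_at?",
#    "output": "MATCH ()-[r:WORKS_AT]->() RETURN r"}
```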