From caff4002a8f04626dabb08ec12d545fb60a8f1c9 Mon Sep 17 00:00:00 2001 From: mpc Date: Fri, 22 Nov 2024 11:51:15 +0000 Subject: [PATCH] Added JSON formatting to RAG prompt builder --- .gitignore | 1 + data/.gitignore | 1 + params.yaml | 1 + scripts/run_rag_pipeline.py | 34 ++++++++++++++++++++++++++-------- scripts/upload_to_docstore.py | 8 ++++++++ 5 files changed, 37 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index 03bcc6a..3790941 100644 --- a/.gitignore +++ b/.gitignore @@ -165,3 +165,4 @@ metrics.txt metrics.png gdrive-oauth.txt /eval +.tmp/ \ No newline at end of file diff --git a/data/.gitignore b/data/.gitignore index 4a16dfd..b8dc417 100644 --- a/data/.gitignore +++ b/data/.gitignore @@ -16,3 +16,4 @@ /eidc_rag_testset.csv /eidc_rag_test_set.csv /rag-pipeline.yml +/pipeline.yml diff --git a/params.yaml b/params.yaml index e679c5f..b2947b8 100644 --- a/params.yaml +++ b/params.yaml @@ -16,6 +16,7 @@ files: eval-set: data/evaluation_data.csv metrics: data/metrics.json eval-plot: data/eval.png + pipeline: data/pipeline.yml sub-sample: 0 # sample n datasets for testing (0 will use all datasets) max-length: 0 # truncate longer texts for testing (0 will use all data) test-set-size: 101 # reduce the size of the test set for faster testing diff --git a/scripts/run_rag_pipeline.py b/scripts/run_rag_pipeline.py index 886f723..e3a218c 100644 --- a/scripts/run_rag_pipeline.py +++ b/scripts/run_rag_pipeline.py @@ -1,3 +1,4 @@ +import os import shutil import sys from argparse import ArgumentParser @@ -24,14 +25,29 @@ def build_rag_pipeline(model_name: str, collection_name: str) -> Pipeline: print("Creating prompt template...") template = """ - Given the following information, answer the question. + You are part of a retrieval augmented generative pipeline. + Your task is to provide an answer to a question based on a given set of retrieved documents. + The retrieved documents will be given in JSON format. 
+ The retrieved documents are chunks of information retrieved from datasets held in the EIDC (Environmental Information Data Centre). + The EIDC is hosted by UKCEH (UK Centre for Ecology and Hydrology). + Your answer should be as faithful as possible to the information provided by the retrieved documents. + Do not use your own knowledge to answer the question, only the information in the retrieved documents. + Do not refer to "retrieved documents" in your answer, instead use phrases like "available information". + Provide a citation to the relevant chunk_id used to generate each part of your answer. Question: {{query}} - Context: - {% for document in documents %} - {{ document.content }} - {% endfor %} + "retrieved_documents": [{% for document in documents %} + { + content: "{{ document.content }}", + meta: { + dataset_id: "{{ document.meta.id }}", + source: "{{ document.meta.field }}", + chunk_id: "{{ document.id }}" + } + } + {% endfor %} + ] Answer: """ @@ -41,7 +57,7 @@ def build_rag_pipeline(model_name: str, collection_name: str) -> Pipeline: print(f"Setting up model ({model_name})...") llm = OllamaGenerator( model=model_name, - generation_kwargs={"num_ctx": 16384}, + generation_kwargs={"num_ctx": 16384, "temperature": 0.0}, url="http://localhost:11434/api/generate", ) @@ -60,6 +76,7 @@ def build_rag_pipeline(model_name: str, collection_name: str) -> Pipeline: rag_pipe.connect("prompt_builder", "llm") rag_pipe.connect("llm.replies", "answer_builder.replies") + rag_pipe.connect("prompt_builder.prompt", "answer_builder.query") return rag_pipe @@ -68,7 +85,6 @@ def run_query(query: str, pipeline: Pipeline) -> Dict[str, Any]: { "retriever": {"query": query}, "prompt_builder": {"query": query}, - "answer_builder": {"query": query}, } ) @@ -93,6 +109,8 @@ def main( model: str, pipeline_file: str, ) -> None: + if os.path.exists(TMP_DOC_PATH): + shutil.rmtree(TMP_DOC_PATH) shutil.copytree(doc_store_path, TMP_DOC_PATH) rag_pipe = build_rag_pipeline(model, 
collection_name) @@ -109,7 +127,7 @@ def main( df["contexts"] = contexts df.to_csv(ouput_file, index=False) - shutil.rmtree(TMP_DOC_PATH) + # shutil.rmtree(TMP_DOC_PATH) if __name__ == "__main__": diff --git a/scripts/upload_to_docstore.py b/scripts/upload_to_docstore.py index 545d113..ad016b9 100644 --- a/scripts/upload_to_docstore.py +++ b/scripts/upload_to_docstore.py @@ -4,6 +4,7 @@ import sys import uuid from argparse import ArgumentParser +import logging __import__("pysqlite3") sys.modules["sqlite3"] = sys.modules.pop("pysqlite3") @@ -11,10 +12,13 @@ from chromadb.utils import embedding_functions from chromadb.utils.batch_utils import create_batches +logging.basicConfig(stream=sys.stdout, level=logging.INFO) +logger = logging.getLogger(__name__) def main( input_file: str, output_path: str, collection_name: str, embedding_model: str ) -> None: + logger.info(f"Uploading data ({input_file}) to chromaDB ({output_path}) in collection {collection_name}.") if os.path.exists(output_path): shutil.rmtree(output_path) @@ -37,9 +41,11 @@ def main( collection = client.create_collection( name=collection_name, embedding_function=func ) + batches = create_batches( api=client, ids=ids, documents=docs, embeddings=embs, metadatas=metas ) + logger.info(f"Uploading {len(docs)} document(s) to chroma in {len(batches)} batch(es).") for batch in batches: collection.add( documents=batch[3], @@ -47,6 +53,8 @@ def main( embeddings=batch[1], ids=batch[0], ) + docs_in_col = collection.count() + logger.info(f"{docs_in_col} document(s) are now in the {collection_name} collection") if __name__ == "__main__":