Skip to content

Commit

Permalink
Added JSON formatting to RAG prompt builder
Browse files (browse the repository at this point in the history)
  • Loading branch information
matthewcoole committed Nov 22, 2024
1 parent 430ca61 commit caff400
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 8 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,4 @@ metrics.txt
metrics.png
gdrive-oauth.txt
/eval
.tmp/
1 change: 1 addition & 0 deletions data/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
/eidc_rag_testset.csv
/eidc_rag_test_set.csv
/rag-pipeline.yml
/pipeline.yml
1 change: 1 addition & 0 deletions params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ files:
eval-set: data/evaluation_data.csv
metrics: data/metrics.json
eval-plot: data/eval.png
pipeline: data/pipeline.yml
sub-sample: 0 # sample n datasets for testing (0 will use all datasets)
max-length: 0 # truncate longer texts for testing (0 will use all data)
test-set-size: 101 # reduce the size of the test set for faster testing
Expand Down
34 changes: 26 additions & 8 deletions scripts/run_rag_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import shutil
import sys
from argparse import ArgumentParser
Expand All @@ -24,14 +25,29 @@ def build_rag_pipeline(model_name: str, collection_name: str) -> Pipeline:
print("Creating prompt template...")

template = """
Given the following information, answer the question.
You are part of a retrieval augmented generative pipeline.
Your task is to provide an answer to a question based on a given set of retrieved documents.
The retrieved documents will be given in JSON format.
The retrieved documents are chunks of information retrieved from datasets held in the EIDC (Environmental Information Data Centre).
The EIDC is hosted by UKCEH (UK Centre for Ecology and Hydrology).
Your answer should be as faithful as possible to the information provided by the retrieved documents.
Do not use your own knowledge to answer the question, only the information in the retrieved documents.
Do not refer to "retrieved documents" in your answer, instead use phrases like "available information".
Provide a citation to the relevant chunk_id used to generate each part of your answer.
Question: {{query}}
Context:
{% for document in documents %}
{{ document.content }}
{% endfor %}
"retrieved_documents": [{% for document in documents %}
{
content: "{{ document.content }}",
meta: {
dataset_id: "{{ document.meta.id }}",
source: "{{ document.meta.field }}",
chunk_id: "{{ document.id }}"
}
}
{% endfor %}
]
Answer:
"""
Expand All @@ -41,7 +57,7 @@ def build_rag_pipeline(model_name: str, collection_name: str) -> Pipeline:
print(f"Setting up model ({model_name})...")
llm = OllamaGenerator(
model=model_name,
generation_kwargs={"num_ctx": 16384},
generation_kwargs={"num_ctx": 16384, "temperature": 0.0},
url="http://localhost:11434/api/generate",
)

Expand All @@ -60,6 +76,7 @@ def build_rag_pipeline(model_name: str, collection_name: str) -> Pipeline:
rag_pipe.connect("prompt_builder", "llm")

rag_pipe.connect("llm.replies", "answer_builder.replies")
rag_pipe.connect("prompt_builder.prompt", "answer_builder.query")
return rag_pipe


Expand All @@ -68,7 +85,6 @@ def run_query(query: str, pipeline: Pipeline) -> Dict[str, Any]:
{
"retriever": {"query": query},
"prompt_builder": {"query": query},
"answer_builder": {"query": query},
}
)

Expand All @@ -93,6 +109,8 @@ def main(
model: str,
pipeline_file: str,
) -> None:
if os.path.exists(TMP_DOC_PATH):
shutil.rmtree(TMP_DOC_PATH)
shutil.copytree(doc_store_path, TMP_DOC_PATH)

rag_pipe = build_rag_pipeline(model, collection_name)
Expand All @@ -109,7 +127,7 @@ def main(
df["contexts"] = contexts
df.to_csv(ouput_file, index=False)

shutil.rmtree(TMP_DOC_PATH)
# shutil.rmtree(TMP_DOC_PATH)


if __name__ == "__main__":
Expand Down
8 changes: 8 additions & 0 deletions scripts/upload_to_docstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,21 @@
import sys
import uuid
from argparse import ArgumentParser
import logging

__import__("pysqlite3")
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
import chromadb
from chromadb.utils import embedding_functions
from chromadb.utils.batch_utils import create_batches

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)

def main(
input_file: str, output_path: str, collection_name: str, embedding_model: str
) -> None:
logger.info(f"Uploading data ({input_file}) to chromaDB ({output_path}) in collection {collection_name}.")
if os.path.exists(output_path):
shutil.rmtree(output_path)

Expand All @@ -37,16 +41,20 @@ def main(
collection = client.create_collection(
name=collection_name, embedding_function=func
)

batches = create_batches(
api=client, ids=ids, documents=docs, embeddings=embs, metadatas=metas
)
logger.info(f"Uploading {len(docs)} document(s) to chroma in {len(batches)} batch(es).")
for batch in batches:
collection.add(
documents=batch[3],
metadatas=batch[2],
embeddings=batch[1],
ids=batch[0],
)
docs_in_col = collection.count()
logger.info(f"{docs_in_col} documents(s) are now in the {collection_name} collection")


if __name__ == "__main__":
Expand Down

1 comment on commit caff400

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

context_precision: 0.45696799875829536
context_recall: 0.5150504867529684
answer_relevancy: 0.5303345339043942
answer_correctness: 0.4851447082020083

Please sign in to comment.