From 557363db1177267303608d699af176e4d0dcbf9e Mon Sep 17 00:00:00 2001 From: mpc Date: Thu, 21 Nov 2024 14:11:17 +0000 Subject: [PATCH 1/2] Adds script to run experiments --- .gitignore | 1 + data/.gitignore | 1 + dvc.yaml | 8 ++++---- params.yaml | 5 +++-- pyproject.toml | 1 + run-experiments.sh | 7 +++++++ scripts/chunk_data.py | 28 +++++++++++++++++++--------- scripts/create_embeddings.py | 5 ++++- scripts/evaluate.py | 12 ++++++++---- 9 files changed, 48 insertions(+), 20 deletions(-) create mode 100755 run-experiments.sh diff --git a/.gitignore b/.gitignore index bf560c6..03bcc6a 100644 --- a/.gitignore +++ b/.gitignore @@ -164,3 +164,4 @@ cython_debug/ metrics.txt metrics.png gdrive-oauth.txt +/eval diff --git a/data/.gitignore b/data/.gitignore index 09fbf7e..addcca2 100644 --- a/data/.gitignore +++ b/data/.gitignore @@ -13,3 +13,4 @@ /supporting-docs.json /metrics.json /eval.png +/eidc_rag_testset.csv diff --git a/dvc.yaml b/dvc.yaml index 2473eec..903cb5a 100644 --- a/dvc.yaml +++ b/dvc.yaml @@ -1,3 +1,5 @@ +metrics: +- data/metrics.json stages: fetch-metadata: cmd: python scripts/fetch_eidc_metadata.py ${files.metadata} -s ${sub-sample} @@ -20,7 +22,7 @@ stages: outs: - ${files.extracted} chunk-data: - cmd: python scripts/chunk_data.py -o ${files.chunked} -c ${hp.chunk-size} -ol ${hp.overlap} ${files.extracted} ${files.supporting-docs} + cmd: python scripts/chunk_data.py -o ${files.chunked} -c ${hp.chunk-size} -ol ${hp.overlap} ${files.extracted} ${files.supporting-docs} -m ${max-length} deps: - ${files.extracted} - ${files.supporting-docs} @@ -42,7 +44,7 @@ stages: outs: - ${files.doc-store} generate-testset: - cmd: cp data/synthetic-datasets/eidc_rag_test_sample.csv data/ + cmd: head -n 2 data/synthetic-datasets/eidc_rag_test_sample.csv > ${files.test-set} outs: - ${files.test-set} run-rag-pipeline: @@ -61,5 +63,3 @@ stages: outs: - ${files.metrics} - ${files.eval-plot} -metrics: -- ${files.metrics} \ No newline at end of file diff --git a/params.yaml b/params.yaml index edf0085..bfc46af 100644 --- a/params.yaml +++ b/params.yaml @@ -12,11 +12,12 @@ files: chunked: data/chunked_data.json embeddings: data/embeddings.json doc-store: data/chroma-data - test-set: data/eidc_rag_test_sample.csv + test-set: data/eidc_rag_testset.csv eval-set: data/evaluation_data.csv metrics: data/metrics.json eval-plot: data/eval.png -sub-sample: 3 # sample size of 0 will process all data +sub-sample: 0 # sample n datasets for testing (0 will use all datasets) +max-length: 0 # truncate longer texts for testing (0 will use all data) rag: model: llama3.1 prompt: >- diff --git a/pyproject.toml b/pyproject.toml index 3dda280..a56e3f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "ragas == 0.1.10", "nltk == 3.9.1", "nbformat == 4.2.0", + "pygit2 == 1.14.1", ] [project.optional-dependencies] diff --git a/run-experiments.sh b/run-experiments.sh new file mode 100755 index 0000000..43a8d67 --- /dev/null +++ b/run-experiments.sh @@ -0,0 +1,7 @@ +#!/bin/bash +dvc queue remove --all +dvc exp run --queue -S hp.chunk-size=400 -S sub-sample=1 -S max-length=500 +dvc exp run --queue -S hp.chunk-size=600 -S sub-sample=1 -S max-length=500 +dvc queue start +dvc queue status +echo "Run `dvc queue status` to check the state of the experiments" diff --git a/scripts/chunk_data.py b/scripts/chunk_data.py index 7fe672b..673651d 100644 --- a/scripts/chunk_data.py +++ b/scripts/chunk_data.py @@ -3,19 +3,20 @@ from typing import Any, Dict, List -def chunk_value(value: str, chunk_size: int, overlap: int) -> List[str]: +def chunk_value(value: str, chunk_size: int, overlap: int, max_length: int) -> List[str]: chunks = [] start = 0 - while start < len(value): + end = max_length if len(value) > max_length > 0 else len(value) + while start < end: chunks.append(value[start : (start + chunk_size)]) start += chunk_size - overlap return chunks def chunk_metadata_value( - metada_value: str, chunk_size: int, overlap: int + metada_value: str, chunk_size: int, overlap: int, max_length: int ) -> List[Dict[str, Any]]: - chunks = chunk_value(metada_value["value"], chunk_size, overlap) + chunks = chunk_value(metada_value["value"], chunk_size, overlap, max_length) return [ { "chunk": chunks[i], @@ -28,20 +29,20 @@ def chunk_metadata_value( def chunk_metadata_file( - file: str, chunk_size: int, overlap: int + file: str, chunk_size: int, overlap: int, max_length: int ) -> List[Dict[str, str]]: chunked_metadata = [] with open(file) as f: json_data = json.load(f) for metadata in json_data: - chunked_metadata.extend(chunk_metadata_value(metadata, chunk_size, overlap)) + chunked_metadata.extend(chunk_metadata_value(metadata, chunk_size, overlap, max_length)) return chunked_metadata -def main(files: List[str], ouput_file: str, chunk_size: int, overlap: int) -> None: +def main(files: List[str], ouput_file: str, chunk_size: int, overlap: int, max_length: int) -> None: all_chunked_metadata = [] for file in files: - all_chunked_metadata.extend(chunk_metadata_file(file, chunk_size, overlap)) + all_chunked_metadata.extend(chunk_metadata_file(file, chunk_size, overlap, max_length)) with open(ouput_file, "w") as f: json.dump(all_chunked_metadata, f, indent=4) @@ -73,6 +74,15 @@ def main(files: List[str], ouput_file: str, chunk_size: int, overlap: int) -> No nargs="?", const=100, ) + parser.add_argument( + "-m", + "--max_length", + help="""Maximum length of data in characters - meant for truncating large + strings in testing. 0 defaults to all data""", + type=int, + nargs="?", + const=0, + ) args = parser.parse_args() assert args.chunk > args.overlap - main(args.input_files, args.output, args.chunk, args.overlap) + main(args.input_files, args.output, args.chunk, args.overlap, args.max_length) diff --git a/scripts/create_embeddings.py b/scripts/create_embeddings.py index 7aa507c..220eed0 100644 --- a/scripts/create_embeddings.py +++ b/scripts/create_embeddings.py @@ -1,6 +1,7 @@ import json from argparse import ArgumentParser - +import gc +import torch from sentence_transformers import SentenceTransformer from torch import Tensor from tqdm import tqdm @@ -16,6 +17,8 @@ def main(input_file: str, output_file: str) -> None: data = json.load(input) for chunk in tqdm(data): chunk["embedding"] = create_embedding(chunk["chunk"]).tolist() + gc.collect() + torch.cuda.empty_cache() json.dump(data, output) diff --git a/scripts/evaluate.py b/scripts/evaluate.py index c130e96..fc2ffb8 100644 --- a/scripts/evaluate.py +++ b/scripts/evaluate.py @@ -1,5 +1,6 @@ import json from argparse import ArgumentParser +from pathlib import Path import nest_asyncio import pandas as pd @@ -44,10 +45,9 @@ def main(eval_dataset: str, metric_output: str, image_output: str) -> None: run_config=RunConfig(max_workers=1), ) result_df = result.to_pandas() - pio.templates.default = "gridon" - fig = go.Figure() - with open(metric_output, "w") as f: + Path(metric_output).parent.mkdir(parents=True, exist_ok=True) + with open(metric_output, "w+") as f: json.dump(result, f) metrics = [ metric @@ -55,6 +55,10 @@ def main(eval_dataset: str, metric_output: str, image_output: str) -> None: if metric not in ["question", "ground_truth", "answer", "contexts"] ] + + pio.templates.default = "gridon" + fig = go.Figure() + for metric in metrics: fig.add_trace( go.Violin( @@ -66,7 +70,7 @@ def main(eval_dataset: str, metric_output: str, image_output: str) -> None: ) ) fig.update_yaxes(range=[-0.02, 1.02]) - with open(image_output, "wb") as f: + with open(image_output, "wb+") as f: f.write(fig.to_image(format="png")) From d6c2eed1485aa979ac8556199832cfec76f389c9 Mon Sep 17 00:00:00 2001 From: mpc Date: Thu, 21 Nov 2024 15:13:54 +0000 Subject: [PATCH 2/2] Modifies shell script to run experiments with various LLMs --- README.md | 26 +++++++- dvc.lock | 112 ++++++++++++++++----------------- dvc.yaml | 4 +- params.yaml | 5 +- run-experiments.sh | 11 +++- scripts/chunk_data.py | 16 +++-- scripts/create_embeddings.py | 3 +- scripts/evaluate.py | 1 - scripts/fetch_eidc_metadata.py | 3 +- scripts/run_rag_pipeline.py | 27 +++++--- 10 files changed, 129 insertions(+), 79 deletions(-) diff --git a/README.md b/README.md index 414f241..5168631 100644 --- a/README.md +++ b/README.md @@ -85,7 +85,31 @@ data/metrics.json faithfulness 0.75 0.69375 -0.05625 Path Param HEAD workspace Change params.yaml hp.chunk-size 300 1000 700 ``` -## Notes + +It is also possible to compare the results of all experiments: +```shell +dvc exp show --only-changed +``` +Experiments can be remove using (`-A` flag removes all experiment, but individually experiment can be removed using their name or ID): +```shell +dvc exp remove -A +``` +### Experiment Runner +The repository includes a simple shell script that can be used as an experiment runner to test various different models: +```shell +./run-experiments.sh +``` +This will run the dvc pipeline with various different llm model (check the shell scripts for details) and save the results as experiments. + +An experiment for each model defined will be queued and run in the background. To check the status of the experiments: +```shell +dvc queue status +``` +To check the output for an experiment currently running use: +```shell +dvc queue log $EXPERIMENT_NAME +``` +## Other Notes ### DVC and CML Notes on the use of Data Version Control and Continuous Machine Learning: diff --git a/dvc.lock b/dvc.lock index f520496..9464e23 100644 --- a/dvc.lock +++ b/dvc.lock @@ -1,17 +1,17 @@ schema: '2.0' stages: fetch-metadata: - cmd: python scripts/fetch_eidc_metadata.py data/eidc_metadata.json -s 3 + cmd: python scripts/fetch_eidc_metadata.py data/eidc_metadata.json -s 1 deps: - path: scripts/fetch_eidc_metadata.py hash: md5 - md5: a564cb0804b482ef09658f0cb4a0a705 - size: 941 + md5: 82907434d9521996e30014df01bbba8e + size: 964 outs: - path: data/eidc_metadata.json hash: md5 - md5: 068ae066ea08ee369c505c8640481cf6 - size: 125674 + md5: ee850e1b0b28cd55ad7d7b31c81645db + size: 114886 prepare: cmd: python scripts/extract_metadata.py data/eidc_metadata.json data/extracted_metadata.json deps: @@ -33,8 +33,8 @@ stages: deps: - path: data/eidc_metadata.json hash: md5 - md5: 068ae066ea08ee369c505c8640481cf6 - size: 125674 + md5: ee850e1b0b28cd55ad7d7b31c81645db + size: 114886 - path: scripts/extract_metadata.py hash: md5 md5: e66f21369c5106eaaad4476612c6fb5e @@ -42,53 +42,53 @@ stages: outs: - path: data/extracted_metadata.json hash: md5 - md5: e71f887d993834e3bda1eb00e711e724 - size: 7005 + md5: 6870e7ecdde041bc8b62d2759ab745c3 + size: 2381 chunk-data: - cmd: python scripts/chunk_data.py -o data/chunked_data.json -c 500 -ol 100 data/extracted_metadata.json - data/supporting-docs.json + cmd: python scripts/chunk_data.py -o data/chunked_data.json -c 250 -ol 75 data/extracted_metadata.json + data/supporting-docs.json -m 250 deps: - path: data/extracted_metadata.json hash: md5 - md5: e71f887d993834e3bda1eb00e711e724 - size: 7005 + md5: 6870e7ecdde041bc8b62d2759ab745c3 + size: 2381 - path: data/supporting-docs.json hash: md5 - md5: bdab1ea8df4a87aa3d314044eb2eaa0a - size: 188762 + md5: 12837e5cbf10fbd75c6fa476d6423a41 + size: 75646 - path: scripts/chunk_data.py hash: md5 - md5: e8de02d6b14c8fc22533d0becfb7d35d - size: 2198 + md5: 3ad449140b03e1c2904b22a5b401a12e + size: 2705 outs: - path: data/chunked_data.json hash: md5 - md5: a01ff8ed4d429203d6903466d26937ff - size: 320740 + md5: 2bd1ec3c646b46de10f43e87a711ec34 + size: 2576 create-embeddings: cmd: python scripts/create_embeddings.py data/chunked_data.json data/embeddings.json deps: - path: data/chunked_data.json hash: md5 - md5: a01ff8ed4d429203d6903466d26937ff - size: 320740 + md5: 2bd1ec3c646b46de10f43e87a711ec34 + size: 2576 - path: scripts/create_embeddings.py hash: md5 - md5: d9282fc92ed400855c4fc2a290289f14 - size: 867 + md5: fa4627c83a65af2e3ea9b2b749f1b29d + size: 952 outs: - path: data/embeddings.json hash: md5 - md5: 363e3eaf7f8baddf9aa2e83f45f074b1 - size: 4345553 + md5: 84df39fc14944f3834863c56062f42bb + size: 61385 upload-to-docstore: cmd: python scripts/upload_to_docstore.py data/embeddings.json -o data/chroma-data -em all-MiniLM-L6-v2 -c eidc-data deps: - path: data/embeddings.json hash: md5 - md5: 363e3eaf7f8baddf9aa2e83f45f074b1 - size: 4345553 + md5: 84df39fc14944f3834863c56062f42bb + size: 61385 - path: scripts/upload_to_docstore.py hash: md5 md5: 7b9433047ff175d5e6af8d6056caf05b @@ -96,45 +96,45 @@ stages: outs: - path: data/chroma-data hash: md5 - md5: 39b81f6d319a02523fbc356dd667b920.dir - size: 5702372 + md5: c302823e4ac392340c4dea80eff42d41.dir + size: 1872612 nfiles: 5 run-rag-pipeline: - cmd: python scripts/run_rag_pipeline.py data/eidc_rag_test_sample.csv data/evaluation_data.csv - data/chroma-data -c eidc-data + cmd: python scripts/run_rag_pipeline.py -i data/eidc_rag_testset.csv -o data/evaluation_data.csv + -ds data/chroma-data -c eidc-data -m llama3.1 deps: - path: data/chroma-data hash: md5 - md5: 39b81f6d319a02523fbc356dd667b920.dir - size: 5702372 + md5: c302823e4ac392340c4dea80eff42d41.dir + size: 1872612 nfiles: 5 - - path: data/eidc_rag_test_sample.csv + - path: data/eidc_rag_testset.csv hash: md5 - md5: a371d83c5822d256286e80d64d58c3fe - size: 7524 + md5: 946861e99a3d1d5c37e48d6c791145ba + size: 4572 - path: scripts/run_rag_pipeline.py hash: md5 - md5: ea2b8d94ee42499870d925f916982e8a - size: 3781 + md5: a3f803eafc1a73d837a763f04a56924e + size: 3937 outs: - path: data/evaluation_data.csv hash: md5 - md5: a4470a84d2de8b1d04c7d2dfd8b5f807 - size: 9859 + md5: 7697d47129fe7491dfa15c8795ca29fe + size: 3911 generate-testset: - cmd: cp data/synthetic-datasets/eidc_rag_test_sample.csv data/ + cmd: head -n 3 data/synthetic-datasets/eidc_rag_test_sample.csv > data/eidc_rag_testset.csv outs: - - path: data/eidc_rag_test_sample.csv + - path: data/eidc_rag_testset.csv hash: md5 - md5: a371d83c5822d256286e80d64d58c3fe - size: 7524 + md5: 946861e99a3d1d5c37e48d6c791145ba + size: 4572 fetch-supporting-docs: cmd: python scripts/fetch_supporting_docs.py data/eidc_metadata.json data/supporting-docs.json deps: - path: data/eidc_metadata.json hash: md5 - md5: 068ae066ea08ee369c505c8640481cf6 - size: 125674 + md5: ee850e1b0b28cd55ad7d7b31c81645db + size: 114886 - path: scripts/fetch_supporting_docs.py hash: md5 md5: 02b94a2cc7bff711784cbdec3650b618 @@ -142,26 +142,26 @@ stages: outs: - path: data/supporting-docs.json hash: md5 - md5: bdab1ea8df4a87aa3d314044eb2eaa0a - size: 188762 + md5: 12837e5cbf10fbd75c6fa476d6423a41 + size: 75646 evaluate: cmd: python scripts/evaluate.py data/evaluation_data.csv -m data/metrics.json -img data/eval.png deps: - path: data/evaluation_data.csv hash: md5 - md5: a4470a84d2de8b1d04c7d2dfd8b5f807 - size: 9859 + md5: 7697d47129fe7491dfa15c8795ca29fe + size: 3911 - path: scripts/evaluate.py hash: md5 - md5: a9c4c04157007c12c068aacdf5e099a9 - size: 2634 + md5: 4154acf8e74c1d8bcd0b0da72af038e0 + size: 2728 outs: - path: data/eval.png hash: md5 - md5: 981434fb5f4e61ce4288a4431f70bcc1 - size: 67852 + md5: ac93331955c478b2a08ae3a4f081e841 + size: 54427 - path: data/metrics.json hash: md5 - md5: 20efb8ebf0d6908f0ee7b35dbff2e7c7 - size: 242 + md5: 9c25b2a92fa4c5fc59fbb2fdae83d0a2 + size: 225 diff --git a/dvc.yaml b/dvc.yaml index 903cb5a..db54d68 100644 --- a/dvc.yaml +++ b/dvc.yaml @@ -44,11 +44,11 @@ stages: outs: - ${files.doc-store} generate-testset: - cmd: head -n 2 data/synthetic-datasets/eidc_rag_test_sample.csv > ${files.test-set} + cmd: head -n ${test-set-size} data/synthetic-datasets/eidc_rag_test_sample.csv > ${files.test-set} outs: - ${files.test-set} run-rag-pipeline: - cmd: python scripts/run_rag_pipeline.py ${files.test-set} ${files.eval-set} ${files.doc-store} -c ${doc-store.collection} + cmd: python scripts/run_rag_pipeline.py -i ${files.test-set} -o ${files.eval-set} -ds ${files.doc-store} -c ${doc-store.collection} -m ${rag.model} deps: - ${files.test-set} - ${files.doc-store} diff --git a/params.yaml b/params.yaml index bfc46af..e679c5f 100644 --- a/params.yaml +++ b/params.yaml @@ -1,6 +1,6 @@ hp: - chunk-size: 500 - overlap: 100 + chunk-size: 250 + overlap: 75 embeddings-model: all-MiniLM-L6-v2 doc-store: collection: eidc-data @@ -18,6 +18,7 @@ files: eval-plot: data/eval.png sub-sample: 0 # sample n datasets for testing (0 will use all datasets) max-length: 0 # truncate longer texts for testing (0 will use all data) +test-set-size: 101 # reduce the size of the test set for faster testing rag: model: llama3.1 prompt: >- diff --git a/run-experiments.sh b/run-experiments.sh index 43a8d67..83849a2 100755 --- a/run-experiments.sh +++ b/run-experiments.sh @@ -1,7 +1,12 @@ #!/bin/bash +NC='\033[0m' +GREEN='\033[0;32m' dvc queue remove --all -dvc exp run --queue -S hp.chunk-size=400 -S sub-sample=1 -S max-length=500 -dvc exp run --queue -S hp.chunk-size=600 -S sub-sample=1 -S max-length=500 +models=("llama3 llama3.1 mistral-nemo") +for model in $models +do + dvc exp run --queue -S rag.model=$model +done dvc queue start dvc queue status -echo "Run `dvc queue status` to check the state of the experiments" +echo -e "Run ${GREEN}dvc queue status${NC} to check the state of the experiments" diff --git a/scripts/chunk_data.py b/scripts/chunk_data.py index 673651d..d2e70d6 100644 --- a/scripts/chunk_data.py +++ b/scripts/chunk_data.py @@ -3,7 +3,9 @@ from typing import Any, Dict, List -def chunk_value(value: str, chunk_size: int, overlap: int, max_length: int) -> List[str]: +def chunk_value( + value: str, chunk_size: int, overlap: int, max_length: int +) -> List[str]: chunks = [] start = 0 end = max_length if len(value) > max_length > 0 else len(value) @@ -35,14 +37,20 @@ def chunk_metadata_file( with open(file) as f: json_data = json.load(f) for metadata in json_data: - chunked_metadata.extend(chunk_metadata_value(metadata, chunk_size, overlap, max_length)) + chunked_metadata.extend( + chunk_metadata_value(metadata, chunk_size, overlap, max_length) + ) return chunked_metadata -def main(files: List[str], ouput_file: str, chunk_size: int, overlap: int, max_length: int) -> None: +def main( + files: List[str], ouput_file: str, chunk_size: int, overlap: int, max_length: int +) -> None: all_chunked_metadata = [] for file in files: - all_chunked_metadata.extend(chunk_metadata_file(file, chunk_size, overlap, max_length)) + all_chunked_metadata.extend( + chunk_metadata_file(file, chunk_size, overlap, max_length) + ) with open(ouput_file, "w") as f: json.dump(all_chunked_metadata, f, indent=4) diff --git a/scripts/create_embeddings.py b/scripts/create_embeddings.py index 220eed0..1ae255e 100644 --- a/scripts/create_embeddings.py +++ b/scripts/create_embeddings.py @@ -1,6 +1,7 @@ +import gc import json from argparse import ArgumentParser -import gc + import torch from sentence_transformers import SentenceTransformer from torch import Tensor diff --git a/scripts/evaluate.py b/scripts/evaluate.py index fc2ffb8..fbe348f 100644 --- a/scripts/evaluate.py +++ b/scripts/evaluate.py @@ -55,7 +55,6 @@ def main(eval_dataset: str, metric_output: str, image_output: str) -> None: if metric not in ["question", "ground_truth", "answer", "contexts"] ] - pio.templates.default = "gridon" fig = go.Figure() diff --git a/scripts/fetch_eidc_metadata.py b/scripts/fetch_eidc_metadata.py index 5e883d9..c53220c 100644 --- a/scripts/fetch_eidc_metadata.py +++ b/scripts/fetch_eidc_metadata.py @@ -17,7 +17,8 @@ def main(output_file: str, sample: int) -> None: }, ) json_data = res.json() - json_data["results"] = json_data["results"][:sample] + if sample > 0: + json_data["results"] = json_data["results"][:sample] with open(output_file, "w") as f: json.dump(json_data, f, indent=4) diff --git a/scripts/run_rag_pipeline.py b/scripts/run_rag_pipeline.py index 2c620e5..a85154f 100644 --- a/scripts/run_rag_pipeline.py +++ b/scripts/run_rag_pipeline.py @@ -35,8 +35,6 @@ def build_rag_pipeline(model_name: str, collection_name: str) -> Pipeline: prompt_builder = PromptBuilder(template=template) - model_name = "llama3.1" - print(f"Setting up model ({model_name})...") llm = OllamaGenerator( model=model_name, @@ -85,11 +83,15 @@ def query_pipeline(questions: List[str], rag_pipe: Pipeline) -> Tuple[str, List[ def main( - test_data_file: str, ouput_file: str, doc_store_path: str, collection_name: str + test_data_file: str, + ouput_file: str, + doc_store_path: str, + collection_name: str, + model: str, ) -> None: shutil.copytree(doc_store_path, TMP_DOC_PATH) - rag_pipe = build_rag_pipeline("llama3.1", collection_name) + rag_pipe = build_rag_pipeline(model, collection_name) df = pd.read_csv(test_data_file) df.drop(columns=["rating", "contexts"], inplace=True) @@ -106,15 +108,18 @@ def main( if __name__ == "__main__": parser = ArgumentParser("run_rag_pipeline.py") parser.add_argument( - "test_data_file", + "-i", + "--input", help="File containing test queries to generate response from the RAG pipeline.", ) parser.add_argument( - "output_file", + "-o", + "--output", help="File to output results to.", ) parser.add_argument( - "doc_store_path", + "-ds", + "--doc_store", help="Path to the doc store.", ) parser.add_argument( @@ -123,5 +128,11 @@ def main( help="Collection name in doc store.", default="eidc-data", ) + parser.add_argument( + "-m", + "--model", + help="Model to use in RAG pipeline.", + default="llama3.1", + ) args = parser.parse_args() - main(args.test_data_file, args.output_file, args.doc_store_path, args.collection) + main(args.input, args.output, args.doc_store, args.collection, args.model)