Skip to content

Commit

Permalink
Merge pull request #36 from epinzur/fix_bugs
Browse files (browse the repository at this point in the history)
fixed some bugs
  • Loading branch information
epinzur authored Jun 27, 2024
2 parents 9ec1de6 + 7dc899a commit 09ed7c7
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 48 deletions.
14 changes: 3 additions & 11 deletions ragulate/cli_commands/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,27 +32,19 @@ def setup_ingest(subparsers):
ingest_parser.add_argument(
"--var-name",
type=str,
help=(
"The name of a variable in the ingest script",
"This should be paired with a `--var-value` argument",
"and can be passed multiple times.",
),
help="The name of a variable in the ingest script. This should be paired with a `--var-value` argument and can be passed multiple times.",
action="append",
)
ingest_parser.add_argument(
"--var-value",
type=str,
help=(
"The value of a variable in the ingest script",
"This should be paired with a `--var-name` argument",
"and can be passed multiple times.",
),
help="The value of a variable in the ingest script. This should be paired with a `--var-name` argument and can be passed multiple times.",
action="append",
)
ingest_parser.add_argument(
"--dataset",
type=str,
help=("The name of a dataset to ingest", "This can be passed multiple times."),
help="The name of a dataset to ingest. This can be passed multiple times.",
action="append",
)
ingest_parser.set_defaults(func=lambda args: call_ingest(**vars(args)))
Expand Down
44 changes: 10 additions & 34 deletions ragulate/cli_commands/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


def setup_query(subparsers):
query_parser = subparsers.add_parser("query", help="Run an query pipeline")
query_parser = subparsers.add_parser("query", help="Run a query pipeline")
query_parser.add_argument(
"-n",
"--name",
Expand All @@ -32,68 +32,47 @@ def setup_query(subparsers):
query_parser.add_argument(
"--var-name",
type=str,
help=(
"The name of a variable in the query script",
"This should be paired with a `--var-value` argument",
"and can be passed multiple times.",
),
help="The name of a variable in the query script. This should be paired with a `--var-value` argument and can be passed multiple times.",
action="append",
)
query_parser.add_argument(
"--var-value",
type=str,
help=(
"The value of a variable in the query script",
"This should be paired with a `--var-name` argument",
"and can be passed multiple times.",
),
help="The value of a variable in the query script. This should be paired with a `--var-name` argument and can be passed multiple times.",
action="append",
)
query_parser.add_argument(
"--dataset",
type=str,
help=("The name of a dataset to query", "This can be passed multiple times."),
help="The name of a dataset to query. This can be passed multiple times.",
action="append",
)
query_parser.add_argument(
"--subset",
type=str,
help=(
"The subset of the dataset to query",
"Only valid when a single dataset is passed.",
),
help="The subset of the dataset to query. Only valid when a single dataset is passed.",
action="append",
)
query_parser.add_argument(
"--sample",
type=float,
help=(
"A decimal percentage of the queries to sample for the test",
"Default is 1.0 (100%)",
),
help="A decimal percentage of the queries to sample for the test. Default is 1.0.",
default=1.0,
)
query_parser.add_argument(
"--seed",
type=int,
help=(
"Random seed to use for query sampling",
"Ensures reproducibility of tests",
),
help="Random seed to use for query sampling. Ensures reproducibility of tests.",
)
query_parser.add_argument(
"--restart",
help=(
"Flag to restart the query process instead of resuming.",
"WARNING: this will delete all existing data this query name,",
"not just the data for the tagged datasets.",
),
help="Flag to restart the query process instead of resuming. WARNING: this will delete all existing data for this query name, not just the data for the tagged datasets.",
action="store_true",
)
query_parser.add_argument(
"--provider",
type=str,
help=("The name of the LLM Provider to use for Evaluation."),
help="The name of the LLM Provider to use for Evaluation.",
choices=[
"OpenAI",
"AzureOpenAI",
Expand All @@ -107,10 +86,7 @@ def setup_query(subparsers):
query_parser.add_argument(
"--model",
type=str,
help=(
"The name or id of the LLM model or deployment to use for Evaluation.",
"Generally used in combination with the --provider param.",
),
help="The name or id of the LLM model or deployment to use for Evaluation. Generally used in combination with the `--provider` param.",
)
query_parser.set_defaults(func=lambda args: call_query(**vars(args)))

Expand Down
5 changes: 3 additions & 2 deletions ragulate/datasets/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from typing import List

import inflection

from .base_dataset import BaseDataset
from .crag_dataset import CragDataset
Expand All @@ -8,7 +9,7 @@

def find_dataset(name: str) -> BaseDataset:
root_path = "datasets"
name = name.lower()
name = inflection.underscore(name)
for kind in os.listdir(root_path):
kind_path = os.path.join(root_path, kind)
if os.path.isdir(kind_path):
Expand Down
4 changes: 3 additions & 1 deletion ragulate/pipelines/query_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ def __init__(
queries = [queries[i] for i in sampled_indices]

# Check for existing records and filter queries
existing_records = self._tru.get_records_and_feedback(app_ids=[dataset])
existing_records = self._tru.get_records_and_feedback(
app_ids=[dataset.name]
)
existing_queries = {record.query for record in existing_records}
queries = [query for query in queries if query not in existing_queries]

Expand Down

0 comments on commit 09ed7c7

Please sign in to comment.