Skip to content

Commit

Permalink
Merge pull request #13 from epinzur/sai/additions
Browse files Browse the repository at this point in the history
Add Support for Other Providers
  • Loading branch information
epinzur authored Jun 27, 2024
2 parents 1bcbe1d + 21047c8 commit 9ec1de6
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 9 deletions.
38 changes: 32 additions & 6 deletions ragulate/cli_commands/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ def setup_query(subparsers):
)
query_parser.add_argument(
"-s",
"--script_path",
"--script",
type=str,
help="The path to the python script that contains the query method",
required=True,
)
query_parser.add_argument(
"-m",
"--method-name",
"--method",
type=str,
help="The name of the method in the script to run query",
required=True,
Expand Down Expand Up @@ -90,19 +90,43 @@ def setup_query(subparsers):
),
action="store_true",
)
query_parser.add_argument(
"--provider",
type=str,
help=("The name of the LLM Provider to use for Evaluation."),
choices=[
"OpenAI",
"AzureOpenAI",
"Bedrock",
"LiteLLM",
"Langchain",
"Huggingface",
],
default="OpenAI",
)
query_parser.add_argument(
    "--model",
    type=str,
    # Bug fix: the original wrapped two comma-terminated strings in parens,
    # which builds a TUPLE, not a string — argparse help must be a string.
    # Implicit string-literal concatenation (no commas) joins them correctly.
    help=(
        "The name or id of the LLM model or deployment to use for Evaluation. "
        "Generally used in combination with the --provider param."
    ),
)
query_parser.set_defaults(func=lambda args: call_query(**vars(args)))

def call_query(
name: str,
script_path: str,
method_name: str,
script: str,
method: str,
var_name: List[str],
var_value: List[str],
dataset: List[str],
subset: List[str],
sample: float,
seed: int,
restart: bool,
provider: str,
model: str,
**kwargs,
):
if sample <= 0.0 or sample > 1.0:
Expand All @@ -124,12 +148,14 @@ def call_query(

query_pipeline = QueryPipeline(
recipe_name=name,
script_path=script_path,
method_name=method_name,
script_path=script,
method_name=method,
ingredients=ingredients,
datasets=datasets,
sample_percent=sample,
random_seed=seed,
restart_pipeline=restart,
llm_provider=provider,
model_name=model,
)
query_pipeline.query()
37 changes: 34 additions & 3 deletions ragulate/pipelines/query_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@

from tqdm import tqdm
from trulens_eval import Tru, TruChain
from trulens_eval.feedback.provider import OpenAI
from trulens_eval.feedback.provider import (
AzureOpenAI,
Bedrock,
Huggingface,
Langchain,
LiteLLM,
OpenAI,
)
from trulens_eval.feedback.provider.base import LLMProvider
from trulens_eval.schema.feedback import FeedbackMode, FeedbackResultStatus

from ragulate.datasets import BaseDataset
Expand Down Expand Up @@ -47,7 +55,9 @@ def __init__(
datasets: List[BaseDataset],
sample_percent: float = 1.0,
random_seed: Optional[int] = None,
restart_pipeline: bool = False,
restart_pipeline: Optional[bool] = False,
llm_provider: Optional[str] = "OpenAI",
model_name: Optional[str] = None,
**kwargs,
):
super().__init__(
Expand All @@ -61,6 +71,8 @@ def __init__(
self.sample_percent = sample_percent
self.random_seed = random_seed
self.restart_pipeline = restart_pipeline
self.llm_provider = llm_provider
self.model_name = model_name

# Set up the signal handler for SIGINT (Ctrl-C)
signal.signal(signal.SIGINT, self.signal_handler)
Expand Down Expand Up @@ -136,11 +148,30 @@ def update_progress(self, query_change: int = 0):

self._finished_feedbacks = done

def get_provider(self) -> LLMProvider:
    """Build the TruLens feedback provider selected for evaluation.

    Dispatches on ``self.llm_provider`` (matched case-insensitively) and
    passes ``self.model_name`` to the provider's model/deployment keyword.

    Returns:
        LLMProvider: the instantiated trulens_eval feedback provider.

    Raises:
        ValueError: if ``self.llm_provider`` names an unsupported provider.
    """
    # Bug fix: __init__ stores the CLI value as `self.llm_provider`;
    # the original read the nonexistent `self.provider_name` and would
    # raise AttributeError on every call.
    provider_name = self.llm_provider.lower()
    model_name = self.model_name

    if provider_name == "openai":
        return OpenAI(model_engine=model_name)
    elif provider_name == "azureopenai":
        # AzureOpenAI takes the Azure deployment name, not a model engine.
        return AzureOpenAI(deployment_name=model_name)
    elif provider_name == "bedrock":
        return Bedrock(model_id=model_name)
    elif provider_name == "litellm":
        return LiteLLM(model_engine=model_name)
    elif provider_name == "langchain":
        # Bug fix: the original compared the lowercased name against the
        # literal "Langchain" (capital L), making this branch unreachable.
        return Langchain(model_engine=model_name)
    elif provider_name == "huggingface":
        return Huggingface(name=model_name)
    else:
        raise ValueError(f"Unsupported provider: {provider_name}")

def query(self):
query_method = self.get_method()

pipeline = query_method(**self.ingredients)
llm_provider = OpenAI(model_engine="gpt-3.5-turbo")
llm_provider = self.get_provider()

feedbacks = Feedbacks(llm_provider=llm_provider, pipeline=pipeline)

Expand Down

0 comments on commit 9ec1de6

Please sign in to comment.