-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
10 changed files
with
230 additions
and
123 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from dataclasses import dataclass, field | ||
from typing import Any, AsyncIterator, Callable, Collection, Dict, Generic, List, Optional, TypeVar, cast | ||
|
||
from openai.types.chat.completion_create_params import ResponseFormat | ||
|
||
from spice.spice_message import MessagesEncoder, SpiceMessage | ||
|
||
|
||
@dataclass | ||
class SpiceCallArgs: | ||
model: str | ||
messages: Collection[SpiceMessage] | ||
stream: bool = False | ||
temperature: Optional[float] = None | ||
max_tokens: Optional[int] = None | ||
response_format: Optional[ResponseFormat] = None |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from abc import ABC, abstractmethod | ||
from enum import Enum | ||
from typing import Any, Callable, Generic, Optional, TypeVar | ||
|
||
from spice.spice import SpiceCallArgs | ||
|
||
T = TypeVar("T") | ||
|
||
# An object that defines how get_response should validate/convert a response. It must implement the decide and | ||
# get_retry_name methods. The decide method takes in the previous call_args, which attempt number this is, and the model | ||
# output. The method returns a tuple of Behavior, SpiceCallArgs, and the result the result and the name of the run. If | ||
# the Behavior is RETURN, then the result will be returned as the result of the SpiceResponse object. If the Behavior | ||
# is RETRY, then the llm will be called again with the new spice_args. It's up to the Behavior to eventually return | ||
# RETURN or throw an exception. | ||
|
||
|
||
class Behavior(Enum): | ||
RETRY = "retry" | ||
RETURN = "return" | ||
|
||
|
||
class RetryStrategy(ABC, Generic[T]): | ||
@abstractmethod | ||
def decide( | ||
self, call_args: SpiceCallArgs, attempt_number: int, model_output: str, name: str | ||
) -> tuple[Behavior, SpiceCallArgs, T, str]: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import dataclasses | ||
from typing import Any, Callable, Optional | ||
|
||
from spice.retry_strategy import Behavior, RetryStrategy | ||
from spice.spice import SpiceCallArgs | ||
from spice.spice_message import assistant_message, user_message | ||
|
||
|
||
def default_failure_message(message: str) -> str: | ||
return f"Failed to convert response for the following reason: {message}\n\nPlease try again." | ||
|
||
|
||
class ConverterStrategy(RetryStrategy): | ||
def __init__( | ||
self, | ||
converter: Callable[[str], Any], | ||
retries: int = 0, | ||
render_failure_message: Callable[[str], str] = default_failure_message, | ||
): | ||
self.converter = converter | ||
self.retries = retries | ||
self.render_failure_message = render_failure_message | ||
|
||
def decide( | ||
self, call_args: SpiceCallArgs, attempt_number: int, model_output: str, name: str | ||
) -> tuple[Behavior, SpiceCallArgs, Any, str]: | ||
try: | ||
result = self.converter(model_output) | ||
return Behavior.RETURN, call_args, result, name | ||
except Exception as e: | ||
if attempt_number < self.retries: | ||
messages = list(call_args.messages) | ||
messages.append(assistant_message(model_output)) | ||
messages.append(user_message(self.render_failure_message(str(e)))) | ||
call_args = dataclasses.replace(call_args, messages=messages) | ||
return Behavior.RETRY, call_args, None, f"{name}-retry-{attempt_number}-fail" | ||
else: | ||
raise ValueError("Failed to get a valid response after all retries") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import dataclasses | ||
from abc import ABC, abstractmethod | ||
from enum import Enum | ||
from typing import Any, Callable, Optional, TypeVar | ||
|
||
from spice.retry_strategy import Behavior, RetryStrategy, T | ||
from spice.spice import SpiceCallArgs | ||
|
||
|
||
class DefaultRetryStrategy(RetryStrategy): | ||
def __init__( | ||
self, validator: Optional[Callable[[str], bool]] = None, converter: Callable[[str], T] = str, retries: int = 0 | ||
): | ||
self.validator = validator | ||
self.converter = converter | ||
self.retries = retries | ||
|
||
def decide( | ||
self, call_args: SpiceCallArgs, attempt_number: int, model_output: str, name: str | ||
) -> tuple[Behavior, SpiceCallArgs, Any, str]: | ||
if attempt_number == 1 and call_args.temperature is not None: | ||
dataclasses.replace(call_args, temperature=max(0.2, call_args.temperature)) | ||
elif attempt_number > 1 and call_args.temperature is not None: | ||
dataclasses.replace(call_args, temperature=max(0.5, call_args.temperature)) | ||
|
||
if self.validator and not self.validator(model_output): | ||
if attempt_number < self.retries: | ||
return Behavior.RETRY, call_args, None, f"{name}-retry-{attempt_number}-fail" | ||
else: | ||
raise ValueError("Failed to get a valid response after all retries") | ||
try: | ||
result = self.converter(model_output) | ||
return Behavior.RETURN, call_args, result, name | ||
except Exception: | ||
if attempt_number < self.retries: | ||
return Behavior.RETRY, call_args, None, f"{name}-retry-{attempt_number}-fail" | ||
else: | ||
raise ValueError("Failed to get a valid response after all retries") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import dataclasses | ||
from typing import Any, Callable, Optional, Tuple | ||
|
||
from spice.call_args import SpiceCallArgs | ||
from spice.retry_strategy import Behavior, RetryStrategy | ||
from spice.spice_message import assistant_message, user_message | ||
|
||
|
||
def default_failure_message(message: str) -> str: | ||
return f"Failed to validate response for the following reason: {message}\n\nPlease try again." | ||
|
||
|
||
class ValidatorStrategy(RetryStrategy): | ||
""" | ||
Validates the model output and if it fails puts the failure reason in model context. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
validator: Callable[[str], Tuple[bool, str]], | ||
retries: int = 0, | ||
render_failure_message: Callable[[str], str] = default_failure_message, | ||
): | ||
""" | ||
Args: | ||
validator: A function that takes in the model output and returns a tuple of a boolean and a message. The | ||
boolean indicates if the model output is valid and the message is the reason why it is invalid. | ||
retries: The number of retries to attempt before failing. | ||
render_failure_message: A function that takes in the failure message and returns a string that will be | ||
displayed to the llm. | ||
""" | ||
self.validator = validator | ||
self.retries = retries | ||
self.render_failure_message = render_failure_message | ||
|
||
def decide( | ||
self, call_args: SpiceCallArgs, attempt_number: int, model_output: str, name: str | ||
) -> tuple[Behavior, SpiceCallArgs, Any, str]: | ||
passed, message = self.validator(model_output) | ||
if not passed: | ||
if attempt_number < self.retries: | ||
messages = list(call_args.messages) | ||
messages.append(assistant_message(model_output)) | ||
messages.append(user_message(self.render_failure_message(message))) | ||
call_args = dataclasses.replace(call_args, messages=messages) | ||
return Behavior.RETRY, call_args, None, f"{name}-retry-{attempt_number}-fail" | ||
else: | ||
raise ValueError("Failed to get a valid response after all retries") | ||
else: | ||
return Behavior.RETURN, call_args, model_output, name |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.