Skip to content

Commit

Permalink
Make Retry and Converter strategies
Browse files Browse the repository at this point in the history
Clean up get_response. Light touch ups.
  • Loading branch information
jakethekoenig committed Jul 15, 2024
1 parent 3ca2d67 commit d9f0632
Show file tree
Hide file tree
Showing 10 changed files with 230 additions and 123 deletions.
16 changes: 16 additions & 0 deletions spice/call_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from dataclasses import dataclass, field
from typing import Any, AsyncIterator, Callable, Collection, Dict, Generic, List, Optional, TypeVar, cast

from openai.types.chat.completion_create_params import ResponseFormat

from spice.spice_message import MessagesEncoder, SpiceMessage


@dataclass
class SpiceCallArgs:
model: str
messages: Collection[SpiceMessage]
stream: bool = False
temperature: Optional[float] = None
max_tokens: Optional[int] = None
response_format: Optional[ResponseFormat] = None
29 changes: 0 additions & 29 deletions spice/custom_retry_strategy.py

This file was deleted.

49 changes: 0 additions & 49 deletions spice/retry_strategy.py

This file was deleted.

27 changes: 27 additions & 0 deletions spice/retry_strategy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Callable, Generic, Optional, TypeVar

from spice.spice import SpiceCallArgs

T = TypeVar("T")

# An object that defines how get_response should validate/convert a response. It must implement the decide and
# get_retry_name methods. The decide method takes in the previous call_args, which attempt number this is, and the model
# output. The method returns a tuple of Behavior, SpiceCallArgs, and the result the result and the name of the run. If
# the Behavior is RETURN, then the result will be returned as the result of the SpiceResponse object. If the Behavior
# is RETRY, then the llm will be called again with the new spice_args. It's up to the Behavior to eventually return
# RETURN or throw an exception.


class Behavior(Enum):
RETRY = "retry"
RETURN = "return"


class RetryStrategy(ABC, Generic[T]):
@abstractmethod
def decide(
self, call_args: SpiceCallArgs, attempt_number: int, model_output: str, name: str
) -> tuple[Behavior, SpiceCallArgs, T, str]:
pass
38 changes: 38 additions & 0 deletions spice/retry_strategy/converter_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import dataclasses
from typing import Any, Callable, Optional

from spice.retry_strategy import Behavior, RetryStrategy
from spice.spice import SpiceCallArgs
from spice.spice_message import assistant_message, user_message


def default_failure_message(message: str) -> str:
return f"Failed to convert response for the following reason: {message}\n\nPlease try again."


class ConverterStrategy(RetryStrategy):
def __init__(
self,
converter: Callable[[str], Any],
retries: int = 0,
render_failure_message: Callable[[str], str] = default_failure_message,
):
self.converter = converter
self.retries = retries
self.render_failure_message = render_failure_message

def decide(
self, call_args: SpiceCallArgs, attempt_number: int, model_output: str, name: str
) -> tuple[Behavior, SpiceCallArgs, Any, str]:
try:
result = self.converter(model_output)
return Behavior.RETURN, call_args, result, name
except Exception as e:
if attempt_number < self.retries:
messages = list(call_args.messages)
messages.append(assistant_message(model_output))
messages.append(user_message(self.render_failure_message(str(e))))
call_args = dataclasses.replace(call_args, messages=messages)
return Behavior.RETRY, call_args, None, f"{name}-retry-{attempt_number}-fail"
else:
raise ValueError("Failed to get a valid response after all retries")
38 changes: 38 additions & 0 deletions spice/retry_strategy/default_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import dataclasses
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Callable, Optional, TypeVar

from spice.retry_strategy import Behavior, RetryStrategy, T
from spice.spice import SpiceCallArgs


class DefaultRetryStrategy(RetryStrategy):
def __init__(
self, validator: Optional[Callable[[str], bool]] = None, converter: Callable[[str], T] = str, retries: int = 0
):
self.validator = validator
self.converter = converter
self.retries = retries

def decide(
self, call_args: SpiceCallArgs, attempt_number: int, model_output: str, name: str
) -> tuple[Behavior, SpiceCallArgs, Any, str]:
if attempt_number == 1 and call_args.temperature is not None:
dataclasses.replace(call_args, temperature=max(0.2, call_args.temperature))
elif attempt_number > 1 and call_args.temperature is not None:
dataclasses.replace(call_args, temperature=max(0.5, call_args.temperature))

if self.validator and not self.validator(model_output):
if attempt_number < self.retries:
return Behavior.RETRY, call_args, None, f"{name}-retry-{attempt_number}-fail"
else:
raise ValueError("Failed to get a valid response after all retries")
try:
result = self.converter(model_output)
return Behavior.RETURN, call_args, result, name
except Exception:
if attempt_number < self.retries:
return Behavior.RETRY, call_args, None, f"{name}-retry-{attempt_number}-fail"
else:
raise ValueError("Failed to get a valid response after all retries")
51 changes: 51 additions & 0 deletions spice/retry_strategy/validator_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import dataclasses
from typing import Any, Callable, Optional, Tuple

from spice.call_args import SpiceCallArgs
from spice.retry_strategy import Behavior, RetryStrategy
from spice.spice_message import assistant_message, user_message


def default_failure_message(message: str) -> str:
return f"Failed to validate response for the following reason: {message}\n\nPlease try again."


class ValidatorStrategy(RetryStrategy):
"""
Validates the model output and if it fails puts the failure reason in model context.
"""

def __init__(
self,
validator: Callable[[str], Tuple[bool, str]],
retries: int = 0,
render_failure_message: Callable[[str], str] = default_failure_message,
):
"""
Args:
validator: A function that takes in the model output and returns a tuple of a boolean and a message. The
boolean indicates if the model output is valid and the message is the reason why it is invalid.
retries: The number of retries to attempt before failing.
render_failure_message: A function that takes in the failure message and returns a string that will be
displayed to the llm.
"""
self.validator = validator
self.retries = retries
self.render_failure_message = render_failure_message

def decide(
self, call_args: SpiceCallArgs, attempt_number: int, model_output: str, name: str
) -> tuple[Behavior, SpiceCallArgs, Any, str]:
passed, message = self.validator(model_output)
if not passed:
if attempt_number < self.retries:
messages = list(call_args.messages)
messages.append(assistant_message(model_output))
messages.append(user_message(self.render_failure_message(message)))
call_args = dataclasses.replace(call_args, messages=messages)
return Behavior.RETRY, call_args, None, f"{name}-retry-{attempt_number}-fail"
else:
raise ValueError("Failed to get a valid response after all retries")
else:
return Behavior.RETURN, call_args, model_output, name
55 changes: 17 additions & 38 deletions spice/spice.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,28 @@
import dataclasses
import glob
import json
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from json import JSONDecodeError
from pathlib import Path
from timeit import default_timer as timer
from typing import Any, AsyncIterator, Callable, Collection, Dict, Generic, List, Optional, TypeVar, Union, cast
from typing import Any, AsyncIterator, Callable, Collection, Dict, Generic, List, Optional, TypeVar, cast

import httpx
from jinja2 import DictLoader, Environment
from openai.types.chat.completion_create_params import ResponseFormat

from spice.call_args import SpiceCallArgs
from spice.errors import InvalidModelError, UnknownModelError
from spice.models import EmbeddingModel, Model, TextModel, TranscriptionModel, get_model_from_name
from spice.providers import Provider, get_provider_from_name
from spice.retry_strategy import DefaultRetryStrategy, RetryStrategy
from spice.retry_strategy import Behavior, RetryStrategy
from spice.retry_strategy.default_strategy import DefaultRetryStrategy
from spice.spice_message import MessagesEncoder, SpiceMessage
from spice.utils import embeddings_request_cost, string_identity, text_request_cost, transcription_request_cost
from spice.wrapped_clients import WrappedClient


@dataclass
class SpiceCallArgs:
model: str
messages: Collection[SpiceMessage]
stream: bool = False
temperature: Optional[float] = None
max_tokens: Optional[int] = None
response_format: Optional[ResponseFormat] = None


T = TypeVar("T")


Expand Down Expand Up @@ -449,17 +438,14 @@ async def get_response(

cost = 0
attempt_number = 0
text_model = self._get_text_model(model)
call_args = self._fix_call_args(
messages, text_model, streaming_callback is not None, temperature, max_tokens, response_format
)
while True:
start_time = timer()
text_model = self._get_text_model(model)
text_model = self._get_text_model(call_args.model)
client = self._get_client(text_model, provider)
call_args = self._fix_call_args(
messages, text_model, streaming_callback is not None, temperature, max_tokens, response_format
)
if attempt_number == 1 and call_args.temperature is not None:
call_args.temperature = max(0.2, call_args.temperature)
elif attempt_number > 1 and call_args.temperature is not None:
call_args.temperature = max(0.5, call_args.temperature)

with client.catch_and_convert_errors():
if streaming_callback is not None:
Expand All @@ -484,26 +470,19 @@ async def get_response(
self._total_cost += completion_cost

end_time = timer()
if name:
retry_name = f"{name}-retry-{attempt_number}-fail"
else:
retry_name = f"retry-{attempt_number}-fail"

behavior, next_call_args, result = retry_strategy.decide(call_args, attempt_number, text)
behavior, next_call_args, result, call_name = retry_strategy.decide(
call_args, attempt_number, text, name or ""
)
response = SpiceResponse(
call_args, text, end_time - start_time, input_tokens, output_tokens, True, cost, _result=result
)
self._log_response(response, call_name)
if behavior == Behavior.RETURN:
response = SpiceResponse(
call_args, text, end_time - start_time, input_tokens, output_tokens, True, cost, _result=result
)
self._log_response(response, name)
return response
else:
response = SpiceResponse(
call_args, text, end_time - start_time, input_tokens, output_tokens, True, cost
)
self._log_response(response, retry_name)
attempt_number += 1

raise ValueError("Failed to get a valid response after all retries")
call_args = next_call_args

async def stream_response(
self,
Expand Down
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,12 @@ async def convert_string_to_asynciter(
class WrappedTestClient(WrappedClient):
"""
A wrapped client that can be used in tests. Accepts what it should respond with in its constructor.
Stores all calls to get_chat_completion_or_stream in the calls attribute for testing.
"""

def __init__(self, response: str | Iterator[str]):
self.calls = list[SpiceCallArgs]()
if isinstance(response, str):
self.response = iter(response)
else:
Expand All @@ -50,6 +53,7 @@ def __init__(self, response: str | Iterator[str]):
async def get_chat_completion_or_stream(
self, call_args: SpiceCallArgs
) -> ChatCompletion | AsyncIterator[ChatCompletionChunk]:
self.calls.append(call_args)
response = next(self.response)

if call_args.stream:
Expand Down
Loading

0 comments on commit d9f0632

Please sign in to comment.