Skip to content

Commit

Permalink
Add exponential backoff to get_response
Browse files Browse the repository at this point in the history
  • Loading branch information
jakethekoenig committed Jul 15, 2024
1 parent 6cbef5d commit f2b5581
Showing 1 changed file with 43 additions and 18 deletions.
61 changes: 43 additions & 18 deletions spice/spice.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import dataclasses
import glob
import json
import time
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
Expand All @@ -15,7 +16,7 @@
from jinja2 import DictLoader, Environment
from openai.types.chat.completion_create_params import ResponseFormat

from spice.errors import InvalidModelError, UnknownModelError
from spice.errors import APIConnectionError, APIError, InvalidModelError, UnknownModelError
from spice.models import EmbeddingModel, Model, TextModel, TranscriptionModel, get_model_from_name
from spice.providers import Provider, get_provider_from_name
from spice.spice_message import MessagesEncoder, SpiceMessage
Expand Down Expand Up @@ -230,6 +231,9 @@ def __init__(
logging_dir: Optional[Path | str] = None,
logging_callback: Optional[Callable[[SpiceResponse, str, str], None]] = None,
default_temperature: Optional[float] = None,
retry_count: int = 0,
retry_wait: float = 1,
retry_backoff: float = 1.5,
):
"""
Creates a new Spice client.
Expand All @@ -249,6 +253,12 @@ def __init__(
and the name of the call after every call finishes.
default_temperature: The default temperature to use for chat completions if no other temperature is given.
retry_count: The number of times to retry on API failure before propagating the error. If 0, will not retry.
retry_wait: The time in seconds to wait before the first retry.
retry_backoff: The factor by which the wait time is multiplied after each successive retry.
"""

if isinstance(default_text_model, str):
Expand Down Expand Up @@ -278,6 +288,10 @@ def __init__(
self.logging_callback = logging_callback
self.new_run("spice")

self._retry_count = retry_count
self._retry_wait = retry_wait
self._retry_backoff = retry_backoff

def new_run(self, name: str):
"""
Create a new run. All llm calls will be logged in a folder with the run name and a timestamp.
Expand Down Expand Up @@ -451,23 +465,34 @@ async def get_response(
elif i > 1 and call_args.temperature is not None:
call_args.temperature = max(0.5, call_args.temperature)

with client.catch_and_convert_errors():
if streaming_callback is not None:
stream = await client.get_chat_completion_or_stream(call_args)
stream = cast(AsyncIterator, stream)
streaming_spice_response = StreamingSpiceResponse(
text_model, call_args, client, stream, None, streaming_callback
)
chat_completion = await streaming_spice_response.complete_response()
text, input_tokens, output_tokens = (
chat_completion.text,
chat_completion.input_tokens,
chat_completion.output_tokens,
)

else:
chat_completion = await client.get_chat_completion_or_stream(call_args)
text, input_tokens, output_tokens = client.extract_text_and_tokens(chat_completion, call_args)
input_tokens, output_tokens, text = 0, 0, ""
for i in range(self._retry_count + 1):
try:
with client.catch_and_convert_errors():
if streaming_callback is not None:
stream = await client.get_chat_completion_or_stream(call_args)
stream = cast(AsyncIterator, stream)
streaming_spice_response = StreamingSpiceResponse(
text_model, call_args, client, stream, None, streaming_callback
)
chat_completion = await streaming_spice_response.complete_response()
text, input_tokens, output_tokens = (
chat_completion.text,
chat_completion.input_tokens,
chat_completion.output_tokens,
)

else:
chat_completion = await client.get_chat_completion_or_stream(call_args)
text, input_tokens, output_tokens = client.extract_text_and_tokens(
chat_completion, call_args
)
break
except (APIConnectionError, APIError) as e:
if i == self._retry_count:
raise e
else:
time.sleep(self._retry_wait * (self._retry_backoff**i))

completion_cost = text_request_cost(text_model, input_tokens, output_tokens)
if completion_cost is not None:
Expand Down

0 comments on commit f2b5581

Please sign in to comment.