From f2b55817e49cbfc69d14e5a80f54c1a8d38fecd4 Mon Sep 17 00:00:00 2001
From: Jake Koenig
Date: Mon, 15 Jul 2024 15:18:22 -0700
Subject: [PATCH] Add exponential backoff to get_response

---
 spice/spice.py | 61 +++++++++++++++++++++++++++++++++++---------------
 1 file changed, 43 insertions(+), 18 deletions(-)

diff --git a/spice/spice.py b/spice/spice.py
index 41823db..93b9632 100644
--- a/spice/spice.py
+++ b/spice/spice.py
@@ -3,6 +3,7 @@
+import asyncio
 import dataclasses
 import glob
 import json
 from collections import defaultdict
 from dataclasses import dataclass, field
 from datetime import datetime
@@ -15,7 +16,7 @@
 from jinja2 import DictLoader, Environment
 from openai.types.chat.completion_create_params import ResponseFormat
 
-from spice.errors import InvalidModelError, UnknownModelError
+from spice.errors import APIConnectionError, APIError, InvalidModelError, UnknownModelError
 from spice.models import EmbeddingModel, Model, TextModel, TranscriptionModel, get_model_from_name
 from spice.providers import Provider, get_provider_from_name
 from spice.spice_message import MessagesEncoder, SpiceMessage
@@ -230,6 +231,9 @@ def __init__(
         logging_dir: Optional[Path | str] = None,
         logging_callback: Optional[Callable[[SpiceResponse, str, str], None]] = None,
         default_temperature: Optional[float] = None,
+        retry_count: int = 0,
+        retry_wait: float = 1,
+        retry_backoff: float = 1.5,
     ):
         """
         Creates a new Spice client.
@@ -249,6 +253,12 @@ def __init__(
                 and the name of the call after every call finishes.
 
             default_temperature: The default temperature to use for chat completions if no other temperature is given.
+
+            retry_count: The number of times to retry on API failure before propagating the error. If 0, will not retry.
+
+            retry_wait: The initial time in seconds to wait between retries.
+
+            retry_backoff: The factor by which the wait time is multiplied after each retry.
         """
 
         if isinstance(default_text_model, str):
@@ -278,6 +288,10 @@ def __init__(
         self.logging_callback = logging_callback
         self.new_run("spice")
 
+        self._retry_count = retry_count
+        self._retry_wait = retry_wait
+        self._retry_backoff = retry_backoff
+
     def new_run(self, name: str):
         """
         Create a new run. All llm calls will be logged in a folder with the run name and a timestamp.
@@ -451,23 +465,34 @@ async def get_response(
         elif i > 1 and call_args.temperature is not None:
             call_args.temperature = max(0.5, call_args.temperature)
 
-        with client.catch_and_convert_errors():
-            if streaming_callback is not None:
-                stream = await client.get_chat_completion_or_stream(call_args)
-                stream = cast(AsyncIterator, stream)
-                streaming_spice_response = StreamingSpiceResponse(
-                    text_model, call_args, client, stream, None, streaming_callback
-                )
-                chat_completion = await streaming_spice_response.complete_response()
-                text, input_tokens, output_tokens = (
-                    chat_completion.text,
-                    chat_completion.input_tokens,
-                    chat_completion.output_tokens,
-                )
-
-            else:
-                chat_completion = await client.get_chat_completion_or_stream(call_args)
-                text, input_tokens, output_tokens = client.extract_text_and_tokens(chat_completion, call_args)
+        input_tokens, output_tokens, text = 0, 0, ""
+        for retry in range(self._retry_count + 1):
+            try:
+                with client.catch_and_convert_errors():
+                    if streaming_callback is not None:
+                        stream = await client.get_chat_completion_or_stream(call_args)
+                        stream = cast(AsyncIterator, stream)
+                        streaming_spice_response = StreamingSpiceResponse(
+                            text_model, call_args, client, stream, None, streaming_callback
+                        )
+                        chat_completion = await streaming_spice_response.complete_response()
+                        text, input_tokens, output_tokens = (
+                            chat_completion.text,
+                            chat_completion.input_tokens,
+                            chat_completion.output_tokens,
+                        )
+
+                    else:
+                        chat_completion = await client.get_chat_completion_or_stream(call_args)
+                        text, input_tokens, output_tokens = client.extract_text_and_tokens(
+                            chat_completion, call_args
+                        )
+                break
+            except (APIConnectionError, APIError):
+                if retry == self._retry_count:
+                    raise
+                # Sleep without blocking the event loop; the wait grows geometrically.
+                await asyncio.sleep(self._retry_wait * (self._retry_backoff**retry))
 
         completion_cost = text_request_cost(text_model, input_tokens, output_tokens)
         if completion_cost is not None:
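Usage note: retries are opt-in via the constructor. A minimal sketch, assuming the other Spice constructor arguments keep their defaults (the parameter names and defaults come from this diff; the values below are only examples):

    # Retry failed API calls up to 3 times before propagating the error.
    # Waits between attempts: 1.0s, 1.5s, 2.25s (retry_wait * retry_backoff**retry).
    client = Spice(retry_count=3, retry_wait=1.0, retry_backoff=1.5)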
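Since the wait before each retry is retry_wait * retry_backoff**retry, the worst-case total delay is easy to sanity-check; a throwaway sketch with the default wait and backoff values:

    # Sleep schedule for retry_count=4 with retry_wait=1 and retry_backoff=1.5.
    waits = [1 * 1.5**retry for retry in range(4)]  # [1.0, 1.5, 2.25, 3.375]
    total = sum(waits)  # at most 8.125 seconds of waiting, plus request time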