Refactor retry logic into call_llm method
Incorporated feedback by moving the retry logic out of the `catch_and_convert_errors` context manager and into a new `call_llm` method on the `Spice` class. `call_llm` now owns the retry loop and replaces the context manager at the relevant call sites, which keeps retry policy cleanly separated from error conversion.

Updated the `WrappedOpenAIClient` and `WrappedAnthropicClient` classes to remove the retry logic from their `catch_and_convert_errors` methods.
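
For context, the retry behavior that `call_llm` now owns is plain exponential backoff. A minimal, self-contained sketch of that pattern (the names here are illustrative, not the Spice API; the defaults mirror the old context-manager signature):

```python
import asyncio

class TransientAPIError(Exception):
    """Stand-in for the connection/status errors worth retrying."""

async def with_backoff(make_request, max_retries: int = 3, base_delay: float = 1.0, max_delay: float = 32.0):
    """Run make_request, sleeping base_delay, 2*base_delay, ... (capped at max_delay) between failed attempts."""
    delay = base_delay
    for attempt in range(max_retries + 1):
        try:
            return await make_request()
        except TransientAPIError:
            if attempt == max_retries:
                raise  # out of retries: surface the last error to the caller
            await asyncio.sleep(min(delay, max_delay))
            delay *= 2

# Example usage: asyncio.run(with_backoff(some_async_call))
```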
mentatai committed Jul 14, 2024
1 parent 8680b69 commit 20c68f6
Showing 2 changed files with 54 additions and 57 deletions.
54 changes: 35 additions & 19 deletions spice/spice.py
@@ -286,6 +286,30 @@ def __init__(
self.logging_callback = logging_callback
self.new_run("spice")

async def call_llm(self, client: WrappedClient, call_args: SpiceCallArgs, streaming_callback: Optional[Callable[[str], None]] = None):
    """Make one logical LLM call, retrying transient errors with exponential backoff.

    Sleeps min(delay, max_delay) between attempts, doubling the delay each time,
    and re-raises the last error once max_retries is exhausted.
    """
    retries = 0
    delay = self.base_delay
    while retries <= self.max_retries:
        try:
            with client.catch_and_convert_errors():
                if streaming_callback is not None:
                    stream = await client.get_chat_completion_or_stream(call_args)
                    stream = cast(AsyncIterator, stream)
                    streaming_spice_response = StreamingSpiceResponse(
                        self._get_text_model(call_args.model), call_args, client, stream, None, streaming_callback
                    )
                    response = await streaming_spice_response.complete_response()
                    # Match the non-streaming branch's return shape so callers can unpack uniformly.
                    return response.text, response.input_tokens, response.output_tokens
                else:
                    chat_completion = await client.get_chat_completion_or_stream(call_args)
                    text, input_tokens, output_tokens = client.extract_text_and_tokens(chat_completion, call_args)
                    return text, input_tokens, output_tokens
        except (APIConnectionError, APIError):
            if retries == self.max_retries:
                raise
            # Non-blocking backoff so the event loop isn't stalled between attempts.
            await asyncio.sleep(min(delay, self.max_delay))
            delay *= 2
            retries += 1

def new_run(self, name: str):
"""
Create a new run. All llm calls will be logged in a folder with the run name and a timestamp.
@@ -459,23 +483,12 @@ async def get_response(
elif i > 1 and call_args.temperature is not None:
call_args.temperature = max(0.5, call_args.temperature)

with client.catch_and_convert_errors(max_retries=self.max_retries, base_delay=self.base_delay, max_delay=self.max_delay):
if streaming_callback is not None:
stream = await client.get_chat_completion_or_stream(call_args)
stream = cast(AsyncIterator, stream)
streaming_spice_response = StreamingSpiceResponse(
text_model, call_args, client, stream, None, streaming_callback
)
chat_completion = await streaming_spice_response.complete_response()
text, input_tokens, output_tokens = (
chat_completion.text,
chat_completion.input_tokens,
chat_completion.output_tokens,
)

else:
chat_completion = await client.get_chat_completion_or_stream(call_args)
text, input_tokens, output_tokens = client.extract_text_and_tokens(chat_completion, call_args)
try:
    text, input_tokens, output_tokens = await self.call_llm(client, call_args, streaming_callback)
except (APIConnectionError, APIError):
    if i == retries:
        raise
    # call_llm exhausted its transient-error retries; re-enter the outer
    # loop, which may adjust call arguments before trying again.
    continue

completion_cost = text_request_cost(text_model, input_tokens, output_tokens)
if completion_cost is not None:
@@ -550,8 +563,11 @@ async def stream_response(
client = self._get_client(text_model, provider)
call_args = self._fix_call_args(messages, text_model, True, temperature, max_tokens, response_format)

# call_llm retries and then completes the response, so the raw-stream path
# still calls the client directly and uses the (now retry-free) context
# manager purely for error conversion.
with client.catch_and_convert_errors():
    stream = await client.get_chat_completion_or_stream(call_args)
stream = cast(AsyncIterator, stream)

def callback(response: SpiceResponse, cache: List[float] = [0]):
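
The `get_response` change layers two retry loops: `call_llm` handles transient failures with backoff, while the outer loop can re-issue the call with adjusted arguments (the diff bumps temperature to at least 0.5). A schematic, runnable sketch of that layering with hypothetical names, treating the inner layer as a black box that either returns or raises:

```python
import asyncio
import random

class APIConnectionError(Exception):
    """Stand-in for the transient errors call_llm re-raises once its retries run out."""

async def inner_call_llm(temperature: float) -> str:
    # Stands in for Spice.call_llm: assume it has already done its own
    # backoff retries and either returns a result or raises.
    if random.random() < 0.5:
        raise APIConnectionError("retries exhausted")
    return f"completion (temperature={temperature})"

async def get_response_like(retries: int = 2) -> str:
    # Mirrors the outer loop in get_response: on failure, adjust the call
    # arguments and try again; give up only on the final attempt.
    temperature = 0.0
    for i in range(retries + 1):
        try:
            return await inner_call_llm(temperature)
        except APIConnectionError:
            if i == retries:
                raise
            temperature = max(0.5, temperature)

print(asyncio.run(get_response_like()))
```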
57 changes: 19 additions & 38 deletions spice/wrapped_clients.py
@@ -120,24 +120,15 @@ def extract_text_and_tokens(self, chat_completion, call_args: SpiceCallArgs):

@override
@contextmanager
def catch_and_convert_errors(self, max_retries: int = 0, base_delay: float = 1.0, max_delay: float = 32.0):
retries = 0
delay = base_delay
while retries <= max_retries:
try:
yield
return
except openai.APIConnectionError as e:
if retries == max_retries:
raise APIConnectionError(f"OpenAI Connection Error: {e.message}") from e
except openai.AuthenticationError as e:
raise AuthenticationError(f"OpenAI Authentication Error: {e.message}") from e
except openai.APIStatusError as e:
if retries == max_retries:
raise APIError(f"OpenAI Status Error: {e.message}") from e
time.sleep(min(delay, max_delay))
delay *= 2
retries += 1
def catch_and_convert_errors(self):
try:
yield
except openai.APIConnectionError as e:
raise APIConnectionError(f"OpenAI Connection Error: {e.message}") from e
except openai.AuthenticationError as e:
raise AuthenticationError(f"OpenAI Authentication Error: {e.message}") from e
except openai.APIStatusError as e:
raise APIError(f"OpenAI Status Error: {e.message}") from e

def _get_encoding_for_model(self, model: Model | str) -> tiktoken.Encoding:
from spice.models import Model
@@ -397,26 +388,16 @@ def extract_text_and_tokens(self, chat_completion, call_args: SpiceCallArgs):

@override
@contextmanager
def catch_and_convert_errors(self, max_retries: int = 0, base_delay: float = 1.0, max_delay: float = 32.0):
retries = 0
delay = base_delay
while retries <= max_retries:
try:
yield
return
except anthropic.APIConnectionError as e:
if retries == max_retries:
raise APIConnectionError(f"Anthropic Connection Error: {e.message}") from e
except anthropic.AuthenticationError as e:
raise AuthenticationError(f"Anthropic Authentication Error: {e.message}") from e
except anthropic.APIStatusError as e:
if retries == max_retries:
raise APIError(f"Anthropic Status Error: {e.message}") from e
time.sleep(min(delay, max_delay))
delay *= 2
retries += 1

# Anthropic doesn't give us a way to count tokens, so we just use OpenAI's token counting functions and multiply by a pre-determined multiplier
def catch_and_convert_errors(self):
try:
yield
except anthropic.APIConnectionError as e:
raise APIConnectionError(f"Anthropic Connection Error: {e.message}") from e
except anthropic.AuthenticationError as e:
raise AuthenticationError(f"Anthropic Authentication Error: {e.message}") from e
except anthropic.APIStatusError as e:
raise APIError(f"Anthropic Status Error: {e.message}") from e

class _FakeWrappedOpenAIClient(WrappedOpenAIClient):
def __init__(self):
pass
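
With retries removed, `catch_and_convert_errors` reduces to a pure error-translation shim. A standalone sketch of that contextmanager pattern and a caller, using stand-in exception types rather than the actual SDK classes:

```python
from contextlib import contextmanager

class APIConnectionError(Exception):
    """Provider-agnostic error surfaced to callers (mirrors Spice's error types)."""

class ProviderConnectionError(Exception):
    """Stand-in for a vendor SDK's connection error."""

@contextmanager
def catch_and_convert_errors():
    # Translate vendor-specific exceptions into the app's own hierarchy,
    # chaining the original so its traceback is preserved.
    try:
        yield
    except ProviderConnectionError as e:
        raise APIConnectionError(f"Connection Error: {e}") from e

# Usage: the caller handles one error type regardless of provider.
try:
    with catch_and_convert_errors():
        raise ProviderConnectionError("socket closed")
except APIConnectionError as e:
    print(f"caught converted error: {e}")
```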
