diff --git a/llama_index/embeddings/openai.py b/llama_index/embeddings/openai.py
index 02fce2d5d0c53..370b6faa5465f 100644
--- a/llama_index/embeddings/openai.py
+++ b/llama_index/embeddings/openai.py
@@ -4,22 +4,24 @@
 from typing import Any, Dict, List, Optional, Tuple
 
 import openai
-from tenacity import (
-    retry,
-    stop_after_attempt,
-    stop_after_delay,
-    stop_all,
-    wait_random_exponential,
-)
 
 from llama_index.bridge.pydantic import Field, PrivateAttr
 from llama_index.callbacks.base import CallbackManager
 from llama_index.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding
 from llama_index.llms.openai_utils import (
+    create_retry_decorator,
     resolve_from_aliases,
     resolve_openai_credentials,
 )
 
+embedding_retry_decorator = create_retry_decorator(
+    max_retries=6,
+    random_exponential=True,
+    stop_after_delay_seconds=60,
+    min_seconds=1,
+    max_seconds=20,
+)
+
 
 class OpenAIEmbeddingMode(str, Enum):
     """OpenAI embedding mode."""
@@ -100,10 +102,7 @@ class OpenAIEmbeddingModeModel(str, Enum):
 }
 
 
-@retry(
-    wait=wait_random_exponential(min=1, max=20),
-    stop=stop_all(stop_after_attempt(6), stop_after_delay(60)),
-)
+@embedding_retry_decorator
 def get_embedding(
     text: str, engine: Optional[str] = None, **kwargs: Any
 ) -> List[float]:
@@ -123,10 +122,7 @@ def get_embedding(
     ]
 
 
-@retry(
-    wait=wait_random_exponential(min=1, max=20),
-    stop=stop_all(stop_after_attempt(6), stop_after_delay(60)),
-)
+@embedding_retry_decorator
async def aget_embedding(
     text: str, engine: Optional[str] = None, **kwargs: Any
 ) -> List[float]:
@@ -146,10 +142,7 @@ async def aget_embedding(
     ][0]["embedding"]
 
 
-@retry(
-    wait=wait_random_exponential(min=1, max=20),
-    stop=stop_all(stop_after_attempt(6), stop_after_delay(60)),
-)
+@embedding_retry_decorator
 def get_embeddings(
     list_of_text: List[str], engine: Optional[str] = None, **kwargs: Any
 ) -> List[List[float]]:
@@ -170,10 +163,7 @@ def get_embeddings(
     return [d["embedding"] for d in data]
 
 
-@retry(
-    wait=wait_random_exponential(min=1, max=20),
-    stop=stop_all(stop_after_attempt(6), stop_after_delay(60)),
-)
+@embedding_retry_decorator
 async def aget_embeddings(
     list_of_text: List[str], engine: Optional[str] = None, **kwargs: Any
 ) -> List[List[float]]:
diff --git a/llama_index/llms/openai_utils.py b/llama_index/llms/openai_utils.py
index 9920ba5d9f29e..6bcec4f44e8a4 100644
--- a/llama_index/llms/openai_utils.py
+++ b/llama_index/llms/openai_utils.py
@@ -9,8 +9,11 @@
     retry,
     retry_if_exception_type,
     stop_after_attempt,
+    stop_after_delay,
     wait_exponential,
+    wait_random_exponential,
 )
+from tenacity.stop import stop_base
 
 from llama_index.bridge.pydantic import BaseModel
 from llama_index.llms.base import ChatMessage
@@ -107,21 +110,37 @@
 CompletionClientType = Union[Type[Completion], Type[ChatCompletion]]
 
 
-def _create_retry_decorator(max_retries: int) -> Callable[[Any], Any]:
-    min_seconds = 4
-    max_seconds = 10
-    # Wait 2^x * 1 second between each retry starting with
-    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
+def create_retry_decorator(
+    max_retries: int,
+    random_exponential: bool = False,
+    stop_after_delay_seconds: Optional[float] = None,
+    min_seconds: float = 4,
+    max_seconds: float = 10,
+) -> Callable[[Any], Any]:
+    wait_strategy = (
+        wait_random_exponential(min=min_seconds, max=max_seconds)
+        if random_exponential
+        else wait_exponential(multiplier=1, min=min_seconds, max=max_seconds)
+    )
+
+    stop_strategy: stop_base = stop_after_attempt(max_retries)
+    if stop_after_delay_seconds is not None:
+        stop_strategy = stop_strategy | stop_after_delay(stop_after_delay_seconds)
+
     return retry(
         reraise=True,
-        stop=stop_after_attempt(max_retries),
-        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+        stop=stop_strategy,
+        wait=wait_strategy,
         retry=(
-            retry_if_exception_type(openai.error.Timeout)
-            | retry_if_exception_type(openai.error.APIError)
-            | retry_if_exception_type(openai.error.APIConnectionError)
-            | retry_if_exception_type(openai.error.RateLimitError)
-            | retry_if_exception_type(openai.error.ServiceUnavailableError)
+            retry_if_exception_type(
+                (
+                    openai.error.Timeout,
+                    openai.error.APIError,
+                    openai.error.APIConnectionError,
+                    openai.error.RateLimitError,
+                    openai.error.ServiceUnavailableError,
+                )
+            )
         ),
         before_sleep=before_sleep_log(logger, logging.WARNING),
     )
@@ -129,7 +148,7 @@ def _create_retry_decorator(max_retries: int) -> Callable[[Any], Any]:
 
 def completion_with_retry(is_chat_model: bool, max_retries: int, **kwargs: Any) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(max_retries=max_retries)
+    retry_decorator = create_retry_decorator(max_retries=max_retries)
 
     @retry_decorator
     def _completion_with_retry(**kwargs: Any) -> Any:
@@ -143,7 +162,7 @@ async def acompletion_with_retry(
     is_chat_model: bool, max_retries: int, **kwargs: Any
 ) -> Any:
     """Use tenacity to retry the async completion call."""
-    retry_decorator = _create_retry_decorator(max_retries=max_retries)
+    retry_decorator = create_retry_decorator(max_retries=max_retries)
 
     @retry_decorator
     async def _completion_with_retry(**kwargs: Any) -> Any:
diff --git a/tests/llms/test_openai_utils.py b/tests/llms/test_openai_utils.py
index b61491efb75e4..ab7e30c203d42 100644
--- a/tests/llms/test_openai_utils.py
+++ b/tests/llms/test_openai_utils.py
@@ -1,8 +1,10 @@
 from typing import List
 
+import openai
 import pytest
 from llama_index.llms.base import ChatMessage, MessageRole
 from llama_index.llms.openai_utils import (
+    create_retry_decorator,
     from_openai_message_dicts,
     to_openai_message_dicts,
 )
@@ -134,3 +136,36 @@ def test_from_openai_message_dicts_function_calling_azure(
         azure_openi_message_dicts_with_function_calling
     )
     assert chat_messages == azure_chat_messages_with_function_calling
+
+
+def test_create_retry_decorator() -> None:
+    test_retry_decorator = create_retry_decorator(
+        max_retries=6,
+        random_exponential=False,
+        stop_after_delay_seconds=10,
+        min_seconds=2,
+        max_seconds=5,
+    )
+
+    @test_retry_decorator
+    def mock_function() -> str:
+        # Simulate an OpenAI API call with potential errors
+        if mock_function.retry.statistics["attempt_number"] == 1:
+            raise openai.error.Timeout(message="Timeout error")
+        elif mock_function.retry.statistics["attempt_number"] == 2:
+            raise openai.error.APIError(message="API error")
+        elif mock_function.retry.statistics["attempt_number"] == 3:
+            raise openai.error.APIConnectionError(message="API connection error")
+        elif mock_function.retry.statistics["attempt_number"] == 4:
+            raise openai.error.ServiceUnavailableError(
+                message="Service Unavailable error"
+            )
+        elif mock_function.retry.statistics["attempt_number"] == 5:
+            raise openai.error.RateLimitError("Rate limit error")
+        else:
+            # Succeed on the final attempt
+            return "Success"
+
+    # Test that the decorator retries as expected
+    with pytest.raises(openai.error.RateLimitError, match="Rate limit error"):
+        mock_function()
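
Reviewer note: the snippet below is a minimal usage sketch of the new create_retry_decorator, not part of the diff. The flaky_embedding_call function and its failure behavior are hypothetical stand-ins; the decorator parameters mirror embedding_retry_decorator above.

import openai

from llama_index.llms.openai_utils import create_retry_decorator

# Same parameters as embedding_retry_decorator above: jittered exponential
# backoff between 1 and 20 seconds, stopping after 6 attempts or 60 seconds
# of retrying, whichever triggers first (the stop conditions are OR-ed).
retry_decorator = create_retry_decorator(
    max_retries=6,
    random_exponential=True,
    stop_after_delay_seconds=60,
    min_seconds=1,
    max_seconds=20,
)


@retry_decorator
def flaky_embedding_call() -> str:
    # Hypothetical stand-in for an OpenAI embeddings request. Any of the
    # retryable errors (Timeout, APIError, APIConnectionError,
    # RateLimitError, ServiceUnavailableError) triggers another attempt.
    raise openai.error.RateLimitError("rate limited")


try:
    flaky_embedding_call()
except openai.error.RateLimitError:
    # reraise=True means the last error is re-raised once a stop
    # condition is hit, rather than being wrapped in tenacity's RetryError.
    print("gave up after exhausting retries")

Because stop_after_attempt and stop_after_delay are combined with `|` (tenacity's stop_any), the delay cap can cut retries short before max_retries is reached, which is the behavior the new test_create_retry_decorator exercises.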