diff --git a/src/dispatch/plugins/dispatch_slack/service.py b/src/dispatch/plugins/dispatch_slack/service.py index 908220fca445..e1c29355ee3f 100644 --- a/src/dispatch/plugins/dispatch_slack/service.py +++ b/src/dispatch/plugins/dispatch_slack/service.py @@ -8,7 +8,13 @@ from slack_sdk.errors import SlackApiError from slack_sdk.web.client import WebClient from slack_sdk.web.slack_response import SlackResponse -from tenacity import retry, retry_if_exception, wait_exponential, stop_after_attempt, RetryCallState +from tenacity import ( + retry, + retry_if_exception, + RetryCallState, + wait_exponential, + stop_after_attempt, +) from typing import Dict, List, Optional @@ -21,28 +27,6 @@ log = logging.getLogger(__name__) -class SlackRetryException(Exception): - def __init__(self, wait_time: int | None = None): - self.wait_time = wait_time - if wait_time: - super().__init__(f"Retrying slack call in {wait_time} seconds.") - else: - super().__init__("Retrying slack call.") - - def get_wait_time(self) -> int: - return self.wait_time - -def slack_wait_strategy(retry_state: RetryCallState) -> float | int: - """Determines the wait time for the Slack retry strategy""" - exc = retry_state.outcome.exception() - - if exc.get_wait_time(): - return exc.get_wait_time() - - # Fallback to exponential backoff if no custom wait time is specified - return wait_exponential(multiplier=1, min=1, max=60)(retry_state) - - def create_slack_client(config: SlackConversationConfiguration) -> WebClient: """Creates a Slack Web API client.""" return WebClient(token=config.api_bot_token.get_secret_value()) @@ -107,47 +91,86 @@ def chunks(ids, n): yield ids[i : i + n] -@retry(stop=stop_after_attempt(5), retry=retry_if_exception(SlackRetryException), wait=slack_wait_strategy) -def make_call(client: WebClient, endpoint: str, **kwargs) -> SlackResponse: - """Makes a call to the Slack API. +def should_retry(exception: Exception) -> bool: + """ + Determine if a retry should be attempted based on the exception type. + + Args: + exception (Exception): The exception that was raised. + + Returns: + bool: True if a retry should be attempted, False otherwise. + """ + match exception: + case SlackApiError(): + # Retry if it's not a fatal error + return exception.response["error"] != SlackAPIErrorCode.FATAL_ERROR + case TimeoutError() | Timeout(): + # Always retry on timeout errors + return True + case _: + # Don't retry for other types of exceptions + return False - This function attempts to be resilient to common Slack API errors, such as rate limiting and fatal errors. - Rate limiting will be retried after the specified wait time (as returned by slack), and fatal errors will be raised as exceptions. + +def get_wait_time(retry_state: RetryCallState) -> int | float: + """ + Determine the wait time before the next retry attempt. Args: - client (WebClient): Slack web client. - endpoint (str): The Slack API endpoint to call. + retry_state (RetryCallState): The current state of the retry process. - Raises: - SlackRetryException: If the call fails and should be retried. + Returns: + int | float: The number of seconds to wait before the next retry. + """ + exception = retry_state.outcome.exception() + match exception: + case SlackApiError() if "Retry-After" in exception.response.headers: + # Use the Retry-After header value if present + return int(exception.response.headers["Retry-After"]) + case _: + # Use exponential backoff for other cases + return wait_exponential(multiplier=1, min=1, max=60)(retry_state) + + +@retry( + stop=stop_after_attempt(5), + retry=retry_if_exception(should_retry), + wait=get_wait_time, +) +def make_call( + client: WebClient, + endpoint: str, + **kwargs, +) -> SlackResponse: + """ + Make a call to the Slack API with built-in retry logic. + + Args: + client (WebClient): The Slack WebClient instance. + endpoint (str): The Slack API endpoint to call. + **kwargs: Additional keyword arguments to pass to the API call. Returns: SlackResponse: The response from the Slack API. + + Raises: + SlackApiError: If there's an error from the Slack API. + TimeoutError: If the request times out. + Timeout: If the request times out (from requests library). """ try: - if endpoint in set(SlackAPIGetEndpoints): + if endpoint in SlackAPIGetEndpoints: + # Use GET method for specific endpoints return client.api_call(endpoint, http_verb="GET", params=kwargs) + # Use POST method (default) for other endpoints return client.api_call(endpoint, json=kwargs) - except SlackApiError as exception: - message = ( - f"SlackAPIError. Response: {exception.response}. Endpoint: {endpoint}. Kwargs: {kwargs}" + except (SlackApiError, TimeoutError, Timeout) as exc: + log.warning( + f"{type(exc).__name__} for Slack API. Endpoint: {endpoint}. Kwargs: {kwargs}", + exc_info=exc if isinstance(exc, SlackApiError) else None, ) - - error = exception.response["error"] - if error == SlackAPIErrorCode.FATAL_ERROR: - log.warn(message) - raise SlackRetryException from None - - elif exception.response.headers.get("Retry-After"): - wait = int(exception.response.headers["Retry-After"]) - log.warn(f"SlackError: Rate limit hit. Waiting {wait} seconds.") - raise SlackRetryException(wait) from None - - # fatal error, don't retry - raise exception - except (TimeoutError, Timeout) as exception: - log.warn(f"{type(exception).__name__} error {exception} for slack. Endpoint: {endpoint}. Kwargs: {kwargs}") - raise SlackRetryException from None + raise def list_conversation_messages(client: WebClient, conversation_id: str, **kwargs) -> SlackResponse: @@ -251,10 +274,15 @@ def set_conversation_topic(client: WebClient, conversation_id: str, topic: str) ) -def set_conversation_description(client: WebClient, conversation_id: str, description: str) -> SlackResponse: +def set_conversation_description( + client: WebClient, conversation_id: str, description: str +) -> SlackResponse: """Sets the topic of the specified conversation.""" return make_call( - client, SlackAPIPostEndpoints.conversations_set_purpose, channel=conversation_id, purpose=description + client, + SlackAPIPostEndpoints.conversations_set_purpose, + channel=conversation_id, + purpose=description, ) @@ -309,7 +337,10 @@ def unarchive_conversation(client: WebClient, conversation_id: str) -> SlackResp def rename_conversation(client: WebClient, conversation_id: str, name: str) -> SlackResponse: """Renames an existing conversation.""" return make_call( - client, SlackAPIPostEndpoints.conversations_rename, channel=conversation_id, name=name.lower() + client, + SlackAPIPostEndpoints.conversations_rename, + channel=conversation_id, + name=name.lower(), )