Skip to content

Commit

Permalink
feat: simplify slack make_call tenacity strategies (#4904)
Browse files Browse the repository at this point in the history
  • Loading branch information
wssheldon authored Jul 1, 2024
1 parent 85ea260 commit 9d1f892
Showing 1 changed file with 86 additions and 55 deletions.
141 changes: 86 additions & 55 deletions src/dispatch/plugins/dispatch_slack/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@
from slack_sdk.errors import SlackApiError
from slack_sdk.web.client import WebClient
from slack_sdk.web.slack_response import SlackResponse
from tenacity import retry, retry_if_exception, wait_exponential, stop_after_attempt, RetryCallState
from tenacity import (
retry,
retry_if_exception,
RetryCallState,
wait_exponential,
stop_after_attempt,
)

from typing import Dict, List, Optional

Expand All @@ -21,28 +27,6 @@
log = logging.getLogger(__name__)


class SlackRetryException(Exception):
def __init__(self, wait_time: int | None = None):
self.wait_time = wait_time
if wait_time:
super().__init__(f"Retrying slack call in {wait_time} seconds.")
else:
super().__init__("Retrying slack call.")

def get_wait_time(self) -> int:
return self.wait_time

def slack_wait_strategy(retry_state: RetryCallState) -> float | int:
"""Determines the wait time for the Slack retry strategy"""
exc = retry_state.outcome.exception()

if exc.get_wait_time():
return exc.get_wait_time()

# Fallback to exponential backoff if no custom wait time is specified
return wait_exponential(multiplier=1, min=1, max=60)(retry_state)


def create_slack_client(config: SlackConversationConfiguration) -> WebClient:
"""Creates a Slack Web API client."""
return WebClient(token=config.api_bot_token.get_secret_value())
Expand Down Expand Up @@ -107,47 +91,86 @@ def chunks(ids, n):
yield ids[i : i + n]


@retry(stop=stop_after_attempt(5), retry=retry_if_exception(SlackRetryException), wait=slack_wait_strategy)
def make_call(client: WebClient, endpoint: str, **kwargs) -> SlackResponse:
"""Makes a call to the Slack API.
def should_retry(exception: Exception) -> bool:
"""
Determine if a retry should be attempted based on the exception type.
Args:
exception (Exception): The exception that was raised.
Returns:
bool: True if a retry should be attempted, False otherwise.
"""
match exception:
case SlackApiError():
# Retry if it's not a fatal error
return exception.response["error"] != SlackAPIErrorCode.FATAL_ERROR
case TimeoutError() | Timeout():
# Always retry on timeout errors
return True
case _:
# Don't retry for other types of exceptions
return False

This function attempts to be resilient to common Slack API errors, such as rate limiting and fatal errors.
Rate limiting will be retried after the specified wait time (as returned by slack), and fatal errors will be raised as exceptions.

def get_wait_time(retry_state: RetryCallState) -> int | float:
"""
Determine the wait time before the next retry attempt.
Args:
client (WebClient): Slack web client.
endpoint (str): The Slack API endpoint to call.
retry_state (RetryCallState): The current state of the retry process.
Raises:
SlackRetryException: If the call fails and should be retried.
Returns:
int | float: The number of seconds to wait before the next retry.
"""
exception = retry_state.outcome.exception()
match exception:
case SlackApiError() if "Retry-After" in exception.response.headers:
# Use the Retry-After header value if present
return int(exception.response.headers["Retry-After"])
case _:
# Use exponential backoff for other cases
return wait_exponential(multiplier=1, min=1, max=60)(retry_state)


@retry(
stop=stop_after_attempt(5),
retry=retry_if_exception(should_retry),
wait=get_wait_time,
)
def make_call(
client: WebClient,
endpoint: str,
**kwargs,
) -> SlackResponse:
"""
Make a call to the Slack API with built-in retry logic.
Args:
client (WebClient): The Slack WebClient instance.
endpoint (str): The Slack API endpoint to call.
**kwargs: Additional keyword arguments to pass to the API call.
Returns:
SlackResponse: The response from the Slack API.
Raises:
SlackApiError: If there's an error from the Slack API.
TimeoutError: If the request times out.
Timeout: If the request times out (from requests library).
"""
try:
if endpoint in set(SlackAPIGetEndpoints):
if endpoint in SlackAPIGetEndpoints:
# Use GET method for specific endpoints
return client.api_call(endpoint, http_verb="GET", params=kwargs)
# Use POST method (default) for other endpoints
return client.api_call(endpoint, json=kwargs)
except SlackApiError as exception:
message = (
f"SlackAPIError. Response: {exception.response}. Endpoint: {endpoint}. Kwargs: {kwargs}"
except (SlackApiError, TimeoutError, Timeout) as exc:
log.warning(
f"{type(exc).__name__} for Slack API. Endpoint: {endpoint}. Kwargs: {kwargs}",
exc_info=exc if isinstance(exc, SlackApiError) else None,
)

error = exception.response["error"]
if error == SlackAPIErrorCode.FATAL_ERROR:
log.warn(message)
raise SlackRetryException from None

elif exception.response.headers.get("Retry-After"):
wait = int(exception.response.headers["Retry-After"])
log.warn(f"SlackError: Rate limit hit. Waiting {wait} seconds.")
raise SlackRetryException(wait) from None

# fatal error, don't retry
raise exception
except (TimeoutError, Timeout) as exception:
log.warn(f"{type(exception).__name__} error {exception} for slack. Endpoint: {endpoint}. Kwargs: {kwargs}")
raise SlackRetryException from None
raise


def list_conversation_messages(client: WebClient, conversation_id: str, **kwargs) -> SlackResponse:
Expand Down Expand Up @@ -251,10 +274,15 @@ def set_conversation_topic(client: WebClient, conversation_id: str, topic: str)
)


def set_conversation_description(client: WebClient, conversation_id: str, description: str) -> SlackResponse:
def set_conversation_description(
client: WebClient, conversation_id: str, description: str
) -> SlackResponse:
"""Sets the topic of the specified conversation."""
return make_call(
client, SlackAPIPostEndpoints.conversations_set_purpose, channel=conversation_id, purpose=description
client,
SlackAPIPostEndpoints.conversations_set_purpose,
channel=conversation_id,
purpose=description,
)


Expand Down Expand Up @@ -309,7 +337,10 @@ def unarchive_conversation(client: WebClient, conversation_id: str) -> SlackResp
def rename_conversation(client: WebClient, conversation_id: str, name: str) -> SlackResponse:
"""Renames an existing conversation."""
return make_call(
client, SlackAPIPostEndpoints.conversations_rename, channel=conversation_id, name=name.lower()
client,
SlackAPIPostEndpoints.conversations_rename,
channel=conversation_id,
name=name.lower(),
)


Expand Down

0 comments on commit 9d1f892

Please sign in to comment.