diff --git a/edx_event_bus_redis/__init__.py b/edx_event_bus_redis/__init__.py index a083009..e18706e 100644 --- a/edx_event_bus_redis/__init__.py +++ b/edx_event_bus_redis/__init__.py @@ -5,6 +5,6 @@ from edx_event_bus_redis.internal.consumer import RedisEventConsumer from edx_event_bus_redis.internal.producer import create_producer -__version__ = '0.5.0' +__version__ = '0.5.1' default_app_config = 'edx_event_bus_redis.apps.EdxEventBusRedisConfig' # pylint: disable=invalid-name diff --git a/edx_event_bus_redis/internal/consumer.py b/edx_event_bus_redis/internal/consumer.py index df161b8..88d4557 100644 --- a/edx_event_bus_redis/internal/consumer.py +++ b/edx_event_bus_redis/internal/consumer.py @@ -19,7 +19,7 @@ from edx_event_bus_redis.internal.message import RedisMessage from .config import get_full_topic, load_common_settings -from .utils import AUDIT_LOGGING_ENABLED +from .utils import AUDIT_LOGGING_ENABLED, Timeout logger = logging.getLogger(__name__) @@ -148,11 +148,19 @@ def _shut_down(self): def _read_pending_msgs(self) -> Optional[tuple]: """ Read pending messages, if no messages found return None. + + These redis calls don't have timout args, and we've seen that they + can hang indefinitely when redis goes down. So we wrap them in a + timeout context manager. """ logger.debug("Consuming pending msgs first.") + if self.claim_msgs_older_than is not None: - self.consumer.autoclaim(self.consumer_name, min_idle_time=self.claim_msgs_older_than, count=1) - msg_meta = self.consumer.pending(count=1, consumer=self.consumer_name) + with Timeout(5): + self.consumer.autoclaim(self.consumer_name, min_idle_time=self.claim_msgs_older_than, count=1) + with Timeout(5): + msg_meta = self.consumer.pending(count=1, consumer=self.consumer_name) + if msg_meta: return self.consumer[msg_meta[0]['message_id']] logger.debug("No more pending messages.") @@ -426,7 +434,7 @@ def _is_fatal_redis_error(self, error: Optional[Exception]) -> bool: Arguments: error: An exception instance, or None if no error. """ - if error and isinstance(error, RedisConnectionError): + if error and isinstance(error, (RedisConnectionError, TimeoutError)): return True return False diff --git a/edx_event_bus_redis/internal/tests/test_utils.py b/edx_event_bus_redis/internal/tests/test_utils.py index 79d8968..f17c0fb 100644 --- a/edx_event_bus_redis/internal/tests/test_utils.py +++ b/edx_event_bus_redis/internal/tests/test_utils.py @@ -1,6 +1,7 @@ """ Test header conversion utils """ +import time from datetime import datetime, timezone from unittest.mock import Mock, patch from uuid import uuid1 @@ -16,6 +17,7 @@ HEADER_ID, HEADER_SOURCELIB, HEADER_TIME, + Timeout, encode, get_headers_from_metadata, get_metadata_from_headers, @@ -139,3 +141,16 @@ def test_generate_metadata_from_missing_or_bad_headers(self, msg_id, msg_time, s expected_metadata = EventsMetadata(event_type="abc", id=TEST_UUID) generated_metadata = get_metadata_from_headers(headers) self.assertDictEqual(attr.asdict(generated_metadata), attr.asdict(expected_metadata)) + + +class TestTimeout(TestCase): + """ + Test the timeout context manager + """ + def test_timeout(self): + """ + Test that the timeout decorator raises a TimeoutError if the function takes too long + """ + with pytest.raises(TimeoutError): + with Timeout(1): + time.sleep(2) diff --git a/edx_event_bus_redis/internal/utils.py b/edx_event_bus_redis/internal/utils.py index 9e2e07e..eb33e0b 100644 --- a/edx_event_bus_redis/internal/utils.py +++ b/edx_event_bus_redis/internal/utils.py @@ -3,6 +3,7 @@ """ import logging +import signal from datetime import datetime from typing import Tuple from uuid import UUID @@ -115,3 +116,31 @@ def get_headers_from_metadata(event_metadata: oed.EventsMetadata): values[encode(header.message_header_key)] = encode(header.from_metadata(event_metadata_value)) return values + + +class Timeout: + """ + Context manager to raise a TimeoutError after a specified number of seconds. + + Some redis calls don't have a timeout parameter, so this can be used to enforce a timeout. + """ + def __init__(self, timeout_seconds): + self.timeout_seconds = timeout_seconds + + def __enter__(self): + """ + Start the timer, if we don't __exit__ in self.seconds it will raise the TimeoutError. + """ + signal.signal(signal.SIGALRM, Timeout._raise_timeout) + signal.alarm(self.timeout_seconds) + return self + + def __exit__(self, exc_type, exc_value, traceback): + """ + Stop the signal timer on context exit. + """ + signal.alarm(0) + + @staticmethod + def _raise_timeout(signum, frame): + raise TimeoutError