diff --git a/pyproject.toml b/pyproject.toml index 6839819f5..f55ade76a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -184,7 +184,7 @@ module = [ "html_sanitizer", "isbnlib", "jwcrypto", - "kombu", + "kombu.*", "lxml.*", "money", "multipledispatch", diff --git a/src/palace/manager/celery/monitoring.py b/src/palace/manager/celery/monitoring.py index 29197549b..c3830d76e 100644 --- a/src/palace/manager/celery/monitoring.py +++ b/src/palace/manager/celery/monitoring.py @@ -5,16 +5,20 @@ from dataclasses import dataclass, field from datetime import datetime from typing import TYPE_CHECKING, Any +from urllib.parse import urlparse import boto3 from boto3.exceptions import Boto3Error from botocore.exceptions import BotoCoreError from celery.events.snapshot import Polaroid from celery.events.state import State, Task +from kombu.transport.redis import PrefixedStrictRedis +from redis import ConnectionPool +from palace.manager.core.exceptions import PalaceValueError from palace.manager.util import chunks from palace.manager.util.datetime_helpers import utc_now -from palace.manager.util.log import LoggerMixin, logger_for_cls +from palace.manager.util.log import logger_for_cls if TYPE_CHECKING: from mypy_boto3_cloudwatch.literals import StandardUnitType @@ -115,29 +119,20 @@ def reset(self) -> None: self.runtime = [] -@dataclass -class QueueStats(LoggerMixin): +@dataclass(frozen=True) +class QueueStats: """ Tracks the number of tasks queued for a specific queue, so we can report this out to Cloudwatch metrics. """ - queued: set[str] = field(default_factory=set) - - def update(self, task: Task) -> None: - self.log.debug("Task: %r", task) - if task.uuid in self.queued and task.started: - self.log.debug(f"Task {task.uuid} started.") - self.queued.remove(task.uuid) - elif task.sent and not task.started: - self.log.debug(f"Task {task.uuid} queued.") - self.queued.add(task.uuid) + queued: int def metrics( self, timestamp: datetime, dimensions: dict[str, str] ) -> list[MetricDatumTypeDef]: return [ - value_metric("QueueWaiting", len(self.queued), timestamp, dimensions), + value_metric("QueueWaiting", self.queued, timestamp, dimensions), ] @@ -161,6 +156,11 @@ def __init__( # because the base class Polaroid already defines a logger attribute, # which conflicts with the logger() method in LoggerMixin. self.logger = logger_for_cls(self.__class__) + broker_url = self.app.conf.get("broker_url") + broker_type = urlparse(broker_url).scheme if broker_url else None + if broker_type != "redis": + raise PalaceValueError(f"Broker type '{broker_type}' is not supported.") + region = self.app.conf.get("cloudwatch_statistics_region") dryrun = self.app.conf.get("cloudwatch_statistics_dryrun") self.cloudwatch_client = ( @@ -169,12 +169,10 @@ def __init__( self.manager_name = self.app.conf.get("broker_transport_options", {}).get( "global_keyprefix" ) + self.redis_client = self.get_redis_client(broker_url, self.manager_name) self.namespace = self.app.conf.get("cloudwatch_statistics_namespace") self.upload_size = self.app.conf.get("cloudwatch_statistics_upload_size") - self.queues = defaultdict( - QueueStats, - {queue.name: QueueStats() for queue in self.app.conf.get("task_queues")}, - ) + self.queues = {queue.name for queue in self.app.conf.get("task_queues")} self.tasks: defaultdict[str, TaskStats] = defaultdict( TaskStats, { @@ -184,6 +182,15 @@ def __init__( }, ) + @classmethod + def get_redis_client( + cls, broker_url: str, global_keyprefix: str | None + ) -> PrefixedStrictRedis: + connection_pool = ConnectionPool.from_url(broker_url) + return PrefixedStrictRedis( + connection_pool=connection_pool, global_keyprefix=global_keyprefix + ) + @staticmethod def is_celery_task(task_name: str) -> bool: return task_name.startswith("celery.") @@ -201,12 +208,11 @@ def reset_tasks(self) -> None: for task in self.tasks.values(): task.reset() - def on_shutter(self, state: State) -> None: - timestamp = utc_now() - + def update_task_stats(self, state: State) -> None: # Reset the task stats for each snapshot self.reset_tasks() + # Update task stats for each task in the state for task in state.tasks.values(): # Update task stats for each task if task.name is None: @@ -218,19 +224,17 @@ def on_shutter(self, state: State) -> None: else: self.tasks[task.name].update(task) - # Update queue stats for each task - if task.routing_key is None: - self.logger.warning( - f"Task has no routing_key. {self.task_info_str(task)}." - ) - if task.started: - # Its possible that we are tracking this task, so we make sure its not in any of our queues - for queue in self.queues.values(): - queue.update(task) - else: - self.queues[task.routing_key].update(task) + def get_queue_stats(self) -> dict[str, QueueStats]: + return { + queue: QueueStats(self.redis_client.llen(queue)) for queue in self.queues + } + + def on_shutter(self, state: State) -> None: + timestamp = utc_now() + self.update_task_stats(state) + queue_stats = self.get_queue_stats() - self.publish(self.tasks, self.queues, timestamp) + self.publish(self.tasks, queue_stats, timestamp) def publish( self, diff --git a/src/palace/manager/core/exceptions.py b/src/palace/manager/core/exceptions.py index 477a91131..834ee166c 100644 --- a/src/palace/manager/core/exceptions.py +++ b/src/palace/manager/core/exceptions.py @@ -10,6 +10,10 @@ def __init__(self, message: str | None = None): self.message = message +class PalaceValueError(BasePalaceException, ValueError): + ... + + class IntegrationException(BasePalaceException): """An exception that happens when the site's connection to a third-party service is broken. diff --git a/tests/manager/celery/test_monitoring.py b/tests/manager/celery/test_monitoring.py index c5cd140f9..d455542ea 100644 --- a/tests/manager/celery/test_monitoring.py +++ b/tests/manager/celery/test_monitoring.py @@ -1,5 +1,4 @@ -from functools import partial -from unittest.mock import MagicMock, create_autospec, patch +from unittest.mock import MagicMock, call, create_autospec, patch from uuid import uuid4 import pytest @@ -25,18 +24,27 @@ def __init__(self, boto_client: MagicMock): self.state = create_autospec(State) self.state.tasks = self.task_list( [ - self.mock_task("task1", "queue1", runtime=1.0), - self.mock_task("task1", "queue1", runtime=2.0), - self.mock_task("task2", "queue2", succeeded=False, failed=True), - self.mock_task( - "task2", "queue2", started=False, succeeded=False, uuid="uuid4" - ), - self.mock_task( - "celery.built_in", "queue2", started=False, succeeded=False - ), + self.mock_task("task1", runtime=1.0), + self.mock_task("task1", runtime=2.0), + self.mock_task("task2", succeeded=False, failed=True), + self.mock_task("task2", started=False, succeeded=False, uuid="uuid4"), + self.mock_task("celery.built_in", started=False, succeeded=False), ] ) - self.create_cloudwatch = partial(Cloudwatch, state=self.state, app=self.app) + self._mock_get_redis: MagicMock | None = None + + def create_cloudwatch(self): + with patch.object(Cloudwatch, "get_redis_client") as mock_get_redis: + self._mock_get_redis = mock_get_redis + return Cloudwatch(state=self.state, app=self.app) + + @property + def mock_get_redis(self): + if self._mock_get_redis is None: + raise ValueError( + "get_redis_client not mocked because create_cloudwatch was not called." + ) + return self._mock_get_redis @staticmethod def task_list(tasks: list[Task]) -> dict[str, Task]: @@ -50,7 +58,7 @@ def mock_queue(self, name: str) -> MagicMock: def mock_task( self, name: str | None = None, - routing_key: str | None = None, + *, sent: bool = True, started: bool = True, succeeded: bool = True, @@ -63,7 +71,6 @@ def mock_task( return Task( uuid=uuid, name=name, - routing_key=routing_key, sent=sent, started=started, succeeded=succeeded, @@ -73,6 +80,7 @@ def mock_task( def configure_app( self, + broker_url: str = "redis://testtesttest:1234/1", region: str = "region", dry_run: bool = False, manager_name: str = "manager", @@ -82,6 +90,7 @@ def configure_app( ) -> None: queues = queues or ["queue1", "queue2"] self.app.conf = { + "broker_url": broker_url, "cloudwatch_statistics_region": region, "cloudwatch_statistics_dryrun": dry_run, "broker_transport_options": {"global_keyprefix": manager_name}, @@ -172,40 +181,8 @@ def test_metrics_with_empty_runtime(self): class TestQueueStats: - def test_update(self, cloudwatch_camera: CloudwatchCameraFixture): - stats = QueueStats() - - assert len(stats.queued) == 0 - - mock_task = cloudwatch_camera.mock_task(sent=False, started=False) - - # Task is not started or sent, so it should not be in the queue. - stats.update(mock_task) - assert len(stats.queued) == 0 - - # Task is both sent and started, so its being processed and should not be in the queue. - mock_task = cloudwatch_camera.mock_task(sent=True, started=True) - stats.update(mock_task) - assert len(stats.queued) == 0 - - # Task is sent but not started, so it should be in the queue. - mock_task = cloudwatch_camera.mock_task(sent=True, started=False) - stats.update(mock_task) - assert len(stats.queued) == 1 - - # If the task is sent again, it should still be in the queue, but not duplicated. - stats.update(mock_task) - assert len(stats.queued) == 1 - - # If the task is started, it should be removed from the queue, even if we no longer - # have its routing key. - mock_task.started = True - mock_task.routing_key = None - stats.update(mock_task) - assert len(stats.queued) == 0 - def test_metrics(self): - stats = QueueStats(queued={"uuid1", "uuid2"}) + stats = QueueStats(queued=2) timestamp = MagicMock() dimensions = {"key": "value", "key2": "value2"} expected_dimensions = [ @@ -218,11 +195,6 @@ def test_metrics(self): assert metric["Dimensions"] == expected_dimensions assert metric["Unit"] == "Count" - stats = QueueStats() - [metric] = stats.metrics(timestamp, dimensions) - assert metric["MetricName"] == "QueueWaiting" - assert metric["Value"] == 0 - class TestCloudwatch: def test__init__(self, cloudwatch_camera: CloudwatchCameraFixture): @@ -236,7 +208,18 @@ def test__init__(self, cloudwatch_camera: CloudwatchCameraFixture): assert cloudwatch.manager_name == "manager" assert cloudwatch.namespace == "namespace" assert cloudwatch.upload_size == 100 - assert cloudwatch.queues == {"queue1": QueueStats(), "queue2": QueueStats()} + assert cloudwatch.queues == {"queue1", "queue2"} + assert cloudwatch.redis_client == cloudwatch_camera.mock_get_redis.return_value + cloudwatch_camera.mock_get_redis.assert_called_once_with( + "redis://testtesttest:1234/1", + "manager", + ) + + def test__init__error(self, cloudwatch_camera: CloudwatchCameraFixture): + cloudwatch_camera.configure_app(broker_url="sqs://") + with pytest.raises(ValueError) as exc_info: + cloudwatch_camera.create_cloudwatch() + assert "Broker type 'sqs' is not supported." in str(exc_info.value) def test__init__dryrun(self, cloudwatch_camera: CloudwatchCameraFixture): cloudwatch_camera.configure_app(dry_run=True) @@ -246,14 +229,17 @@ def test__init__dryrun(self, cloudwatch_camera: CloudwatchCameraFixture): def test_on_shutter( self, cloudwatch_camera: CloudwatchCameraFixture, - caplog: pytest.LogCaptureFixture, ): - caplog.set_level(LogLevel.warning) cloudwatch = cloudwatch_camera.create_cloudwatch() mock_publish = create_autospec(cloudwatch.publish) cloudwatch.publish = mock_publish + cloudwatch_camera.mock_get_redis.return_value.llen.return_value = 10 with freeze_time("2021-01-01"): cloudwatch.on_shutter(cloudwatch_camera.state) + assert cloudwatch_camera.mock_get_redis.return_value.llen.call_count == 2 + cloudwatch_camera.mock_get_redis.return_value.llen.assert_has_calls( + [call("queue1"), call("queue2")], any_order=True + ) mock_publish.assert_called_once() [tasks, queues, time] = mock_publish.call_args.args @@ -262,25 +248,23 @@ def test_on_shutter( "task2": TaskStats(failed=1), } assert queues == { - "queue1": QueueStats(), - "queue2": QueueStats(queued={"uuid4"}), + "queue1": QueueStats(queued=10), + "queue2": QueueStats(queued=10), } assert time.isoformat() == "2021-01-01T00:00:00+00:00" - # We can also handle the case where we see a task with an unknown queue or unknown name. + def test_on_shutter_unknown_task_name( + self, + cloudwatch_camera: CloudwatchCameraFixture, + ): + # We can also handle the case where we see a task with an unknown name. + cloudwatch = cloudwatch_camera.create_cloudwatch() + mock_publish = create_autospec(cloudwatch.publish) + cloudwatch.publish = mock_publish cloudwatch_camera.state.tasks = cloudwatch_camera.task_list( [ - cloudwatch_camera.mock_task( - "task2", - "unknown_queue", - started=False, - succeeded=False, - uuid="uuid5", - runtime=3.0, - ), cloudwatch_camera.mock_task( "unknown_task", - "queue1", failed=True, succeeded=False, uuid="uuid6", @@ -288,67 +272,49 @@ def test_on_shutter( ] ) cloudwatch.on_shutter(cloudwatch_camera.state) - [tasks, queues, _] = mock_publish.call_args.args + [tasks, _, _] = mock_publish.call_args.args assert tasks == { "task1": TaskStats(), "task2": TaskStats(), "unknown_task": TaskStats(failed=1), } - assert queues == { - "queue1": QueueStats(), - "queue2": QueueStats(queued={"uuid4"}), - "unknown_queue": QueueStats(queued={"uuid5"}), - } - # We can handle tasks with no name or tasks with no routing key. - caplog.clear() + def test_on_shutter_no_task_name( + self, + cloudwatch_camera: CloudwatchCameraFixture, + caplog: pytest.LogCaptureFixture, + ): + # We can handle tasks with no name + caplog.set_level(LogLevel.warning) + cloudwatch = cloudwatch_camera.create_cloudwatch() + mock_publish = create_autospec(cloudwatch.publish) + cloudwatch.publish = mock_publish cloudwatch_camera.state.tasks = cloudwatch_camera.task_list( [ - cloudwatch_camera.mock_task( - None, "unknown_queue", started=True, uuid="uuid7" - ), - cloudwatch_camera.mock_task( - "task2", None, started=False, succeeded=False, uuid="uuid8" - ), - cloudwatch_camera.mock_task(None, None, started=True, uuid="uuid5"), + cloudwatch_camera.mock_task(None, started=True, uuid="uuid7"), + cloudwatch_camera.mock_task(None, started=True, uuid="uuid5"), ] ) cloudwatch.on_shutter(cloudwatch_camera.state) - [tasks, queues, _] = mock_publish.call_args.args + [tasks, _, _] = mock_publish.call_args.args assert tasks == { "task1": TaskStats(), "task2": TaskStats(), - "unknown_task": TaskStats(), - } - assert queues == { - "queue1": QueueStats(), - "queue2": QueueStats(queued={"uuid4"}), - "unknown_queue": QueueStats(), } - # We log the information about tasks with no name or routing key. + # We log the information about tasks with no name ( no_name_warning_1, - no_routing_key_warning_1, no_name_warning_2, - no_routing_key_warning_2, ) = caplog.messages assert ( - "Task has no name. [routing_key]:unknown_queue, [sent]:True, [started]:True, [uuid]:uuid7." + "Task has no name. [sent]:True, [started]:True, [uuid]:uuid7." in no_name_warning_1 ) - assert ( - "Task has no routing_key. [name]:task2, [sent]:True, [started]:False, [uuid]:uuid8." - in no_routing_key_warning_1 - ) assert ( "Task has no name. [sent]:True, [started]:True, [uuid]:uuid5." in no_name_warning_2 ) - assert ( - "Task has no routing_key. [sent]:True, [started]:True, [uuid]:uuid5." - in no_routing_key_warning_2 - ) def test_publish( self, @@ -363,7 +329,7 @@ def test_publish( } timestamp = MagicMock() tasks = {"task1": TaskStats(succeeded=2, failed=5, runtime=[3.5, 2.2])} - queues = {"queue1": QueueStats(queued={"uuid1", "uuid2"})} + queues = {"queue1": QueueStats(queued=2)} cloudwatch.publish(tasks, queues, timestamp) mock_put_metric_data.assert_called_once()