diff --git a/src/palace/manager/celery/monitoring.py b/src/palace/manager/celery/monitoring.py index a8746be04c..29197549b9 100644 --- a/src/palace/manager/celery/monitoring.py +++ b/src/palace/manager/celery/monitoring.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections import defaultdict from collections.abc import Sequence from dataclasses import dataclass, field from datetime import datetime @@ -108,6 +109,11 @@ def metrics( return metric_data + def reset(self) -> None: + self.succeeded = 0 + self.failed = 0 + self.runtime = [] + @dataclass class QueueStats(LoggerMixin): @@ -120,14 +126,12 @@ class QueueStats(LoggerMixin): def update(self, task: Task) -> None: self.log.debug("Task: %r", task) - if task.uuid in self.queued: - if task.started: - self.log.debug("Task %s started.", task.uuid) - self.queued.remove(task.uuid) - else: - if task.sent and not task.started: - self.log.debug("Task %s queued.", task.uuid) - self.queued.add(task.uuid) + if task.uuid in self.queued and task.started: + self.log.debug(f"Task {task.uuid} started.") + self.queued.remove(task.uuid) + elif task.sent and not task.started: + self.log.debug(f"Task {task.uuid} queued.") + self.queued.add(task.uuid) def metrics( self, timestamp: datetime, dimensions: dict[str, str] @@ -167,30 +171,66 @@ def __init__( ) self.namespace = self.app.conf.get("cloudwatch_statistics_namespace") self.upload_size = self.app.conf.get("cloudwatch_statistics_upload_size") - self.queues = { - str(queue.name): QueueStats() for queue in self.app.conf.get("task_queues") - } + self.queues = defaultdict( + QueueStats, + {queue.name: QueueStats() for queue in self.app.conf.get("task_queues")}, + ) + self.tasks: defaultdict[str, TaskStats] = defaultdict( + TaskStats, + { + task: TaskStats() + for task in self.app.tasks.keys() + if not self.is_celery_task(task) + }, + ) + + @staticmethod + def is_celery_task(task_name: str) -> bool: + return task_name.startswith("celery.") + + @staticmethod + def 
task_info_str(task: Task) -> str: + return ", ".join( + [ + f"[{k}]:{v}" + for k, v in task.info(extra=["name", "sent", "started", "uuid"]).items() + ] + ) + + def reset_tasks(self) -> None: + for task in self.tasks.values(): + task.reset() def on_shutter(self, state: State) -> None: timestamp = utc_now() - tasks = { - task: TaskStats() - for task in self.app.tasks.keys() - if not task.startswith("celery.") - } + + # Reset the task stats for each snapshot + self.reset_tasks() for task in state.tasks.values(): - try: - tasks[task.name].update(task) - self.queues[task.routing_key].update(task) - except KeyError: - self.logger.exception( - "Error processing task %s with routing key %s", - task.name, - task.routing_key, + # Update task stats for each task + if task.name is None: + self.logger.warning(f"Task has no name. {self.task_info_str(task)}.") + elif self.is_celery_task(task.name): + # If this is an internal Celery task, we skip it entirely. + # We don't want to track internal Celery tasks. + continue + else: + self.tasks[task.name].update(task) + + # Update queue stats for each task + if task.routing_key is None: + self.logger.warning( + f"Task has no routing_key. {self.task_info_str(task)}." ) + if task.started: + # It's possible that we are tracking this task, so we make sure it's not in any of our queues + for queue in self.queues.values(): + queue.update(task) + else: + self.queues[task.routing_key].update(task) - self.publish(tasks, self.queues, timestamp) + self.publish(self.tasks, self.queues, timestamp) def publish( self, @@ -226,4 +266,4 @@ def publish( else: self.logger.info("Dry run enabled. 
Not sending metrics to Cloudwatch.") for data in chunk: - self.logger.info("Data: %s", data) + self.logger.info(f"Data: {data}") diff --git a/tests/manager/celery/test_monitoring.py b/tests/manager/celery/test_monitoring.py index 9611297a22..9f63c40f50 100644 --- a/tests/manager/celery/test_monitoring.py +++ b/tests/manager/celery/test_monitoring.py @@ -23,16 +23,25 @@ def __init__(self, boto_client: MagicMock): } self.client = boto_client self.state = create_autospec(State) - self.state.tasks = { - "uuid1": self.mock_task("task1", "queue1", runtime=1.0), - "uuid2": self.mock_task("task1", "queue1", runtime=2.0), - "uuid3": self.mock_task("task2", "queue2", succeeded=False, failed=True), - "uuid4": self.mock_task( - "task2", "queue2", started=False, succeeded=False, uuid="uuid4" - ), - } + self.state.tasks = self.task_list( + [ + self.mock_task("task1", "queue1", runtime=1.0), + self.mock_task("task1", "queue1", runtime=2.0), + self.mock_task("task2", "queue2", succeeded=False, failed=True), + self.mock_task( + "task2", "queue2", started=False, succeeded=False, uuid="uuid4" + ), + self.mock_task( + "celery.built_in", "queue2", started=False, succeeded=False + ), + ] + ) self.create_cloudwatch = partial(Cloudwatch, state=self.state, app=self.app) + @staticmethod + def task_list(tasks: list[Task]) -> dict[str, Task]: + return {task.uuid: task for task in tasks} + def mock_queue(self, name: str) -> MagicMock: queue = MagicMock() queue.name = name @@ -51,10 +60,6 @@ def mock_task( ) -> Task: if uuid is None: uuid = str(uuid4()) - if name is None: - name = "task" - if routing_key is None: - routing_key = "queue" return Task( uuid=uuid, name=name, @@ -192,8 +197,10 @@ def test_update(self, cloudwatch_camera: CloudwatchCameraFixture): stats.update(mock_task) assert len(stats.queued) == 1 - # If the task is started, it should be removed from the queue. + # If the task is started, it should be removed from the queue, even if we no longer + # have its routing key. 
mock_task.started = True + mock_task.routing_key = None stats.update(mock_task) assert len(stats.queued) == 0 @@ -236,7 +243,11 @@ def test__init__dryrun(self, cloudwatch_camera: CloudwatchCameraFixture): cloudwatch = cloudwatch_camera.create_cloudwatch() assert cloudwatch.cloudwatch_client is None - def test_on_shutter(self, cloudwatch_camera: CloudwatchCameraFixture): + def test_on_shutter( + self, + cloudwatch_camera: CloudwatchCameraFixture, + caplog: pytest.LogCaptureFixture, + ): cloudwatch = cloudwatch_camera.create_cloudwatch() mock_publish = create_autospec(cloudwatch.publish) cloudwatch.publish = mock_publish @@ -255,23 +266,88 @@ def test_on_shutter(self, cloudwatch_camera: CloudwatchCameraFixture): } assert time.isoformat() == "2021-01-01T00:00:00+00:00" - def test_on_shutter_error( - self, - cloudwatch_camera: CloudwatchCameraFixture, - caplog: pytest.LogCaptureFixture, - ): - cloudwatch_camera.app.tasks = {"task1": MagicMock()} - cloudwatch = cloudwatch_camera.create_cloudwatch() - mock_publish = create_autospec(cloudwatch.publish) - cloudwatch.publish = mock_publish + # We can also handle the case where we see a task with an unknown queue or unknown name. 
+ cloudwatch_camera.state.tasks = cloudwatch_camera.task_list( + [ + cloudwatch_camera.mock_task( + "task2", + "unknown_queue", + started=False, + succeeded=False, + uuid="uuid5", + runtime=3.0, + ), + cloudwatch_camera.mock_task( + "unknown_task", + "queue1", + failed=True, + succeeded=False, + uuid="uuid6", + ), + ] + ) cloudwatch.on_shutter(cloudwatch_camera.state) - mock_publish.assert_called_once() - [tasks, queues, time] = mock_publish.call_args.args + [tasks, queues, _] = mock_publish.call_args.args + assert tasks == { + "task1": TaskStats(), + "task2": TaskStats(), + "unknown_task": TaskStats(failed=1), + } + assert queues == { + "queue1": QueueStats(), + "queue2": QueueStats(queued={"uuid4"}), + "unknown_queue": QueueStats(queued={"uuid5"}), + } - assert tasks == {"task1": TaskStats(succeeded=2, runtime=[1.0, 2.0])} - assert queues == {"queue1": QueueStats(), "queue2": QueueStats()} - assert time is not None - assert "Error processing task" in caplog.text + # We can handle tasks with no name or tasks with no routing key. + caplog.clear() + cloudwatch_camera.state.tasks = cloudwatch_camera.task_list( + [ + cloudwatch_camera.mock_task( + None, "unknown_queue", started=True, uuid="uuid7" + ), + cloudwatch_camera.mock_task( + "task2", None, started=False, succeeded=False, uuid="uuid8" + ), + cloudwatch_camera.mock_task(None, None, started=True, uuid="uuid5"), + ] + ) + cloudwatch.on_shutter(cloudwatch_camera.state) + [tasks, queues, _] = mock_publish.call_args.args + assert tasks == { + "task1": TaskStats(), + "task2": TaskStats(), + "unknown_task": TaskStats(), + } + assert queues == { + "queue1": QueueStats(), + "queue2": QueueStats(queued={"uuid4"}), + "unknown_queue": QueueStats(), + } + + # We log the information about tasks with no name or routing key. + ( + no_name_warning_1, + no_routing_key_warning_1, + no_name_warning_2, + no_routing_key_warning_2, + ) = caplog.messages + assert ( + "Task has no name. 
[routing_key]:unknown_queue, [sent]:True, [started]:True, [uuid]:uuid7." + in no_name_warning_1 + ) + assert ( + "Task has no routing_key. [name]:task2, [sent]:True, [started]:False, [uuid]:uuid8." + in no_routing_key_warning_1 + ) + assert ( + "Task has no name. [sent]:True, [started]:True, [uuid]:uuid5." + in no_name_warning_2 + ) + assert ( + "Task has no routing_key. [sent]:True, [started]:True, [uuid]:uuid5." + in no_routing_key_warning_2 + ) def test_publish( self,