Skip to content

Commit

Permalink
Fix issue with monitoring queue stats (PP-1312) (#1871)
Browse files Browse the repository at this point in the history
* Update cloudwatch monitoring to be more resiliant
* Update monitoring
* Make sure we account for the tasks we know about.
  • Loading branch information
jonathangreen authored May 29, 2024
1 parent 662995e commit 7777f85
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 55 deletions.
92 changes: 66 additions & 26 deletions src/palace/manager/celery/monitoring.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass, field
from datetime import datetime
Expand Down Expand Up @@ -108,6 +109,11 @@ def metrics(

return metric_data

def reset(self) -> None:
self.succeeded = 0
self.failed = 0
self.runtime = []


@dataclass
class QueueStats(LoggerMixin):
Expand All @@ -120,14 +126,12 @@ class QueueStats(LoggerMixin):

def update(self, task: Task) -> None:
self.log.debug("Task: %r", task)
if task.uuid in self.queued:
if task.started:
self.log.debug("Task %s started.", task.uuid)
self.queued.remove(task.uuid)
else:
if task.sent and not task.started:
self.log.debug("Task %s queued.", task.uuid)
self.queued.add(task.uuid)
if task.uuid in self.queued and task.started:
self.log.debug(f"Task {task.uuid} started.")
self.queued.remove(task.uuid)
elif task.sent and not task.started:
self.log.debug(f"Task {task.uuid} queued.")
self.queued.add(task.uuid)

def metrics(
self, timestamp: datetime, dimensions: dict[str, str]
Expand Down Expand Up @@ -167,30 +171,66 @@ def __init__(
)
self.namespace = self.app.conf.get("cloudwatch_statistics_namespace")
self.upload_size = self.app.conf.get("cloudwatch_statistics_upload_size")
self.queues = {
str(queue.name): QueueStats() for queue in self.app.conf.get("task_queues")
}
self.queues = defaultdict(
QueueStats,
{queue.name: QueueStats() for queue in self.app.conf.get("task_queues")},
)
self.tasks: defaultdict[str, TaskStats] = defaultdict(
TaskStats,
{
task: TaskStats()
for task in self.app.tasks.keys()
if not self.is_celery_task(task)
},
)

@staticmethod
def is_celery_task(task_name: str) -> bool:
return task_name.startswith("celery.")

@staticmethod
def task_info_str(task: Task) -> str:
return ", ".join(
[
f"[{k}]:{v}"
for k, v in task.info(extra=["name", "sent", "started", "uuid"]).items()
]
)

def reset_tasks(self) -> None:
for task in self.tasks.values():
task.reset()

def on_shutter(self, state: State) -> None:
timestamp = utc_now()
tasks = {
task: TaskStats()
for task in self.app.tasks.keys()
if not task.startswith("celery.")
}

# Reset the task stats for each snapshot
self.reset_tasks()

for task in state.tasks.values():
try:
tasks[task.name].update(task)
self.queues[task.routing_key].update(task)
except KeyError:
self.logger.exception(
"Error processing task %s with routing key %s",
task.name,
task.routing_key,
# Update task stats for each task
if task.name is None:
self.logger.warning(f"Task has no name. {self.task_info_str(task)}.")
elif self.is_celery_task(task.name):
# If this is an internal Celery task, we skip it entirely.
# We don't want to track internal Celery tasks.
continue
else:
self.tasks[task.name].update(task)

# Update queue stats for each task
if task.routing_key is None:
self.logger.warning(
f"Task has no routing_key. {self.task_info_str(task)}."
)
if task.started:
# Its possible that we are tracking this task, so we make sure its not in any of our queues
for queue in self.queues.values():
queue.update(task)
else:
self.queues[task.routing_key].update(task)

self.publish(tasks, self.queues, timestamp)
self.publish(self.tasks, self.queues, timestamp)

def publish(
self,
Expand Down Expand Up @@ -226,4 +266,4 @@ def publish(
else:
self.logger.info("Dry run enabled. Not sending metrics to Cloudwatch.")
for data in chunk:
self.logger.info("Data: %s", data)
self.logger.info(f"Data: {data}")
134 changes: 105 additions & 29 deletions tests/manager/celery/test_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,25 @@ def __init__(self, boto_client: MagicMock):
}
self.client = boto_client
self.state = create_autospec(State)
self.state.tasks = {
"uuid1": self.mock_task("task1", "queue1", runtime=1.0),
"uuid2": self.mock_task("task1", "queue1", runtime=2.0),
"uuid3": self.mock_task("task2", "queue2", succeeded=False, failed=True),
"uuid4": self.mock_task(
"task2", "queue2", started=False, succeeded=False, uuid="uuid4"
),
}
self.state.tasks = self.task_list(
[
self.mock_task("task1", "queue1", runtime=1.0),
self.mock_task("task1", "queue1", runtime=2.0),
self.mock_task("task2", "queue2", succeeded=False, failed=True),
self.mock_task(
"task2", "queue2", started=False, succeeded=False, uuid="uuid4"
),
self.mock_task(
"celery.built_in", "queue2", started=False, succeeded=False
),
]
)
self.create_cloudwatch = partial(Cloudwatch, state=self.state, app=self.app)

@staticmethod
def task_list(tasks: list[Task]) -> dict[str, Task]:
return {task.uuid: task for task in tasks}

def mock_queue(self, name: str) -> MagicMock:
queue = MagicMock()
queue.name = name
Expand All @@ -51,10 +60,6 @@ def mock_task(
) -> Task:
if uuid is None:
uuid = str(uuid4())
if name is None:
name = "task"
if routing_key is None:
routing_key = "queue"
return Task(
uuid=uuid,
name=name,
Expand Down Expand Up @@ -192,8 +197,10 @@ def test_update(self, cloudwatch_camera: CloudwatchCameraFixture):
stats.update(mock_task)
assert len(stats.queued) == 1

# If the task is started, it should be removed from the queue.
# If the task is started, it should be removed from the queue, even if we no longer
# have its routing key.
mock_task.started = True
mock_task.routing_key = None
stats.update(mock_task)
assert len(stats.queued) == 0

Expand Down Expand Up @@ -236,7 +243,11 @@ def test__init__dryrun(self, cloudwatch_camera: CloudwatchCameraFixture):
cloudwatch = cloudwatch_camera.create_cloudwatch()
assert cloudwatch.cloudwatch_client is None

def test_on_shutter(self, cloudwatch_camera: CloudwatchCameraFixture):
def test_on_shutter(
self,
cloudwatch_camera: CloudwatchCameraFixture,
caplog: pytest.LogCaptureFixture,
):
cloudwatch = cloudwatch_camera.create_cloudwatch()
mock_publish = create_autospec(cloudwatch.publish)
cloudwatch.publish = mock_publish
Expand All @@ -255,23 +266,88 @@ def test_on_shutter(self, cloudwatch_camera: CloudwatchCameraFixture):
}
assert time.isoformat() == "2021-01-01T00:00:00+00:00"

def test_on_shutter_error(
self,
cloudwatch_camera: CloudwatchCameraFixture,
caplog: pytest.LogCaptureFixture,
):
cloudwatch_camera.app.tasks = {"task1": MagicMock()}
cloudwatch = cloudwatch_camera.create_cloudwatch()
mock_publish = create_autospec(cloudwatch.publish)
cloudwatch.publish = mock_publish
# We can also handle the case where we see a task with an unknown queue or unknown name.
cloudwatch_camera.state.tasks = cloudwatch_camera.task_list(
[
cloudwatch_camera.mock_task(
"task2",
"unknown_queue",
started=False,
succeeded=False,
uuid="uuid5",
runtime=3.0,
),
cloudwatch_camera.mock_task(
"unknown_task",
"queue1",
failed=True,
succeeded=False,
uuid="uuid6",
),
]
)
cloudwatch.on_shutter(cloudwatch_camera.state)
mock_publish.assert_called_once()
[tasks, queues, time] = mock_publish.call_args.args
[tasks, queues, _] = mock_publish.call_args.args
assert tasks == {
"task1": TaskStats(),
"task2": TaskStats(),
"unknown_task": TaskStats(failed=1),
}
assert queues == {
"queue1": QueueStats(),
"queue2": QueueStats(queued={"uuid4"}),
"unknown_queue": QueueStats(queued={"uuid5"}),
}

assert tasks == {"task1": TaskStats(succeeded=2, runtime=[1.0, 2.0])}
assert queues == {"queue1": QueueStats(), "queue2": QueueStats()}
assert time is not None
assert "Error processing task" in caplog.text
# We can handle tasks with no name or tasks with no routing key.
caplog.clear()
cloudwatch_camera.state.tasks = cloudwatch_camera.task_list(
[
cloudwatch_camera.mock_task(
None, "unknown_queue", started=True, uuid="uuid7"
),
cloudwatch_camera.mock_task(
"task2", None, started=False, succeeded=False, uuid="uuid8"
),
cloudwatch_camera.mock_task(None, None, started=True, uuid="uuid5"),
]
)
cloudwatch.on_shutter(cloudwatch_camera.state)
[tasks, queues, _] = mock_publish.call_args.args
assert tasks == {
"task1": TaskStats(),
"task2": TaskStats(),
"unknown_task": TaskStats(),
}
assert queues == {
"queue1": QueueStats(),
"queue2": QueueStats(queued={"uuid4"}),
"unknown_queue": QueueStats(),
}

# We log the information about tasks with no name or routing key.
(
no_name_warning_1,
no_routing_key_warning_1,
no_name_warning_2,
no_routing_key_warning_2,
) = caplog.messages
assert (
"Task has no name. [routing_key]:unknown_queue, [sent]:True, [started]:True, [uuid]:uuid7."
in no_name_warning_1
)
assert (
"Task has no routing_key. [name]:task2, [sent]:True, [started]:False, [uuid]:uuid8."
in no_routing_key_warning_1
)
assert (
"Task has no name. [sent]:True, [started]:True, [uuid]:uuid5."
in no_name_warning_2
)
assert (
"Task has no routing_key. [sent]:True, [started]:True, [uuid]:uuid5."
in no_routing_key_warning_2
)

def test_publish(
self,
Expand Down

0 comments on commit 7777f85

Please sign in to comment.