Skip to content

Commit

Permalink
More accurately monitor celery queue length (PP-1397) (#1915)
Browse files Browse the repository at this point in the history
* More accurately monitor queue length
* Fix test
  • Loading branch information
jonathangreen authored Jun 21, 2024
1 parent b66ee21 commit 11f4e2c
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 136 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ module = [
"html_sanitizer",
"isbnlib",
"jwcrypto",
"kombu",
"kombu.*",
"lxml.*",
"money",
"multipledispatch",
Expand Down
70 changes: 37 additions & 33 deletions src/palace/manager/celery/monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,20 @@
from dataclasses import dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse

import boto3
from boto3.exceptions import Boto3Error
from botocore.exceptions import BotoCoreError
from celery.events.snapshot import Polaroid
from celery.events.state import State, Task
from kombu.transport.redis import PrefixedStrictRedis
from redis import ConnectionPool

from palace.manager.core.exceptions import PalaceValueError
from palace.manager.util import chunks
from palace.manager.util.datetime_helpers import utc_now
from palace.manager.util.log import LoggerMixin, logger_for_cls
from palace.manager.util.log import logger_for_cls

if TYPE_CHECKING:
from mypy_boto3_cloudwatch.literals import StandardUnitType
Expand Down Expand Up @@ -115,29 +119,20 @@ def reset(self) -> None:
self.runtime = []


@dataclass
class QueueStats(LoggerMixin):
@dataclass(frozen=True)
class QueueStats:
"""
Tracks the number of tasks queued for a specific queue, so we can
report this out to Cloudwatch metrics.
"""

queued: set[str] = field(default_factory=set)

def update(self, task: Task) -> None:
self.log.debug("Task: %r", task)
if task.uuid in self.queued and task.started:
self.log.debug(f"Task {task.uuid} started.")
self.queued.remove(task.uuid)
elif task.sent and not task.started:
self.log.debug(f"Task {task.uuid} queued.")
self.queued.add(task.uuid)
queued: int

def metrics(
self, timestamp: datetime, dimensions: dict[str, str]
) -> list[MetricDatumTypeDef]:
return [
value_metric("QueueWaiting", len(self.queued), timestamp, dimensions),
value_metric("QueueWaiting", self.queued, timestamp, dimensions),
]


Expand All @@ -161,6 +156,11 @@ def __init__(
# because the base class Polaroid already defines a logger attribute,
# which conflicts with the logger() method in LoggerMixin.
self.logger = logger_for_cls(self.__class__)
broker_url = self.app.conf.get("broker_url")
broker_type = urlparse(broker_url).scheme if broker_url else None
if broker_type != "redis":
raise PalaceValueError(f"Broker type '{broker_type}' is not supported.")

region = self.app.conf.get("cloudwatch_statistics_region")
dryrun = self.app.conf.get("cloudwatch_statistics_dryrun")
self.cloudwatch_client = (
Expand All @@ -169,12 +169,10 @@ def __init__(
self.manager_name = self.app.conf.get("broker_transport_options", {}).get(
"global_keyprefix"
)
self.redis_client = self.get_redis_client(broker_url, self.manager_name)
self.namespace = self.app.conf.get("cloudwatch_statistics_namespace")
self.upload_size = self.app.conf.get("cloudwatch_statistics_upload_size")
self.queues = defaultdict(
QueueStats,
{queue.name: QueueStats() for queue in self.app.conf.get("task_queues")},
)
self.queues = {queue.name for queue in self.app.conf.get("task_queues")}
self.tasks: defaultdict[str, TaskStats] = defaultdict(
TaskStats,
{
Expand All @@ -184,6 +182,15 @@ def __init__(
},
)

@classmethod
def get_redis_client(
    cls, broker_url: str, global_keyprefix: str | None
) -> PrefixedStrictRedis:
    """
    Create the Redis client used to inspect the broker.

    A connection pool is built from ``broker_url`` and handed to
    ``PrefixedStrictRedis`` together with ``global_keyprefix``.

    :param broker_url: URL of the Redis broker to connect to.
    :param global_keyprefix: Key prefix forwarded to the client; may be None.
    :return: A ``PrefixedStrictRedis`` backed by a pool for ``broker_url``.
    """
    return PrefixedStrictRedis(
        connection_pool=ConnectionPool.from_url(broker_url),
        global_keyprefix=global_keyprefix,
    )

@staticmethod
def is_celery_task(task_name: str) -> bool:
return task_name.startswith("celery.")
Expand All @@ -201,12 +208,11 @@ def reset_tasks(self) -> None:
for task in self.tasks.values():
task.reset()

def on_shutter(self, state: State) -> None:
timestamp = utc_now()

def update_task_stats(self, state: State) -> None:
# Reset the task stats for each snapshot
self.reset_tasks()

# Update task stats for each task in the state
for task in state.tasks.values():
# Update task stats for each task
if task.name is None:
Expand All @@ -218,19 +224,17 @@ def on_shutter(self, state: State) -> None:
else:
self.tasks[task.name].update(task)

# Update queue stats for each task
if task.routing_key is None:
self.logger.warning(
f"Task has no routing_key. {self.task_info_str(task)}."
)
if task.started:
# It's possible that we are tracking this task, so we make sure it's not in any of our queues
for queue in self.queues.values():
queue.update(task)
else:
self.queues[task.routing_key].update(task)
def get_queue_stats(self) -> dict[str, QueueStats]:
    """
    Collect the current length of every monitored queue.

    :return: A mapping from each name in ``self.queues`` to a ``QueueStats``
        built from the value ``self.redis_client.llen()`` reports for that name.
    """
    stats: dict[str, QueueStats] = {}
    for queue_name in self.queues:
        waiting = self.redis_client.llen(queue_name)
        stats[queue_name] = QueueStats(waiting)
    return stats

def on_shutter(self, state: State) -> None:
timestamp = utc_now()
self.update_task_stats(state)
queue_stats = self.get_queue_stats()

self.publish(self.tasks, self.queues, timestamp)
self.publish(self.tasks, queue_stats, timestamp)

def publish(
self,
Expand Down
4 changes: 4 additions & 0 deletions src/palace/manager/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ def __init__(self, message: str | None = None):
self.message = message


class PalaceValueError(BasePalaceException, ValueError):
    """
    A ValueError that is also part of the Palace exception hierarchy.

    Because it derives from both ``BasePalaceException`` and ``ValueError``,
    handlers written for either base class will catch it.
    """


class IntegrationException(BasePalaceException):
"""An exception that happens when the site's connection to a
third-party service is broken.
Expand Down
Loading

0 comments on commit 11f4e2c

Please sign in to comment.