add concurrency for webhook worker #875

Open · wants to merge 15 commits into master
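In short: `start_webhook_worker` becomes a coroutine that reads a per-queue concurrency level from Redis (key `queue_concurrency:<queue_name>`, defaulting to 1), and `_start_worker` goes from managing a single task per queue to reconciling a list of tasks against that level, spawning or cancelling workers as needed. A queue's concurrency can therefore be tuned by writing that key. A minimal sketch, assuming the synchronous `redis` client and an invented queue name and URL:

```python
# Hypothetical operator snippet. Only the key format
# "queue_concurrency:<queue_name>" comes from this PR; the queue name,
# URL, and client choice are assumptions for illustration.
import redis

r = redis.Redis.from_url("redis://localhost:6379")
# Read the next time start_webhook_worker runs for this queue
# (on boot and on each enqueued event).
r.set("queue_concurrency:webhook:1234", 4)
```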
bot/kodiak/entrypoints/worker.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -127,7 +127,7 @@ async def main() -> NoReturn:
         if task_meta.kind == "repo":
             queue.start_repo_worker(queue_name=task_meta.queue_name)
         elif task_meta.kind == "webhook":
-            queue.start_webhook_worker(queue_name=task_meta.queue_name)
+            await queue.start_webhook_worker(queue_name=task_meta.queue_name)
         else:
             assert_never(task_meta.kind)
         if ingest_queue_watcher.done():
```
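The call site changes because `start_webhook_worker` is now `async def` (it awaits a Redis read before starting workers), so it must be awaited. A generic Python illustration, not Kodiak code:

```python
import asyncio

async def start_worker() -> None:
    print("worker started")

async def main() -> None:
    start_worker()        # creates a coroutine object, never runs the body;
                          # Python warns "coroutine ... was never awaited"
    await start_worker()  # actually runs: prints "worker started"

asyncio.run(main())
```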
bot/kodiak/queue.py (85 changes: 65 additions & 20 deletions)
```diff
@@ -8,7 +8,7 @@
 from asyncio.tasks import Task
 from dataclasses import dataclass
 from datetime import timedelta
-from typing import Iterator, MutableMapping, NoReturn, Optional, Tuple
+from typing import Any, Iterator, MutableMapping, NoReturn, Optional, Sequence, Tuple
 
 import sentry_sdk
 import structlog
```
```diff
@@ -461,7 +461,7 @@ class TaskMeta:
 class RedisWebhookQueue:
     def __init__(self) -> None:
         self.worker_tasks: MutableMapping[
-            str, tuple[Task[NoReturn], Literal["repo", "webhook"]]
+            str, list[tuple[Task[NoReturn], Literal["repo", "webhook"]]]
         ] = {}  # type: ignore [assignment]
 
     async def create(self) -> None:
```
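The `worker_tasks` value type changes from a single `(task, kind)` pair to a list of pairs, which is what allows one queue to be served by several workers. A toy model of the new shape (queue name and helpers invented):

```python
import asyncio

async def idle_worker() -> None:
    await asyncio.sleep(60)

async def demo() -> None:
    # one queue name -> many (task, kind) pairs
    worker_tasks: dict[str, list[tuple[asyncio.Task[None], str]]] = {
        "webhook:1234": [
            (asyncio.create_task(idle_worker()), "webhook"),
            (asyncio.create_task(idle_worker()), "webhook"),
        ]
    }
    for tasks in worker_tasks.values():
        for task, _kind in tasks:
            task.cancel()

asyncio.run(demo())
```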
```diff
@@ -476,13 +476,20 @@ async def create(self) -> None:
 
         for webhook_result in webhook_queues:
             queue_name = webhook_result.decode()
-            self.start_webhook_worker(queue_name=queue_name)
+            await self.start_webhook_worker(queue_name=queue_name)
+
+    async def start_webhook_worker(self, *, queue_name: str) -> None:
+        concurrency_str = await redis_bot.get("queue_concurrency:" + queue_name)
+        try:
+            concurrency = int(concurrency_str or 1)
+        except ValueError:
+            concurrency = 1
 
-    def start_webhook_worker(self, *, queue_name: str) -> None:
         self._start_worker(
             queue_name,
             "webhook",
             webhook_event_consumer(webhook_queue=self, queue_name=queue_name),
+            concurrency=concurrency,
         )
 
     def start_repo_worker(self, *, queue_name: str) -> None:
```
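The concurrency lookup is deliberately forgiving: a missing key and an unparseable value both fall back to 1. A sketch of just the parse (the helper name is mine, not the PR's):

```python
from typing import Optional

def parse_concurrency(raw: Optional[bytes]) -> int:
    # mirrors the try/int/except above; the Redis get returns bytes or None
    try:
        return int(raw or 1)
    except ValueError:
        return 1

assert parse_concurrency(None) == 1       # key not set
assert parse_concurrency(b"4") == 4       # int() accepts bytes directly
assert parse_concurrency(b"banana") == 1  # garbage falls back to 1
```

Note that a stored `0` parses cleanly and would scale the queue down to zero workers.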
```diff
@@ -499,21 +506,58 @@ def _start_worker(
         key: str,
         kind: Literal["repo", "webhook"],
         fut: typing.Coroutine[None, None, NoReturn],
+        *,
+        concurrency: int = 1,
     ) -> None:
         log = logger.bind(queue_name=key, kind=kind)
-        worker_task_result = self.worker_tasks.get(key)
-        if worker_task_result is not None:
-            worker_task, _task_kind = worker_task_result
-            if not worker_task.done():
-                return
-            log.info("task failed")
-            # task failed. record result and restart
-            exception = worker_task.exception()
-            log.info("exception", excep=exception)
-            sentry_sdk.capture_exception(exception)
-        log.info("creating task for queue")
-        # create new task for queue
-        self.worker_tasks[key] = (asyncio.create_task(fut), kind)
+        worker_task_results = ()  # type: Sequence[tuple[Task[NoReturn], Literal["repo", "webhook"]]]
+        try:
+            worker_task_results = self.worker_tasks[key]
+        except KeyError:
+            pass
+        new_workers: list[tuple[asyncio.Task[Any], Literal["repo", "webhook"]]] = []
+
+        previous_task_count = len(worker_task_results)
+        failed_task_count = 0
+
+        for (worker_task, _task_kind) in worker_task_results:
+            if worker_task.done():
+                log.info("task failed")
+                # task failed. record result.
+                exception = worker_task.exception()
+                log.info("exception", excep=exception)
+                sentry_sdk.capture_exception(exception)
+                failed_task_count += 1
+            else:
+                new_workers.append((worker_task, _task_kind))
+        tasks_to_create = concurrency - len(new_workers)
+
+        tasks_created = 0
+        tasks_cancelled = 0
+        # we need to create tasks
+        if tasks_to_create > 0:
+            for _ in range(tasks_to_create):
+                new_workers.append((asyncio.create_task(fut), kind))
+                tasks_created += 1
+        # we need to remove tasks
+        elif tasks_to_create < 0:
+            # split off tasks we need to cancel.
+            new_workers, workers_to_delete = (
+                new_workers[:concurrency],
+                new_workers[concurrency:],
+            )
+            for (task, _kind) in workers_to_delete:
+                task.cancel()
+                tasks_cancelled += 1
+
+        self.worker_tasks[key] = new_workers
+        log.info(
+            "start_workers",
+            previous_task_count=previous_task_count,
+            failed_task_count=failed_task_count,
+            tasks_created=tasks_created,
+            tasks_cancelled=tasks_cancelled,
+        )
 
     async def enqueue(self, *, event: WebhookEvent) -> None:
         """
```
```diff
@@ -531,7 +575,7 @@ async def enqueue(self, *, event: WebhookEvent) -> None:
             install=event.installation_id,
         )
         log.info("enqueue webhook event")
-        self.start_webhook_worker(queue_name=queue_name)
+        await self.start_webhook_worker(queue_name=queue_name)
 
     async def enqueue_for_repo(
         self, *, event: WebhookEvent, first: bool
```
```diff
@@ -577,8 +621,9 @@ async def enqueue_for_repo(
         return find_position((key for key, value in kvs), event.json().encode())
 
     def all_tasks(self) -> Iterator[tuple[TaskMeta, Task[NoReturn]]]:
-        for queue_name, (task, task_kind) in self.worker_tasks.items():
-            yield (TaskMeta(kind=task_kind, queue_name=queue_name), task)
+        for queue_name, tasks in self.worker_tasks.items():
+            for (task, task_kind) in tasks:
+                yield (TaskMeta(kind=task_kind, queue_name=queue_name), task)
 
 
 def get_merge_queue_name(event: WebhookEvent) -> str:
```
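`all_tasks` now yields one `(TaskMeta, Task)` pair per worker rather than per queue, so the supervisor loop in `worker.py` can monitor and restart each concurrent worker individually. A minimal model of the flattening, with placeholder strings standing in for tasks:

```python
worker_tasks = {"q1": ["t1", "t2"], "q2": ["t3"]}

pairs = [
    (queue_name, task)
    for queue_name, tasks in worker_tasks.items()
    for task in tasks
]
assert pairs == [("q1", "t1"), ("q1", "t2"), ("q2", "t3")]
```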