Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add concurrency for webhook worker #875

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
20 changes: 19 additions & 1 deletion bot/kodiak/app_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import base64
from pathlib import Path
from typing import Any, Optional, Type, TypeVar, overload
from typing import Any, Mapping, Optional, Sequence, Type, TypeVar, overload

import databases
from starlette.config import Config, undefined
Expand Down Expand Up @@ -54,6 +54,24 @@ def __call__(
default=["pull_request", "pull_request_review", "pull_request_comment"],
)
)


def parse_worker_concurrency(items: Sequence[str]) -> Mapping[str, int]:
    """Parse ``installation_id=concurrency`` pairs into a mapping.

    Each entry in *items* must look like ``"12312309=4"``: the text before
    the first ``=`` is the installation id (kept as a string) and the text
    after it is the desired worker concurrency (converted to ``int``).

    Raises:
        ValueError: if an entry lacks an ``=`` or has a non-integer
            concurrency value.
    """
    concurrency_by_install = {}
    for item in items:
        # maxsplit=1 (not 2): an id=value pair must split into exactly two
        # parts; with maxsplit=2 a stray second "=" would break unpacking
        # with a confusing error instead of failing on int().
        install, concurrency = item.split("=", maxsplit=1)
        concurrency_by_install[install] = int(concurrency)
    return concurrency_by_install


# Optional per-installation webhook worker concurrency overrides.
# Format: comma-separated "installation_id=concurrency" pairs, e.g.
#   WEBHOOK_WORKER_CONCURRENCY="12312309=4,1290301293=1"
# gives installation 12312309 four concurrent webhook workers.
WEBHOOK_WORKER_CONCURRENCY = parse_worker_concurrency(
    config(
        "WEBHOOK_WORKER_CONCURRENCY",
        cast=CommaSeparatedStrings,
        default=[],
    )
)
USAGE_REPORTING_QUEUE_LENGTH = config(
"USAGE_REPORTING_QUEUE_LENGTH", cast=int, default=10_000
)
Expand Down
2 changes: 1 addition & 1 deletion bot/kodiak/entrypoints/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ async def main() -> NoReturn:
if task_meta.kind == "repo":
queue.start_repo_worker(queue_name=task_meta.queue_name)
elif task_meta.kind == "webhook":
queue.start_webhook_worker(queue_name=task_meta.queue_name)
await queue.start_webhook_worker(queue_name=task_meta.queue_name)
else:
assert_never(task_meta.kind)
if ingest_queue_watcher.done():
Expand Down
21 changes: 17 additions & 4 deletions bot/kodiak/logging.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import logging
import sys
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Tuple, Union

import inflection
import sentry_sdk
import structlog
from requests import Response
from sentry_sdk import capture_event
from sentry_sdk.integrations.logging import LoggingIntegration
from sentry_sdk.utils import event_from_exception
from typing_extensions import Literal

from kodiak import app_config as conf
from kodiak.http import Response

################################################################################
# based on https://github.com/kiwicom/structlog-sentry/blob/18adbfdac85930ca5578e7ef95c1f2dc169c2f2f/structlog_sentry/__init__.py#L10-L86
Expand Down Expand Up @@ -117,11 +118,23 @@ def add_request_info_processor(
"""
response = event_dict.get("res", None)
if isinstance(response, Response):
event_dict["response_content"] = cast(Any, response)._content
event_dict["response_content"] = response._content
event_dict["response_status_code"] = response.status_code
event_dict["request_body"] = response.request.body
event_dict["request_url"] = response.request.url
event_dict["request_method"] = response.request.method

for header in (
"retry-after",
"x-ratelimit-limit",
"x-ratelimit-remaining",
"x-ratelimit-reset",
"x-ratelimit-used",
"x-ratelimit-resource",
):
event_dict[
f"response_header_{inflection.underscore(header)}"
] = response.headers.get(header)

return event_dict


Expand Down
5 changes: 4 additions & 1 deletion bot/kodiak/pull_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from kodiak.evaluation import mergeable
from kodiak.http import HTTPStatusError as HTTPError
from kodiak.queries import Client, EventInfoResponse
from kodiak.queries import Client, EventInfoResponse, SecondaryRateLimit

logger = structlog.get_logger()

Expand Down Expand Up @@ -152,6 +152,9 @@ async def evaluate_pr(
continue
log.warning("api_call_retries_remaining", exc_info=True)
return
except SecondaryRateLimit:
log.info("secondary_rate_limit")
await requeue_callback()
except asyncio.TimeoutError:
# On timeout we add the PR to the back of the queue to try again.
log.warning("mergeable_timeout", exc_info=True)
Expand Down
13 changes: 11 additions & 2 deletions bot/kodiak/queries/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,10 @@ async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None
...


class SecondaryRateLimit(Exception):
    """GitHub responded 403 with a "secondary rate limit" message.

    Raised by Client.send_query so callers (e.g. evaluate_pr) can requeue
    the work and retry later instead of treating it as a hard failure.
    """

    ...


class Client:
throttler: ThrottlerProtocol

Expand Down Expand Up @@ -852,7 +856,7 @@ def __init__(self, *, owner: str, repo: str, installation_id: str):
)

async def __aenter__(self) -> Client:
self.throttler = get_thottler_for_installation(
self.throttler = await get_thottler_for_installation(
installation_id=self.installation_id
)
return self
Expand Down Expand Up @@ -883,6 +887,11 @@ async def send_query(
try:
res.raise_for_status()
except http.HTTPError:
if (
res.status_code == 403
and b"You have exceeded a secondary rate limit" in res.content
):
raise SecondaryRateLimit()
log.warning("github api request error", res=res, exc_info=True)
return None
return cast(GraphQLResponse, res.json())
Expand Down Expand Up @@ -1383,7 +1392,7 @@ async def get_token_for_install(
app_token = generate_jwt(
private_key=conf.PRIVATE_KEY, app_identifier=conf.GITHUB_APP_ID
)
throttler = get_thottler_for_installation(
throttler = await get_thottler_for_installation(
# this isn't a real installation ID, but it provides rate limiting
# for our GithubApp instead of the installations we typically act as
installation_id=APPLICATION_ID
Expand Down
106 changes: 81 additions & 25 deletions bot/kodiak/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,20 @@
import time
import typing
import urllib
import uuid
from asyncio.tasks import Task
from dataclasses import dataclass
from datetime import timedelta
from typing import Iterator, MutableMapping, NoReturn, Optional, Tuple
from typing import (
Any,
Callable,
Iterator,
MutableMapping,
NoReturn,
Optional,
Sequence,
Tuple,
)

import sentry_sdk
import structlog
Expand Down Expand Up @@ -367,7 +377,9 @@ async def webhook_event_consumer(
scope.set_tag("queue", queue_name)
scope.set_tag("installation", installation_id_from_queue(queue_name))
log = logger.bind(
queue=queue_name, install=installation_id_from_queue(queue_name)
queue=queue_name,
install=installation_id_from_queue(queue_name),
task_id=uuid.uuid4().hex,
)
log.info("start webhook event consumer")
while True:
Expand Down Expand Up @@ -461,7 +473,7 @@ class TaskMeta:
class RedisWebhookQueue:
def __init__(self) -> None:
self.worker_tasks: MutableMapping[
str, tuple[Task[NoReturn], Literal["repo", "webhook"]]
str, list[tuple[Task[NoReturn], Literal["repo", "webhook"]]]
] = {} # type: ignore [assignment]

async def create(self) -> None:
Expand All @@ -476,20 +488,26 @@ async def create(self) -> None:

for webhook_result in webhook_queues:
queue_name = webhook_result.decode()
self.start_webhook_worker(queue_name=queue_name)

def start_webhook_worker(self, *, queue_name: str) -> None:
await self.start_webhook_worker(queue_name=queue_name)

    async def start_webhook_worker(self, *, queue_name: str) -> None:
        """Start (or resize) the webhook consumer tasks for *queue_name*.

        The desired concurrency is read from the Redis key
        ``queue_concurrency:<queue_name>``; a missing or malformed value
        falls back to a single worker.
        """
        concurrency_str = await redis_bot.get("queue_concurrency:" + queue_name)
        try:
            # redis returns the raw value (None when the key is unset);
            # `or 1` covers None/empty, int() handles the stored number.
            concurrency = int(concurrency_str or 1)
        except ValueError:
            # Non-numeric value stored in redis — default to one worker.
            concurrency = 1
        self._start_worker(
            queue_name,
            "webhook",
            # A factory (not a coroutine object) so _start_worker can spawn
            # multiple tasks and recreate fresh coroutines after failures.
            lambda: webhook_event_consumer(webhook_queue=self, queue_name=queue_name),
            concurrency=concurrency,
        )

def start_repo_worker(self, *, queue_name: str) -> None:
self._start_worker(
queue_name,
"repo",
repo_queue_consumer(
lambda: repo_queue_consumer(
queue_name=queue_name,
),
)
Expand All @@ -498,22 +516,59 @@ def _start_worker(
self,
key: str,
kind: Literal["repo", "webhook"],
fut: typing.Coroutine[None, None, NoReturn],
fut: Callable[[], typing.Coroutine[None, None, NoReturn]],
*,
concurrency: int = 1,
) -> None:
log = logger.bind(queue_name=key, kind=kind)
worker_task_result = self.worker_tasks.get(key)
if worker_task_result is not None:
worker_task, _task_kind = worker_task_result
if not worker_task.done():
return
log.info("task failed")
# task failed. record result and restart
exception = worker_task.exception()
log.info("exception", excep=exception)
sentry_sdk.capture_exception(exception)
log.info("creating task for queue")
# create new task for queue
self.worker_tasks[key] = (asyncio.create_task(fut), kind)
worker_task_results = () # type: Sequence[tuple[Task[NoReturn], Literal["repo", "webhook"]]]
try:
worker_task_results = self.worker_tasks[key]
except KeyError:
pass
new_workers: list[tuple[asyncio.Task[Any], Literal["repo", "webhook"]]] = []

previous_task_count = len(worker_task_results)
failed_task_count = 0

for (worker_task, _task_kind) in worker_task_results:
if worker_task.done():
log.info("task failed")
# task failed. record result.
exception = worker_task.exception()
log.info("exception", excep=exception)
sentry_sdk.capture_exception(exception)
failed_task_count += 1
else:
new_workers.append((worker_task, _task_kind))
tasks_to_create = concurrency - len(new_workers)

tasks_created = 0
tasks_cancelled = 0
# we need to create tasks
if tasks_to_create > 0:
for _ in range(tasks_to_create):
new_workers.append((asyncio.create_task(fut()), kind))
tasks_created += 1
# we need to remove tasks
elif tasks_to_create < 0:
# split off tasks we need to cancel.
new_workers, workers_to_delete = (
new_workers[:concurrency],
new_workers[concurrency:],
)
for (task, _kind) in workers_to_delete:
task.cancel()
tasks_cancelled += 1

self.worker_tasks[key] = new_workers
log.info(
"start_workers",
previous_task_count=previous_task_count,
failed_task_count=failed_task_count,
tasks_created=tasks_created,
tasks_cancelled=tasks_cancelled,
)

async def enqueue(self, *, event: WebhookEvent) -> None:
"""
Expand All @@ -531,7 +586,7 @@ async def enqueue(self, *, event: WebhookEvent) -> None:
install=event.installation_id,
)
log.info("enqueue webhook event")
self.start_webhook_worker(queue_name=queue_name)
await self.start_webhook_worker(queue_name=queue_name)

async def enqueue_for_repo(
self, *, event: WebhookEvent, first: bool
Expand Down Expand Up @@ -577,8 +632,9 @@ async def enqueue_for_repo(
return find_position((key for key, value in kvs), event.json().encode())

def all_tasks(self) -> Iterator[tuple[TaskMeta, Task[NoReturn]]]:
for queue_name, (task, task_kind) in self.worker_tasks.items():
yield (TaskMeta(kind=task_kind, queue_name=queue_name), task)
for queue_name, tasks in self.worker_tasks.items():
for (task, task_kind) in tasks:
yield (TaskMeta(kind=task_kind, queue_name=queue_name), task)


def get_merge_queue_name(event: WebhookEvent) -> str:
Expand Down
18 changes: 5 additions & 13 deletions bot/kodiak/test_logging.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import json
import logging
from typing import Any, cast
from typing import Any

import pytest
from requests import PreparedRequest, Request, Response

from kodiak.http import Request, Response
from kodiak.logging import (
SentryLevel,
SentryProcessor,
Expand Down Expand Up @@ -160,20 +159,13 @@ def test_add_request_info_processor() -> None:
url = "https://api.example.com/v1/me"
payload = dict(user_id=54321)
req = Request("POST", url, json=payload)
res = Response()
res.status_code = 500
res.url = url
res.reason = "Internal Server Error"
cast(
Any, res
)._content = b"Your request could not be completed due to an internal error."
res.request = cast(PreparedRequest, req.prepare()) # type: ignore
res = Response(status_code=500, request=req)
res._content = b"Your request could not be completed due to an internal error."
event_dict = add_request_info_processor(
None, None, dict(event="request failed", res=res)
)
assert event_dict["response_content"] == cast(Any, res)._content
assert event_dict["response_content"] == res._content
assert event_dict["response_status_code"] == res.status_code
assert event_dict["request_body"] == json.dumps(payload).encode()
assert event_dict["request_url"] == req.url
assert event_dict["request_method"] == "POST"
assert event_dict["res"] is res
3 changes: 2 additions & 1 deletion bot/kodiak/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def github_installation_id() -> str:
@pytest.fixture
def api_client(mocker: MockFixture, github_installation_id: str) -> Client:
mocker.patch(
"kodiak.queries.get_thottler_for_installation", return_value=FakeThottler()
"kodiak.queries.get_thottler_for_installation",
return_value=wrap_future(FakeThottler()),
)
client = Client(installation_id=github_installation_id, owner="foo", repo="foo")
mocker.patch.object(client, "send_query")
Expand Down
Loading