diff --git a/bot/kodiak/app_config.py b/bot/kodiak/app_config.py index c1a3b9314..213c28a94 100644 --- a/bot/kodiak/app_config.py +++ b/bot/kodiak/app_config.py @@ -1,6 +1,6 @@ import base64 from pathlib import Path -from typing import Any, Optional, Type, TypeVar, overload +from typing import Any, Mapping, Optional, Sequence, Type, TypeVar, overload import databases from starlette.config import Config, undefined @@ -54,6 +54,24 @@ def __call__( default=["pull_request", "pull_request_review", "pull_request_comment"], ) ) + + +def parse_worker_concurrency(items: Sequence[str]) -> Mapping[str, int]: + maps = {} + for item in items: + (install, concurrency) = item.split("=", maxsplit=2) + maps[install] = int(concurrency) + return maps + + +# 12312309=4,1290301293=1 +WEBHOOK_WORKER_CONCURRENCY = parse_worker_concurrency( + config( + "WEBHOOK_WORKER_CONCURRENCY", + cast=CommaSeparatedStrings, + default=[], + ) +) USAGE_REPORTING_QUEUE_LENGTH = config( "USAGE_REPORTING_QUEUE_LENGTH", cast=int, default=10_000 ) diff --git a/bot/kodiak/entrypoints/worker.py b/bot/kodiak/entrypoints/worker.py index a2ca46aa3..2c51e6c4d 100644 --- a/bot/kodiak/entrypoints/worker.py +++ b/bot/kodiak/entrypoints/worker.py @@ -127,7 +127,7 @@ async def main() -> NoReturn: if task_meta.kind == "repo": queue.start_repo_worker(queue_name=task_meta.queue_name) elif task_meta.kind == "webhook": - queue.start_webhook_worker(queue_name=task_meta.queue_name) + await queue.start_webhook_worker(queue_name=task_meta.queue_name) else: assert_never(task_meta.kind) if ingest_queue_watcher.done(): diff --git a/bot/kodiak/logging.py b/bot/kodiak/logging.py index 54eaaeea8..4963bd3a5 100644 --- a/bot/kodiak/logging.py +++ b/bot/kodiak/logging.py @@ -1,16 +1,17 @@ import logging import sys -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Tuple, Union +import inflection import sentry_sdk import structlog -from requests import Response from sentry_sdk import capture_event from sentry_sdk.integrations.logging import LoggingIntegration from sentry_sdk.utils import event_from_exception from typing_extensions import Literal from kodiak import app_config as conf +from kodiak.http import Response ################################################################################ # based on https://github.com/kiwicom/structlog-sentry/blob/18adbfdac85930ca5578e7ef95c1f2dc169c2f2f/structlog_sentry/__init__.py#L10-L86 @@ -117,11 +118,23 @@ def add_request_info_processor( """ response = event_dict.get("res", None) if isinstance(response, Response): - event_dict["response_content"] = cast(Any, response)._content + event_dict["response_content"] = response._content event_dict["response_status_code"] = response.status_code - event_dict["request_body"] = response.request.body event_dict["request_url"] = response.request.url event_dict["request_method"] = response.request.method + + for header in ( + "retry-after", + "x-ratelimit-limit", + "x-ratelimit-remaining", + "x-ratelimit-reset", + "x-ratelimit-used", + "x-ratelimit-resource", + ): + event_dict[ + f"response_header_{inflection.underscore(header)}" + ] = response.headers.get(header) + return event_dict diff --git a/bot/kodiak/pull_request.py b/bot/kodiak/pull_request.py index 120adb9d9..c074592f5 100644 --- a/bot/kodiak/pull_request.py +++ b/bot/kodiak/pull_request.py @@ -16,7 +16,7 @@ ) from kodiak.evaluation import mergeable from kodiak.http import HTTPStatusError as HTTPError -from kodiak.queries import Client, EventInfoResponse +from kodiak.queries import Client, EventInfoResponse, SecondaryRateLimit logger = structlog.get_logger() @@ -152,6 +152,9 @@ async def evaluate_pr( continue log.warning("api_call_retries_remaining", exc_info=True) return + except SecondaryRateLimit: + log.info("secondary_rate_limit") + await requeue_callback() except asyncio.TimeoutError: # On timeout we add the PR to the back of the queue to try again. log.warning("mergeable_timeout", exc_info=True) diff --git a/bot/kodiak/queries/__init__.py b/bot/kodiak/queries/__init__.py index 486917a57..f8574a893 100644 --- a/bot/kodiak/queries/__init__.py +++ b/bot/kodiak/queries/__init__.py @@ -822,6 +822,10 @@ async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None ... +class SecondaryRateLimit(Exception): + ... + + class Client: throttler: ThrottlerProtocol @@ -852,7 +856,7 @@ def __init__(self, *, owner: str, repo: str, installation_id: str): ) async def __aenter__(self) -> Client: - self.throttler = get_thottler_for_installation( + self.throttler = await get_thottler_for_installation( installation_id=self.installation_id ) return self @@ -883,6 +887,11 @@ async def send_query( try: res.raise_for_status() except http.HTTPError: + if ( + res.status_code == 403 + and b"You have exceeded a secondary rate limit" in res.content + ): + raise SecondaryRateLimit() log.warning("github api request error", res=res, exc_info=True) return None return cast(GraphQLResponse, res.json()) @@ -1383,7 +1392,7 @@ async def get_token_for_install( app_token = generate_jwt( private_key=conf.PRIVATE_KEY, app_identifier=conf.GITHUB_APP_ID ) - throttler = get_thottler_for_installation( + throttler = await get_thottler_for_installation( # this isn't a real installation ID, but it provides rate limiting # for our GithubApp instead of the installations we typically act as installation_id=APPLICATION_ID diff --git a/bot/kodiak/queue.py b/bot/kodiak/queue.py index 94880a37f..d42d6598a 100644 --- a/bot/kodiak/queue.py +++ b/bot/kodiak/queue.py @@ -5,10 +5,20 @@ import time import typing import urllib +import uuid from asyncio.tasks import Task from dataclasses import dataclass from datetime import timedelta -from typing import Iterator, MutableMapping, NoReturn, Optional, Tuple +from typing import ( + Any, + Callable, + Iterator, + MutableMapping, + NoReturn, + Optional, + Sequence, + Tuple, +) import sentry_sdk import structlog @@ -367,7 +377,9 @@ async def webhook_event_consumer( scope.set_tag("queue", queue_name) scope.set_tag("installation", installation_id_from_queue(queue_name)) log = logger.bind( - queue=queue_name, install=installation_id_from_queue(queue_name) + queue=queue_name, + install=installation_id_from_queue(queue_name), + task_id=uuid.uuid4().hex, ) log.info("start webhook event consumer") while True: @@ -461,7 +473,7 @@ class TaskMeta: class RedisWebhookQueue: def __init__(self) -> None: self.worker_tasks: MutableMapping[ - str, tuple[Task[NoReturn], Literal["repo", "webhook"]] + str, list[tuple[Task[NoReturn], Literal["repo", "webhook"]]] ] = {} # type: ignore [assignment] async def create(self) -> None: @@ -476,20 +488,26 @@ async def create(self) -> None: for webhook_result in webhook_queues: queue_name = webhook_result.decode() - self.start_webhook_worker(queue_name=queue_name) - - def start_webhook_worker(self, *, queue_name: str) -> None: + await self.start_webhook_worker(queue_name=queue_name) + + async def start_webhook_worker(self, *, queue_name: str) -> None: + concurrency_str = await redis_bot.get("queue_concurrency:" + queue_name) + try: + concurrency = int(concurrency_str or 1) + except ValueError: + concurrency = 1 self._start_worker( queue_name, "webhook", - webhook_event_consumer(webhook_queue=self, queue_name=queue_name), + lambda: webhook_event_consumer(webhook_queue=self, queue_name=queue_name), + concurrency=concurrency, ) def start_repo_worker(self, *, queue_name: str) -> None: self._start_worker( queue_name, "repo", - repo_queue_consumer( + lambda: repo_queue_consumer( queue_name=queue_name, ), ) @@ -498,22 +516,59 @@ def _start_worker( self, key: str, kind: Literal["repo", "webhook"], - fut: typing.Coroutine[None, None, NoReturn], + fut: Callable[[], typing.Coroutine[None, None, NoReturn]], + *, + concurrency: int = 1, ) -> None: log = logger.bind(queue_name=key, kind=kind) - worker_task_result = self.worker_tasks.get(key) - if worker_task_result is not None: - worker_task, _task_kind = worker_task_result - if not worker_task.done(): - return - log.info("task failed") - # task failed. record result and restart - exception = worker_task.exception() - log.info("exception", excep=exception) - sentry_sdk.capture_exception(exception) - log.info("creating task for queue") - # create new task for queue - self.worker_tasks[key] = (asyncio.create_task(fut), kind) + worker_task_results = () # type: Sequence[tuple[Task[NoReturn], Literal["repo", "webhook"]]] + try: + worker_task_results = self.worker_tasks[key] + except KeyError: + pass + new_workers: list[tuple[asyncio.Task[Any], Literal["repo", "webhook"]]] = [] + + previous_task_count = len(worker_task_results) + failed_task_count = 0 + + for (worker_task, _task_kind) in worker_task_results: + if worker_task.done(): + log.info("task failed") + # task failed. record result. + exception = worker_task.exception() + log.info("exception", excep=exception) + sentry_sdk.capture_exception(exception) + failed_task_count += 1 + else: + new_workers.append((worker_task, _task_kind)) + tasks_to_create = concurrency - len(new_workers) + + tasks_created = 0 + tasks_cancelled = 0 + # we need to create tasks + if tasks_to_create > 0: + for _ in range(tasks_to_create): + new_workers.append((asyncio.create_task(fut()), kind)) + tasks_created += 1 + # we need to remove tasks + elif tasks_to_create < 0: + # split off tasks we need to cancel. + new_workers, workers_to_delete = ( + new_workers[:concurrency], + new_workers[concurrency:], + ) + for (task, _kind) in workers_to_delete: + task.cancel() + tasks_cancelled += 1 + + self.worker_tasks[key] = new_workers + log.info( + "start_workers", + previous_task_count=previous_task_count, + failed_task_count=failed_task_count, + tasks_created=tasks_created, + tasks_cancelled=tasks_cancelled, + ) async def enqueue(self, *, event: WebhookEvent) -> None: """ @@ -531,7 +586,7 @@ async def enqueue(self, *, event: WebhookEvent) -> None: install=event.installation_id, ) log.info("enqueue webhook event") - self.start_webhook_worker(queue_name=queue_name) + await self.start_webhook_worker(queue_name=queue_name) async def enqueue_for_repo( self, *, event: WebhookEvent, first: bool @@ -577,8 +632,9 @@ async def enqueue_for_repo( return find_position((key for key, value in kvs), event.json().encode()) def all_tasks(self) -> Iterator[tuple[TaskMeta, Task[NoReturn]]]: - for queue_name, (task, task_kind) in self.worker_tasks.items(): - yield (TaskMeta(kind=task_kind, queue_name=queue_name), task) + for queue_name, tasks in self.worker_tasks.items(): + for (task, task_kind) in tasks: + yield (TaskMeta(kind=task_kind, queue_name=queue_name), task) def get_merge_queue_name(event: WebhookEvent) -> str: diff --git a/bot/kodiak/test_logging.py b/bot/kodiak/test_logging.py index 929be9f09..470f5e7c8 100644 --- a/bot/kodiak/test_logging.py +++ b/bot/kodiak/test_logging.py @@ -1,10 +1,9 @@ -import json import logging -from typing import Any, cast +from typing import Any import pytest -from requests import PreparedRequest, Request, Response +from kodiak.http import Request, Response from kodiak.logging import ( SentryLevel, SentryProcessor, @@ -160,20 +159,13 @@ def test_add_request_info_processor() -> None: url = "https://api.example.com/v1/me" payload = dict(user_id=54321) req = Request("POST", url, json=payload) - res = Response() - res.status_code = 500 - res.url = url - res.reason = "Internal Server Error" - cast( - Any, res - )._content = b"Your request could not be completed due to an internal error." - res.request = cast(PreparedRequest, req.prepare()) # type: ignore + res = Response(status_code=500, request=req) + res._content = b"Your request could not be completed due to an internal error." event_dict = add_request_info_processor( None, None, dict(event="request failed", res=res) ) - assert event_dict["response_content"] == cast(Any, res)._content + assert event_dict["response_content"] == res._content assert event_dict["response_status_code"] == res.status_code - assert event_dict["request_body"] == json.dumps(payload).encode() assert event_dict["request_url"] == req.url assert event_dict["request_method"] == "POST" assert event_dict["res"] is res diff --git a/bot/kodiak/test_queries.py b/bot/kodiak/test_queries.py index 639990249..605961f0c 100644 --- a/bot/kodiak/test_queries.py +++ b/bot/kodiak/test_queries.py @@ -55,7 +55,8 @@ def github_installation_id() -> str: @pytest.fixture def api_client(mocker: MockFixture, github_installation_id: str) -> Client: mocker.patch( - "kodiak.queries.get_thottler_for_installation", return_value=FakeThottler() + "kodiak.queries.get_thottler_for_installation", + return_value=wrap_future(FakeThottler()), ) client = Client(installation_id=github_installation_id, owner="foo", repo="foo") mocker.patch.object(client, "send_query") diff --git a/bot/kodiak/throttle.py b/bot/kodiak/throttle.py index ce977ea87..f2288d543 100644 --- a/bot/kodiak/throttle.py +++ b/bot/kodiak/throttle.py @@ -1,10 +1,13 @@ import asyncio import time -from collections import defaultdict, deque -from typing import Any, Mapping +from collections import deque +from dataclasses import dataclass, field +from typing import Any, MutableMapping from typing_extensions import Deque +from kodiak.redis_client import redis_bot + class Throttler: """ @@ -66,12 +69,47 @@ async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: pass -# installation_id => Throttler -THROTTLER_CACHE: Mapping[str, Throttler] = defaultdict( - # TODO(chdsbd): Store rate limits in redis and update via http rate limit response headers - lambda: Throttler(rate_limit=5000 / 60 / 60) -) +DEFAULT_REQUESTS_PER_HOUR = 5000 +CACHE_TTL_SECS = 60 + + +@dataclass(frozen=True) +class ThrottleEntry: + throttler: Throttler + insertion_time: float = field(default_factory=time.monotonic) + + def expired(self) -> bool: + return (time.monotonic() - self.insertion_time) > CACHE_TTL_SECS + + +class ThrottleCache: + + _cache: MutableMapping[str, ThrottleEntry] + + def __init__(self) -> None: + self._cache = {} + + async def get(self, *, installation_id: str) -> Throttler: + cache_entry = self._cache.get(installation_id) + if cache_entry and not cache_entry.expired(): + return cache_entry.throttler + try: + requests_per_hour_str = await redis_bot.get( + "installation_rate_limit:" + installation_id + ) + installation_rate_limit = int( + requests_per_hour_str or DEFAULT_REQUESTS_PER_HOUR + ) + except ValueError: + installation_rate_limit = DEFAULT_REQUESTS_PER_HOUR + self._cache[installation_id] = ThrottleEntry( + throttler=Throttler(rate_limit=installation_rate_limit / 60 / 60) + ) + return self._cache[installation_id].throttler + + +throttle_cache = ThrottleCache() -def get_thottler_for_installation(*, installation_id: str) -> Throttler: - return THROTTLER_CACHE[installation_id] +async def get_thottler_for_installation(*, installation_id: str) -> Throttler: + return await throttle_cache.get(installation_id=installation_id)