From f5b46770fbfaf5a10af514a4f77758d4a2de3f20 Mon Sep 17 00:00:00 2001
From: Le Bao Hiep
Date: Sun, 22 Sep 2024 22:53:01 +0700
Subject: [PATCH] Add judge balancer

Add a standalone balancer process (`manage.py runbalancer -c CONFIG`) that
sits between one or more bridged instances and a shared pool of judges.

* BridgeHandler connects to each configured bridge and performs the judge
  handshake (id/key) there, answers pings, and queues every
  `submission-request` it receives on the scheduler.
* JudgeHandler accepts judges on BALANCER_JUDGE_ADDRESS, authenticates them
  against the `judges` section of the config and relays their grading
  packets to the bridge the submission came from.
* Scheduler pairs each queued submission with a judge that is not working
  and remembers the judge <-> bridge pairing until the judge is freed.

bridged gains an `executors` packet handler so the runtime list of a
connected judge can be replaced after the handshake, and JudgeList now keeps
the problem map next to the problem id list.

Review fixes folded into this revision:
* JudgeList.update_problems_all() forwards `problems` to
  JudgeHandler.update_problems(), whose signature now requires it.
* Scheduler.register_judge() disconnects same-named judges by name; it used
  to pass the handler object, which never equals a name.
* Scheduler.remove_judge() also drops the handler from the judge pool, so a
  closed connection is never offered a submission.
* Scheduler.free_judge() no longer raises KeyError when the pairing was
  already removed by reset_bridge().
* runbalancer requires --config instead of failing inside open(None).
---
Notes:
  - The dmoj/settings.py and judge/bridge/monitor.py hunks carry only the
    context lines that are known verbatim (leading context trimmed).
  - The blob ids on the "index" lines were not recomputed after the review
    fixes; they only matter for a 3-way fallback (`git am -3`).

 dmoj/settings.py                         |   3 +
 judge/balancer/bridge_handler.py         | 245 +++++++++++++++++++++++
 judge/balancer/daemon.py                 |  27 +++
 judge/balancer/judge_handler.py          | 158 +++++++++++++++
 judge/balancer/judge_list.py             |   0
 judge/balancer/scheduler.py              | 134 +++++++++++++
 judge/balancer/sysinfo.py                |  33 +++
 judge/bridge/judge_handler.py            |  48 +++--
 judge/bridge/judge_list.py               |  14 +-
 judge/bridge/monitor.py                  |   3 +-
 judge/management/commands/runbalancer.py |  14 ++
 11 files changed, 657 insertions(+), 22 deletions(-)
 create mode 100644 judge/balancer/bridge_handler.py
 create mode 100644 judge/balancer/daemon.py
 create mode 100644 judge/balancer/judge_handler.py
 create mode 100644 judge/balancer/judge_list.py
 create mode 100644 judge/balancer/scheduler.py
 create mode 100644 judge/balancer/sysinfo.py
 create mode 100644 judge/management/commands/runbalancer.py

diff --git a/dmoj/settings.py b/dmoj/settings.py
index 06ffdf42c..e9ecdc3c1 100755
--- a/dmoj/settings.py
+++ b/dmoj/settings.py
@@ -656,3 +656,6 @@ ENABLE_FTS = False
+# Balancer configuration
+BALANCER_JUDGE_ADDRESS = [('localhost', 8888)]
+
 # Bridged configuration
 BRIDGED_JUDGE_ADDRESS = [('localhost', 9999)]
 BRIDGED_JUDGE_PROXIES = None
diff --git a/judge/balancer/bridge_handler.py b/judge/balancer/bridge_handler.py
new file mode 100644
index 000000000..dff2ee444
--- /dev/null
+++ b/judge/balancer/bridge_handler.py
@@ -0,0 +1,245 @@
+import errno
+import json
+import logging
+import socket
+import ssl
+import struct
+import threading
+import time
+import zlib
+from typing import Optional
+
+from judge.balancer import sysinfo
+
+
+log = logging.getLogger(__name__)
+
+
+class JudgeAuthenticationFailed(Exception):
+    pass
+
+
+class BridgeHandler:
+    SIZE_PACK = struct.Struct('!I')
+
+    ssl_context: Optional[ssl.SSLContext]
+
+    def __init__(
+        self,
+        host: str,
+        port: int,
+        id: str,
+        key: str,
+        scheduler,
+        bridge_id: int,
+        secure: bool = False,
+        no_cert_check: bool = False,
+        cert_store: Optional[str] = None,
+        **kwargs,
+    ):
+        self.host = host
+        self.port = port
+        self.scheduler = scheduler
+        self.name = id
+        self.key = key
+        self.bridge_id = bridge_id
+        self._closed = False
+
+        log.info('Preparing to connect to [%s]:%s as: %s', host, port, id)
+        if secure:
+            self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+            self.ssl_context.options |= ssl.OP_NO_SSLv2
+            self.ssl_context.options |= ssl.OP_NO_SSLv3
+
+            if not no_cert_check:
+                self.ssl_context.verify_mode = ssl.CERT_REQUIRED
+                self.ssl_context.check_hostname = True
+
+            if cert_store is None:
+                self.ssl_context.load_default_certs()
+            else:
+                self.ssl_context.load_verify_locations(cafile=cert_store)
+            log.info('Configured to use TLS.')
+        else:
+            self.ssl_context = None
+            log.info('TLS not enabled.')
+
+        self.secure = secure
+        self.no_cert_check = no_cert_check
+        self.cert_store = cert_store
+
+        self._lock = threading.RLock()
+        self.shutdown_requested = False
+
+        # Exponential backoff: starting at 4 seconds, max 60 seconds.
+        # If it fails to connect for something like 7 hours, it could RecursionError.
+        self.fallback = 4
+
+        self.conn = None
+        self._do_reconnect()
+
+    def _connect(self):
+        problems = []  # should be handled by bridged's monitor
+        versions = self.scheduler.get_runtime_versions()
+
+        log.info('Opening connection to: [%s]:%s', self.host, self.port)
+
+        while True:
+            try:
+                self.conn = socket.create_connection((self.host, self.port), timeout=5)
+            except OSError as e:
+                if e.errno != errno.EINTR:
+                    raise
+            else:
+                break
+
+        self.conn.settimeout(300)
+        self.conn.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+
+        if self.ssl_context:
+            log.info('Starting TLS on: [%s]:%s', self.host, self.port)
+            self.conn = self.ssl_context.wrap_socket(self.conn, server_hostname=self.host)
+
+        log.info('Starting handshake with: [%s]:%s', self.host, self.port)
+        self.input = self.conn.makefile('rb')
+        self.handshake(problems, versions, self.name, self.key)
+        log.info('Judge "%s" online: [%s]:%s', self.name, self.host, self.port)
+
+    def _reconnect(self):
+        if self.shutdown_requested:
+            log.info('Shutdown requested, not reconnecting.')
+            return
+
+        log.warning('Attempting reconnection in %.0fs: [%s]:%s', self.fallback, self.host, self.port)
+
+        if self.conn is not None:
+            log.info('Dropping old connection.')
+            self.conn.close()
+        time.sleep(self.fallback)
+        self.fallback = min(self.fallback * 1.5, 60)  # Limit fallback to one minute.
+        self._do_reconnect()
+
+    def _do_reconnect(self):
+        try:
+            self._connect()
+        except JudgeAuthenticationFailed:
+            log.error('Authentication as "%s" failed on: [%s]:%s', self.name, self.host, self.port)
+            self._reconnect()
+        except socket.error:
+            log.exception('Connection failed due to socket error: [%s]:%s', self.host, self.port)
+            self._reconnect()
+
+    def _read_forever(self):
+        try:
+            while True:
+                packet = self._read_single()
+                if packet is None:
+                    break
+                self._receive_packet(packet)
+        except Exception:
+            self.scheduler.abort_submission(self.bridge_id)
+            self.scheduler.reset_bridge(self.bridge_id)
+            self._reconnect()
+
+    def _read_single(self) -> Optional[dict]:
+        if self.shutdown_requested:
+            return None
+
+        try:
+            data = self.input.read(BridgeHandler.SIZE_PACK.size)
+        except socket.error:
+            self._reconnect()
+            return self._read_single()
+        if not data:
+            self._reconnect()
+            return self._read_single()
+        size = BridgeHandler.SIZE_PACK.unpack(data)[0]
+        try:
+            packet = zlib.decompress(self.input.read(size))
+        except zlib.error:
+            self._reconnect()
+            return self._read_single()
+        else:
+            return json.loads(packet.decode('utf-8', 'strict'))
+
+    def listen(self):
+        threading.Thread(target=self._read_forever).start()
+
+    def shutdown(self):
+        self.shutdown_requested = True
+        self._close()
+
+    def _close(self):
+        if self.conn and not self._closed:
+            try:
+                # May already be closed despite self._closed == False if a network error occurred and `close` is being
+                # called as part of cleanup.
+                self.conn.shutdown(socket.SHUT_RDWR)
+            except socket.error:
+                pass
+            self._closed = True
+
+    def __del__(self):
+        self.shutdown()
+
+    def send_packet(self, packet: dict):
+        for k, v in packet.items():
+            if isinstance(v, bytes):
+                # Make sure we don't have any garbage utf-8 from e.g. weird compilers
+                # *cough* fpc *cough* that could cause this routine to crash
+                # We cannot use utf8text because it may not be text.
+                packet[k] = v.decode('utf-8', 'replace')
+
+        raw = zlib.compress(json.dumps(packet).encode('utf-8'))
+        with self._lock:
+            try:
+                assert self.conn is not None
+                self.conn.sendall(BridgeHandler.SIZE_PACK.pack(len(raw)) + raw)
+            except Exception:  # connection reset by peer
+                self.scheduler.abort_submission(self.bridge_id)
+                self.scheduler.reset_bridge(self.bridge_id)
+                self._reconnect()
+
+    def _receive_packet(self, packet: dict):
+        name = packet['name']
+        if name == 'ping':
+            self.ping_packet(packet['when'])
+        elif name == 'submission-request':
+            self.submission_acknowledged_packet(packet['submission-id'])
+            self.scheduler.queue_submission(self.bridge_id, packet)
+        elif name == 'terminate-submission':
+            self.scheduler.abort_submission(self.bridge_id)
+        elif name == 'disconnect':
+            self.scheduler.abort_submission(self.bridge_id)
+            self._close()
+        else:
+            log.error('Unknown packet %s, payload %s', name, packet)
+
+    def handshake(self, problems: str, runtimes, id: str, key: str):
+        self.send_packet({'name': 'handshake', 'problems': problems, 'executors': runtimes, 'id': id, 'key': key})
+        log.info('Awaiting handshake response: [%s]:%s', self.host, self.port)
+        try:
+            data = self.input.read(BridgeHandler.SIZE_PACK.size)
+            size = BridgeHandler.SIZE_PACK.unpack(data)[0]
+            packet = zlib.decompress(self.input.read(size)).decode('utf-8', 'strict')
+            resp = json.loads(packet)
+        except Exception:
+            log.exception('Cannot understand handshake response: [%s]:%s', self.host, self.port)
+            raise JudgeAuthenticationFailed()
+        else:
+            if resp['name'] != 'handshake-success':
+                log.error('Handshake failed.')
+                raise JudgeAuthenticationFailed()
+
+    def ping_packet(self, when: float):
+        data = {'name': 'ping-response', 'when': when, 'time': time.time()}
+        for fn in sysinfo.report_callbacks:
+            key, value = fn()
+            data[key] = value
+        self.send_packet(data)
+
+    def submission_acknowledged_packet(self, sub_id: int):
+        self.send_packet({'name': 'submission-acknowledged', 'submission-id': sub_id})
+
+    def executors_packet(self, executors):
+        self.send_packet({'name': 'executors', 'executors': executors})
diff --git a/judge/balancer/daemon.py b/judge/balancer/daemon.py
new file mode 100644
index 000000000..4cd53f277
--- /dev/null
+++ b/judge/balancer/daemon.py
@@ -0,0 +1,27 @@
+import logging
+import signal
+import threading
+
+from judge.balancer.scheduler import Scheduler
+
+logger = logging.getLogger('judge.balancer')
+
+
+def balancer_daemon(config):
+    scheduler = Scheduler(config)
+    scheduler.run()
+
+    stop = threading.Event()
+
+    def signal_handler(signum, _):
+        logger.info('Exiting due to %s', signal.Signals(signum).name)
+        stop.set()
+
+    signal.signal(signal.SIGINT, signal_handler)
+    signal.signal(signal.SIGQUIT, signal_handler)
+    signal.signal(signal.SIGTERM, signal_handler)
+
+    try:
+        stop.wait()
+    finally:
+        scheduler.shutdown()
diff --git a/judge/balancer/judge_handler.py b/judge/balancer/judge_handler.py
new file mode 100644
index 000000000..45402dd9e
--- /dev/null
+++ b/judge/balancer/judge_handler.py
@@ -0,0 +1,158 @@
+import json
+import logging
+import threading
+import time
+
+from django.conf import settings
+
+from judge.bridge.base_handler import ZlibPacketHandler, proxy_list
+
+logger = logging.getLogger('judge.balancer')
+
+
+class JudgeHandler(ZlibPacketHandler):
+    proxies = proxy_list(settings.BRIDGED_JUDGE_PROXIES or [])
+
+    def __init__(self, request, client_address, server, scheduler):
+        super().__init__(request, client_address, server)
+
+        self.scheduler = scheduler
+        self.handlers = {
+            'grading-begin': self.forward_packet,
+            'grading-end': self.forward_packet_and_free_self,
+            'compile-error': self.forward_packet_and_free_self,
+            'compile-message': self.forward_packet,
+            'batch-begin': self.forward_packet,
+            'batch-end': self.forward_packet,
+            'test-case-status': self.forward_packet,
+            'internal-error': self.forward_packet_and_free_self,
+            'submission-terminated': self.forward_packet_and_free_self,
+            'submission-acknowledged': self.on_submission_acknowledged,
+            'ping-response': self.ignore_packet,
+            'supported-problems': self.ignore_packet,
+            'handshake': self.on_handshake,
+        }
+        self.current_submission_id = None
+        self._no_response_job = None
+        self.name = None
+        self._stop_ping = threading.Event()
+
+    def on_connect(self):
+        self.timeout = 15
+        logger.info('Judge connected from: %s', self.client_address)
+
+    def on_disconnect(self):
+        self._stop_ping.set()
+        if self.current_submission_id:
+            self.internal_error_packet('Judge disconnected while handling submission')
+            bridge_id = self.scheduler.get_paired_bridge(self.name)
+            self.scheduler.reset_bridge(bridge_id)
+            logger.error('Judge %s disconnected while handling submission %s', self.name, self.current_submission_id)
+
+        self.scheduler.remove_judge(self)
+        logger.info('Judge disconnected from: %s with name %s', self.client_address, self.name)
+
+    def send(self, data):
+        super().send(json.dumps(data, separators=(',', ':')))
+
+    def on_handshake(self, packet):
+        if 'id' not in packet or 'key' not in packet:
+            logger.warning('Malformed handshake: %s', self.client_address)
+            self.close()
+            return
+
+        if not self.scheduler.authenticate_judge(packet['id'], packet['key'], self.client_address):
+            self.close()
+            return
+
+        self.timeout = 60
+        self.name = packet['id']
+
+        self.send({'name': 'handshake-success'})
+        logger.info('Judge authenticated: %s (%s)', self.client_address, packet['id'])
+        self.scheduler.set_runtime_versions(packet['executors'])
+        self.scheduler.register_judge(self)
+        threading.Thread(target=self._ping_thread).start()
+
+    @property
+    def working(self):
+        return bool(self.current_submission_id)
+
+    def disconnect(self, force=False):
+        if force:
+            # Yank the power out.
+            self.close()
+        else:
+            self.send({'name': 'disconnect'})
+
+    def submit(self, packet):
+        self.current_submission_id = packet['submission-id']
+        self._no_response_job = threading.Timer(20, self._kill_if_no_response)
+        self.send(packet)
+
+    def _kill_if_no_response(self):
+        logger.error('Judge failed to acknowledge submission: %s: %s', self.name, self.current_submission_id)
+        self.close()
+
+    def on_timeout(self):
+        if self.name:
+            logger.warning('Judge seems dead: %s: %s', self.name, self.current_submission_id)
+
+    def on_submission_acknowledged(self, packet):
+        self.scheduler.forward_packet_to_bridge(self.name, packet)
+        if not packet.get('submission-id', None) == self.current_submission_id:
+            logger.error('Wrong acknowledgement: %s: %s, expected: %s', self.name, packet.get('submission-id', None),
+                         self.current_submission_id)
+            self.close()
+        logger.info('Submission acknowledged: %d', self.current_submission_id)
+        if self._no_response_job:
+            self._no_response_job.cancel()
+            self._no_response_job = None
+
+    def abort(self):
+        self.send({'name': 'terminate-submission'})
+
+    def ping(self):
+        self.send({'name': 'ping', 'when': time.time()})
+
+    def on_packet(self, data):
+        try:
+            try:
+                data = json.loads(data)
+                if 'name' not in data:
+                    raise ValueError
+            except ValueError:
+                self.on_malformed(data)
+            else:
+                handler = self.handlers.get(data['name'], self.on_malformed)
+                handler(data)
+        except Exception:
+            logger.exception('Error in packet handling (Judge-side): %s', self.name)
+
+    def on_malformed(self, packet):
+        logger.error('%s: Malformed packet: %s', self.name, packet)
+
+    def forward_packet(self, packet):
+        self.scheduler.forward_packet_to_bridge(self.name, packet)
+
+    def forward_packet_and_free_self(self, packet):
+        self.scheduler.forward_packet_to_bridge(self.name, packet)
+        self.current_submission_id = None
+        self.scheduler.free_judge(self)
+
+    def ignore_packet(self, packet):
+        pass
+
+    def internal_error_packet(self, message: str):
+        self.forward_packet({'name': 'internal-error', 'submission-id': self.current_submission_id, 'message': message})
+
+    def _ping_thread(self):
+        try:
+            while True:
+                self.ping()
+                if self._stop_ping.wait(10):
+                    break
+        except Exception:
+            logger.exception('Ping error in %s', self.name)
+            self.close()
+            raise
diff --git a/judge/balancer/judge_list.py b/judge/balancer/judge_list.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/judge/balancer/scheduler.py b/judge/balancer/scheduler.py
new file mode 100644
index 000000000..173ff5146
--- /dev/null
+++ b/judge/balancer/scheduler.py
@@ -0,0 +1,134 @@
+import hmac
+import logging
+import threading
+from collections import deque
+from functools import partial
+from threading import RLock
+
+from django.conf import settings
+
+from judge.balancer.bridge_handler import BridgeHandler
+from judge.balancer.judge_handler import JudgeHandler
+from judge.bridge.server import Server
+
+
+logger = logging.getLogger('judge.balancer')
+
+
+class Scheduler:
+    def __init__(self, config):
+        self.executors = {}
+        self.config = config
+        self.judges = set()
+        self.queue = deque()
+        self.lock = RLock()
+        self.judge_to_bridge = {}
+        self.bridge_to_judge = {}
+
+        self.judge_server = Server(
+            settings.BALANCER_JUDGE_ADDRESS,
+            partial(JudgeHandler, scheduler=self),
+        )
+
+        self.bridges = []
+        for bridge in config['bridges']:
+            bridge_id = len(self.bridges)
+            self.bridges.append(BridgeHandler(scheduler=self, bridge_id=bridge_id, **bridge))
+
+    def run(self):
+        threading.Thread(target=self.judge_server.serve_forever).start()
+        for bridge in self.bridges:
+            bridge.listen()
+
+    def shutdown(self):
+        self.judge_server.shutdown()
+        for bridge in self.bridges:
+            bridge.shutdown()
+
+    def get_paired_bridge(self, judge_name):
+        return self.judge_to_bridge.get(judge_name)
+
+    def reset_bridge(self, bridge_id):
+        with self.lock:
+            if bridge_id in self.bridge_to_judge:
+                judge = self.bridge_to_judge[bridge_id]
+                del self.judge_to_bridge[judge.name]
+                del self.bridge_to_judge[bridge_id]
+
+    def _try_judge(self):
+        with self.lock:
+            available = [judge for judge in self.judges if not judge.working]
+            while available and self.queue:
+                judge = available.pop()
+                bridge_id, packet = self.queue.popleft()
+                self.judge_to_bridge[judge.name] = bridge_id
+                self.bridge_to_judge[bridge_id] = judge
+
+                packet['storage_namespace'] = self.config['bridges'][bridge_id].get('storage_namespace')
+                judge.submit(packet)
+
+    def free_judge(self, judge):
+        with self.lock:
+            # The pairing may already be gone if reset_bridge() ran first; do not raise KeyError then.
+            bridge_id = self.judge_to_bridge.pop(judge.name, None)
+            self.bridge_to_judge.pop(bridge_id, None)
+
+            self._try_judge()
+
+    def authenticate_judge(self, judge_id, key, client_address):
+        judge_config = ([judge for judge in self.config['judges'] if judge['id'] == judge_id] or [None])[0]
+        if judge_config is None:
+            return False
+
+        if not hmac.compare_digest(judge_config.get('key'), key):
+            logger.warning('Judge authentication failure: %s', client_address)
+            return False
+
+        return True
+
+    def register_judge(self, judge):
+        with self.lock:
+            # Force-disconnect every already registered judge that uses the same name.
+            self.disconnect(judge.name, force=True)
+            self.judges.add(judge)
+            self._try_judge()
+
+    def disconnect(self, judge_id, force=False):
+        with self.lock:
+            for judge in tuple(self.judges):  # snapshot: a disconnect may end up mutating the set
+                if judge.name == judge_id:
+                    judge.disconnect(force=force)
+
+    def remove_judge(self, judge):
+        with self.lock:
+            self.judges.discard(judge)  # a closed handler must never be offered work again
+            bridge_id = self.judge_to_bridge.pop(judge.name, None)
+            if bridge_id is not None:
+                self.bridge_to_judge.pop(bridge_id, None)
+
+    def set_runtime_versions(self, executors):
+        self.executors = executors
+        for bridge in self.bridges:
+            bridge.executors_packet(executors)
+
+    def get_runtime_versions(self):
+        return self.executors
+
+    def queue_submission(self, bridge_id: int, packet: dict):
+        with self.lock:
+            self.queue.append((bridge_id, packet))
+            self._try_judge()
+
+    def abort_submission(self, bridge_id):
+        try:
+            judge = self.bridge_to_judge[bridge_id]
+            judge.abort()
+        except KeyError:
+            pass
+
+    def forward_packet_to_bridge(self, judge_name, packet: dict):
+        try:
+            bridge_id = self.judge_to_bridge[judge_name]
+            self.bridges[bridge_id].send_packet(packet)
+        except KeyError:
+            pass
diff --git a/judge/balancer/sysinfo.py b/judge/balancer/sysinfo.py
new file mode 100644
index 000000000..1fcc50e16
--- /dev/null
+++ b/judge/balancer/sysinfo.py
@@ -0,0 +1,33 @@
+import os
+from multiprocessing import cpu_count as _get_cpu_count
+
+_cpu_count = _get_cpu_count()
+
+
+if hasattr(os, 'getloadavg'):
+
+    def load_fair():
+        try:
+            load = os.getloadavg()[0] / _cpu_count
+        except OSError:  # as of May 2016, Windows' Linux subsystem throws OSError on getloadavg
+            load = -1
+        return 'load', load
+
+else:
+    # There exist some Unix platforms (like Android) which don't
+    # have `getloadavg` implemented, but aren't Windows
+    # so we manually read the `/proc/loadavg` file.
+    def load_fair():
+        try:
+            with open('/proc/loadavg', 'r') as f:
+                load = float(f.read().split()[0]) / _cpu_count
+        except FileNotFoundError:
+            load = -1
+        return 'load', load
+
+
+def cpu_count():
+    return 'cpu-count', _cpu_count
+
+
+report_callbacks = [load_fair, cpu_count]
diff --git a/judge/bridge/judge_handler.py b/judge/bridge/judge_handler.py
index f8af0818d..6f64ee305 100644
--- a/judge/bridge/judge_handler.py
+++ b/judge/bridge/judge_handler.py
@@ -53,6 +53,7 @@ def __init__(self, request, client_address, server, judges, ignore_problems_pack
             'submission-acknowledged': self.on_submission_acknowledged,
             'ping-response': self.on_ping_response,
             'supported-problems': self.on_supported_problems,
+            'executors': self.on_executors,
             'handshake': self.on_handshake,
         }
         self._working = False
@@ -119,24 +120,18 @@ def _connected(self):
         judge = self.judge = Judge.objects.get(name=self.name)
         judge.start_time = timezone.now()
         judge.online = True
-        judge.runtimes.set(Language.objects.filter(key__in=list(self.executors.keys())).values_list('id', flat=True))
+
+        self.update_runtimes()
+
         if self.ignore_problems_packet:
-            judge.problems.set(self.judges.problem_ids_cache)
+            self.problems = self.judges.problems
+            judge.problems.set(self.judges.problem_ids)
         else:
             judge.problems.set(Problem.objects.filter(code__in=list(self.problems.keys())).values_list('id', flat=True))
 
         # Cache is_disabled for faster access
         self.is_disabled = judge.is_disabled
 
-        # Delete now in case we somehow crashed and left some over from the last connection
-        RuntimeVersion.objects.filter(judge=judge).delete()
-        versions = []
-        for lang in judge.runtimes.all():
-            versions += [
-                RuntimeVersion(language=lang, name=name, version='.'.join(map(str, version)), priority=idx, judge=judge)
-                for idx, (name, version) in enumerate(self.executors[lang.key])
-            ]
-        RuntimeVersion.objects.bulk_create(versions)
         judge.last_ip = self.client_address[0]
         judge.save()
         self.judge_address = '[%s]:%s' % (self.client_address[0], self.client_address[1])
@@ -327,8 +322,9 @@ def _submission_is_batch(self, id):
         if not Submission.objects.filter(id=id).update(batch=True):
             logger.warning('Unknown submission: %s', id)
 
-    def update_problems(self, problem_ids):
+    def update_problems(self, problems, problem_ids):
         logger.info('%s: Updating problem list', self.name)
+        self.problems = problems
         self.judge.problems.set(problem_ids)
         logger.info('%s: Updated %d problems', self.name, len(problem_ids))
         json_log.info(self._make_json_log(action='update-problems', count=len(problem_ids)))
@@ -337,9 +333,31 @@ def on_supported_problems(self, packet):
         if self.ignore_problems_packet:
             return
 
-        self.problems = dict(packet['problems'])
-        problem_ids = list(Problem.objects.filter(code__in=list(self.problems.keys())).values_list('id', flat=True))
-        self.judges.update_problems(self, problem_ids)
+        problems = dict(packet['problems'])
+        problem_ids = list(Problem.objects.filter(code__in=list(problems.keys())).values_list('id', flat=True))
+        self.judges.update_problems(self, problems, problem_ids)
+
+    def update_runtimes(self):
+        self.judge.runtimes.set(
+            Language.objects.filter(key__in=list(self.executors.keys())).values_list('id', flat=True),
+        )
+
+        RuntimeVersion.objects.filter(judge=self.judge).delete()
+        versions = []
+        for lang in self.judge.runtimes.all():
+            versions += [
+                RuntimeVersion(language=lang, name=name, version='.'.join(map(str, version)),
+                               priority=idx, judge=self.judge)
+                for idx, (name, version) in enumerate(self.executors[lang.key])
+            ]
+        RuntimeVersion.objects.bulk_create(versions)
+
+    def on_executors(self, packet):
+        logger.info('%s: Updating runtimes', self.name)
+        self.executors = packet['executors']
+        self.update_runtimes()
+        logger.info('%s: Updated runtimes', self.name)
+        json_log.info(self._make_json_log(action='update-executors', executors=list(self.executors.keys())))
 
     def on_grading_begin(self, packet):
         logger.info('%s: Grading has begun on: %s', self.name, packet['submission-id'])
diff --git a/judge/bridge/judge_list.py b/judge/bridge/judge_list.py
index 1a69921de..2269d106e 100644
--- a/judge/bridge/judge_list.py
+++ b/judge/bridge/judge_list.py
@@ -25,7 +25,8 @@ def __init__(self):
         self.node_map = {}
         self.submission_map = {}
         self.lock = RLock()
-        self.problem_ids_cache = []
+        self.problems = {}
+        self.problem_ids = []
 
     def _handle_free_judge(self, judge):
         with self.lock:
@@ -69,17 +70,18 @@ def disconnect(self, judge_id, force=False):
                 if judge.name == judge_id:
                     judge.disconnect(force=force)
 
-    def update_problems_all(self, problem_ids):
+    def update_problems_all(self, problems, problem_ids):
         with self.lock:
-            self.problem_ids_cache = problem_ids
+            self.problems = problems
+            self.problem_ids = problem_ids
             for judge in self.judges:
-                judge.update_problems(problem_ids)
+                judge.update_problems(problems, problem_ids)
                 if not judge.working:
                     self._handle_free_judge(judge)
 
-    def update_problems(self, judge, problem_ids):
+    def update_problems(self, judge, problems, problem_ids):
         with self.lock:
-            judge.update_problems(problem_ids)
+            judge.update_problems(problems, problem_ids)
             if not judge.working:
                 self._handle_free_judge(judge)
 
diff --git a/judge/bridge/monitor.py b/judge/bridge/monitor.py
index 31b927a86..aa093abd7 100644
--- a/judge/bridge/monitor.py
+++ b/judge/bridge/monitor.py
@@ -83,6 +83,7 @@ def update_supported_problems(self):
         problems = list(set(problems))
+        problems = {problem: 0 for problem in problems}
         problem_ids = list(Problem.objects.filter(code__in=problems).values_list('id', flat=True))
-        self.judges.update_problems_all(problem_ids)
+        self.judges.update_problems_all(problems, problem_ids)
 
     def updater_thread(self) -> None:
         while True:
diff --git a/judge/management/commands/runbalancer.py b/judge/management/commands/runbalancer.py
new file mode 100644
index 000000000..260db3367
--- /dev/null
+++ b/judge/management/commands/runbalancer.py
@@ -0,0 +1,14 @@
+import yaml
+from django.core.management.base import BaseCommand
+
+from judge.balancer.daemon import balancer_daemon
+
+
+class Command(BaseCommand):
+    def add_arguments(self, parser) -> None:
+        parser.add_argument('-c', '--config', type=str, required=True, help='file to load balancer configurations from')
+
+    def handle(self, *args, **options):
+        with open(options['config'], 'r') as f:
+            config = yaml.safe_load(f)
+        balancer_daemon(config)