From 4aaf86e0025c732b54aa608082f9be765258dfe9 Mon Sep 17 00:00:00 2001 From: David Schultz Date: Fri, 10 Jan 2025 11:44:38 -0600 Subject: [PATCH] queue gpu tasks separately --- iceprod/rest/handlers/tasks.py | 99 ++++++++++++-------------- iceprod/scheduled_tasks/queue_tasks.py | 39 ++++++---- tests/rest/tasks_test.py | 21 +++++- 3 files changed, 93 insertions(+), 66 deletions(-) diff --git a/iceprod/rest/handlers/tasks.py b/iceprod/rest/handlers/tasks.py index 4fd0fd0a..99471f96 100644 --- a/iceprod/rest/handlers/tasks.py +++ b/iceprod/rest/handlers/tasks.py @@ -308,44 +308,6 @@ async def put(self, task_id): self.finish() -class TaskCountsStatusHandler(APIBase): - """ - Handle task summary grouping by status. - """ - @authorization(roles=['admin', 'system']) - async def get(self): - """ - Get the task counts for all tasks, group by status. - - Params (optional): - status: | separated list of task status to filter by - - Returns: - dict: {: num} - """ - match = {} - status = self.get_argument('status', None) - if status: - status_list = status.split('|') - if any(s not in TASK_STATUS for s in status_list): - raise tornado.web.HTTPError(400, reaosn='Unknown task status') - match['status'] = {'$in': status_list} - - ret = {} - cursor = self.db.tasks.aggregate([ - {'$match': match}, - {'$group': {'_id': '$status', 'total': {'$sum': 1}}}, - ]) - ret = {} - async for row in cursor: - ret[row['_id']] = row['total'] - ret2 = {} - for k in sorted(ret, key=task_status_sort): - ret2[k] = ret[k] - self.write(ret2) - self.finish() - - class DatasetMultiTasksHandler(APIBase): """ Handle multi tasks requests. @@ -553,23 +515,11 @@ async def get(self, dataset_id): self.finish() -class DatasetTaskCountsStatusHandler(APIBase): +class TaskCountsStatusHandler(APIBase): """ Handle task summary grouping by status. 
    """
-    @authorization(roles=['admin', 'user', 'system'])
-    @attr_auth(arg='dataset_id', role='read')
-    async def get(self, dataset_id):
-        """
-        Get the task counts for all tasks in a dataset, group by status.
-
-        Args:
-            dataset_id (str): dataset id
-
-        Returns:
-            dict: {<status>: num}
-        """
-        match = {'dataset_id': dataset_id}
+    async def counts(self, match):
         status = self.get_argument('status', None)
         if status:
             status_list = status.split('|')
@@ -577,6 +527,13 @@ async def get(self, dataset_id):
                 raise tornado.web.HTTPError(400, reaosn='Unknown task status')
             match['status'] = {'$in': status_list}
 
+        gpu = self.get_argument('gpu', None)
+        if gpu is not None:
+            if gpu and str(gpu).lower() not in ('false', '0', 'no'):
+                match['requirements.gpu'] = {'$gte': 1}
+            else:
+                match['$or'] = [{"requirements.gpu": {"$exists": False}}, {"requirements.gpu": {"$lte": 0}}]
+
         ret = {}
         cursor = self.db.tasks.aggregate([
             {'$match': match},
@@ -591,6 +548,44 @@ async def get(self, dataset_id):
         self.write(ret2)
         self.finish()
 
+    @authorization(roles=['admin', 'system'])
+    async def get(self):
+        """
+        Get the task counts for all tasks, group by status.
+
+        Params (optional):
+            status: | separated list of task status to filter by
+            gpu: bool to select only gpu tasks or non-gpu tasks
+
+        Returns:
+            dict: {<status>: num}
+        """
+        await self.counts(match={})
+
+
+class DatasetTaskCountsStatusHandler(TaskCountsStatusHandler):
+    """
+    Handle task summary grouping by status.
+    """
+    @authorization(roles=['admin', 'user', 'system'])
+    @attr_auth(arg='dataset_id', role='read')
+    async def get(self, dataset_id):
+        """
+        Get the task counts for all tasks in a dataset, group by status.
+ + Args: + dataset_id (str): dataset id + + Params (optional): + status: | separated list of task status to filter by + gpu: bool to select only gpu tasks or non-gpu tasks + + Returns: + dict: {: num} + """ + match = {'dataset_id': dataset_id} + await self.counts(match=match) + class DatasetTaskCountsNameStatusHandler(APIBase): """ diff --git a/iceprod/scheduled_tasks/queue_tasks.py b/iceprod/scheduled_tasks/queue_tasks.py index 7234e888..ac9cf50c 100644 --- a/iceprod/scheduled_tasks/queue_tasks.py +++ b/iceprod/scheduled_tasks/queue_tasks.py @@ -10,38 +10,49 @@ import logging import os +from wipac_dev_tools import from_environment, strtobool + from iceprod.client_auth import add_auth_to_argparse, create_rest_client logger = logging.getLogger('queue_tasks') -NTASKS = 250000 -NTASKS_PER_CYCLE = 1000 +default_config = { + 'NTASKS': 250000, + 'NTASKS_PER_CYCLE': 1000, +} -async def run(rest_client, dataset_id=None, ntasks=NTASKS, ntasks_per_cycle=NTASKS_PER_CYCLE, debug=False): +async def run(rest_client, config, dataset_id='', gpus=None, debug=False): """ Actual runtime / loop. 
     Args:
         rest_client (:py:class:`iceprod.core.rest_client.Client`): rest client
+        config (dict): config dict
+        dataset_id (str): dataset to queue
+        gpus (bool): run on gpu tasks, cpu tasks, or both
         debug (bool): debug flag to propagate exceptions
     """
     try:
+        num_tasks_idle = 0
         num_tasks_waiting = 0
-        num_tasks_queued = 0
         if dataset_id:
             route = f'/datasets/{dataset_id}/task_counts/status'
         else:
             route = '/task_counts/status'
         args = {'status': 'idle|waiting'}
+        if gpus is not None:
+            args['gpu'] = gpus
+        logger.debug('task_counts request args: %r', args)
+
         tasks = await rest_client.request('GET', route, args)
         if 'idle' in tasks:
-            num_tasks_waiting = tasks['idle']
+            num_tasks_idle = tasks['idle']
         if 'waiting' in tasks:
-            num_tasks_queued = tasks['waiting']
+            num_tasks_waiting = tasks['waiting']
-        tasks_to_queue = min(num_tasks_waiting, ntasks - num_tasks_queued, ntasks_per_cycle)
-        logger.warning(f'num tasks idle: {num_tasks_waiting}')
-        logger.warning(f'num tasks waiting: {num_tasks_queued}')
+        tasks_to_queue = min(num_tasks_idle, config['NTASKS'] - num_tasks_waiting, config['NTASKS_PER_CYCLE'])
+        logger.warning(f'num tasks idle: {num_tasks_idle}')
+        logger.warning(f'num tasks waiting: {num_tasks_waiting}')
         logger.warning(f'tasks to waiting: {tasks_to_queue}')
 
         if tasks_to_queue > 0:
@@ -128,24 +139,28 @@ async def check_deps(task):
 
 
 def main():
+    config = from_environment(default_config)
+
     parser = argparse.ArgumentParser(description='run a scheduled task once')
     add_auth_to_argparse(parser)
     parser.add_argument('--dataset-id', help='dataset id')
-    parser.add_argument('--ntasks', type=int, default=os.environ.get('NTASKS', NTASKS),
+    parser.add_argument('--gpus', default=None, type=strtobool, help='whether to select only gpu or non-gpu tasks')
+    parser.add_argument('--ntasks', type=int, default=config['NTASKS'],
                         help='number of tasks to keep queued')
-    parser.add_argument('--ntasks_per_cycle', type=int, default=os.environ.get('NTASKS_PER_CYCLE', NTASKS_PER_CYCLE),
+    parser.add_argument('--ntasks_per_cycle', type=int, default=config['NTASKS_PER_CYCLE'],
                         help='number of tasks to queue per cycle')
     parser.add_argument('--log-level', default='info', help='log level')
     parser.add_argument('--debug', default=False, action='store_true', help='debug enabled')
     args = parser.parse_args()
+    config.update({'NTASKS': args.ntasks, 'NTASKS_PER_CYCLE': args.ntasks_per_cycle})
 
     logformat = '%(asctime)s %(levelname)s %(name)s %(module)s:%(lineno)s - %(message)s'
     logging.basicConfig(format=logformat, level=getattr(logging, args.log_level.upper()))
 
     rest_client = create_rest_client(args)
 
-    asyncio.run(run(rest_client, dataset_id=args.dataset_id, ntasks=args.ntasks, ntasks_per_cycle=args.ntasks_per_cycle, debug=args.debug))
+    asyncio.run(run(rest_client, dataset_id=args.dataset_id, gpus=args.gpus, config=config, debug=args.debug))
 
 
 if __name__ == '__main__':
diff --git a/tests/rest/tasks_test.py b/tests/rest/tasks_test.py
index f1963639..d9254c94 100644
--- a/tests/rest/tasks_test.py
+++ b/tests/rest/tasks_test.py
@@ -297,13 +297,30 @@ async def test_rest_tasks_dataset_counts_status(server):
         'requirements': {},
     }
     ret = await client.request('POST', '/tasks', data)
-    task_id = ret['result']
+
+    data = {
+        'dataset_id': 'foo',
+        'job_id': 'foo1',
+        'task_index': 1,
+        'job_index': 0,
+        'name': 'baz',
+        'depends': [],
+        'requirements': {'gpu': 1},
+        'status': 'processing'
+    }
+    ret = await client.request('POST', '/tasks', data)
 
     ret = await client.request('GET', f'/datasets/{data["dataset_id"]}/task_counts/status')
-    assert ret == {states.TASK_STATUS_START: 1}
+    assert ret == {states.TASK_STATUS_START: 1, 'processing': 1}
 
     ret = await client.request('GET', f'/datasets/{data["dataset_id"]}/task_counts/status?status=complete')
     assert ret == {}
+
+    ret = await client.request('GET', f'/datasets/{data["dataset_id"]}/task_counts/status?gpu=false')
+    assert ret == {states.TASK_STATUS_START: 1}
+
+    ret = await client.request('GET', f'/datasets/{data["dataset_id"]}/task_counts/status?gpu=true')
+    assert ret ==
{'processing': 1} async def test_rest_tasks_dataset_counts_name_status(server):