Skip to content

Commit

Permalink
queue gpu tasks separately
Browse files Browse the repository at this point in the history
  • Loading branch information
dsschult committed Jan 10, 2025
1 parent ea65812 commit 4aaf86e
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 66 deletions.
99 changes: 47 additions & 52 deletions iceprod/rest/handlers/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,44 +308,6 @@ async def put(self, task_id):
self.finish()


class TaskCountsStatusHandler(APIBase):
    """
    Handle task summary grouping by status.
    """
    @authorization(roles=['admin', 'system'])
    async def get(self):
        """
        Get the task counts for all tasks, grouped by status.

        Params (optional):
            status: | separated list of task status to filter by

        Returns:
            dict: {<status>: num}
        """
        match = {}
        status = self.get_argument('status', None)
        if status:
            status_list = status.split('|')
            if any(s not in TASK_STATUS for s in status_list):
                # fix: keyword was misspelled `reaosn`, which would raise a
                # TypeError instead of returning the intended 400 response
                raise tornado.web.HTTPError(400, reason='Unknown task status')
            match['status'] = {'$in': status_list}

        cursor = self.db.tasks.aggregate([
            {'$match': match},
            {'$group': {'_id': '$status', 'total': {'$sum': 1}}},
        ])
        ret = {}
        async for row in cursor:
            ret[row['_id']] = row['total']
        # emit statuses in canonical lifecycle order, not Mongo's arbitrary order
        ret2 = {k: ret[k] for k in sorted(ret, key=task_status_sort)}
        self.write(ret2)
        self.finish()


class DatasetMultiTasksHandler(APIBase):
"""
Handle multi tasks requests.
Expand Down Expand Up @@ -553,30 +515,25 @@ async def get(self, dataset_id):
self.finish()


class DatasetTaskCountsStatusHandler(APIBase):
class TaskCountsStatusHandler(APIBase):
"""
Handle task summary grouping by status.
"""
@authorization(roles=['admin', 'user', 'system'])
@attr_auth(arg='dataset_id', role='read')
async def get(self, dataset_id):
"""
Get the task counts for all tasks in a dataset, group by status.
Args:
dataset_id (str): dataset id
Returns:
dict: {<status>: num}
"""
match = {'dataset_id': dataset_id}
async def counts(self, match):
status = self.get_argument('status', None)
if status:
status_list = status.split('|')
if any(s not in TASK_STATUS for s in status_list):
raise tornado.web.HTTPError(400, reaosn='Unknown task status')
match['status'] = {'$in': status_list}

gpu = self.get_argument('gpu', None)
if gpu is not None:
if gpu:
match['requirements.gpu'] = {'$gte': 1}
else:
match['$or'] = [{"requirements.gpu": {"$exists": False}}, {"requirements.gpu": {"$lte": 0}}]

ret = {}
cursor = self.db.tasks.aggregate([
{'$match': match},
Expand All @@ -591,6 +548,44 @@ async def get(self, dataset_id):
self.write(ret2)
self.finish()

@authorization(roles=['admin', 'system'])
async def get(self):
    """
    Get the task counts for all tasks, grouped by status.

    Params (optional):
        status: | separated list of task status to filter by
        gpu: bool to select only gpu tasks or non-gpu tasks

    Returns:
        dict: {<status>: num}
    """
    # no extra filter criteria - count across every dataset
    match = {}
    await self.counts(match=match)


class DatasetTaskCountsStatusHandler(TaskCountsStatusHandler):
    """
    Handle per-dataset task summary grouping by status.
    """
    @authorization(roles=['admin', 'user', 'system'])
    @attr_auth(arg='dataset_id', role='read')
    async def get(self, dataset_id):
        """
        Get the task counts for all tasks in a dataset, grouped by status.

        Args:
            dataset_id (str): dataset id

        Params (optional):
            status: | separated list of task status to filter by
            gpu: bool to select only gpu tasks or non-gpu tasks

        Returns:
            dict: {<status>: num}
        """
        # restrict the shared counting logic to this dataset only
        await self.counts(match={'dataset_id': dataset_id})


class DatasetTaskCountsNameStatusHandler(APIBase):
"""
Expand Down
39 changes: 27 additions & 12 deletions iceprod/scheduled_tasks/queue_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,49 @@
import logging
import os

from wipac_dev_tools import from_environment, strtobool

from iceprod.client_auth import add_auth_to_argparse, create_rest_client

logger = logging.getLogger('queue_tasks')

NTASKS = 250000
NTASKS_PER_CYCLE = 1000
default_config = {
'NTASKS': 250000,
'NTASKS_PER_CYCLE': 1000,
}


async def run(rest_client, dataset_id=None, ntasks=NTASKS, ntasks_per_cycle=NTASKS_PER_CYCLE, debug=False):
async def run(rest_client, config, dataset_id='', gpus=None, debug=False):
"""
Actual runtime / loop.
Args:
rest_client (:py:class:`iceprod.core.rest_client.Client`): rest client
config (dict): config dict
dataset_id (str): dataset to queue
gpus (bool): run on gpu tasks, cpu tasks, or both
debug (bool): debug flag to propagate exceptions
"""
try:
num_tasks_idle = 0
num_tasks_waiting = 0
num_tasks_queued = 0
if dataset_id:
route = f'/datasets/{dataset_id}/task_counts/status'
else:
route = '/task_counts/status'
args = {'status': 'idle|waiting'}
if gpus is not None:
args['gpu'] = gpus
print(args)
return
tasks = await rest_client.request('GET', route, args)
if 'idle' in tasks:
num_tasks_waiting = tasks['idle']
num_tasks_idle = tasks['idle']
if 'waiting' in tasks:
num_tasks_queued = tasks['waiting']
tasks_to_queue = min(num_tasks_waiting, ntasks - num_tasks_queued, ntasks_per_cycle)
logger.warning(f'num tasks idle: {num_tasks_waiting}')
logger.warning(f'num tasks waiting: {num_tasks_queued}')
num_tasks_waiting = tasks['waiting']
tasks_to_queue = min(num_tasks_idle, config['NTASKS'] - num_tasks_waiting, config['NTASKS_PER_CYCLE'])
logger.warning(f'num tasks idle: {num_tasks_idle}')
logger.warning(f'num tasks waiting: {num_tasks_waiting}')
logger.warning(f'tasks to waiting: {tasks_to_queue}')

if tasks_to_queue > 0:
Expand Down Expand Up @@ -128,24 +139,28 @@ async def check_deps(task):


def main():
    """CLI entry point: parse arguments, build config, and queue tasks once."""
    config = from_environment(default_config)

    parser = argparse.ArgumentParser(description='run a scheduled task once')
    add_auth_to_argparse(parser)
    parser.add_argument('--dataset-id', help='dataset id')
    parser.add_argument('--gpus', default=None, type=strtobool, help='whether to select only gpu or non-gpu tasks')
    parser.add_argument('--ntasks', type=int, default=config['NTASKS'],
                        help='number of tasks to keep queued')
    parser.add_argument('--ntasks_per_cycle', type=int, default=config['NTASKS_PER_CYCLE'],
                        help='number of tasks to queue per cycle')
    parser.add_argument('--log-level', default='info', help='log level')
    parser.add_argument('--debug', default=False, action='store_true', help='debug enabled')

    args = parser.parse_args()
    config.update(vars(args))
    # BUG FIX: argparse stores these under the lowercase keys 'ntasks' /
    # 'ntasks_per_cycle', but run() reads config['NTASKS'] and
    # config['NTASKS_PER_CYCLE'] - without this copy, explicit CLI
    # overrides were silently ignored.
    config['NTASKS'] = args.ntasks
    config['NTASKS_PER_CYCLE'] = args.ntasks_per_cycle

    logformat = '%(asctime)s %(levelname)s %(name)s %(module)s:%(lineno)s - %(message)s'
    logging.basicConfig(format=logformat, level=getattr(logging, args.log_level.upper()))

    rest_client = create_rest_client(args)

    asyncio.run(run(rest_client, dataset_id=args.dataset_id, gpus=args.gpus, config=config, debug=args.debug))


if __name__ == '__main__':
Expand Down
21 changes: 19 additions & 2 deletions tests/rest/tasks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,13 +297,30 @@ async def test_rest_tasks_dataset_counts_status(server):
'requirements': {},
}
ret = await client.request('POST', '/tasks', data)
task_id = ret['result']

data = {
'dataset_id': 'foo',
'job_id': 'foo1',
'task_index': 1,
'job_index': 0,
'name': 'baz',
'depends': [],
'requirements': {'gpu': 1},
'status': 'processing'
}
ret = await client.request('POST', '/tasks', data)

ret = await client.request('GET', f'/datasets/{data["dataset_id"]}/task_counts/status')
assert ret == {states.TASK_STATUS_START: 1}
assert ret == {states.TASK_STATUS_START: 1, 'processing': 1}

ret = await client.request('GET', f'/datasets/{data["dataset_id"]}/task_counts/status?status=complete')
assert ret == {}

ret = await client.request('GET', f'/datasets/{data["dataset_id"]}/task_counts/status?gpu=false')
assert ret == {states.TASK_STATUS_START: 1}

ret = await client.request('GET', f'/datasets/{data["dataset_id"]}/task_counts/status?gpu=true')
assert ret == {'processing': 1}


async def test_rest_tasks_dataset_counts_name_status(server):
Expand Down

0 comments on commit 4aaf86e

Please sign in to comment.