Skip to content

Commit

Permalink
lots of comments addressed
Browse files Browse the repository at this point in the history
  • Loading branch information
jigold committed Feb 5, 2024
1 parent 8a468ba commit 8aa3bb8
Show file tree
Hide file tree
Showing 13 changed files with 199 additions and 194 deletions.
18 changes: 10 additions & 8 deletions batch/batch/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from hailtop.utils import humanize_timedelta_msecs, time_msecs_str

from .batch_format_version import BatchFormatVersion
from .exceptions import NonExistentBatchError, OpenBatchError
from .constants import ROOT_JOB_GROUP_ID
from .exceptions import NonExistentJobGroupError
from .utils import coalesce

log = logging.getLogger('batch')
Expand Down Expand Up @@ -160,17 +161,18 @@ async def cancel_job_group_in_db(db, batch_id, job_group_id):
async def cancel(tx):
    """Cancel one job group inside the transaction ``tx``.

    ``batch_id`` and ``job_group_id`` are taken from the enclosing
    ``cancel_job_group_in_db`` scope.

    The job group row (and the joined batch / batch-update rows) are locked
    with ``FOR UPDATE`` before the ``cancel_job_group`` stored procedure is
    called.  A job group is only visible here when its batch is not deleted
    and either its update has been committed or it is the root job group.

    Raises:
        NonExistentJobGroupError: no matching, visible job group exists.
    """
    # Every column is table-qualified: ``batch_id`` exists on both
    # ``job_groups`` and ``batch_updates``, so an unqualified reference is
    # rejected by MySQL as ambiguous.  Only the existence of the row is
    # checked below, so a constant is selected instead of ``state``
    # (which would also need disambiguating between the joined tables).
    record = await tx.execute_and_fetchone(
        """
SELECT 1
FROM job_groups
LEFT JOIN batches ON batches.id = job_groups.batch_id
LEFT JOIN batch_updates ON job_groups.batch_id = batch_updates.batch_id AND
  job_groups.update_id = batch_updates.update_id
WHERE job_groups.batch_id = %s AND job_groups.job_group_id = %s AND NOT batches.deleted AND (batch_updates.committed OR job_groups.job_group_id = %s)
FOR UPDATE;
""",
        (batch_id, job_group_id, ROOT_JOB_GROUP_ID),
    )
    if not record:
        raise NonExistentJobGroupError(batch_id, job_group_id)

    await tx.just_execute('CALL cancel_job_group(%s, %s);', (batch_id, job_group_id))

Expand Down
44 changes: 21 additions & 23 deletions batch/batch/driver/canceller.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ async def user_cancelled_ready_jobs(user, remaining) -> AsyncIterator[Dict[str,
(user,),
):
if job_group['cancelled']:
async for record in self.db.select_and_fetchall( # FIXME: Do we need a new index again?
async for record in self.db.select_and_fetchall(
"""
SELECT jobs.job_id
FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled)
Expand All @@ -118,7 +118,7 @@ async def user_cancelled_ready_jobs(user, remaining) -> AsyncIterator[Dict[str,
record['batch_id'] = job_group['batch_id']
yield record
else:
async for record in self.db.select_and_fetchall( # FIXME: Do we need a new index again?
async for record in self.db.select_and_fetchall(
"""
SELECT jobs.job_id
FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled)
Expand Down Expand Up @@ -185,29 +185,28 @@ async def cancel_cancelled_creating_jobs_loop_body(self):
async def user_cancelled_creating_jobs(user, remaining) -> AsyncIterator[Dict[str, Any]]:
    """Yield the cancellable ``Creating`` jobs of ``user``'s cancelled job groups.

    The outer query lists the user's running job groups that have a row in
    ``job_groups_cancelled`` (INNER JOIN, so only cancelled groups are
    returned and no per-row ``cancelled`` flag is needed).  For each group,
    the inner query fetches up to ``remaining.value`` jobs in state
    ``Creating`` that are not ``always_run`` and not already cancelled,
    together with their attempt id and instance name.

    Each yielded record is the inner row with ``batch_id`` added.
    """
    async for job_group in self.db.select_and_fetchall(
        """
SELECT job_groups.batch_id, job_groups.job_group_id
FROM job_groups
INNER JOIN job_groups_cancelled
  ON job_groups.batch_id = job_groups_cancelled.id AND
     job_groups.job_group_id = job_groups_cancelled.job_group_id
WHERE user = %s AND `state` = 'running';
""",
        (user,),
    ):
        async for record in self.db.select_and_fetchall(
            """
SELECT jobs.job_id, attempts.attempt_id, attempts.instance_name
FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled)
STRAIGHT_JOIN attempts
  ON attempts.batch_id = jobs.batch_id AND attempts.job_id = jobs.job_id
WHERE jobs.batch_id = %s AND jobs.job_group_id = %s AND state = 'Creating' AND always_run = 0 AND cancelled = 0
LIMIT %s;
""",
            (job_group['batch_id'], job_group['job_group_id'], remaining.value),
        ):
            # The inner query does not select batch_id; attach it for the consumer.
            record['batch_id'] = job_group['batch_id']
            yield record

waitable_pool = WaitableSharedPool(self.async_worker_pool)

Expand Down Expand Up @@ -286,27 +285,26 @@ async def user_cancelled_running_jobs(user, remaining) -> AsyncIterator[Dict[str
"""
SELECT job_groups.batch_id, job_groups.job_group_id, job_groups_cancelled.id IS NOT NULL AS cancelled
FROM job_groups
LEFT JOIN job_groups_cancelled
ON job_groups.batch_id = job_groups_cancelled.id AND
job_groups.job_group_id = job_groups_cancelled.job_group_id
INNER JOIN job_groups_cancelled
ON job_groups.batch_id = job_groups_cancelled.id AND
job_groups.job_group_id = job_groups_cancelled.job_group_id
WHERE user = %s AND `state` = 'running';
""",
(user,),
):
if job_group['cancelled']:
async for record in self.db.select_and_fetchall(
"""
async for record in self.db.select_and_fetchall(
"""
SELECT jobs.job_id, attempts.attempt_id, attempts.instance_name
FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled)
STRAIGHT_JOIN attempts
ON attempts.batch_id = jobs.batch_id AND attempts.job_id = jobs.job_id
WHERE jobs.batch_id = %s AND jobs.job_group_id = %s AND state = 'Running' AND always_run = 0 AND cancelled = 0
LIMIT %s;
""",
(job_group['batch_id'], job_group['job_group_id'], remaining.value),
):
record['batch_id'] = job_group['batch_id']
yield record
(job_group['batch_id'], job_group['job_group_id'], remaining.value),
):
record['batch_id'] = job_group['batch_id']
yield record

waitable_pool = WaitableSharedPool(self.async_worker_pool)

Expand Down
5 changes: 2 additions & 3 deletions batch/batch/driver/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,12 @@ async def notify_batch_job_complete(db: Database, client_session: httpx.ClientSe
LEFT JOIN LATERAL (
SELECT COALESCE(SUM(`usage` * rate), 0) AS cost, JSON_OBJECTAGG(resources.resource, COALESCE(`usage` * rate, 0)) AS cost_breakdown
FROM (
SELECT batch_id, job_group_id, resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage`
SELECT resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage`
FROM aggregated_job_group_resources_v3
WHERE job_groups.batch_id = aggregated_job_group_resources_v3.batch_id AND job_groups.job_group_id = aggregated_job_group_resources_v3.job_group_id
GROUP BY batch_id, job_group_id, resource_id
GROUP BY resource_id
) AS usage_t
LEFT JOIN resources ON usage_t.resource_id = resources.resource_id
GROUP BY batch_id, job_group_id
) AS cost_t ON TRUE
LEFT JOIN job_groups_cancelled
ON job_groups.batch_id = job_groups_cancelled.id AND job_groups.job_group_id = job_groups_cancelled.job_group_id
Expand Down
32 changes: 10 additions & 22 deletions batch/batch/driver/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,11 +1203,11 @@ async def check(tx):
await check()


async def _cancel_job_group(app, batch_id, job_group_id):
    """Cancel one job group and flag that cancellation state changed.

    Args:
        app: application mapping; ``app['db']`` is passed to
            ``cancel_job_group_in_db``.
        batch_id: id of the batch that owns the job group.
        job_group_id: id of the job group to cancel.

    A ``BatchUserError`` from the database layer is logged and swallowed;
    in that case the cancel-state-changed signal is not set.
    """
    try:
        await cancel_job_group_in_db(app['db'], batch_id, job_group_id)
    except BatchUserError as exc:
        log.info(f'cannot cancel job group because {exc.message}')
        return
    set_cancel_state_changed(app)

Expand All @@ -1229,34 +1229,22 @@ async def monitor_billing_limits(app):
(record['billing_project'],),
)
async for batch in running_batches:
await _cancel_batch(app, batch['id'])
await _cancel_job_group(app, batch['id'], ROOT_JOB_GROUP_ID)


async def cancel_fast_failing_job_groups(app):
    """Cancel every running job group that has reached its failure limit.

    Selects running job groups whose ``cancel_after_n_failures`` is set and
    whose ``n_failed`` count (from ``job_groups_n_jobs_in_complete_states``)
    has reached it, then cancels each one via ``_cancel_job_group``.

    Args:
        app: application mapping; ``app['db']`` is the database handle.
    """
    db: Database = app['db']

    records = db.select_and_fetchall(
        """
SELECT job_groups.batch_id, job_groups.job_group_id, job_groups_n_jobs_in_complete_states.n_failed
FROM job_groups
LEFT JOIN job_groups_n_jobs_in_complete_states
  ON job_groups.batch_id = job_groups_n_jobs_in_complete_states.id AND job_groups.job_group_id = job_groups_n_jobs_in_complete_states.job_group_id
WHERE state = 'running' AND cancel_after_n_failures IS NOT NULL AND n_failed >= cancel_after_n_failures;
""",
    )
    async for job_group in records:
        await _cancel_job_group(app, job_group['batch_id'], job_group['job_group_id'])


USER_CORES = pc.Gauge('batch_user_cores', 'Batch user cores (i.e. total in-use cores)', ['state', 'user', 'inst_coll'])
Expand Down Expand Up @@ -1618,7 +1606,7 @@ async def close_and_wait():
exit_stack.push_async_callback(app['task_manager'].shutdown_and_wait)

task_manager.ensure_future(periodically_call(10, monitor_billing_limits, app))
task_manager.ensure_future(periodically_call(10, cancel_fast_failing_batches, app))
task_manager.ensure_future(periodically_call(10, cancel_fast_failing_job_groups, app))
task_manager.ensure_future(periodically_call(60, scheduling_cancelling_bump, app))
task_manager.ensure_future(periodically_call(15, monitor_system, app))
task_manager.ensure_future(periodically_call(5, refresh_globals_from_db, app, db))
Expand Down
Loading

0 comments on commit 8aa3bb8

Please sign in to comment.