Skip to content

Commit

Permalink
Order tasks by record created_on, taking into account parent services
Browse files Browse the repository at this point in the history
  • Loading branch information
bennybp committed Jul 19, 2023
1 parent f8aeb95 commit d5aa7bb
Showing 1 changed file with 29 additions and 4 deletions.
33 changes: 29 additions & 4 deletions qcfractal/qcfractal/components/tasks/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@

import pydantic
from qcelemental.models import FailedOperation
from sqlalchemy import select
from sqlalchemy import select, func
from sqlalchemy.dialects.postgresql import array
from sqlalchemy.orm import joinedload, Load
from sqlalchemy.orm import joinedload, aliased, Load

from qcfractal.components.managers.db_models import ComputeManagerORM
from qcfractal.components.record_db_models import BaseRecordORM
from qcfractal.components.services.db_models import ServiceQueueORM, ServiceDependencyORM
from qcportal.all_results import AllResultTypes
from qcportal.compression import CompressionEnum, compress
from qcportal.compression import decompress
Expand Down Expand Up @@ -248,6 +249,23 @@ def claim_tasks(
# to claim absolutely everything. So double check here
limit = calculate_limit(self._tasks_claim_limit, limit)

# CTE that computes an effective created_on for each waiting record that is a dependency of a service.
# For such a record, we use the earlier of the record's own created_on and the earliest created_on of
# the services that depend on it. That way, a service doesn't have to wait for all other services to
# finish their tasks before it can finish.
br_task = aliased(BaseRecordORM) # BaseRecord for the task
br_svc = aliased(BaseRecordORM) # BaseRecord for services

least_date = func.least(br_task.created_on, func.min(br_svc.created_on)).label("created_on")
svcdate_cte = select(br_task.id.label("record_id"), least_date)
svcdate_cte = svcdate_cte.join(ServiceDependencyORM, ServiceDependencyORM.record_id == br_task.id)
svcdate_cte = svcdate_cte.join(ServiceQueueORM, ServiceQueueORM.id == ServiceDependencyORM.service_id)
svcdate_cte = svcdate_cte.join(br_svc, br_svc.id == ServiceQueueORM.record_id)
svcdate_cte = svcdate_cte.where(br_task.status == RecordStatusEnum.waiting)
svcdate_cte = svcdate_cte.group_by(br_task.id)
svcdate_cte = svcdate_cte.order_by(least_date.asc())
svcdate_cte = svcdate_cte.cte()

with self.root_socket.optional_session(session) as session:
stmt = select(ComputeManagerORM).where(ComputeManagerORM.name == manager_name)
stmt = stmt.with_for_update(skip_locked=False)
Expand Down Expand Up @@ -295,16 +313,23 @@ def claim_tasks(
BaseRecordORM.status, BaseRecordORM.manager_name, BaseRecordORM.modified_on
)
)

stmt = stmt.join(svcdate_cte, svcdate_cte.c.record_id == BaseRecordORM.id, isouter=True)
stmt = stmt.filter(BaseRecordORM.status == RecordStatusEnum.waiting)
stmt = stmt.filter(manager_programs.contains(TaskQueueORM.required_programs))
stmt = stmt.order_by(TaskQueueORM.priority.desc(), TaskQueueORM.created_on)

# Order by priority (highest first), then by created_on (earliest first).
# The created_on used here may be that of the parent service, if that is earlier (see CTE above)
stmt = stmt.order_by(
TaskQueueORM.priority.desc(), func.least(BaseRecordORM.created_on, svcdate_cte.c.created_on).asc()
)

# If tag is "*", then the manager will pull anything
if tag != "*":
stmt = stmt.filter(TaskQueueORM.tag == tag)

# Skip locked rows - They may be in the process of being claimed by someone else
stmt = stmt.limit(new_limit).with_for_update(skip_locked=True)
stmt = stmt.limit(new_limit).with_for_update(of=[BaseRecordORM, TaskQueueORM], skip_locked=True)

new_items = session.execute(stmt).all()

Expand Down

0 comments on commit d5aa7bb

Please sign in to comment.