Skip to content

Commit

Permalink
Clean up task & service queue
Browse files Browse the repository at this point in the history
  • Loading branch information
bennybp committed Aug 16, 2023
1 parent 25f1ea2 commit 2bf3f9a
Show file tree
Hide file tree
Showing 33 changed files with 71 additions and 97 deletions.
6 changes: 3 additions & 3 deletions qcarchivetesting/qcarchivetesting/data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

if TYPE_CHECKING:
from typing import List, Dict, Any
from qcportal.tasks import TaskInformation
from qcportal.record_models import RecordTask
from qcfractal.config import FractalConfig
from qcfractalcompute.apps.models import AppTaskResult

Expand Down Expand Up @@ -79,7 +79,7 @@ def postprocess_results(self, results: Dict[int, AppTaskResult]):
clean_conda_env(r_dict)
self._result_queue.put((self._task_map[task_id], r_dict))

def preprocess_new_tasks(self, new_tasks: List[TaskInformation]):
def preprocess_new_tasks(self, new_tasks: List[RecordTask]):
for task in new_tasks:
# Store the full task by task id
self._task_map[task.id] = task
Expand Down Expand Up @@ -112,7 +112,7 @@ def _stop(cls, compute, compute_thread):
def stop(self) -> None:
self._finalizer()

def get_data(self) -> List[tuple[TaskInformation, Dict[str, Any]]]:
def get_data(self) -> List[tuple[RecordTask, Dict[str, Any]]]:
# Returns list of iterable (task, result)
data = []

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Remove created_on from task/service queue
Revision ID: d1ee87a66b71
Revises: e6f5053c7600
Create Date: 2023-08-15 17:58:28.061638
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "d1ee87a66b71"
down_revision = "e6f5053c7600"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("ix_service_queue_waiting_sort", table_name="service_queue")
op.drop_column("service_queue", "created_on")
op.drop_index("ix_task_queue_waiting_sort", table_name="task_queue")
op.drop_column("task_queue", "created_on")
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("task_queue", sa.Column("created_on", postgresql.TIMESTAMP(), autoincrement=False, nullable=False))
op.create_index("ix_task_queue_waiting_sort", "task_queue", [sa.text("priority DESC"), "created_on"], unique=False)
op.add_column("service_queue", sa.Column("created_on", postgresql.TIMESTAMP(), autoincrement=False, nullable=False))
op.create_index(
"ix_service_queue_waiting_sort", "service_queue", [sa.text("priority DESC"), "created_on"], unique=False
)
# ### end Alembic commands ###
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def test_gridoptimization_client_add_get(

assert time_0 < r.created_on < time_1
assert time_0 < r.modified_on < time_1
assert time_0 < r.service.created_on < time_1

assert recs[0].initial_molecule.identifiers.molecule_hash == hooh.get_hash()
assert recs[1].initial_molecule.identifiers.molecule_hash == h3ns.get_hash()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def test_gridoptimization_socket_add_get(

assert time_0 < r.created_on < time_1
assert time_0 < r.modified_on < time_1
assert time_0 < r.service.created_on < time_1

assert recs[0].initial_molecule.identifiers["molecule_hash"] == hooh.get_hash()
assert recs[1].initial_molecule.identifiers["molecule_hash"] == h3ns.get_hash()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@
from qcfractal.testing_helpers import run_service
from qcportal.gridoptimization import GridoptimizationSpecification, GridoptimizationKeywords
from qcportal.optimization import OptimizationSpecification
from qcportal.record_models import PriorityEnum, RecordStatusEnum
from qcportal.record_models import PriorityEnum, RecordStatusEnum, RecordTask
from qcportal.singlepoint import SinglepointProtocols, QCSpecification
from qcportal.utils import recursive_normalizer

if TYPE_CHECKING:
from qcfractal.db_socket import SQLAlchemySocket
from qcportal.managers import ManagerName
from qcportal.tasks import TaskInformation


def compare_gridoptimization_specs(
Expand Down Expand Up @@ -83,7 +82,7 @@ def compare_gridoptimization_specs(
]


def generate_task_key(task: TaskInformation):
def generate_task_key(task: RecordTask):
# task is an optimization
inp_data = task.function_kwargs["input_data"]
assert inp_data["schema_name"] in "qcschema_optimization_input"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def test_manybody_client_add_get(

assert time_0 < r.created_on < time_1
assert time_0 < r.modified_on < time_1
assert time_0 < r.service.created_on < time_1

assert recs[0].initial_molecule.get_hash() == water2.get_hash()
assert recs[1].initial_molecule.get_hash() == water4.get_hash()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def test_manybody_socket_add_get(storage_socket: SQLAlchemySocket, session: Sess

assert time_0 < r.created_on < time_1
assert time_0 < r.modified_on < time_1
assert time_0 < r.service.created_on < time_1

assert recs[0].initial_molecule.identifiers["molecule_hash"] == water2.get_hash()
assert recs[1].initial_molecule.identifiers["molecule_hash"] == water4.get_hash()
Expand Down
5 changes: 2 additions & 3 deletions qcfractal/qcfractal/components/manybody/testing_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@
from qcfractal.components.manybody.record_db_models import ManybodyRecordORM
from qcfractal.testing_helpers import run_service
from qcportal.manybody import ManybodySpecification, ManybodyKeywords
from qcportal.record_models import PriorityEnum, RecordStatusEnum
from qcportal.record_models import PriorityEnum, RecordStatusEnum, RecordTask
from qcportal.singlepoint import SinglepointProtocols, QCSpecification

if TYPE_CHECKING:
from qcfractal.db_socket import SQLAlchemySocket
from qcportal.managers import ManagerName
from qcportal.tasks import TaskInformation

test_specs = [
ManybodySpecification(
Expand Down Expand Up @@ -67,7 +66,7 @@ def compare_manybody_specs(
return input_spec == output_spec


def generate_task_key(task: TaskInformation):
def generate_task_key(task: RecordTask):
# task is a singlepoint
inp_data = task.function_kwargs["input_data"]
assert inp_data["schema_name"] in "qcschema_input"
Expand Down
1 change: 0 additions & 1 deletion qcfractal/qcfractal/components/neb/test_record_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def test_neb_client_add_get(submitter_client: PortalClient, spec: NEBSpecificati

assert time_0 < r.created_on < time_1
assert time_0 < r.modified_on < time_1
assert time_0 < r.service.created_on < time_1

assert len(recs[0].initial_chain) == 11 # default image number
assert len(recs[1].initial_chain) == 11
Expand Down
1 change: 0 additions & 1 deletion qcfractal/qcfractal/components/neb/test_record_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def test_neb_socket_add_get(storage_socket: SQLAlchemySocket, session: Session,

assert time_0 < r.created_on < time_1
assert time_0 < r.modified_on < time_1
assert time_0 < r.service.created_on < time_1

assert len(recs[0].initial_chain) == spec.keywords.images
assert len(recs[1].initial_chain) == spec.keywords.images
Expand Down
5 changes: 2 additions & 3 deletions qcfractal/qcfractal/components/neb/testing_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,13 @@
from qcfractal.testing_helpers import run_service
from qcportal.generic_result import GenericTaskResult
from qcportal.neb import NEBSpecification, NEBKeywords
from qcportal.record_models import PriorityEnum, RecordStatusEnum
from qcportal.record_models import PriorityEnum, RecordStatusEnum, RecordTask
from qcportal.singlepoint import SinglepointProtocols, QCSpecification
from qcportal.utils import recursive_normalizer, hash_dict

if TYPE_CHECKING:
from qcfractal.db_socket import SQLAlchemySocket
from qcportal.managers import ManagerName
from qcportal.tasks import TaskInformation

test_specs = [
NEBSpecification(
Expand Down Expand Up @@ -65,7 +64,7 @@ def compare_neb_specs(
return input_spec == output_spec


def generate_task_key(task: TaskInformation):
def generate_task_key(task: RecordTask):
if task.function in ("qcengine.compute", "qcengine.compute_procedure"):
inp_data = task.function_kwargs["input_data"]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ def test_optimization_client_add_get(

assert time_0 < r.created_on < time_1
assert time_0 < r.modified_on < time_1
assert time_0 < r.task.created_on < time_1

mol1 = submitter_client.get_molecules([recs[0].initial_molecule_id])[0]
mol2 = submitter_client.get_molecules([recs[1].initial_molecule_id])[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@
from qcportal.optimization import (
OptimizationSpecification,
)
from qcportal.record_models import RecordStatusEnum, PriorityEnum
from qcportal.record_models import RecordStatusEnum, PriorityEnum, RecordTask
from qcportal.singlepoint import (
QCSpecification,
SinglepointDriver,
SinglepointProtocols,
)
from qcportal.tasks import TaskInformation

if TYPE_CHECKING:
from qcfractal.db_socket import SQLAlchemySocket
Expand All @@ -47,7 +46,7 @@ def test_optimization_socket_task_spec(
assert meta.success

tasks = storage_socket.tasks.claim_tasks(activated_manager_name.fullname, activated_manager_programs, ["*"])
tasks = [TaskInformation(**t) for t in tasks]
tasks = [RecordTask(**t) for t in tasks]

assert len(tasks) == 3
for t in tasks:
Expand All @@ -73,8 +72,6 @@ def test_optimization_socket_task_spec(
assert t.tag == "tag1"
assert t.priority == PriorityEnum.low

assert time_0 < t.created_on < time_1

rec_id_mol_map = {
id[0]: all_mols[0],
id[1]: all_mols[1],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def test_reaction_client_add_get(

assert time_0 < r.created_on < time_1
assert time_0 < r.modified_on < time_1
assert time_0 < r.service.created_on < time_1

mol_hash_0 = set(x.molecule.identifiers.molecule_hash for x in recs[0].components)
mol_hash_1 = set(x.molecule.identifiers.molecule_hash for x in recs[1].components)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def test_reaction_socket_add_get(storage_socket: SQLAlchemySocket, session: Sess

assert time_0 < r.created_on < time_1
assert time_0 < r.modified_on < time_1
assert time_0 < r.service.created_on < time_1

mol_hash_0 = set(x.molecule.identifiers["molecule_hash"] for x in recs[0].components)
mol_hash_1 = set(x.molecule.identifiers["molecule_hash"] for x in recs[1].components)
Expand Down
5 changes: 2 additions & 3 deletions qcfractal/qcfractal/components/reaction/testing_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@
from qcfractal.components.reaction.record_db_models import ReactionRecordORM
from qcfractal.testing_helpers import run_service
from qcportal.reaction import ReactionSpecification, ReactionKeywords
from qcportal.record_models import PriorityEnum, RecordStatusEnum
from qcportal.record_models import PriorityEnum, RecordStatusEnum, RecordTask
from qcportal.singlepoint import SinglepointProtocols, QCSpecification

if TYPE_CHECKING:
from qcfractal.db_socket import SQLAlchemySocket
from qcportal.managers import ManagerName
from qcportal.tasks import TaskInformation

test_specs = [
ReactionSpecification(
Expand Down Expand Up @@ -52,7 +51,7 @@ def compare_reaction_specs(
return input_spec == output_spec


def generate_task_key(task: TaskInformation):
def generate_task_key(task: RecordTask):
inp_data = task.function_kwargs["input_data"]

if inp_data["schema_name"] == "qcschema_optimization_input":
Expand Down
4 changes: 0 additions & 4 deletions qcfractal/qcfractal/components/services/db_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import datetime
from typing import Optional, Iterable, Dict, Any

from sqlalchemy import (
Expand All @@ -9,7 +8,6 @@
JSON,
ForeignKey,
String,
DateTime,
Boolean,
Index,
UniqueConstraint,
Expand Down Expand Up @@ -64,7 +62,6 @@ class ServiceQueueORM(BaseORM):

tag = Column(String, nullable=False)
priority = Column(Integer, nullable=False)
created_on = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
find_existing = Column(Boolean, nullable=False)

service_state = Column(PlainMsgpackExt)
Expand All @@ -76,7 +73,6 @@ class ServiceQueueORM(BaseORM):
__table_args__ = (
UniqueConstraint("record_id", name="ux_service_queue_record_id"),
Index("ix_service_queue_tag", "tag"),
Index("ix_service_queue_waiting_sort", priority.desc(), created_on),
CheckConstraint("tag = LOWER(tag)", name="ck_service_queue_tag_lower"),
)

Expand Down
2 changes: 1 addition & 1 deletion qcfractal/qcfractal/components/services/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def iterate_services(self, session: Session, job_progress: JobProgress) -> int:
.join(ServiceQueueORM.record)
.options(contains_eager(ServiceQueueORM.record))
.filter(BaseRecordORM.status == RecordStatusEnum.waiting)
.order_by(ServiceQueueORM.priority.desc(), ServiceQueueORM.created_on)
.order_by(ServiceQueueORM.priority.desc(), BaseRecordORM.created_on)
.limit(new_service_count)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ def test_singlepoint_client_add_get(submitter_client: PortalClient, spec: QCSpec
assert r.owner_group == owner_group
assert time_0 < r.created_on < time_1
assert time_0 < r.modified_on < time_1
assert time_0 < r.task.created_on < time_1

assert recs[0].molecule == water
assert recs[1].molecule == hooh
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
from qcportal.compression import decompress
from qcportal.managers import ManagerName
from qcportal.molecules import Molecule
from qcportal.record_models import RecordStatusEnum, PriorityEnum
from qcportal.record_models import RecordStatusEnum, PriorityEnum, RecordTask
from qcportal.singlepoint import QCSpecification, SinglepointDriver, SinglepointProtocols
from qcportal.tasks import TaskInformation
from .record_db_models import SinglepointRecordORM
from .testing_helpers import test_specs, load_test_data, run_test_data

Expand Down Expand Up @@ -45,7 +44,7 @@ def test_singlepoint_socket_task_spec(
assert meta.success

tasks = storage_socket.tasks.claim_tasks(activated_manager_name.fullname, activated_manager_programs, ["*"])
tasks = [TaskInformation(**t) for t in tasks]
tasks = [RecordTask(**t) for t in tasks]

assert len(tasks) == 3
for t in tasks:
Expand All @@ -56,7 +55,6 @@ def test_singlepoint_socket_task_spec(
assert function_kwargs["program"] == spec.program
assert t.tag == "tag1"
assert t.priority == PriorityEnum.low
assert time_0 < t.created_on < time_1

rec_id_mol_map = {
id[0]: all_mols[0],
Expand Down
6 changes: 0 additions & 6 deletions qcfractal/qcfractal/components/tasks/db_models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import datetime

from sqlalchemy import (
Column,
Integer,
String,
DateTime,
ForeignKey,
Index,
LargeBinary,
Expand Down Expand Up @@ -37,8 +34,6 @@ class TaskQueueORM(BaseORM):
tag = Column(String, nullable=False)
priority = Column(Integer, nullable=False)

created_on = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)

record_id = Column(Integer, ForeignKey(BaseRecordORM.id, ondelete="cascade"), nullable=False)
record = relationship(BaseRecordORM, back_populates="task", uselist=False)

Expand All @@ -49,7 +44,6 @@ class TaskQueueORM(BaseORM):
__table_args__ = (
Index("ix_task_queue_tag", "tag"),
Index("ix_task_queue_required_programs", "required_programs"),
Index("ix_task_queue_waiting_sort", priority.desc(), created_on),
UniqueConstraint("record_id", name="ux_task_queue_record_id"),
# WARNING - these are not autodetected by alembic
CheckConstraint(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def test_torsiondrive_client_add_get(

assert time_0 < r.created_on < time_1
assert time_0 < r.modified_on < time_1
assert time_0 < r.service.created_on < time_1

assert len(recs[0].initial_molecules) == 1
assert len(recs[1].initial_molecules) == 2
Expand Down
Loading

0 comments on commit 2bf3f9a

Please sign in to comment.