Skip to content

Commit

Permalink
docs: extend and correct docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
Ayush5120 committed Mar 7, 2023
1 parent cc83844 commit 477baf8
Show file tree
Hide file tree
Showing 7 changed files with 169 additions and 99 deletions.
20 changes: 10 additions & 10 deletions pro_tes/ga4gh/tes/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,8 +666,8 @@ class TesEndpoint(CustomBaseModel):
Args:
host: Host at which the TES API is served that is processing this
request; note that this should include the path information but
*not* the base path path defined in the TES API specification;
e.g., specify https://my.tes.com/api if the actual API is hosted at
*not* the base path defined in the TES API specification; e.g.,
specify https://my.tes.com/api if the actual API is hosted at
https://my.tes.com/api/ga4gh/tes/v1.
base_path: Override the default path suffix defined in the TES API
specification, i.e., `/ga4gh/tes/v1`.
Expand All @@ -676,8 +676,8 @@ class TesEndpoint(CustomBaseModel):
Attributes:
host: Host at which the TES API is served that is processing this
request; note that this should include the path information but
*not* the base path path defined in the TES API specification;
e.g., specify https://my.tes.com/api if the actual API is hosted at
*not* the base path defined in the TES API specification; e.g.,
specify https://my.tes.com/api if the actual API is hosted at
https://my.tes.com/api/ga4gh/tes/v1.
base_path: Override the default path suffix defined in the TES API
specification, i.e., `/ga4gh/tes/v1`.
Expand All @@ -692,24 +692,24 @@ class DbDocument(CustomBaseModel):
"""Create model instance for task request database document.
Args:
task_incoming: Information about incoming task.
task_outgoing: Information about outgoing task.
task: Information about task.
task_original: Information about original task.
user_id: Identifier of resource owner.
worker_id: Identifier of worker task.
basic_auth: Basic authentication credentials.
tes_endpoint: External TES endpoint.
Attributes:
task_incoming: Information about incoming task.
task_outgoing: Information about outgoing task.
task: Information about task.
task_original: Information about original task.
user_id: Identifier of resource owner.
worker_id: Identifier of worker task.
basic_auth: Basic authentication credentials.
tes_endpoint: External TES endpoint.
"""

task_incoming: TesTask = TesTask()
task_outgoing: TesTask = TesTask(executors=[])
task: TesTask = TesTask()
task_original: TesTask = TesTask(executors=[])
user_id: Optional[str] = None
worker_id: str = ""
basic_auth: BasicAuth = BasicAuth()
Expand Down
131 changes: 84 additions & 47 deletions pro_tes/ga4gh/tes/task_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,24 +77,24 @@ def create_task( # pylint: disable=too-many-statements,too-many-branches

db_document.basic_auth = self.parse_basic_auth(request.authorization)

db_document.task_outgoing = TesTask(**payload)
db_document.task_original = TesTask(**payload)

# middleware is called after the task is created in the database
payload = self.task_distributor.modify_request(request=request).json

tes_uri_list = deepcopy(payload["tes_uri"])
del payload["tes_uri"]

db_document.task_incoming = TesTask(**payload)
db_document = self._update_task_incoming(
db_document.task = TesTask(**payload)
db_document = self._update_task(
payload=payload,
db_document=db_document,
start_time=start_time,
**kwargs,
)
logger.info(
"Trying to forward incoming task with task identifier "
f"'{db_document.task_incoming.id}' and worker job identifier "
"Trying to forward task with task identifier "
f"'{db_document.task.id}' and worker job identifier "
f"'{db_document.worker_id}'"
)
db_connector = DbDocumentConnector(
Expand All @@ -108,7 +108,7 @@ def create_task( # pylint: disable=too-many-statements,too-many-branches
except TypeError as exc:
db_connector.update_task_state(state=TesState.SYSTEM_ERROR.value)
raise BadRequest(
f"Task '{db_document.task_incoming.id}' could not be "
f"Task '{db_document.task.id}' could not be "
f"validate. Original error message: '{type(exc).__name__}: "
f"{exc}'"
) from exc
Expand All @@ -131,7 +131,7 @@ def create_task( # pylint: disable=too-many-statements,too-many-branches
state=TesState.SYSTEM_ERROR.value
)
logger.info(
f"Task '{db_document.task_incoming.id}' could not "
f"Task '{db_document.task.id}' could not "
f"be sentto TES endpoint hosted at: {url}. Invalid TES"
" endpoint URL. Original error message: "
f"'{type(exc).__name__}: {exc}'"
Expand Down Expand Up @@ -161,7 +161,7 @@ def create_task( # pylint: disable=too-many-statements,too-many-branches
state=TesState.SYSTEM_ERROR.value
)
logger.info(
f"Task '{db_document.task_incoming.id}' "
f"Task '{db_document.task.id}' "
"could not be sent to TES endpoint hosted "
f"at: {url}. Task could not be created. Original "
f"error message: '{type(exc).__name__}: "
Expand All @@ -173,28 +173,28 @@ def create_task( # pylint: disable=too-many-statements,too-many-branches
f"Task '{remote_task_id}' "
"forwarded to TES endpoint "
f"hosted at: {url}. proTES task identifier: "
f"{db_document.task_incoming.id}."
f"{db_document.task.id}."
)
try:
task: Task = cli.get_task(remote_task_id)
task_model_converter = TaskModelConverter(task=task)
task_converted: TesTask = task_model_converter.convert_task()
db_document.task_incoming.state = task_converted.state
db_document.task.state = task_converted.state
except requests.HTTPError as exc:
logger.error(
f"Task '{db_document.task_incoming.id}' info could "
f"Task '{db_document.task.id}' info could "
"not be retrieved from TES endpoint hosted at: "
f"{url}. Original error message: "
f"'{type(exc).__name__}: {exc}'"
)
except PyMongoError as exc:
logger.error(
"Database could not be updated with task info "
f"retrieved for task '{db_document.task_incoming.id}'"
f"retrieved for task '{db_document.task.id}'"
f"sent to TES endpoint hosted at: {url}. "
f"Original error message:'{type(exc).__name__}: {exc}'"
)
# update task_logs, tes_endpoint and task_incoming in db
# update task_logs, tes_endpoint and task in db
db_document = self._update_doc_in_db(
db_connector=db_connector,
tes_uri=tes_uri,
Expand All @@ -211,7 +211,7 @@ def create_task( # pylint: disable=too-many-statements,too-many-branches
"password": db_document.basic_auth.password,
},
)
return {"id": db_document.task_incoming.id}
return {"id": db_document.task.id}

def list_tasks(self, **kwargs) -> Dict:
"""Return list of tasks.
Expand Down Expand Up @@ -253,13 +253,13 @@ def list_tasks(self, **kwargs) -> Dict:
for task in tasks_list:
del task["_id"]
if view == "MINIMAL":
task["id"] = task["task_incoming"]["id"]
task["state"] = task["task_incoming"]["state"]
task["id"] = task["task"]["id"]
task["state"] = task["task"]["state"]
tasks_lists.append({"id": task["id"], "state": task["state"]})
if view == "BASIC":
tasks_lists.append(task["task_incoming"])
tasks_lists.append(task["task"])
if view == "FULL":
tasks_lists.append(task["task_incoming"])
tasks_lists.append(task["task"])

return {"next_page_token": next_page_token, "tasks": tasks_lists}

Expand All @@ -280,12 +280,12 @@ def get_task(self, id=str, **kwargs) -> Dict:
"""
projection = self._set_projection(view=kwargs.get("view", "BASIC"))
document = self.db_client.find_one(
filter={"task_incoming.id": id}, projection=projection
filter={"task.id": id}, projection=projection
)
if document is None:
logger.error(f"Task '{id}' not found.")
raise TaskNotFound
return document["task_incoming"]
return document["task"]

def cancel_task(self, id: str, **kwargs) -> Dict:
"""Cancel task.
Expand All @@ -304,15 +304,15 @@ def cancel_task(self, id: str, **kwargs) -> Dict:
available.
"""
document = self.db_client.find_one(
filter={"task_incoming.id": id},
filter={"task.id": id},
projection={"_id": False},
)
if document is None:
logger.error(f"task '{id}' not found.")
raise TaskNotFound
db_document = DbDocument(**document)

if db_document.task_incoming.state in States.CANCELABLE:
if db_document.task.state in States.CANCELABLE:
db_connector = DbDocumentConnector(
collection=self.db_client,
worker_id=db_document.worker_id,
Expand All @@ -322,11 +322,11 @@ def cancel_task(self, id: str, **kwargs) -> Dict:
f"{db_document.tes_endpoint.base_path.strip('/')}"
)
if self.store_logs:
task_id = db_document.task_incoming.logs[
task_id = db_document.task.logs[
0
].metadata.forwarded_to.id
else:
task_id = db_document.task_incoming.logs[0].metadata[
task_id = db_document.task.logs[0].metadata[
"remote_task_id"
]
logger.info(
Expand Down Expand Up @@ -370,7 +370,7 @@ def _write_doc_to_db(

# try inserting until unused task id found
for _ in range(controller_config["db"]["insert_attempts"]):
document.task_incoming.id = generate_id(
document.task.id = generate_id(
charset=charset,
length=length,
)
Expand All @@ -380,14 +380,14 @@ def _write_doc_to_db(
except DuplicateKeyError:
continue
assert document is not None
return document.task_incoming.id, document.worker_id
return document.task.id, document.worker_id
raise DuplicateKeyError("Could not insert document into database.")

def _sanitize_request(self, payload: dict) -> Dict:
"""Sanitize request for use with py-tes.
Args:
payloads: Request payload.
payload: Request payload.
Returns:
Sanitized request payload.
Expand Down Expand Up @@ -417,7 +417,7 @@ def _sanitize_request(self, payload: dict) -> Dict:
return payload

def _set_projection(self, view: str) -> Dict:
"""Set database projectoin for selected view.
"""Set database projection for selected view.
Args:
view: View path parameter.
Expand All @@ -430,15 +430,15 @@ def _set_projection(self, view: str) -> Dict:
"""
if view == "MINIMAL":
projection = {
"task_incoming.id": True,
"task_incoming.state": True,
"task.id": True,
"task.state": True,
}
elif view == "BASIC":
projection = {
"task_incoming.inputs.content": False,
"task_incoming.system_logs": False,
"task_incoming.logs.stdout": False,
"task_incoming.logs.stderr": False,
"task.inputs.content": False,
"task.system_logs": False,
"task.logs.stdout": False,
"task.logs.stderr": False,
"tes_endpoint": False,
}
elif view == "FULL":
Expand All @@ -450,24 +450,41 @@ def _set_projection(self, view: str) -> Dict:
raise BadRequest
return projection

def _update_task_incoming(
def _update_task(
self, payload: dict, db_document: DbDocument, start_time: str, **kwargs
) -> DbDocument:
"""Update the task incoming object."""
"""Update the incoming task document.
Args:
payload: A dictionary containing the payload for the update.
db_document: The document in the database to be updated.
start_time: The starting time of the incoming TES request.
Returns:
DbDocument: The updated database document.
"""
logs = self._set_logs(
payloads=deepcopy(payload), start_time=start_time
)
db_document.task_incoming.logs = [TesTaskLog(**logs) for logs in logs]
db_document.task_incoming.state = TesState.UNKNOWN
db_document.task.logs = [TesTaskLog(**logs) for logs in logs]
db_document.task.state = TesState.UNKNOWN
db_document.user_id = kwargs.get("user_id", None)

(task_id, worker_id) = self._write_doc_to_db(document=db_document)
db_document.task_incoming.id = task_id
db_document.task.id = task_id
db_document.worker_id = worker_id
return db_document

def _set_logs(self, payloads: dict, start_time: str) -> Dict:
"""Set up the logs for the incoming request."""
"""Create or update TesTask.logs and set start time.
Args:
payloads: A dictionary containing the payload for the update.
start_time: The starting time of the incoming TES request.
Returns:
Task logs with start time set.
"""
if "logs" not in payloads.keys():
logs = [
{
Expand All @@ -491,7 +508,16 @@ def _update_doc_in_db(
tes_uri: str,
remote_task_id: str,
) -> DbDocument:
"""Update the document in the database."""
"""Set end time, task metadata in TesTask.logs, and update document.
Args:
db_connector: The database connector.
tes_uri: The TES URI to which the task is forwarded.
remote_task_id: Task identifier at the remote TES instance.
Returns:
The updated database document.
"""
time_now = datetime.now().strftime("%m-%d-%Y %H:%M:%S")
tes_endpoint_dict = {"host": tes_uri, "base_path": ""}
db_document = db_connector.upsert_fields_in_root_object(
Expand All @@ -503,7 +529,7 @@ def _update_doc_in_db(
"finally to database "
)
# updating the end time in TesTask logs
for logs in db_document.task_incoming.logs:
for logs in db_document.task.logs:
logs.end_time = time_now

# updating the metadata in TesTask logs
Expand All @@ -514,22 +540,33 @@ def _update_doc_in_db(
remote_task_id=remote_task_id,
)
else:
for logs in db_document.task_incoming.logs:
for logs in db_document.task.logs:
logs.metadata = {"remote_task_id": remote_task_id}

db_document = db_connector.upsert_fields_in_root_object(
root="task_incoming",
**db_document.dict()["task_incoming"],
root="task",
**db_document.dict()["task"],
)
logger.info(
f"Task '{db_document.task_incoming}' inserted to database "
f"Task '{db_document.task}' inserted to database "
)
return db_document

def _update_task_metadata(
self, db_document: DbDocument, tes_uri: str, remote_task_id: str
) -> DbDocument:
"""Update the task metadata."""
"""Update the task metadata.
Set TES endpoint and remote task identifier in `TesTask.logs.metadata`.
Args:
db_document: The document in the database to be updated.
tes_uri: The TES URI to which the task is forwarded.
remote_task_id: Task identifier at the remote TES instance.
Returns:
The updated database document.
"""
for logs in db_document.task_incoming.logs:
tesNextTes_obj = TesNextTes(id=remote_task_id, url=tes_uri)
if logs.metadata.forwarded_to is None:
Expand Down
Loading

0 comments on commit 477baf8

Please sign in to comment.