From 61fc7a6f7956428e61b41b5a353a01780763faf3 Mon Sep 17 00:00:00 2001 From: Gavin Date: Mon, 30 Dec 2024 13:09:46 -0800 Subject: [PATCH] Trigger upload files job state In future will need to poll JAWS waiting for job to be complete since JAWS doesn't support callbacks --- cdmtaskservice/app_state.py | 12 +++++++++- cdmtaskservice/config.py | 3 +++ cdmtaskservice/jobflows/nersc_jaws.py | 34 ++++++++++++++++++++++++++- cdmtaskservice/routes.py | 5 ++-- cdmtaskservice_config.toml.jinja | 3 +++ 5 files changed, 53 insertions(+), 4 deletions(-) diff --git a/cdmtaskservice/app_state.py b/cdmtaskservice/app_state.py index 23030a9..c65c7d8 100644 --- a/cdmtaskservice/app_state.py +++ b/cdmtaskservice/app_state.py @@ -17,6 +17,7 @@ from cdmtaskservice.coroutine_manager import CoroutineWrangler from cdmtaskservice.image_remote_lookup import DockerImageInfo from cdmtaskservice.images import Images +from cdmtaskservice.jaws.client import JAWSClient from cdmtaskservice.jobflows.nersc_jaws import NERSCJAWSRunner from cdmtaskservice.job_state import JobState from cdmtaskservice.job_submit import JobSubmit @@ -35,6 +36,7 @@ class AppState(NamedTuple): """ Holds application state. """ auth: KBaseAuth sfapi_client: NERSCSFAPIClientProvider + jaws_client: JAWSClient s3_client: S3Client job_submit: JobSubmit job_state: JobState @@ -93,11 +95,16 @@ async def build_app( logr.info("Initializing MongoDB client...") mongocli = await get_mongo_client(cfg) logr.info("Done") + jaws_client = None try: + logr.info("Initializing JAWS Central client... 
") + jaws_client = await JAWSClient.create(cfg.jaws_url, cfg.jaws_token) + logr.info("Done") mongodao = await MongoDAO.create(mongocli[cfg.mongo_db]) job_state = JobState(mongodao) nerscjawsflow = NERSCJAWSRunner( # this has a lot of required args, yech nerscman, + jaws_client, job_state, # TODO CODE if this isn't necessary, remove and recombine with job_submit mongodao, s3, @@ -113,10 +120,12 @@ async def build_app( app.state._mongo = mongocli app.state._coroman = coman app.state._cdmstate = AppState( - auth, sfapi_client, s3, job_submit, job_state, images, runners + auth, sfapi_client, jaws_client, s3, job_submit, job_state, images, runners ) except: mongocli.close() + if jaws_client: + await jaws_client.close() raise @@ -135,6 +144,7 @@ async def destroy_app_state(app: FastAPI): app.state._mongo.close() app.state._coroman.destroy() await appstate.sfapi_client.destroy() + await appstate.jaws_client.close() # https://docs.aiohttp.org/en/stable/client_advanced.html#graceful-shutdown await asyncio.sleep(0.250) diff --git a/cdmtaskservice/config.py b/cdmtaskservice/config.py index 3e45745..645a938 100644 --- a/cdmtaskservice/config.py +++ b/cdmtaskservice/config.py @@ -32,6 +32,7 @@ class CDMTaskServiceConfig: as the remaining lines. sfapi_user: str - the user name of the user accociated with the credentials. nersc_remote_code_dir: str - the location at NERSC to upload remote code. + jaws_url: str - the URL of the JAWS Central service. jaws_token: str - the JAWS token used to run jobs. jaws_group: str - the JAWS group used to run jobs. s3_url: str - the URL of the S3 instance to use for data storage. 
@@ -78,6 +79,7 @@ def __init__(self, config_file: BinaryIO): self.sfapi_cred_path = _get_string_required(config, _SEC_NERSC, "sfapi_cred_path") self.sfapi_user = _get_string_required(config, _SEC_NERSC, "sfapi_user") self.nersc_remote_code_dir = _get_string_required(config, _SEC_NERSC, "remote_code_dir") + self.jaws_url = _get_string_required(config, _SEC_JAWS, "url") self.jaws_token = _get_string_required(config, _SEC_JAWS, "token") self.jaws_group = _get_string_required(config, _SEC_JAWS, "group") self.s3_url = _get_string_required(config, _SEC_S3, "url") @@ -116,6 +118,7 @@ def print_config(self, output: TextIO): f"NERSC client credential path: {self.sfapi_cred_path}", f"NERSC client user: {self.sfapi_user}", f"NERSC remote code dir: {self.nersc_remote_code_dir}", + f"JAWS Central URL: {self.jaws_url}", "JAWS token: REDACTED FOR THE NATIONAL SECURITY OF GONDWANALAND", f"JAWS group: {self.jaws_group}", f"S3 URL: {self.s3_url}", diff --git a/cdmtaskservice/jobflows/nersc_jaws.py b/cdmtaskservice/jobflows/nersc_jaws.py index fdb631f..82a8c3e 100644 --- a/cdmtaskservice/jobflows/nersc_jaws.py +++ b/cdmtaskservice/jobflows/nersc_jaws.py @@ -3,6 +3,7 @@ """ import logging +from typing import Any from cdmtaskservice import models from cdmtaskservice import timestamp @@ -10,6 +11,7 @@ from cdmtaskservice.callback_url_paths import get_download_complete_callback from cdmtaskservice.coroutine_manager import CoroutineWrangler from cdmtaskservice.exceptions import InvalidJobStateError +from cdmtaskservice.jaws.client import JAWSClient from cdmtaskservice.job_state import JobState from cdmtaskservice.mongo import MongoDAO from cdmtaskservice.nersc.manager import NERSCManager @@ -28,6 +30,7 @@ class NERSCJAWSRunner: def __init__( self, nersc_manager: NERSCManager, + jaws_client: JAWSClient, job_state: JobState, mongodao: MongoDAO, s3_client: S3Client, @@ -40,6 +43,7 @@ def __init__( Create the runner. nersc_manager - the NERSC manager. + jaws_client - a JAWS Central client. 
job_state - the job state manager. mongodao - the Mongo DAO object. s3_client - an S3 client pointed to the data stores. @@ -52,6 +56,7 @@ def __init__( leaving the service open to MITM attacks. """ self._nman = _not_falsy(nersc_manager, "nersc_manager") + self._jaws = _not_falsy(jaws_client, "jaws_client") self._jstate = _not_falsy(job_state, "job_state") self._mongo = _not_falsy(mongodao, "mongodao") self._s3 = _not_falsy(s3_client, "s3_client") @@ -89,7 +94,6 @@ async def start_job(self, job: models.Job, objmeta: list[S3ObjectMeta]): # will be deleted automatically by JAWS, or need own file deletion # TODO DISKSPACE will need to clean up job downloads @ NERSC # TODO LOGGING make the remote code log summary of results and upload and store - # TODO NOW how get remote paths at next step? task_id = await self._nman.download_s3_files( job.id, objmeta, presigned, callback_url, insecure_ssl=self._s3insecure ) @@ -152,3 +156,31 @@ async def _submit_jaws_job(self, job: models.AdminJobDetails): logr.exception(f"Error starting JAWS job for job {job.id}") # TODO IMPORTANT ERRORHANDLING update job state to ERROR w/ message and don't raise raise e + + async def job_complete(self, job: models.AdminJobDetails): + """ + Continue a job after the remote job run is complete. The job is expected to be in the + job submitted state. + """ + if _not_falsy(job, "job").state != models.JobState.JOB_SUBMITTED: + raise InvalidJobStateError("Job must be in the job submitted state") + # We assume this is a jaws job if it was mapped to this runner + # TODO RETRIES this line might need changes + jaws_info = await self._jaws.status(job.jaws_details.run_id[-1]) + if jaws_info["status"] != "done": + raise InvalidJobStateError("JAWS run is incomplete") + # TODO ERRHANDLING IMPORTANT if in an error state, pull the errors.json file from the + # JAWS job dir and add stderr / out to job record (what to do about huge + # logs?) 
and set job to error + await self._mongo.update_job_state( + job.id, + models.JobState.JOB_SUBMITTED, + models.JobState.UPLOAD_SUBMITTING, + timestamp.utcdatetime() + ) + await self._coman.run_coroutine(self._upload_files(job, jaws_info)) + + async def _upload_files(self, job: models.AdminJobDetails, jaws_info: dict[str, Any]): + logr = logging.getLogger(__name__) + # TODO REMOVE after implementing file upload + logr.info(f"Starting file upload for job {job.id} JAWS run {jaws_info['id']}") diff --git a/cdmtaskservice/routes.py b/cdmtaskservice/routes.py index 1fbfd8a..28c4c87 100644 --- a/cdmtaskservice/routes.py +++ b/cdmtaskservice/routes.py @@ -260,8 +260,9 @@ async def job_complete( job_id: _ANN_JOB_ID ): logging.getLogger(__name__).info(f"Remote job reported as complete for job {job_id}") - # TODO JOBS implement when job is complete - raise NotImplementedError() + appstate = app_state.get_app_state(r) + job = await appstate.job_state.get_job(job_id, _SERVICE_USER, as_admin=True) + await appstate.runners[job.job_input.cluster].job_complete(job) class ClientLifeTimeError(Exception): diff --git a/cdmtaskservice_config.toml.jinja b/cdmtaskservice_config.toml.jinja index 7c3f3f9..fc80252 100644 --- a/cdmtaskservice_config.toml.jinja +++ b/cdmtaskservice_config.toml.jinja @@ -39,6 +39,9 @@ remote_code_dir = "{{ KBCTS_NERSC_REMOTE_CODE_DIR or "/global/cfs/cdirs/kbase/cd [JAWS] +# The JAWS Central server URL. +url = "{{ KBCTS_JAWS_URL or "https://jaws-api.jgi.doe.gov/api/v2" }}" + # The JGI JAWS token to use to run jobs. This can be obtained from a JAWS representative. token = "{{ KBCTS_JAWS_TOKEN or "" }}"