diff --git a/cdmtaskservice/app_state.py b/cdmtaskservice/app_state.py index 0807d1e..2b9d6ed 100644 --- a/cdmtaskservice/app_state.py +++ b/cdmtaskservice/app_state.py @@ -9,17 +9,21 @@ import logging from motor.motor_asyncio import AsyncIOMotorClient from pathlib import Path -from typing import NamedTuple +from typing import NamedTuple, Any from fastapi import FastAPI, Request +from cdmtaskservice import models from cdmtaskservice.config import CDMTaskServiceConfig from cdmtaskservice.image_remote_lookup import DockerImageInfo from cdmtaskservice.images import Images +from cdmtaskservice.jobflows.nersc_jaws import NERSCJAWSRunner from cdmtaskservice.job_state import JobState from cdmtaskservice.kb_auth import KBaseAuth from cdmtaskservice.mongo import MongoDAO from cdmtaskservice.nersc.client import NERSCSFAPIClientProvider +from cdmtaskservice.nersc.manager import NERSCManager from cdmtaskservice.s3.client import S3Client +from cdmtaskservice.version import VERSION # The main point of this module is to handle all the application state in one place # to keep it consistent and allow for refactoring without breaking other code @@ -32,6 +36,10 @@ class AppState(NamedTuple): s3_client: S3Client job_state: JobState images: Images + # TODO CODE make an abstract jobflow class or something. For now just duck type them + # may not even need this, but not hard to do and shows how different flows might work + # in the future + runners: dict[models.Cluster, Any] async def build_app( @@ -44,6 +52,7 @@ async def build_app( app - the FastAPI app. cfg - the CDM task service config. """ + # This method is getting pretty long but it's stupid simple so... # May want to parallelize some of this for faster startups. would need to rework prints logr = logging.getLogger(__name__) logr.info("Connecting to KBase auth service... ") @@ -54,12 +63,23 @@ async def build_app( ) logr.info("Done") logr.info("Initializing NERSC SFAPI client... ") - nersc = await NERSCSFAPIClientProvider.create(Path(cfg.sfapi_cred_path), cfg.sfapi_user) + sfapi_client = await NERSCSFAPIClientProvider.create(Path(cfg.sfapi_cred_path), cfg.sfapi_user) + logr.info("Done") + logr.info("Setting up NERSC manager and installing code at NERSC...") + remote_code_loc = Path(cfg.nersc_remote_code_dir) / VERSION + nerscman = await NERSCManager.create(sfapi_client.get_client, remote_code_loc) logr.info("Done") logr.info("Initializing S3 client... ") s3 = await S3Client.create( cfg.s3_url, cfg.s3_access_key, cfg.s3_access_secret, insecure_ssl=cfg.s3_allow_insecure ) + s3_external = await S3Client.create( + cfg.s3_external_url, + cfg.s3_access_key, + cfg.s3_access_secret, + insecure_ssl=cfg.s3_allow_insecure, + skip_connection_check=not cfg.s3_verify_external_url + ) logr.info("Done") logr.info("Initializing MongoDB client...") mongocli = await get_mongo_client(cfg) @@ -67,10 +87,23 @@ async def build_app( try: mongodao = await MongoDAO.create(mongocli[cfg.mongo_db]) job_state = JobState(mongodao, s3) + nerscjawsflow = NERSCJAWSRunner( + nerscman, + job_state, + s3, + s3_external, + cfg.jaws_token, + cfg.jaws_group, + cfg.service_root_url, + s3_insecure_ssl=cfg.s3_allow_insecure, + ) + runners = {models.Cluster.PERLMUTTER_JAWS: nerscjawsflow} imginfo = await DockerImageInfo.create(Path(cfg.crane_path).expanduser().absolute()) images = Images(mongodao, imginfo) app.state._mongo = mongocli - app.state._cdmstate = AppState(auth, nersc, s3, job_state, images) + app.state._cdmstate = AppState( + auth, sfapi_client, s3, job_state, images, runners + ) except: mongocli.close() raise diff --git a/cdmtaskservice/jobflows/nersc_jaws.py b/cdmtaskservice/jobflows/nersc_jaws.py new file mode 100644 index 0000000..975a9b8 --- /dev/null +++ b/cdmtaskservice/jobflows/nersc_jaws.py @@ -0,0 +1,52 @@ +""" +Manages running jobs at NERSC using the JAWS system. +""" + +from cdmtaskservice.arg_checkers import not_falsy as _not_falsy, require_string as _require_string +from cdmtaskservice.job_state import JobState +from cdmtaskservice.nersc.manager import NERSCManager +from cdmtaskservice.s3.client import S3Client + +# Not sure how other flows would work and how much code they might share. For now just make +# this work and pull it apart / refactor later. + + +class NERSCJAWSRunner: + """ + Runs jobs at NERSC using JAWS. + """ + + def __init__( + self, + nersc_manager: NERSCManager, + job_state: JobState, + s3_client: S3Client, + s3_external_client: S3Client, + jaws_token: str, + jaws_group: str, + service_root_url: str, + s3_insecure_ssl: bool = False, + ): + """ + Create the runner. + + nersc_manager - the NERSC manager. + job_state - the job state manager. + s3_client - an S3 client pointed to the data stores. + s3_external_client - an S3 client pointing to an external URL for the S3 data stores + that may not be accessible from the current process, but is accessible to remote + processes at NERSC. + jaws_token - a token for the JGI JAWS system. + jaws_group - the group to use for running JAWS jobs. + service_root_url - the URL of the service root, used for constructing service callbacks. + s3_insecure_url - whether to skip checking the SSL certificate for the S3 instance, + leaving the service open to MITM attacks. + """ + self._nman = _not_falsy(nersc_manager, "nersc_manager") + self._jstate = _not_falsy(job_state, "job_state") + self._s3 = _not_falsy(s3_client, "s3_client") + self._s3ext = _not_falsy(s3_external_client, "s3_external_client") + self._s3insecure = s3_insecure_ssl + self._jtoken = _require_string(jaws_token, "jaws_token") + self._jgroup = _require_string(jaws_group, "jaws_group") + self._callback_root = _require_string(service_root_url, "service_root_url") diff --git a/test/jobflows/nersc_jaws_test.py b/test/jobflows/nersc_jaws_test.py new file mode 100644 index 0000000..ac28999 --- /dev/null +++ b/test/jobflows/nersc_jaws_test.py @@ -0,0 +1,7 @@ +# TODO TEST add tests + +from cdmtaskservice.jobflows import nersc_jaws # @UnusedImport + + +def test_noop(): + pass