diff --git a/cdmtaskservice/jobflows/nersc_jaws.py b/cdmtaskservice/jobflows/nersc_jaws.py index b5dc0f6..8d4be16 100644 --- a/cdmtaskservice/jobflows/nersc_jaws.py +++ b/cdmtaskservice/jobflows/nersc_jaws.py @@ -133,7 +133,14 @@ async def download_complete(self, job: models.AdminJobDetails): await self._coman.run_coroutine(self._download_complete(job)) async def _download_complete(self, job: models.AdminJobDetails): - jaws_job_id = await self._nman.run_JAWS(job) - # TODO JAWS record job ID in DB logr = logging.getLogger(__name__) - logr.info(f"JAWS job id: {jaws_job_id}") + try: + # TODO PERF configure file download concurrency + jaws_job_id = await self._nman.run_JAWS(job) + # TODO JAWS record job ID in DB + logr.info(f"JAWS job id: {jaws_job_id}") + except Exception as e: + # TODO LOGGING figure out how logging it going to work etc. + logr.exception(f"Error starting JAWS job for job {job.id}") + # TODO IMPORTANT ERRORHANDLING update job state to ERROR w/ message and don't raise + raise e diff --git a/cdmtaskservice/nersc/manager.py b/cdmtaskservice/nersc/manager.py index ea30394..c05d832 100644 --- a/cdmtaskservice/nersc/manager.py +++ b/cdmtaskservice/nersc/manager.py @@ -384,22 +384,48 @@ def _base_manifest(self, op: str, concurrency: int, insecure_ssl: bool): "sec-per-GB": _SEC_PER_GB, } - async def run_JAWS(self, job: models.Job) -> str: + async def run_JAWS(self, job: models.Job, file_download_concurrency: int = 10) -> str: """ Run a JAWS job at NERSC and return the job ID. + + job - the job to process + file_download_concurrency - the number of files at one time to download to NERSC. """ + _check_int(file_download_concurrency, "file_download_concurrency") if not _not_falsy(job, "job").job_input.inputs_are_S3File(): raise ValueError("Job files must be S3File objects") manifest_files = generate_manifest_files(job) manifest_file_paths = self._get_manifest_file_paths(job.id, len(manifest_files)) fmap = {m: self._localize_s3_path(job.id, m.file) for m in job.job_input.input_files} wdljson = wdl.generate_wdl(job, fmap, manifest_file_paths) - # TODO REMOVE these lines - logr = logging.getLogger(__name__) - for m in manifest_files: - logr.info("***") - logr.info(m) - logr.info(f"*** wdl:\n{wdljson.wdl}\njson:\n{json.dumps(wdljson.input_json, indent=4)}") + uploads = {fp: f for fp, f in zip(manifest_file_paths, manifest_files)} + pre = self._dtn_scratch / _CTS_SCRATCH_ROOT_DIR / job.id + wdlpath = pre / "input.wdl" + jsonpath = pre / "input.json" + uploads[wdlpath] = wdljson.wdl + uploads[jsonpath] = json.dumps(wdljson.input_json, indent=4) + cli = self._client_provider() + dt = await cli.compute(_DT_TARGET) + semaphore = asyncio.Semaphore(file_download_concurrency) + async def sem_coro(coro): + async with semaphore: + return await coro + coros = [] + try: + async with asyncio.TaskGroup() as tg: + for path, file in uploads.items(): + coros.append(self._upload_file_to_nersc( + dt, path, bio=io.BytesIO(file.encode()) + )) + tg.create_task(sem_coro(coros[-1])) + except ExceptionGroup as eg: + e = eg.exceptions[0] # just pick one, essentially at random + raise e from eg + finally: + # otherwise you can get coroutine never awaited warnings if a failure occurs + for c in coros: + c.close() + # TODO NEXT run jaws job return "fake_job_id" def _get_manifest_file_paths(self, job_id: str, count: int) -> list[Path]: