Skip to content

Commit

Permalink
Merge pull request #140 from kbase/dev-service
Browse files Browse the repository at this point in the history
Load manifest, wdl, and input.json files to NERSC
  • Loading branch information
MrCreosote authored Jan 6, 2025
2 parents 289a6be + e87693e commit 3e959f7
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 10 deletions.
13 changes: 10 additions & 3 deletions cdmtaskservice/jobflows/nersc_jaws.py
Original file line number Diff line number Diff line change
async def download_complete(self, job: models.AdminJobDetails):
    """
    Continue a job flow after the input file download to NERSC is complete.

    job - the job whose downloads have completed.
    """
    # Hand the long-running work off to the coroutine manager so this call
    # returns promptly; _download_complete does the actual JAWS submission.
    await self._coman.run_coroutine(self._download_complete(job))

async def _download_complete(self, job: models.AdminJobDetails):
    """
    Start the JAWS run for a job whose input files are staged at NERSC.

    job - the job to submit to JAWS.

    Raises whatever exception `run_JAWS` raises after logging it.
    """
    logr = logging.getLogger(__name__)
    try:
        # TODO PERF configure file download concurrency
        jaws_job_id = await self._nman.run_JAWS(job)
        # TODO JAWS record job ID in DB
        logr.info(f"JAWS job id: {jaws_job_id}")
    except Exception:
        # TODO LOGGING figure out how logging is going to work etc.
        logr.exception(f"Error starting JAWS job for job {job.id}")
        # TODO IMPORTANT ERRORHANDLING update job state to ERROR w/ message and don't raise
        # Bare `raise` preserves the original traceback; `raise e` would
        # append a redundant frame.
        raise
40 changes: 33 additions & 7 deletions cdmtaskservice/nersc/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,22 +384,48 @@ def _base_manifest(self, op: str, concurrency: int, insecure_ssl: bool):
"sec-per-GB": _SEC_PER_GB,
}

async def run_JAWS(self, job: models.Job, file_download_concurrency: int = 10) -> str:
    """
    Run a JAWS job at NERSC and return the job ID.

    job - the job to process. All input files must be S3File objects.
    file_download_concurrency - the number of files at one time to download to NERSC.

    Raises ValueError if the job's input files are not S3File objects, and
    re-raises one (arbitrarily chosen) exception from any failed upload task.
    """
    _check_int(file_download_concurrency, "file_download_concurrency")
    if not _not_falsy(job, "job").job_input.inputs_are_S3File():
        raise ValueError("Job files must be S3File objects")
    manifest_files = generate_manifest_files(job)
    manifest_file_paths = self._get_manifest_file_paths(job.id, len(manifest_files))
    # Map each input file to its path on the NERSC file system.
    fmap = {m: self._localize_s3_path(job.id, m.file) for m in job.job_input.input_files}
    wdljson = wdl.generate_wdl(job, fmap, manifest_file_paths)
    # TODO REMOVE these lines
    logr = logging.getLogger(__name__)
    for m in manifest_files:
        logr.info("***")
        logr.info(m)
    logr.info(f"*** wdl:\n{wdljson.wdl}\njson:\n{json.dumps(wdljson.input_json, indent=4)}")
    # Everything to push to NERSC: remote path -> file content string.
    uploads = dict(zip(manifest_file_paths, manifest_files))
    pre = self._dtn_scratch / _CTS_SCRATCH_ROOT_DIR / job.id
    wdlpath = pre / "input.wdl"
    jsonpath = pre / "input.json"
    uploads[wdlpath] = wdljson.wdl
    uploads[jsonpath] = json.dumps(wdljson.input_json, indent=4)
    cli = self._client_provider()
    dt = await cli.compute(_DT_TARGET)
    # Bound the number of simultaneous uploads to NERSC.
    semaphore = asyncio.Semaphore(file_download_concurrency)
    async def sem_coro(coro):
        async with semaphore:
            return await coro
    coros = []
    try:
        async with asyncio.TaskGroup() as tg:
            for path, file in uploads.items():
                coros.append(self._upload_file_to_nersc(
                    dt, path, bio=io.BytesIO(file.encode())
                ))
                tg.create_task(sem_coro(coros[-1]))
    except ExceptionGroup as eg:
        e = eg.exceptions[0]  # just pick one, essentially at random
        raise e from eg
    finally:
        # otherwise you can get coroutine never awaited warnings if a failure occurs
        for c in coros:
            c.close()
    # TODO NEXT run jaws job
    return "fake_job_id"

def _get_manifest_file_paths(self, job_id: str, count: int) -> list[Path]:
Expand Down

0 comments on commit 3e959f7

Please sign in to comment.