Load manifest, wdl, and input.json files to NERSC #140

Merged · 1 commit · Jan 6, 2025
13 changes: 10 additions & 3 deletions cdmtaskservice/jobflows/nersc_jaws.py
@@ -133,7 +133,14 @@
         await self._coman.run_coroutine(self._download_complete(job))
 
     async def _download_complete(self, job: models.AdminJobDetails):
-        jaws_job_id = await self._nman.run_JAWS(job)
-        # TODO JAWS record job ID in DB
         logr = logging.getLogger(__name__)
-        logr.info(f"JAWS job id: {jaws_job_id}")
+        try:
+            # TODO PERF configure file download concurrency
+            jaws_job_id = await self._nman.run_JAWS(job)
+            # TODO JAWS record job ID in DB
+            logr.info(f"JAWS job id: {jaws_job_id}")
+        except Exception as e:
+            # TODO LOGGING figure out how logging it going to work etc.
+            logr.exception(f"Error starting JAWS job for job {job.id}")
+            # TODO IMPORTANT ERRORHANDLING update job state to ERROR w/ message and don't raise
+            raise e

(Codecov flagged the added lines as not covered by tests.)
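For reference, a minimal standalone sketch (not project code) of the `logger.exception` pattern used in the handler above: called inside an `except` block, it logs at ERROR level and appends the traceback of the exception currently being handled; the PR handler then re-raises, pending the ERRORHANDLING TODO. The `start_job` function is a made-up stand-in for the `run_JAWS` call.

```python
import logging

logging.basicConfig(level=logging.INFO)
logr = logging.getLogger(__name__)

def start_job(job_id: str) -> None:
    # made-up stand-in for the run_JAWS call
    raise ConnectionError("could not reach NERSC")

try:
    start_job("job-1")
except Exception:
    # Inside an except block, logger.exception logs the message at ERROR
    # level plus the traceback of the exception being handled.
    logr.exception("Error starting JAWS job for job job-1")
```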
40 changes: 33 additions & 7 deletions cdmtaskservice/nersc/manager.py
@@ -384,22 +384,48 @@
             "sec-per-GB": _SEC_PER_GB,
         }
 
-    async def run_JAWS(self, job: models.Job) -> str:
+    async def run_JAWS(self, job: models.Job, file_download_concurrency: int = 10) -> str:
         """
         Run a JAWS job at NERSC and return the job ID.
 
         job - the job to process
+        file_download_concurrency - the number of files at one time to download to NERSC.
         """
+        _check_int(file_download_concurrency, "file_download_concurrency")
         if not _not_falsy(job, "job").job_input.inputs_are_S3File():
             raise ValueError("Job files must be S3File objects")
         manifest_files = generate_manifest_files(job)
         manifest_file_paths = self._get_manifest_file_paths(job.id, len(manifest_files))
         fmap = {m: self._localize_s3_path(job.id, m.file) for m in job.job_input.input_files}
         wdljson = wdl.generate_wdl(job, fmap, manifest_file_paths)
-        # TODO REMOVE these lines
-        logr = logging.getLogger(__name__)
-        for m in manifest_files:
-            logr.info("***")
-            logr.info(m)
-        logr.info(f"*** wdl:\n{wdljson.wdl}\njson:\n{json.dumps(wdljson.input_json, indent=4)}")
+        uploads = {fp: f for fp, f in zip(manifest_file_paths, manifest_files)}
+        pre = self._dtn_scratch / _CTS_SCRATCH_ROOT_DIR / job.id
+        wdlpath = pre / "input.wdl"
+        jsonpath = pre / "input.json"
+        uploads[wdlpath] = wdljson.wdl
+        uploads[jsonpath] = json.dumps(wdljson.input_json, indent=4)
+        cli = self._client_provider()
+        dt = await cli.compute(_DT_TARGET)
+        semaphore = asyncio.Semaphore(file_download_concurrency)
+        async def sem_coro(coro):
+            async with semaphore:
+                return await coro
+        coros = []
+        try:
+            async with asyncio.TaskGroup() as tg:
+                for path, file in uploads.items():
+                    coros.append(self._upload_file_to_nersc(
+                        dt, path, bio=io.BytesIO(file.encode())
+                    ))
+                    tg.create_task(sem_coro(coros[-1]))
+        except ExceptionGroup as eg:
+            e = eg.exceptions[0]  # just pick one, essentially at random
Collaborator (review comment on the `e = eg.exceptions[0]` line): Wondering if we should log all exceptions here.

Suggested change:

            for ex in eg.exceptions:
                logging.getLogger(__name__).error("upload error", exc_info=ex)
            e = eg.exceptions[0]  # just pick one, essentially at random

Member Author: To me this seems like overkill for a couple of reasons:

  • If multiple coroutines fail, 99.99% of the time it's going to be for the same reason, so this will just log the same exception over and over.
  • In the rare case where coroutines fail for different reasons, it should be easy enough to fix the first reason, retry, fix the second reason, and so on.

Collaborator: (same suggestion repeated)

Member Author: As above.
+            raise e from eg
+        finally:
+            # otherwise you can get coroutine never awaited warnings if a failure occurs
+            for c in coros:
+                c.close()
Collaborator (review comment on the `finally` block): I wasn't sure why we need to manually call `close()` on a coroutine object. It's already wrapped by `async with asyncio.TaskGroup() as tg:`, right?

Member Author (@MrCreosote, Jan 6, 2025): I copied this setup from the remote code, so it's been a while since I messed with it, but from what I recall I got warnings in the logs about unclosed or unawaited coroutines when I didn't close them manually. IIRC, when a coroutine fails, any remaining coroutines that haven't started running aren't run.
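A standalone sketch (Python 3.11+, not project code) of the behaviour described above: when one upload fails, the TaskGroup cancels tasks still queued behind the semaphore, so their inner coroutine objects are never awaited, and without `close()` Python emits "RuntimeWarning: coroutine ... was never awaited" when those objects are garbage collected. The `upload` coroutine is a made-up stand-in for `self._upload_file_to_nersc(...)`.

```python
import asyncio

async def upload(name: str, fail: bool = False) -> str:
    # made-up stand-in for self._upload_file_to_nersc(...)
    if fail:
        raise IOError(f"upload of {name} failed")
    await asyncio.sleep(0.01)
    return name

async def main() -> None:
    semaphore = asyncio.Semaphore(1)  # one upload at a time

    async def sem_coro(coro):
        async with semaphore:
            return await coro

    coros = [
        upload("input.wdl", fail=True),
        upload("input.json"),
        upload("manifest-1.tsv"),
    ]
    try:
        async with asyncio.TaskGroup() as tg:
            for c in coros:
                tg.create_task(sem_coro(c))
    except* IOError as eg:
        print("caught:", eg.exceptions[0])
    finally:
        # Tasks cancelled while still waiting on the semaphore never awaited
        # their inner coroutine; closing the coroutine objects suppresses the
        # "was never awaited" RuntimeWarning. close() is a no-op for
        # coroutines that already ran.
        for c in coros:
            c.close()

asyncio.run(main())
```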

+        # TODO NEXT run jaws job
+        return "fake_job_id"
+
     def _get_manifest_file_paths(self, job_id: str, count: int) -> list[Path]:

(Codecov flagged the added lines in this file as not covered by tests.)
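To summarize the technique this PR introduces in `run_JAWS`: a semaphore-wrapping helper bounds how many of the queued coroutines run at once inside a `TaskGroup`. A generic, runnable sketch under that assumption (Python 3.11+; the `fetch` coroutine and the limit of 3 are illustrative only):

```python
import asyncio

async def fetch(i: int) -> int:
    await asyncio.sleep(0.01)  # stand-in for an upload or download
    return i

async def run_limited(coros: list, limit: int) -> list:
    semaphore = asyncio.Semaphore(limit)

    async def sem_coro(coro):
        # At most `limit` wrapped coroutines proceed past this point at once.
        async with semaphore:
            return await coro

    try:
        async with asyncio.TaskGroup() as tg:
            tasks = [tg.create_task(sem_coro(c)) for c in coros]
        return [t.result() for t in tasks]
    finally:
        # Close anything that never ran so it doesn't warn at garbage collection.
        for c in coros:
            c.close()

print(asyncio.run(run_limited([fetch(i) for i in range(10)], limit=3)))
```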