Skip to content

Commit

Permalink
Merge pull request #148 from kbase/dev-service
Browse files Browse the repository at this point in the history
Submit job to JAWS
  • Loading branch information
MrCreosote authored Jan 6, 2025
2 parents 3e959f7 + 4eeb567 commit 69a2a0b
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 21 deletions.
4 changes: 2 additions & 2 deletions cdmtaskservice/jobflows/nersc_jaws.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,9 @@ async def download_complete(self, job: models.AdminJobDetails):
models.JobState.JOB_SUBMITTING,
timestamp.utcdatetime()
)
await self._coman.run_coroutine(self._download_complete(job))
await self._coman.run_coroutine(self._submit_jaws_job(job))

async def _download_complete(self, job: models.AdminJobDetails):
async def _submit_jaws_job(self, job: models.AdminJobDetails):
logr = logging.getLogger(__name__)
try:
# TODO PERF configure file download concurrency
Expand Down
82 changes: 63 additions & 19 deletions cdmtaskservice/nersc/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import logging
from pathlib import Path
from sfapi_client import AsyncClient
from sfapi_client.exceptions import SfApiError
from sfapi_client.paths import AsyncRemotePath
from sfapi_client.compute import Machine, AsyncCompute
import sys
Expand Down Expand Up @@ -46,17 +47,26 @@
_SEC_PER_GB = 2 * 60 # may want to make this configurable

_CTS_SCRATCH_ROOT_DIR = Path("cdm_task_service")
_JOB_FILES = "files"
_MANIFESTS = "manifests"
_JOB_FILES = Path("files")
_JOB_MANIFESTS = Path("manifests")
_MANIFEST_FILE_PREFIX = "manifest-"


_JAWS_CONF_FILENAME = "jaws.conf"
_JAWS_CONF_FILENAME = "jaws_cts.conf"
_JAWS_CONF_TEMPLATE = """
[USER]
token = {token}
default_team = {group}
"""
_JAWS_COMMAND_TEMPLATE = f"""
module use /global/cfs/projectdirs/kbase/jaws/modulefiles
module load jaws
export JAWS_USER_CONFIG=~/{_JAWS_CONF_FILENAME}
jaws submit --quiet --tag {{job_id}} {{wdlpath}} {{inputjsonpath}} {{site}}
"""
_JAWS_SITE_PERLMUTTER = "kbase" # add lawrencium later, maybe
_JAWS_INPUT_WDL = "input.wdl"
_JAWS_INPUT_JSON = "input.json"


# TODO PROD add start and end time to task output and record
Expand Down Expand Up @@ -184,7 +194,8 @@ async def _setup_remote_code(self, file_group: str, jaws_token: str, jaws_group:
),
chmod = "600"
))
scratch = tg.create_task(self._set_up_dtn_scratch(cli, file_group))
pm_scratch = tg.create_task(perlmutter.run("echo $SCRATCH"))
dtn_scratch = tg.create_task(self._set_up_dtn_scratch(cli, file_group))
if _PIP_DEPENDENCIES:
deps = " ".join(
# may need to do something else if module doesn't have __version__
Expand All @@ -198,7 +209,11 @@ async def _setup_remote_code(self, file_group: str, jaws_token: str, jaws_group:
+ f"pip install {deps}" # adding notapackage causes a failure
)
tg.create_task(dt.run(command))
self._dtn_scratch = scratch.result()
self._dtn_scratch = dtn_scratch.result()
self._perlmutter_scratch = Path(pm_scratch.result().strip())
logging.getLogger(__name__).info(
f"NERSC perlmutter scratch path: {self._perlmutter_scratch}"
)

async def _set_up_dtn_scratch(self, client: AsyncClient, file_group: str) -> Path:
dt = await client.compute(_DT_TARGET)
Expand All @@ -208,7 +223,7 @@ async def _set_up_dtn_scratch(self, client: AsyncClient, file_group: str) -> Pat
raise ValueError("Unable to determine $SCRATCH variable for NERSC dtns")
logging.getLogger(__name__).info(f"NERSC DTN scratch path: {scratch}")
await dt.run(
f"{_DT_WORKAROUND}; set -e; chgrp {file_group} {scratch}; chmod g+rs {scratch}"
f"{_DT_WORKAROUND}; set -e; chgrp {file_group} {scratch}; chmod g+rsx {scratch}"
)
return Path(scratch)

Expand Down Expand Up @@ -394,26 +409,58 @@ async def run_JAWS(self, job: models.Job, file_download_concurrency: int = 10) -
_check_int(file_download_concurrency, "file_download_concurrency")
if not _not_falsy(job, "job").job_input.inputs_are_S3File():
raise ValueError("Job files must be S3File objects")
cli = self._client_provider()
await self._generate_and_load_job_files_to_nersc(cli, job, file_download_concurrency)
perl = await cli.compute(Machine.perlmutter)
pre = self._perlmutter_scratch / _CTS_SCRATCH_ROOT_DIR / job.id
try:
res = await perl.run(_JAWS_COMMAND_TEMPLATE.format(
job_id=job.id,
wdlpath=pre / _JAWS_INPUT_WDL,
inputjsonpath=pre / _JAWS_INPUT_JSON,
site=_JAWS_SITE_PERLMUTTER
))
except SfApiError as e:
# TODO ERRORHANDLING if jaws provides valid json parse it and return just the detail
#try:
# j = json.loads(f"{e}")
# if "detail" in j:
# raise ValueError(f"JAWS error: {j['detail']}") from e
raise ValueError(f"JAWS error: {e}") from e
#except json.JSONDecodeError as je:
# raise ValueError(f"JAWS returned invalid JSON ({je}) in error: {e}") from e
try:
j = json.loads(res)
if "run_id" not in j:
raise ValueError(f"JAWS returned no run_id in JSON {res}")
run_id = j["run_id"]
logging.getLogger(__name__).info(
f"Submitted JAWS job with run id {run_id} for job {job.id}"
)
return run_id
except json.JSONDecodeError as e:
raise ValueError(f"JAWS returned invalid JSON: {e}\n{res}") from e

async def _generate_and_load_job_files_to_nersc(
self, cli: AsyncClient, job: models.Job, concurrency: int
):
manifest_files = generate_manifest_files(job)
manifest_file_paths = self._get_manifest_file_paths(job.id, len(manifest_files))
fmap = {m: self._localize_s3_path(job.id, m.file) for m in job.job_input.input_files}
fmap = {m: _JOB_FILES / m.file for m in job.job_input.input_files}
wdljson = wdl.generate_wdl(job, fmap, manifest_file_paths)
uploads = {fp: f for fp, f in zip(manifest_file_paths, manifest_files)}
pre = self._dtn_scratch / _CTS_SCRATCH_ROOT_DIR / job.id
wdlpath = pre / "input.wdl"
jsonpath = pre / "input.json"
uploads[wdlpath] = wdljson.wdl
uploads[jsonpath] = json.dumps(wdljson.input_json, indent=4)
cli = self._client_provider()
downloads = {pre / fp: f for fp, f in zip(manifest_file_paths, manifest_files)}
downloads[pre / _JAWS_INPUT_WDL] = wdljson.wdl
downloads[pre / _JAWS_INPUT_JSON] = json.dumps(wdljson.input_json, indent=4)
dt = await cli.compute(_DT_TARGET)
semaphore = asyncio.Semaphore(file_download_concurrency)
semaphore = asyncio.Semaphore(concurrency)
async def sem_coro(coro):
async with semaphore:
return await coro
coros = []
try:
async with asyncio.TaskGroup() as tg:
for path, file in uploads.items():
for path, file in downloads.items():
coros.append(self._upload_file_to_nersc(
dt, path, bio=io.BytesIO(file.encode())
))
Expand All @@ -425,11 +472,8 @@ async def sem_coro(coro):
# otherwise you can get coroutine never awaited warnings if a failure occurs
for c in coros:
c.close()
# TODO NEXT run jaws job
return "fake_job_id"

def _get_manifest_file_paths(self, job_id: str, count: int) -> list[Path]:
if count == 0:
return []
pre = self._dtn_scratch / _CTS_SCRATCH_ROOT_DIR / job_id / _MANIFESTS
return [pre / f"{_MANIFEST_FILE_PREFIX}{c}" for c in range(1, count + 1)]
return [_JOB_MANIFESTS / f"{_MANIFEST_FILE_PREFIX}{c}" for c in range(1, count + 1)]

0 comments on commit 69a2a0b

Please sign in to comment.