Add OAR scheduler support
petitalb committed Aug 17, 2023
1 parent 29d3d1f commit b8c39fd
Showing 3 changed files with 253 additions and 0 deletions.
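For context before the diff: a minimal sketch of how the new plugin is used, assuming an OAR cluster with oarsub/oarstat on PATH (sleep_add_one is the helper task the tests below rely on):

from pydra import Submitter
from pydra.engine.tests.utils import sleep_add_one

# every task becomes its own OAR job; extra flags are passed through to oarsub
task = sleep_add_one(x=1)
with Submitter("oar", oarsub_args="-l nodes=1") as sub:
    sub(task)
print(task.result().output.out)  # 2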
80 changes: 80 additions & 0 deletions pydra/engine/tests/test_submitter.py
@@ -8,6 +8,7 @@
from .utils import (
need_sge,
need_slurm,
need_oar,
gen_basic_wf,
gen_basic_wf_with_threadcount,
gen_basic_wf_with_threadcount_concurrent,
@@ -573,6 +574,85 @@ def test_sge_no_limit_maxthreads(tmpdir):
assert job_1_endtime > job_2_starttime


@need_oar
def test_oar_wf(tmpdir):
wf = gen_basic_wf()
wf.cache_dir = tmpdir
# submit workflow and every task as oar job
with Submitter("oar") as sub:
sub(wf)

res = wf.result()
assert res.output.out == 9
script_dir = tmpdir / "OarWorker_scripts"
assert script_dir.exists()
# ensure each task was executed with oar
assert len([sd for sd in script_dir.listdir() if sd.isdir()]) == 2


@need_oar
def test_oar_wf_cf(tmpdir):
    # submit the entire workflow as a single job executed with the cf worker
wf = gen_basic_wf()
wf.cache_dir = tmpdir
wf.plugin = "cf"
with Submitter("oar") as sub:
sub(wf)
res = wf.result()
assert res.output.out == 9
script_dir = tmpdir / "OarWorker_scripts"
assert script_dir.exists()
# ensure only workflow was executed with oar
sdirs = [sd for sd in script_dir.listdir() if sd.isdir()]
assert len(sdirs) == 1
    # the oar script dir should be named with the workflow uid
assert sdirs[0].basename == wf.uid


@need_oar
def test_oar_wf_state(tmpdir):
wf = gen_basic_wf()
wf.split("x", x=[5, 6])
wf.cache_dir = tmpdir
with Submitter("oar") as sub:
sub(wf)
res = wf.result()
assert res[0].output.out == 9
assert res[1].output.out == 10
script_dir = tmpdir / "OarWorker_scripts"
assert script_dir.exists()
sdirs = [sd for sd in script_dir.listdir() if sd.isdir()]
assert len(sdirs) == 2 * len(wf.inputs.x)


@need_oar
def test_oar_args_1(tmpdir):
"""testing sbatch_args provided to the submitter"""
task = sleep_add_one(x=1)
task.cache_dir = tmpdir
    # submit the task as an oar job
with Submitter("oar", oarsub_args="-l nodes=2") as sub:
sub(task)

res = task.result()
assert res.output.out == 2
script_dir = tmpdir / "OarWorker_scripts"
assert script_dir.exists()


@need_oar
def test_oar_args_2(tmpdir):
"""testing oarsub_args provided to the submitter
exception should be raised for invalid options
"""
task = sleep_add_one(x=1)
task.cache_dir = tmpdir
    # submit the task as an oar job
with pytest.raises(RuntimeError, match="Error returned from oarsub:"):
with Submitter("oar", oarsub_args="-l nodes=2 --invalid") as sub:
sub(task)


# @pytest.mark.xfail(reason="Not sure")
def test_wf_with_blocked_tasks(tmpdir):
wf = Workflow(name="wf_with_blocked_tasks", input_spec=["x"])
4 changes: 4 additions & 0 deletions pydra/engine/tests/utils.py
@@ -30,6 +30,10 @@
not (bool(shutil.which("qsub")) and bool(shutil.which("qacct"))),
reason="sge not available",
)
need_oar = pytest.mark.skipif(
not (bool(shutil.which("oarsub")) and bool(shutil.which("oarstat"))),
reason="oar not available",
)


def result_no_submitter(shell_task, plugin=None):
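The new marker mirrors need_slurm/need_sge: a test runs only when both OAR client commands are on PATH. A hypothetical use (the test name is illustrative):

from pydra.engine.tests.utils import need_oar

@need_oar  # skipped automatically unless oarsub and oarstat are installed
def test_my_oar_only_behavior(tmpdir):
    ...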
169 changes: 169 additions & 0 deletions pydra/engine/workers.py
@@ -2,6 +2,7 @@
import asyncio
import sys
import json
import os
import re
from tempfile import gettempdir
from pathlib import Path
@@ -186,6 +187,173 @@ def close(self):
self.pool.shutdown()


class OarWorker(DistributedWorker):
"""A worker to execute tasks on OAR systems."""

_cmd = "oarsub"

def __init__(self, loop=None, max_jobs=None, poll_delay=1, oarsub_args=None):
"""
Initialize OAR Worker.
Parameters
----------
poll_delay : seconds
Delay between polls to oar
oarsub_args : str
Additional oarsub arguments
max_jobs : int
Maximum number of submitted jobs
"""
super().__init__(loop=loop, max_jobs=max_jobs)
if not poll_delay or poll_delay < 0:
poll_delay = 0
self.poll_delay = poll_delay
self.oarsub_args = oarsub_args or ""
self.error = {}

def run_el(self, runnable, rerun=False):
"""Worker submission API."""
script_dir, batch_script = self._prepare_runscripts(runnable, rerun=rerun)
        if str(script_dir).startswith(gettempdir()):
logger.warning("Temporary directories may not be shared across computers")
if isinstance(runnable, TaskBase):
cache_dir = runnable.cache_dir
name = runnable.name
uid = runnable.uid
else: # runnable is a tuple (ind, pkl file, task)
cache_dir = runnable[-1].cache_dir
name = runnable[-1].name
uid = f"{runnable[-1].uid}_{runnable[0]}"

return self._submit_job(batch_script, name=name, uid=uid, cache_dir=cache_dir)

def _prepare_runscripts(self, task, interpreter="/bin/sh", rerun=False):
if isinstance(task, TaskBase):
cache_dir = task.cache_dir
ind = None
uid = task.uid
else:
ind = task[0]
cache_dir = task[-1].cache_dir
uid = f"{task[-1].uid}_{ind}"

script_dir = cache_dir / f"{self.__class__.__name__}_scripts" / uid
script_dir.mkdir(parents=True, exist_ok=True)
if ind is None:
            if not (script_dir / "_task.pklz").exists():
save(script_dir, task=task)
else:
copyfile(task[1], script_dir / "_task.pklz")

task_pkl = script_dir / "_task.pklz"
if not task_pkl.exists() or not task_pkl.stat().st_size:
raise Exception("Missing or empty task!")

batchscript = script_dir / f"batchscript_{uid}.sh"
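        # the script just re-invokes the current interpreter to unpickle and run the task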
python_string = (
f"""'from pydra.engine.helpers import load_and_run; """
f"""load_and_run(task_pkl="{task_pkl}", ind={ind}, rerun={rerun}) '"""
)
bcmd = "\n".join(
(
f"#!{interpreter}",
f"{sys.executable} -c " + python_string,
)
)
with batchscript.open("wt") as fp:
            fp.write(bcmd)
os.chmod(batchscript, 0o544)
return script_dir, batchscript

async def _submit_job(self, batchscript, name, uid, cache_dir):
"""Coroutine that submits task runscript and polls job until completion or error."""
script_dir = cache_dir / f"{self.__class__.__name__}_scripts" / uid
sargs = self.oarsub_args.split()
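        # honor any name/stdout/stderr the user already passed via oarsub_args;
        # only fall back to task-derived defaults when they are absent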
jobname = re.search(r"(?<=-n )\S+|(?<=--name=)\S+", self.oarsub_args)
if not jobname:
jobname = ".".join((name, uid))
sargs.append(f"--name={jobname}")
output = re.search(r"(?<=-O )\S+|(?<=--stdout=)\S+", self.oarsub_args)
if not output:
output_file = str(script_dir / "oar-%jobid%.out")
sargs.append(f"--stdout={output_file}")
error = re.search(r"(?<=-E )\S+|(?<=--stderr=)\S+", self.oarsub_args)
if not error:
error_file = str(script_dir / "oar-%jobid%.err")
sargs.append(f"--stderr={error_file}")
else:
error_file = None
sargs.append(str(batchscript))
# TO CONSIDER: add random sleep to avoid overloading calls
logger.debug(f"Submitting job {' '.join(sargs)}")
rc, stdout, stderr = await read_and_display_async(
self._cmd, *sargs, hide_display=True
)
jobid = re.search(r"OAR_JOB_ID=(\d+)", stdout)
if rc:
raise RuntimeError(f"Error returned from oarsub: {stderr}")
elif not jobid:
raise RuntimeError("Could not extract job ID")
jobid = jobid.group(1)
        if error_file:
            error_file = error_file.replace("%jobid%", jobid)
        self.error[jobid] = error_file
# intermittent polling
while True:
# 4 possibilities
# False: job is still pending/working
# Terminated: job is complete
            # Error + idempotent: job has been stopped and resubmitted with another jobid
# Error: Job failure
done = await self._poll_job(jobid)
if not done:
await asyncio.sleep(self.poll_delay)
elif done == "Terminated":
return True
elif done == "Error" and "idempotent" in self.oarsub_args:
logger.debug(
f"Job {jobid} has been stopped. Looking for its resubmission..."
)
# loading info about task with a specific uid
info_file = cache_dir / f"{uid}_info.json"
if info_file.exists():
checksum = json.loads(info_file.read_text())["checksum"]
if (cache_dir / f"{checksum}.lock").exists():
                        # with Python 3.8+ we could use missing_ok=True
(cache_dir / f"{checksum}.lock").unlink()
cmd_re = ("oarstat", "-J", "--sql", f"resubmit_job_id='{jobid}'")
_, stdout, _ = await read_and_display_async(*cmd_re, hide_display=True)
if not stdout:
raise RuntimeError(
"Job information about resubmission of job {jobid} not found"
)
jobid = next(iter(json.loads(stdout).keys()), None)
            else:
                # a missing or empty stderr file should not mask the job failure
                error_file = self.error.get(jobid)
                error_message = "Job failed (unknown reason - TODO)"
                if error_file and Path(error_file).exists():
                    error_lines = Path(error_file).read_text().splitlines()
                    error_line = error_lines[-1] if error_lines else ""
                    if "Exception" in error_line:
                        error_message = error_line.replace("Exception: ", "")
                    elif "Error" in error_line:
                        error_message = error_line.replace("Error: ", "")
                raise Exception(error_message)
return True

async def _poll_job(self, jobid):
cmd = ("oarstat", "-J", "-s", "-j", jobid)
logger.debug(f"Polling job {jobid}")
_, stdout, _ = await read_and_display_async(*cmd, hide_display=True)
if not stdout:
raise RuntimeError("Job information not found")
status = json.loads(stdout)[jobid]
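        # these are OAR's in-flight states; any other value (e.g. Terminated,
        # Error) is final and is returned to the caller to interpret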
if status in ["Waiting", "Launching", "Running", "Finishing"]:
return False
return status


class SlurmWorker(DistributedWorker):
"""A worker to execute tasks on SLURM systems."""

@@ -894,4 +1062,5 @@ def close(self):
"slurm": SlurmWorker,
"dask": DaskWorker,
"sge": SGEWorker,
"oar": OarWorker,
}
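The mapping above is how a plugin string reaches a worker class. Roughly, and assuming the dict is bound to the module-level name WORKERS that Submitter consults (the name is not visible in this hunk):

from pydra.engine.workers import WORKERS

worker_cls = WORKERS["oar"]  # -> OarWorker
worker = worker_cls(poll_delay=2, oarsub_args="-l nodes=1")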
