Add job_id_position Parameter to launch_slurm_job Method (#282)
* Added changes for handling custom Slurm sbatch outputs for jobs requiring dependencies

* Applied ruff formatting to changes
StephenRebel authored Jan 9, 2025
1 parent 338b3ad commit 2fc7660
Showing 1 changed file with 10 additions and 4 deletions.
src/datatrove/executor/slurm.py (10 additions, 4 deletions)
@@ -61,6 +61,8 @@ class SlurmPipelineExecutor(PipelineExecutor):
         depends: another SlurmPipelineExecutor that should run
             before this one
         depends_job_id: alternatively to the above, you can pass the job id of a dependency
+        job_id_position: position of the job ID in custom sbatch outputs.
+            default: -1
         logging_dir: where to save logs, stats, etc. Should be parsable into a datatrove.io.DataFolder
         skip_completed: whether to skip tasks that were completed in
             previous runs. default: True
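To make the new parameter concrete, here is a minimal sketch (not part of the commit; the example outputs are hypothetical) of how an index into sbatch's whitespace-split stdout selects the job ID:

```python
# Illustrative sketch, not part of the commit: how job_id_position selects
# the job ID from sbatch's stdout once it is split on whitespace.

def extract_job_id(sbatch_output: str, job_id_position: int = -1) -> str:
    """Return the token at job_id_position from the sbatch output."""
    return sbatch_output.split()[job_id_position]

# Stock sbatch prints "Submitted batch job <id>", so the default -1 works:
print(extract_job_id("Submitted batch job 12345"))  # -> 12345
# A hypothetical site wrapper that prints extra text after the ID needs a
# custom position (here the ID is the second token, index 1):
print(extract_job_id("Job 67890 queued; check email for status", 1))  # -> 67890
```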
@@ -99,6 +101,7 @@ def __init__(
         max_array_size: int = 1001,
         depends: SlurmPipelineExecutor | None = None,
         depends_job_id: str | None = None,
+        job_id_position: int = -1,
         logging_dir: DataFolderLike = None,
         skip_completed: bool = True,
         slurm_logs_folder: str = None,
@@ -128,6 +131,7 @@ def __init__(
         self.venv_path = venv_path
         self.depends = depends
         self.depends_job_id = depends_job_id
+        self.job_id_position = job_id_position
         self._sbatch_args = sbatch_args if sbatch_args else {}
         self.max_array_size = max_array_size
         self.max_array_launch_parallel = max_array_launch_parallel
@@ -198,7 +202,8 @@ def launch_merge_stats(self):
                 },
                 f'merge_stats {self.logging_dir.resolve_paths("stats")} '
                 f'-o {self.logging_dir.resolve_paths("stats.json")}',
-            )
+            ),
+            self.job_id_position,
         )

     @property
@@ -277,7 +282,7 @@ def launch_job(self):
             args = [f"--export=ALL,RUN_OFFSET={launched_jobs}"]
             if self.dependency:
                 args.append(f"--dependency={self.dependency}")
-            self.job_id = launch_slurm_job(launch_file_contents, *args)
+            self.job_id = launch_slurm_job(launch_file_contents, self.job_id_position, *args)
             launched_jobs += 1
         logger.info(f"Slurm job launched successfully with (last) id={self.job_id}.")
         self.launch_merge_stats()
@@ -355,11 +360,12 @@ def world_size(self) -> int:
         return self.tasks


-def launch_slurm_job(launch_file_contents, *args):
+def launch_slurm_job(launch_file_contents, job_id_position, *args):
     """
         Small helper function to save a sbatch script and call it.
     Args:
         launch_file_contents: Contents of the sbatch script
+        job_id_position: index of the job ID in the sbatch command's output
         *args: any other arguments to pass to the sbatch command

     Returns: the id of the launched slurm job
@@ -368,4 +374,4 @@ def launch_slurm_job(launch_file_contents, *args):
     with tempfile.NamedTemporaryFile("w") as f:
         f.write(launch_file_contents)
         f.flush()
-        return subprocess.check_output(["sbatch", *args, f.name]).decode("utf-8").split()[-1]
+        return subprocess.check_output(["sbatch", *args, f.name]).decode("utf-8").split()[job_id_position]
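With the new signature, a caller on a cluster whose sbatch is wrapped to print a non-standard line can pass the position explicitly. A hedged usage sketch; the script contents, partition name, and index are illustrative assumptions, not taken from the commit:

```python
# Hedged usage sketch: requires sbatch on the machine; all values below
# (script, partition, job-ID index) are illustrative assumptions.
from datatrove.executor.slurm import launch_slurm_job

script = "#!/bin/bash\n#SBATCH --job-name=demo\nsrun hostname\n"
# Stock clusters print "Submitted batch job <id>", so -1 suffices:
job_id = launch_slurm_job(script, -1)
# A site whose wrapper prints e.g. "Job <id> queued ..." puts the ID at index 1:
job_id = launch_slurm_job(script, 1, "--partition=cpu")
```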
