diff --git a/src/datatrove/executor/slurm.py b/src/datatrove/executor/slurm.py index 611818a6..777a2cfb 100644 --- a/src/datatrove/executor/slurm.py +++ b/src/datatrove/executor/slurm.py @@ -61,6 +61,8 @@ class SlurmPipelineExecutor(PipelineExecutor): depends: another SlurmPipelineExecutor that should run before this one depends_job_id: alternatively to the above, you can pass the job id of a dependency + job_id_position: position of job ID in custom Sbatch outputs. + default: -1 logging_dir: where to save logs, stats, etc. Should be parsable into a datatrove.io.DataFolder skip_completed: whether to skip tasks that were completed in previous runs. default: True @@ -99,6 +101,7 @@ def __init__( max_array_size: int = 1001, depends: SlurmPipelineExecutor | None = None, depends_job_id: str | None = None, + job_id_position: int = -1, logging_dir: DataFolderLike = None, skip_completed: bool = True, slurm_logs_folder: str = None, @@ -128,6 +131,7 @@ def __init__( self.venv_path = venv_path self.depends = depends self.depends_job_id = depends_job_id + self.job_id_position = job_id_position self._sbatch_args = sbatch_args if sbatch_args else {} self.max_array_size = max_array_size self.max_array_launch_parallel = max_array_launch_parallel @@ -198,7 +202,8 @@ def launch_merge_stats(self): }, f'merge_stats {self.logging_dir.resolve_paths("stats")} ' f'-o {self.logging_dir.resolve_paths("stats.json")}', - ) + ), + self.job_id_position, ) @property @@ -277,7 +282,7 @@ def launch_job(self): args = [f"--export=ALL,RUN_OFFSET={launched_jobs}"] if self.dependency: args.append(f"--dependency={self.dependency}") - self.job_id = launch_slurm_job(launch_file_contents, *args) + self.job_id = launch_slurm_job(launch_file_contents, self.job_id_position, *args) launched_jobs += 1 logger.info(f"Slurm job launched successfully with (last) id={self.job_id}.") self.launch_merge_stats() @@ -355,11 +360,12 @@ def world_size(self) -> int: return self.tasks -def launch_slurm_job(launch_file_contents, *args): +def launch_slurm_job(launch_file_contents, job_id_position, *args): """ Small helper function to save a sbatch script and call it. Args: launch_file_contents: Contents of the sbatch script + job_id_position: Index of dependecy job ID. *args: any other arguments to pass to the sbatch command Returns: the id of the launched slurm job @@ -368,4 +374,4 @@ def launch_slurm_job(launch_file_contents, *args): with tempfile.NamedTemporaryFile("w") as f: f.write(launch_file_contents) f.flush() - return subprocess.check_output(["sbatch", *args, f.name]).decode("utf-8").split()[-1] + return subprocess.check_output(["sbatch", *args, f.name]).decode("utf-8").split()[job_id_position]