diff --git a/docs/tutorial/creating_workflows.md b/docs/tutorial/creating_workflows.md
index 2aa94b55..89cecca6 100644
--- a/docs/tutorial/creating_workflows.md
+++ b/docs/tutorial/creating_workflows.md
@@ -89,7 +89,8 @@ Currently, the following `Operator`s are maintained:
 ### `JIDSlurmOperator` arguments
 - `task_id`: This is nominally the name of the task on the Airflow side. However, for simplicity this is used 1-1 to match the name of a **managed** Task defined in LUTE's `managed_tasks.py` module. I.e., it should be the name of an `Executor("Task")` object which will run the specific Task of interest. This **must** match the name of a defined managed Task.
-- `max_cores`: Used to cap the maximum number of cores which should be requested of SLURM. By default all jobs will run with the same number of cores, which should be specified when running the `launch_airflow.py` script (either from the ARP, or by hand). This behaviour was chosen because in general we want to increase or decrease the core-count for all Tasks uniformly, and we don't want to have to specify core number arguments for each job individually. Nonetheless, on occassion it may be necessary to cap the number of cores a specific job will use. E.g. if the default value specified when launching the Airflow DAG is multiple cores, and one job is single threaded, the core count can be capped for that single job to 1, while the rest run with multiple cores.
+- `max_cores`: Used to cap the maximum number of cores which should be requested of SLURM. By default all jobs will run with the same number of cores, which should be specified when running the `launch_airflow.py` script (either from the ARP, or by hand). This behaviour was chosen because in general we want to increase or decrease the core-count for all `Task`s uniformly, and we don't want to have to specify core number arguments for each job individually. Nonetheless, on occasion it may be necessary to cap the number of cores a specific job will use. E.g. if the default value specified when launching the Airflow DAG is multiple cores, and one job is single threaded, the core count can be capped for that single job to 1, while the rest run with multiple cores.
+- `max_nodes`: Similar to the above. This ensures the `Task` is distributed across no more than the specified number of nodes. This is useful for, e.g., multi-threaded software which does not make use of tools like `MPI`: the `Task` can run on multiple cores, but only within a single node. A short usage sketch follows below.

 # Creating a new workflow
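As a usage sketch of these two arguments, an operator for a multi-threaded but non-MPI managed `Task` might be declared as below. This is illustrative only and not part of the diff: the `task_id` and DAG name are hypothetical, and the `JIDSlurmOperator` import path is assumed to be the repo-local one used by the workflow files further down.

```python
from airflow import DAG
from operators.jidoperators import JIDSlurmOperator  # repo-local import (assumed path)

# Stand-in for the DAG object each workflow file defines
dag: DAG = DAG("hypothetical_workflow")

# "SomeThreadedTask" is a hypothetical managed Task name. The job is capped
# at 64 cores and confined to a single node, since the underlying software
# is assumed to be multi-threaded but not MPI-aware.
threaded_task: JIDSlurmOperator = JIDSlurmOperator(
    max_cores=64, max_nodes=1, task_id="SomeThreadedTask", dag=dag
)
```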
diff --git a/workflows/airflow/find_peaks_index.py b/workflows/airflow/find_peaks_index.py
index 2335c090..3c512474 100644
--- a/workflows/airflow/find_peaks_index.py
+++ b/workflows/airflow/find_peaks_index.py
@@ -31,7 +31,7 @@
 peak_finder: JIDSlurmOperator = JIDSlurmOperator(task_id="PeakFinderPsocake", dag=dag)

 indexer: JIDSlurmOperator = JIDSlurmOperator(
-    max_cores=120, task_id="CrystFELIndexer", dag=dag
+    max_cores=120, max_nodes=1, task_id="CrystFELIndexer", dag=dag
 )

 peak_finder >> indexer
diff --git a/workflows/airflow/operators/jidoperators.py b/workflows/airflow/operators/jidoperators.py
index 6367d698..01e05c3b 100644
--- a/workflows/airflow/operators/jidoperators.py
+++ b/workflows/airflow/operators/jidoperators.py
@@ -121,6 +121,7 @@ def __init__(
         user: str = getpass.getuser(),
         poke_interval: float = 30.0,
         max_cores: Optional[int] = None,
+        max_nodes: Optional[int] = None,
         *args,
         **kwargs,
     ) -> None:
@@ -129,6 +130,53 @@
         self.user: str = user
         self.poke_interval: float = poke_interval
         self.max_cores: Optional[int] = max_cores
+        self.max_nodes: Optional[int] = max_nodes
+
+    def _sub_overridable_arguments(self, slurm_param_str: str) -> str:
+        """Overrides certain SLURM arguments given instance options.
+
+        Since the same SLURM arguments are used by default for the entire DAG,
+        individual Operator instances can override some important ones if they
+        are passed at instantiation.
+
+        ASSUMES `=` is used with SLURM arguments! E.g. --ntasks=12, --nodes=0-4
+
+        Args:
+            slurm_param_str (str): Constructed string of DAG SLURM arguments
+                without modification.
+
+        Returns:
+            slurm_param_str (str): Modified SLURM argument string.
+        """
+        # Cap the maximum cores used by a managed Task if requested.
+        # Only search for the part after `=`, since this is how the argument
+        # is usually passed.
+        if self.max_cores is not None:
+            pattern: str = r"(?<=\bntasks=)\d+"
+            ntasks: int
+            try:
+                ntasks = int(re.findall(pattern, slurm_param_str)[0])
+                if ntasks > self.max_cores:
+                    slurm_param_str = re.sub(
+                        pattern, f"{self.max_cores}", slurm_param_str
+                    )
+            except IndexError:  # If `ntasks` is not passed, 1 is the default
+                ntasks = 1
+                slurm_param_str = f"{slurm_param_str} --ntasks={ntasks}"
+
+        # Cap the maximum nodes. Unlike above, search for the full argument;
+        # if it is not present, add it.
+        if self.max_nodes is not None:
+            pattern = r"nodes=\S+"
+            nnodes_str: str
+            try:
+                nnodes_str = re.findall(pattern, slurm_param_str)[0]
+                # Presence was checked by `findall` above; if the pattern were
+                # absent, the `re.sub` below would be a silent no-op rather
+                # than raise an error.
+                slurm_param_str = re.sub(
+                    pattern, f"nodes=0-{self.max_nodes}", slurm_param_str
+                )
+            except IndexError:  # `--nodes` not present
+                slurm_param_str = f"{slurm_param_str} --nodes=0-{self.max_nodes}"
+
+        return slurm_param_str

     def create_control_doc(
         self, context: Dict[str, Any]
diff --git a/workflows/airflow/psocake_sfx_phasing.py b/workflows/airflow/psocake_sfx_phasing.py
index 433778a9..7bc62bf4 100644
--- a/workflows/airflow/psocake_sfx_phasing.py
+++ b/workflows/airflow/psocake_sfx_phasing.py
@@ -35,7 +35,7 @@
 peak_finder: JIDSlurmOperator = JIDSlurmOperator(task_id="PeakFinderPsocake", dag=dag)

 indexer: JIDSlurmOperator = JIDSlurmOperator(
-    max_cores=120, task_id="CrystFELIndexer", dag=dag
+    max_cores=120, max_nodes=1, task_id="CrystFELIndexer", dag=dag
 )

 # Concatenate stream files from all previous runs with same tag
@@ -45,17 +45,17 @@

 # Merge
 merger: JIDSlurmOperator = JIDSlurmOperator(
-    max_cores=120, task_id="PartialatorMerger", dag=dag
+    max_cores=120, max_nodes=1, task_id="PartialatorMerger", dag=dag
 )

 # Figures of merit
 hkl_comparer: JIDSlurmOperator = JIDSlurmOperator(
-    max_cores=8, task_id="HKLComparer", dag=dag
+    max_cores=8, max_nodes=1, task_id="HKLComparer", dag=dag
 )

 # HKL conversions
 hkl_manipulator: JIDSlurmOperator = JIDSlurmOperator(
-    max_cores=8, task_id="HKLManipulator", dag=dag
+    max_cores=8, max_nodes=1, task_id="HKLManipulator", dag=dag
 )

 # SHELX Tasks
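The behaviour of `_sub_overridable_arguments` (added in the `jidoperators.py` hunk above) can be exercised in isolation. Below is a minimal, self-contained sketch that is not part of this diff: `cap_slurm_args`, the partition name, and the sample argument string are hypothetical stand-ins mirroring the same regex handling. Note that the method itself relies on the `re` module, whose import lies outside the hunks shown.

```python
import re
from typing import Optional

def cap_slurm_args(
    params: str, max_cores: Optional[int] = None, max_nodes: Optional[int] = None
) -> str:
    """Sketch mirroring the capping logic of `_sub_overridable_arguments`."""
    if max_cores is not None:
        match = re.search(r"(?<=\bntasks=)\d+", params)
        if match is None:
            # `--ntasks` absent: SLURM defaults to 1 task, so make it explicit
            params = f"{params} --ntasks=1"
        elif int(match.group()) > max_cores:
            # Requested task count exceeds the cap: rewrite it in place
            params = re.sub(r"(?<=\bntasks=)\d+", str(max_cores), params)
    if max_nodes is not None:
        if re.search(r"nodes=\S+", params) is not None:
            # Argument present: rewrite it as a 0-to-max range
            params = re.sub(r"nodes=\S+", f"nodes=0-{max_nodes}", params)
        else:
            params = f"{params} --nodes=0-{max_nodes}"
    return params

print(cap_slurm_args("--partition=milano --ntasks=512", max_cores=120, max_nodes=1))
# --partition=milano --ntasks=120 --nodes=0-1
```

As the expected output shows, the DAG-wide defaults pass through untouched except where an instance-level cap demands a rewrite, which is the design rationale stated in the method's docstring.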