From e72f0512787bde527cfb4ec61182f793545071a4 Mon Sep 17 00:00:00 2001 From: JosePizarro3 Date: Thu, 19 Sep 2024 10:40:15 +0200 Subject: [PATCH] Added check_n_tasks decorator --- .../workflow/base_workflows.py | 30 ++++++++++++++++--- .../schema_packages/workflow/dft_plus_tb.py | 6 +++- tests/workflow/test_base_workflows.py | 2 +- 3 files changed, 32 insertions(+), 6 deletions(-) diff --git a/src/nomad_simulations/schema_packages/workflow/base_workflows.py b/src/nomad_simulations/schema_packages/workflow/base_workflows.py index b518fa54..4e76fb7b 100644 --- a/src/nomad_simulations/schema_packages/workflow/base_workflows.py +++ b/src/nomad_simulations/schema_packages/workflow/base_workflows.py @@ -17,6 +17,7 @@ # limitations under the License. # +from functools import wraps from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: @@ -32,6 +33,30 @@ from nomad_simulations.schema_packages.outputs import Outputs +def check_n_tasks(n_tasks: Optional[int] = None): + """ + Check if the `tasks` of a workflow exist. If the `n_tasks` input specified, it checks whether `tasks` + is of the same length as `n_tasks`. + + Args: + n_tasks (Optional[int], optional): The length of the `tasks` needs to be checked if set to an integer. Defaults to None. + """ + + def decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + if not self.tasks: + return None + if n_tasks is not None and len(self.tasks) != n_tasks: + return None + + return func(self, *args, **kwargs) + + return wrapper + + return decorator + + class SimulationWorkflow(Workflow): """ A base section used to define the workflows of a simulation with references to specific `tasks`, `inputs`, and `outputs`. The @@ -144,6 +169,7 @@ def resolve_beyonddft_method_ref( class BeyondDFT(SimulationWorkflow): method = SubSection(sub_section=BeyondDFTMethod.m_def) + @check_n_tasks() def resolve_all_outputs(self) -> list[Outputs]: """ Resolves all the `Outputs` sections from the `tasks` in the workflow. This is useful when @@ -153,10 +179,6 @@ def resolve_all_outputs(self) -> list[Outputs]: Returns: list[Outputs]: A list of all the `Outputs` sections from the `tasks`. """ - # Initial check - if not self.tasks: - return [] - # Populate the list of outputs from the last element in `tasks` all_outputs = [] for task in self.tasks: diff --git a/src/nomad_simulations/schema_packages/workflow/dft_plus_tb.py b/src/nomad_simulations/schema_packages/workflow/dft_plus_tb.py index e75d06bf..b6b0770a 100644 --- a/src/nomad_simulations/schema_packages/workflow/dft_plus_tb.py +++ b/src/nomad_simulations/schema_packages/workflow/dft_plus_tb.py @@ -32,6 +32,7 @@ BeyondDFT, BeyondDFTMethod, ) +from nomad_simulations.schema_packages.workflow.base_workflows import check_n_tasks class DFTPlusTBMethod(BeyondDFTMethod): @@ -69,6 +70,7 @@ class DFTPlusTB(BeyondDFT): - `method`: references to the `ModelMethod` sections in the DFT and TB entries. """ + @check_n_tasks(n_tasks=2) def resolve_method(self) -> DFTPlusTBMethod: """ Resolves the `DFT` and `TB` `ModelMethod` references for the `tasks` in the workflow by using the @@ -91,6 +93,7 @@ def resolve_method(self) -> DFTPlusTBMethod: return method + @check_n_tasks(n_tasks=2) def link_tasks(self) -> None: """ Links the `outputs` of the DFT task with the `inputs` of the TB task. @@ -123,6 +126,7 @@ def link_tasks(self) -> None: ) ] + @check_n_tasks(n_tasks=2) def overwrite_fermi_level(self) -> None: """ Overwrites the Fermi level in the TB calculation with the Fermi level from the DFT calculation. @@ -139,7 +143,7 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: super().normalize(archive, logger) # Initial check for the number of tasks - if len(self.tasks) != 2: + if not self.tasks or len(self.tasks) != 2: logger.error('A `DFTPlusTB` workflow must have two tasks.') return diff --git a/tests/workflow/test_base_workflows.py b/tests/workflow/test_base_workflows.py index 80011c6e..da6797fb 100644 --- a/tests/workflow/test_base_workflows.py +++ b/tests/workflow/test_base_workflows.py @@ -260,7 +260,7 @@ class TestBeyondDFT: 'tasks, result', [ # no task - (None, []), + (None, None), # empty task ([Task()], []), # task only contains inputs