Skip to content

Commit

Permalink
Added check_n_tasks decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
JosePizarro3 committed Sep 19, 2024
1 parent 98a5fb4 commit e72f051
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 6 deletions.
30 changes: 26 additions & 4 deletions src/nomad_simulations/schema_packages/workflow/base_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# limitations under the License.
#

from functools import wraps
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
Expand All @@ -32,6 +33,30 @@
from nomad_simulations.schema_packages.outputs import Outputs


def check_n_tasks(n_tasks: Optional[int] = None):
"""
Check if the `tasks` of a workflow exist. If the `n_tasks` input specified, it checks whether `tasks`
is of the same length as `n_tasks`.
Args:
n_tasks (Optional[int], optional): The length of the `tasks` needs to be checked if set to an integer. Defaults to None.
"""

def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if not self.tasks:
return None
if n_tasks is not None and len(self.tasks) != n_tasks:
return None

return func(self, *args, **kwargs)

return wrapper

return decorator


class SimulationWorkflow(Workflow):
"""
A base section used to define the workflows of a simulation with references to specific `tasks`, `inputs`, and `outputs`. The
Expand Down Expand Up @@ -144,6 +169,7 @@ def resolve_beyonddft_method_ref(
class BeyondDFT(SimulationWorkflow):
method = SubSection(sub_section=BeyondDFTMethod.m_def)

@check_n_tasks()
def resolve_all_outputs(self) -> list[Outputs]:
"""
Resolves all the `Outputs` sections from the `tasks` in the workflow. This is useful when
Expand All @@ -153,10 +179,6 @@ def resolve_all_outputs(self) -> list[Outputs]:
Returns:
list[Outputs]: A list of all the `Outputs` sections from the `tasks`.
"""
# Initial check
if not self.tasks:
return []

# Populate the list of outputs from the last element in `tasks`
all_outputs = []
for task in self.tasks:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
BeyondDFT,
BeyondDFTMethod,
)
from nomad_simulations.schema_packages.workflow.base_workflows import check_n_tasks


class DFTPlusTBMethod(BeyondDFTMethod):
Expand Down Expand Up @@ -69,6 +70,7 @@ class DFTPlusTB(BeyondDFT):
- `method`: references to the `ModelMethod` sections in the DFT and TB entries.
"""

@check_n_tasks(n_tasks=2)
def resolve_method(self) -> DFTPlusTBMethod:
"""
Resolves the `DFT` and `TB` `ModelMethod` references for the `tasks` in the workflow by using the
Expand All @@ -91,6 +93,7 @@ def resolve_method(self) -> DFTPlusTBMethod:

return method

@check_n_tasks(n_tasks=2)
def link_tasks(self) -> None:
"""
Links the `outputs` of the DFT task with the `inputs` of the TB task.
Expand Down Expand Up @@ -123,6 +126,7 @@ def link_tasks(self) -> None:
)
]

@check_n_tasks(n_tasks=2)
def overwrite_fermi_level(self) -> None:
"""
Overwrites the Fermi level in the TB calculation with the Fermi level from the DFT calculation.
Expand All @@ -139,7 +143,7 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
super().normalize(archive, logger)

# Initial check for the number of tasks
if len(self.tasks) != 2:
if not self.tasks or len(self.tasks) != 2:
logger.error('A `DFTPlusTB` workflow must have two tasks.')
return

Expand Down
2 changes: 1 addition & 1 deletion tests/workflow/test_base_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ class TestBeyondDFT:
'tasks, result',
[
# no task
(None, []),
(None, None),
# empty task
([Task()], []),
# task only contains inputs
Expand Down

0 comments on commit e72f051

Please sign in to comment.