Skip to content

Commit

Permalink
Add testing for BeyondDFT workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
JosePizarro3 committed Sep 18, 2024
1 parent 91dbb86 commit e14d3e2
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,15 @@ def resolve_all_outputs(self) -> list[Outputs]:
Returns:
list[Outputs]: A list of all the `Outputs` sections from the `tasks`.
"""
# Initial check
if not self.tasks:
return []

# Populate the list of outputs from the last element in `tasks`
all_outputs = []
for task in self.tasks:
if not task.outputs:
continue
all_outputs.append(task.outputs[-1])
return all_outputs

Expand Down
87 changes: 83 additions & 4 deletions tests/workflow/test_base_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@

import pytest
from nomad.datamodel import EntryArchive
from nomad.datamodel.metainfo.workflow import Link, Task, Workflow
from nomad.datamodel.metainfo.workflow import Link, Task

from nomad_simulations.schema_packages.model_method import BaseModelMethod, ModelMethod
from nomad_simulations.schema_packages.model_system import ModelSystem
from nomad_simulations.schema_packages.outputs import Outputs
from nomad_simulations.schema_packages.workflow import (
BeyondDFT,
BeyondDFTMethod,
BeyondDFTWorkflow,
SimulationWorkflow,
)

Expand Down Expand Up @@ -211,5 +211,84 @@ def test_resolve_beyonddft_method_ref(


class TestBeyondDFT:
def test_resolve_all_outputs(self):
assert True
@pytest.mark.parametrize(
'tasks, result',
[
# no task
(None, []),
# empty task
([Task()], []),
# task only contains inputs
(
[Task(inputs=[Link(name='Input Model System', section=ModelSystem())])],
[],
),
# one task with one output
(
[Task(outputs=[Link(name='Output Data 1', section=Outputs())])],
[Link(name='Output Data 1', section=Outputs())],
),
# one task with multiple outputs (only last is resolved)
(
[
Task(
outputs=[
Link(name='Output Data 1', section=Outputs()),
Link(name='Output Data 2', section=Outputs()),
]
)
],
[Link(name='Output Data 2', section=Outputs())],
),
# multiple task with one output each
(
[
Task(
outputs=[Link(name='Task 1:Output Data 1', section=Outputs())]
),
Task(
outputs=[Link(name='Task 2:Output Data 1', section=Outputs())]
),
],
[
Link(name='Task 1:Output Data 1', section=Outputs()),
Link(name='Task 2:Output Data 1', section=Outputs()),
],
),
# multiple task with two outputs each (only last is resolved)
(
[
Task(
outputs=[
Link(name='Task 1:Output Data 1', section=Outputs()),
Link(name='Task 1:Output Data 2', section=Outputs()),
]
),
Task(
outputs=[
Link(name='Task 2:Output Data 1', section=Outputs()),
Link(name='Task 2:Output Data 2', section=Outputs()),
]
),
],
[
Link(name='Task 1:Output Data 2', section=Outputs()),
Link(name='Task 2:Output Data 2', section=Outputs()),
],
),
],
)
def test_resolve_all_outputs(
self, tasks: Optional[list[Task]], result: list[Outputs]
):
"""
Test the `resolve_all_outputs` method of the `BeyondDFT` section.
"""
workflow = BeyondDFT()
if tasks is not None:
workflow.tasks = tasks
if result is not None:
for i, output in enumerate(workflow.resolve_all_outputs()):
assert output.name == result[i].name
else:
assert workflow.resolve_all_outputs() == result

0 comments on commit e14d3e2

Please sign in to comment.