Skip to content

Commit

Permalink
Fix resolve_inputs_outputs method
Browse files Browse the repository at this point in the history
  • Loading branch information
JosePizarro3 committed Sep 18, 2024
1 parent e14d3e2 commit 889238f
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 7 deletions.
23 changes: 19 additions & 4 deletions src/nomad_simulations/schema_packages/workflow/base_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,20 @@ class SimulationWorkflow(Workflow):
description="""Methodological parameters used during the workflow.""",
)

def resolve_inputs_outputs(
def _resolve_inputs_outputs_from_archive(
self, archive: 'EntryArchive', logger: 'BoundLogger'
) -> None:
"""
Resolves the `inputs` and `outputs` sections from the archive sections under `data` and stores
Resolves the `ModelSystem`, `ModelMethod`, and `Outputs` sections from the archive and stores
them in private attributes.
Args:
archive (EntryArchive): The archive to resolve the sections from.
logger (BoundLogger): The logger to log messages.
"""
self._input_systems = []
self._input_methods = []
self._outputs = []
if (
not archive.data.model_system
or not archive.data.model_method
Expand All @@ -73,14 +76,26 @@ def resolve_inputs_outputs(
self._input_methods = archive.data.model_method
self._outputs = archive.data.outputs

def resolve_inputs_outputs(
self, archive: 'EntryArchive', logger: 'BoundLogger'
) -> None:
"""
Resolves the `inputs` and `outputs` of the `SimulationWorkflow`.
Args:
archive (EntryArchive): The archive to resolve the sections from.
logger (BoundLogger): The logger to log messages.
"""
self._resolve_inputs_outputs_from_archive(archive=archive, logger=logger)

# Resolve `inputs`
if not self.inputs:
if not self.inputs and self._input_systems:
self.m_add_sub_section(
Workflow.inputs,
Link(name='Input Model System', section=self._input_systems[0]),
)
# Resolve `outputs`
if not self.outputs:
if not self.outputs and self._outputs:
self.m_add_sub_section(
Workflow.outputs,
Link(name='Output Data', section=self._outputs[-1]),
Expand Down
51 changes: 48 additions & 3 deletions tests/workflow/test_base_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,54 @@


class TestSimulationWorkflow:
@pytest.mark.parametrize(
'model_system, model_method, outputs',
[
# empty sections in archive.data
(None, None, None),
# only one section in archive.data
(ModelSystem(), None, None),
# another section in archive.data
(None, ModelMethod(), None),
# only two sections in archive.data
(ModelSystem(), ModelMethod(), None),
# all sections in archive.data
(ModelSystem(), ModelMethod(), Outputs()),
],
)
def test_resolve_inputs_outputs_from_archive(
self,
model_system: Optional[ModelSystem],
model_method: Optional[ModelMethod],
outputs: Optional[Outputs],
):
"""
Test the `_resolve_inputs_outputs_from_archive` method of the `SimulationWorkflow` section.
"""
archive = EntryArchive()
simulation = generate_simulation(
model_system=model_system, model_method=model_method, outputs=outputs
)
archive.data = simulation
workflow = SimulationWorkflow()
archive.workflow2 = workflow
workflow._resolve_inputs_outputs_from_archive(archive=archive, logger=logger)
if (
model_system is not None
and model_method is not None
and outputs is not None
):
for input_system in workflow._input_systems:
assert isinstance(input_system, ModelSystem)
for input_method in workflow._input_methods:
assert isinstance(input_method, ModelMethod)
for output in workflow._outputs:
assert isinstance(output, Outputs)
else:
assert not workflow._input_systems
assert not workflow._input_methods
assert not workflow._outputs

@pytest.mark.parametrize(
'model_system, model_method, outputs, workflow_inputs, workflow_outputs',
[
Expand Down Expand Up @@ -84,16 +132,13 @@ def test_resolve_inputs_outputs(
assert workflow.inputs[0].name == workflow_inputs[0].name
# ! direct comparison of section does not work (probably an issue with references)
# assert workflow.inputs[0].section == workflow_inputs[0].section
assert workflow._input_systems[0] == model_system
assert workflow._input_methods[0] == model_method
if not workflow_outputs:
assert workflow.outputs == workflow_outputs
else:
assert len(workflow.outputs) == 1
assert workflow.outputs[0].name == workflow_outputs[0].name
# ! direct comparison of section does not work (probably an issue with references)
# assert workflow.outputs[0].section == workflow_outputs[0].section
assert workflow._outputs[0] == outputs

@pytest.mark.parametrize(
'model_system, model_method, outputs, workflow_inputs, workflow_outputs',
Expand Down

0 comments on commit 889238f

Please sign in to comment.