From 889238f45d2aaecfcfa1cb2a54870c3c2707b61a Mon Sep 17 00:00:00 2001 From: JosePizarro3 Date: Wed, 18 Sep 2024 12:19:08 +0200 Subject: [PATCH] Fix resolve_inputs_outputs method --- .../workflow/base_workflows.py | 23 +++++++-- tests/workflow/test_base_workflows.py | 51 +++++++++++++++++-- 2 files changed, 67 insertions(+), 7 deletions(-) diff --git a/src/nomad_simulations/schema_packages/workflow/base_workflows.py b/src/nomad_simulations/schema_packages/workflow/base_workflows.py index b52f2efa..b518fa54 100644 --- a/src/nomad_simulations/schema_packages/workflow/base_workflows.py +++ b/src/nomad_simulations/schema_packages/workflow/base_workflows.py @@ -49,17 +49,20 @@ class SimulationWorkflow(Workflow): description="""Methodological parameters used during the workflow.""", ) - def resolve_inputs_outputs( + def _resolve_inputs_outputs_from_archive( self, archive: 'EntryArchive', logger: 'BoundLogger' ) -> None: """ - Resolves the `inputs` and `outputs` sections from the archive sections under `data` and stores + Resolves the `ModelSystem`, `ModelMethod`, and `Outputs` sections from the archive and stores them in private attributes. Args: archive (EntryArchive): The archive to resolve the sections from. logger (BoundLogger): The logger to log messages. """ + self._input_systems = [] + self._input_methods = [] + self._outputs = [] if ( not archive.data.model_system or not archive.data.model_method @@ -73,14 +76,26 @@ def resolve_inputs_outputs( self._input_methods = archive.data.model_method self._outputs = archive.data.outputs + def resolve_inputs_outputs( + self, archive: 'EntryArchive', logger: 'BoundLogger' + ) -> None: + """ + Resolves the `inputs` and `outputs` of the `SimulationWorkflow`. + + Args: + archive (EntryArchive): The archive to resolve the sections from. + logger (BoundLogger): The logger to log messages. + """ + self._resolve_inputs_outputs_from_archive(archive=archive, logger=logger) + # Resolve `inputs` - if not self.inputs: + if not self.inputs and self._input_systems: self.m_add_sub_section( Workflow.inputs, Link(name='Input Model System', section=self._input_systems[0]), ) # Resolve `outputs` - if not self.outputs: + if not self.outputs and self._outputs: self.m_add_sub_section( Workflow.outputs, Link(name='Output Data', section=self._outputs[-1]), diff --git a/tests/workflow/test_base_workflows.py b/tests/workflow/test_base_workflows.py index df7494e2..80011c6e 100644 --- a/tests/workflow/test_base_workflows.py +++ b/tests/workflow/test_base_workflows.py @@ -36,6 +36,54 @@ class TestSimulationWorkflow: + @pytest.mark.parametrize( + 'model_system, model_method, outputs', + [ + # empty sections in archive.data + (None, None, None), + # only one section in archive.data + (ModelSystem(), None, None), + # another section in archive.data + (None, ModelMethod(), None), + # only two sections in archive.data + (ModelSystem(), ModelMethod(), None), + # all sections in archive.data + (ModelSystem(), ModelMethod(), Outputs()), + ], + ) + def test_resolve_inputs_outputs_from_archive( + self, + model_system: Optional[ModelSystem], + model_method: Optional[ModelMethod], + outputs: Optional[Outputs], + ): + """ + Test the `_resolve_inputs_outputs_from_archive` method of the `SimulationWorkflow` section. + """ + archive = EntryArchive() + simulation = generate_simulation( + model_system=model_system, model_method=model_method, outputs=outputs + ) + archive.data = simulation + workflow = SimulationWorkflow() + archive.workflow2 = workflow + workflow._resolve_inputs_outputs_from_archive(archive=archive, logger=logger) + if ( + model_system is not None + and model_method is not None + and outputs is not None + ): + for input_system in workflow._input_systems: + assert isinstance(input_system, ModelSystem) + for input_method in workflow._input_methods: + assert isinstance(input_method, ModelMethod) + for output in workflow._outputs: + assert isinstance(output, Outputs) + else: + assert not workflow._input_systems + assert not workflow._input_methods + assert not workflow._outputs + @pytest.mark.parametrize( 'model_system, model_method, outputs, workflow_inputs, workflow_outputs', [ @@ -84,8 +132,6 @@ def test_resolve_inputs_outputs( assert workflow.inputs[0].name == workflow_inputs[0].name # ! direct comparison of section does not work (probably an issue with references) # assert workflow.inputs[0].section == workflow_inputs[0].section - assert workflow._input_systems[0] == model_system - assert workflow._input_methods[0] == model_method if not workflow_outputs: assert workflow.outputs == workflow_outputs else: @@ -93,7 +139,6 @@ def test_resolve_inputs_outputs( assert workflow.outputs[0].name == workflow_outputs[0].name # ! direct comparison of section does not work (probably an issue with references) # assert workflow.outputs[0].section == workflow_outputs[0].section - assert workflow._outputs[0] == outputs @pytest.mark.parametrize( 'model_system, model_method, outputs, workflow_inputs, workflow_outputs',