diff --git a/src/nomad_simulations/schema_packages/workflow/dft_plus_tb.py b/src/nomad_simulations/schema_packages/workflow/dft_plus_tb.py index b6b0770a..2b191a7d 100644 --- a/src/nomad_simulations/schema_packages/workflow/dft_plus_tb.py +++ b/src/nomad_simulations/schema_packages/workflow/dft_plus_tb.py @@ -81,6 +81,11 @@ def resolve_method(self) -> DFTPlusTBMethod: """ method = DFTPlusTBMethod() + # Check if TaskReference exists for both tasks + for task in self.tasks: + if not task.task: + return None + # DFT method reference dft_method = method.resolve_beyonddft_method_ref(task=self.tasks[0].task) if dft_method is not None: @@ -98,6 +103,14 @@ def link_tasks(self) -> None: """ Links the `outputs` of the DFT task with the `inputs` of the TB task. """ + # Initial checks on the `inputs` and `tasks[*].outputs` + if not self.inputs: + return None + for task in self.tasks: + if not task.m_xpath('task.outputs'): + return None + + # Assign dft task `inputs` to the `self.inputs[0]` dft_task = self.tasks[0] dft_task.inputs = [ Link( @@ -105,24 +118,27 @@ def link_tasks(self) -> None: section=self.inputs[0], ) ] + # and rewrite dft task `outputs` and its name dft_task.outputs = [ Link( name='Output DFT Data', - section=dft_task.outputs[-1], + section=dft_task.task.outputs[-1], ) ] + # Assign tb task `inputs` to the `dft_task.outputs[-1]` tb_task = self.tasks[1] tb_task.inputs = [ Link( name='Output DFT Data', - section=dft_task.outputs[-1], + section=dft_task.task.outputs[-1], ), ] + # and rewrite tb task `outputs` and its name tb_task.outputs = [ Link( name='Output TB Data', - section=tb_task.outputs[-1], + section=tb_task.task.outputs[-1], ) ] diff --git a/tests/workflow/test_dft_plus_tb.py b/tests/workflow/test_dft_plus_tb.py new file mode 100644 index 00000000..40d54644 --- /dev/null +++ b/tests/workflow/test_dft_plus_tb.py @@ -0,0 +1,197 @@ +# +# Copyright The NOMAD Authors. +# +# This file is part of NOMAD. See https://nomad-lab.eu for further info. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Optional + +import pytest +from nomad.datamodel import EntryArchive +from nomad.datamodel.metainfo.workflow import Link, Task, TaskReference, Workflow + +from nomad_simulations.schema_packages.model_method import ( + DFT, + TB, + BaseModelMethod, + ModelMethod, +) +from nomad_simulations.schema_packages.model_system import ModelSystem +from nomad_simulations.schema_packages.outputs import Outputs +from nomad_simulations.schema_packages.workflow import ( + DFTPlusTB, + DFTPlusTBMethod, +) + +from ..conftest import generate_simulation +from . import logger + + +class TestDFTPlusTB: + @pytest.mark.parametrize( + 'tasks, result', + [ + (None, None), + ([TaskReference(name='dft')], None), + ( + [ + TaskReference(name='dft'), + TaskReference(name='tb 1'), + TaskReference(name='tb 2'), + ], + None, + ), + ([TaskReference(name='dft'), TaskReference(name='tb')], None), + ( + [ + TaskReference(name='dft', task=Task(name='dft task')), + TaskReference(name='tb'), + ], + None, + ), + ( + [ + TaskReference( + name='dft', + task=Task( + name='dft task', + inputs=[ + Link(name='model system', section=ModelSystem()), + Link(name='model method dft', section=DFT()), + ], + ), + ), + TaskReference( + name='tb', + task=Task(name='tb task'), + ), + ], + [DFT, None], + ), + ( + [ + TaskReference( + name='dft', + task=Task( + name='dft task', + inputs=[ + Link(name='model system', section=ModelSystem()), + Link(name='model method dft', section=DFT()), + ], + ), + ), + TaskReference( + name='tb', + task=Task( + name='tb task', + inputs=[ + Link(name='model system', section=ModelSystem()), + Link(name='model method tb', section=TB()), + ], + ), + ), + ], + [DFT, TB], + ), + ], + ) + def test_resolve_method( + self, + tasks: list[Task], + result: DFTPlusTBMethod, + ): + """ + Test the `resolve_method` method of the `DFTPlusTB` section. + """ + archive = EntryArchive() + workflow = DFTPlusTB() + archive.workflow2 = workflow + workflow.tasks = tasks + workflow_method = workflow.resolve_method() + if workflow_method is None: + assert workflow_method == result + else: + if result[0] is not None: + assert isinstance(workflow_method.dft_method_ref, result[0]) + else: + assert workflow_method.dft_method_ref == result[0] + if result[1] is not None: + assert isinstance(workflow_method.tb_method_ref, result[1]) + else: + assert workflow_method.tb_method_ref == result[1] + + def test_link_tasks(self): + """ + Test the `resolve_n_scf_steps` method of the `DFTPlusTB` section. + """ + archive = EntryArchive() + workflow = DFTPlusTB() + archive.workflow2 = workflow + workflow.tasks = [ + TaskReference( + name='dft', + task=Task( + name='dft task', + inputs=[ + Link(name='model system', section=ModelSystem()), + Link(name='model method dft', section=DFT()), + ], + outputs=[ + Link(name='output dft', section=Outputs()), + ], + ), + ), + TaskReference( + name='tb', + task=Task( + name='tb task', + inputs=[ + Link(name='model system', section=ModelSystem()), + Link(name='model method tb', section=TB()), + ], + outputs=[ + Link(name='output tb', section=Outputs()), + ], + ), + ), + ] + workflow.inputs = [Link(name='model system', section=ModelSystem())] + workflow.outputs = [Link(name='output tb', section=Outputs())] + + # Linking and overwritting inputs and outputs + workflow.link_tasks() + + dft_task = workflow.tasks[0] + assert len(dft_task.inputs) == 1 + assert dft_task.inputs[0].name == 'Input Model System' + assert len(dft_task.outputs) == 1 + assert dft_task.outputs[0].name == 'Output DFT Data' + tb_task = workflow.tasks[1] + assert len(tb_task.inputs) == 1 + assert tb_task.inputs[0].name == 'Output DFT Data' + assert len(tb_task.outputs) == 1 + assert tb_task.outputs[0].name == 'Output TB Data' + + def test_overwrite_fermi_level(self): + """ + Test the `overwrite_fermi_level` method of the `DFTPlusTB` section. + """ + assert True + + def test_normalize(self): + """ + Test the `normalize` method of the `DFTPlusTB` section. + """ + assert True