Skip to content

Commit

Permalink
Added testing for link_tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
JosePizarro3 committed Sep 19, 2024
1 parent e72f051 commit 1170649
Show file tree
Hide file tree
Showing 2 changed files with 216 additions and 3 deletions.
22 changes: 19 additions & 3 deletions src/nomad_simulations/schema_packages/workflow/dft_plus_tb.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ def resolve_method(self) -> DFTPlusTBMethod:
"""
method = DFTPlusTBMethod()

# Check if TaskReference exists for both tasks
for task in self.tasks:
if not task.task:
return None

# DFT method reference
dft_method = method.resolve_beyonddft_method_ref(task=self.tasks[0].task)
if dft_method is not None:
Expand All @@ -98,31 +103,42 @@ def link_tasks(self) -> None:
"""
Links the `outputs` of the DFT task with the `inputs` of the TB task.
"""
# Initial checks on the `inputs` and `tasks[*].outputs`
if not self.inputs:
return None
for task in self.tasks:
if not task.m_xpath('task.outputs'):
return None

# Assign dft task `inputs` to the `self.inputs[0]`
dft_task = self.tasks[0]
dft_task.inputs = [
Link(
name='Input Model System',
section=self.inputs[0],
)
]
# and rewrite dft task `outputs` and its name
dft_task.outputs = [
Link(
name='Output DFT Data',
section=dft_task.outputs[-1],
section=dft_task.task.outputs[-1],
)
]

# Assign tb task `inputs` to the `dft_task.outputs[-1]`
tb_task = self.tasks[1]
tb_task.inputs = [
Link(
name='Output DFT Data',
section=dft_task.outputs[-1],
section=dft_task.task.outputs[-1],
),
]
# and rewrite tb task `outputs` and its name
tb_task.outputs = [
Link(
name='Output TB Data',
section=tb_task.outputs[-1],
section=tb_task.task.outputs[-1],
)
]

Expand Down
197 changes: 197 additions & 0 deletions tests/workflow/test_dft_plus_tb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
#
# Copyright The NOMAD Authors.
#
# This file is part of NOMAD. See https://nomad-lab.eu for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Optional

import pytest
from nomad.datamodel import EntryArchive
from nomad.datamodel.metainfo.workflow import Link, Task, TaskReference, Workflow

from nomad_simulations.schema_packages.model_method import (
DFT,
TB,
BaseModelMethod,
ModelMethod,
)
from nomad_simulations.schema_packages.model_system import ModelSystem
from nomad_simulations.schema_packages.outputs import Outputs
from nomad_simulations.schema_packages.workflow import (
DFTPlusTB,
DFTPlusTBMethod,
)

from ..conftest import generate_simulation
from . import logger


class TestDFTPlusTB:
@pytest.mark.parametrize(
'tasks, result',
[
(None, None),
([TaskReference(name='dft')], None),
(
[
TaskReference(name='dft'),
TaskReference(name='tb 1'),
TaskReference(name='tb 2'),
],
None,
),
([TaskReference(name='dft'), TaskReference(name='tb')], None),
(
[
TaskReference(name='dft', task=Task(name='dft task')),
TaskReference(name='tb'),
],
None,
),
(
[
TaskReference(
name='dft',
task=Task(
name='dft task',
inputs=[
Link(name='model system', section=ModelSystem()),
Link(name='model method dft', section=DFT()),
],
),
),
TaskReference(
name='tb',
task=Task(name='tb task'),
),
],
[DFT, None],
),
(
[
TaskReference(
name='dft',
task=Task(
name='dft task',
inputs=[
Link(name='model system', section=ModelSystem()),
Link(name='model method dft', section=DFT()),
],
),
),
TaskReference(
name='tb',
task=Task(
name='tb task',
inputs=[
Link(name='model system', section=ModelSystem()),
Link(name='model method tb', section=TB()),
],
),
),
],
[DFT, TB],
),
],
)
def test_resolve_method(
self,
tasks: list[Task],
result: DFTPlusTBMethod,
):
"""
Test the `resolve_method` method of the `DFTPlusTB` section.
"""
archive = EntryArchive()
workflow = DFTPlusTB()
archive.workflow2 = workflow
workflow.tasks = tasks
workflow_method = workflow.resolve_method()
if workflow_method is None:
assert workflow_method == result
else:
if result[0] is not None:
assert isinstance(workflow_method.dft_method_ref, result[0])
else:
assert workflow_method.dft_method_ref == result[0]
if result[1] is not None:
assert isinstance(workflow_method.tb_method_ref, result[1])
else:
assert workflow_method.tb_method_ref == result[1]

def test_link_tasks(self):
"""
Test the `resolve_n_scf_steps` method of the `DFTPlusTB` section.
"""
archive = EntryArchive()
workflow = DFTPlusTB()
archive.workflow2 = workflow
workflow.tasks = [
TaskReference(
name='dft',
task=Task(
name='dft task',
inputs=[
Link(name='model system', section=ModelSystem()),
Link(name='model method dft', section=DFT()),
],
outputs=[
Link(name='output dft', section=Outputs()),
],
),
),
TaskReference(
name='tb',
task=Task(
name='tb task',
inputs=[
Link(name='model system', section=ModelSystem()),
Link(name='model method tb', section=TB()),
],
outputs=[
Link(name='output tb', section=Outputs()),
],
),
),
]
workflow.inputs = [Link(name='model system', section=ModelSystem())]
workflow.outputs = [Link(name='output tb', section=Outputs())]

# Linking and overwritting inputs and outputs
workflow.link_tasks()

dft_task = workflow.tasks[0]
assert len(dft_task.inputs) == 1
assert dft_task.inputs[0].name == 'Input Model System'
assert len(dft_task.outputs) == 1
assert dft_task.outputs[0].name == 'Output DFT Data'
tb_task = workflow.tasks[1]
assert len(tb_task.inputs) == 1
assert tb_task.inputs[0].name == 'Output DFT Data'
assert len(tb_task.outputs) == 1
assert tb_task.outputs[0].name == 'Output TB Data'

def test_overwrite_fermi_level(self):
"""
Test the `overwrite_fermi_level` method of the `DFTPlusTB` section.
"""
assert True

def test_normalize(self):
"""
Test the `normalize` method of the `DFTPlusTB` section.
"""
assert True

0 comments on commit 1170649

Please sign in to comment.