Skip to content

Commit

Permalink
xtb
Browse files Browse the repository at this point in the history
  • Loading branch information
loriab committed Nov 14, 2024
1 parent b716788 commit 5b8a813
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 16 deletions.
14 changes: 7 additions & 7 deletions qcengine/programs/dftd_ng.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,16 @@ def compute(self, input_model: AtomicInput, config: TaskConfig) -> AtomicResult:
input_data["keywords"]["level_hint"] = level_hint

# dftd4 speaks qcsk.v1
input_model = qcelemental.models.v1.AtomicInput(**input_data)
input_model_v1 = qcelemental.models.v1.AtomicInput(**input_data)

# Run the Harness
output = run_qcschema(input_model)
output_v1 = run_qcschema(input_model_v1)

# d4 qcschema interface stores error in Result model
if not output.success:
return FailedOperation(input_data=input_data, error=output.error.model_dump())
if not output_v1.success:
return FailedOperation(input_data=input_data, error=output_v1.error.model_dump())

output = output.convert_v(2)
output = output_v1.convert_v(2, external_input_data=input_model)

if "info" in output.extras:
qcvkey = output.extras["info"]["fctldash"].upper()
Expand All @@ -126,14 +126,14 @@ def compute(self, input_model: AtomicInput, config: TaskConfig) -> AtomicResult:
if qcvkey:
calcinfo[f"{qcvkey} DISPERSION CORRECTION ENERGY"] = energy

if output.driver == "gradient":
if output.input_data.driver == "gradient":
gradient = output.return_result
calcinfo["CURRENT GRADIENT"] = gradient
calcinfo["DISPERSION CORRECTION GRADIENT"] = gradient
if qcvkey:
calcinfo[f"{qcvkey} DISPERSION CORRECTION GRADIENT"] = gradient

if output.keywords.get("pair_resolved", False):
if output.input_data.keywords.get("pair_resolved", False):
pw2 = output.extras["dftd4"]["additive pairwise energy"]
pw3 = output.extras["dftd4"]["non-additive pairwise energy"]
assert abs(pw2.sum() + pw3.sum() - energy) < 1.0e-8, f"{pw2.sum()} + {pw3.sum()} != {energy}"
Expand Down
6 changes: 6 additions & 0 deletions qcengine/programs/tests/test_dftd4.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def test_dftd4_task_tpss_m02(schema_versions, request):
},
},
driver="gradient",
extras={"mymsg": "will I pass through the calc?"},
)

atomic_input = checkver_and_convert(atomic_input, request.node.name, "pre")
Expand All @@ -82,6 +83,11 @@ def test_dftd4_task_tpss_m02(schema_versions, request):

assert atomic_result.success
assert pytest.approx(atomic_result.return_result, abs=thr) == return_result
if "v2" in request.node.name:
assert "will I pass" in atomic_result.input_data.extras.get("mymsg", "no key!"), "input extras roundtrip fail"
assert "mymsg" not in atomic_result.extras.get("mymsg", "no key!"), "input extras wrongly present in result"
else:
assert "will I pass" in atomic_result.extras.get("mymsg", "no key!"), "input extras roundtrip fail"


@using("dftd4")
Expand Down
13 changes: 5 additions & 8 deletions qcengine/programs/xtb.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,12 @@ def compute(self, input_data: AtomicInput, config: TaskConfig) -> AtomicResult:
from xtb.qcschema.harness import run_qcschema

# Run the Harness
input_data = input_data.convert_v(1)
output = run_qcschema(input_data)
input_data_v1 = input_data.convert_v(1)
output_v1 = run_qcschema(input_data_v1)

# xtb qcschema interface stores error in Result model
if not output.success:
return FailedOperation(input_data=input_data, error=output.error.model_dump())
if not output_v1.success:
return FailedOperation(input_data=input_data, error=output_v1.error.model_dump())

output = output.convert_v(2)

# Make sure all keys from the initial input spec are sent along
output.extras.update(input_data.extras)
output = output_v1.convert_v(2, external_input_data=input_data)
return output
6 changes: 5 additions & 1 deletion qcengine/tests/test_harness_canonical.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,11 @@ def test_compute_gradient(program, model, keywords, schema_versions, request):
assert isinstance(ret.return_result, np.ndarray)
assert len(ret.return_result.shape) == 2
assert ret.return_result.shape[1] == 3
assert "mytag" in ret.extras, ret.extras
if "v2" in request.node.name:
assert "mytag" in ret.input_data.extras, ret.input_data.extras
assert "mytag" not in ret.extras, "input extras wrongly present in result"
else:
assert "mytag" in ret.extras, ret.extras


@pytest.mark.parametrize("program, model, keywords", _canonical_methods_qcsk_basis)
Expand Down

0 comments on commit 5b8a813

Please sign in to comment.