Skip to content

Commit

Permalink
atspec changes round1
Browse files Browse the repository at this point in the history
  • Loading branch information
loriab committed Nov 26, 2024
1 parent 5eb04e3 commit 60cedf1
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 40 deletions.
1 change: 1 addition & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ New Features

Enhancements
++++++++++++
- (536c) ``v2.AtomicInput`` gained a ``specification`` field where driver, model, keywords, and protocols now live. ``v2.AtomicSpecification`` and ``v1.QCInputSpecification`` (used by opt and td) learned a ``convert_v`` to interconvert.
- (536b) ``v1.AtomicResult.convert_v`` learned an ``external_input_data`` option to inject that field (if known) rather than using incomplete reconstruction from the v1 Result. This may not be the final solution.
- (536b) ``v2.FailedOperation`` gained schema_name and schema_version=2.
- (536b) ``v2.AtomicResult`` no longer inherits from ``v2.AtomicInput``. It gained an ``input_data`` field for the corresponding ``AtomicInput`` and independent ``id`` and ``molecule`` fields (the latter being equivalent to ``v1.AtomicResult.molecule`` with the frame of the results; ``v2.AtomicResult.input_data.molecule`` is new, preserving the input frame). Gained independent ``extras``
Expand Down
21 changes: 21 additions & 0 deletions qcelemental/models/v1/procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,24 @@ class QCInputSpecification(ProtoModel):
def _version_stamp(cls, v):
return 1

def convert_v(
    self, version: int
) -> Union["qcelemental.models.v1.QCInputSpecification", "qcelemental.models.v2.AtomicSpecification"]:
    """Return this specification as a model of the requested QCSchema version.

    Parameters
    ----------
    version
        Target QCSchema version. It is passed to ``check_convertible_version``;
        when that helper answers ``"self"`` this very instance is returned.

    Returns
    -------
    QCInputSpecification or AtomicSpecification
        ``self`` when no conversion is needed, otherwise a
        ``qcelemental.models.v2.AtomicSpecification`` built from this model's fields.
    """
    import qcelemental as qcel

    target = check_convertible_version(version, error="QCInputSpecification")
    if target == "self":
        return self

    fields = self.dict()
    if version == 2:
        # The v1 schema stamps are removed before the v2 model is constructed.
        for stamp in ("schema_name", "schema_version"):
            fields.pop(stamp)

        converted = qcel.models.v2.AtomicSpecification(**fields)

    return converted


class OptimizationInput(ProtoModel):
id: Optional[str] = None
Expand Down Expand Up @@ -297,6 +315,7 @@ def convert_v(
return self

dself = self.dict()
# dself = self.model_dump(exclude_unset=True, exclude_none=True)
if version == 2:
dself["input_specification"].pop("schema_version", None)
dself["optimization_spec"].pop("schema_version", None)
Expand Down Expand Up @@ -363,6 +382,8 @@ def convert_v(
k: [opthist_class(**res).convert_v(version) for res in lst]
for k, lst in dself["optimization_history"].items()
}
# if dself["optimization_spec"].pop("extras", None):
# pass

self_vN = qcel.models.v2.TorsionDriveResult(**dself)

Expand Down
22 changes: 17 additions & 5 deletions qcelemental/models/v1/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,15 @@ def convert_v(
return self

dself = self.dict()
spec = {}
if version == 2:
dself.pop("schema_name") # changes in v2

spec["driver"] = dself.pop("driver")
spec["model"] = dself.pop("model")
spec["keywords"] = dself.pop("keywords", None)
spec["protocols"] = dself.pop("protocols", None)
dself["specification"] = spec
self_vN = qcel.models.v2.AtomicInput(**dself)

return self_vN
Expand Down Expand Up @@ -831,13 +839,17 @@ def convert_v(
dself.pop("error", None)

input_data = {
k: dself.pop(k) for k in list(dself.keys()) if k in ["driver", "keywords", "model", "protocols"]
"specification": {
k: dself.pop(k) for k in list(dself.keys()) if k in ["driver", "keywords", "model", "protocols"]
},
"molecule": dself["molecule"], # duplicate since input mol has been overwritten
"extras": {
k: dself["extras"].pop(k) for k in list(dself["extras"].keys()) if k in []
}, # sep any merged extras
}
input_data["molecule"] = dself["molecule"] # duplicate since input mol has been overwritten
# any input provenance has been overwritten
input_data["extras"] = {
k: dself["extras"].pop(k) for k in list(dself["extras"].keys()) if k in []
} # sep any merged extras
# if dself["id"]:
# input_data["id"] = dself["id"] # in/out should likely match
if external_input_data:
# Note: overwriting with external, not updating. reconsider?
dself["input_data"] = external_input_data
Expand Down
9 changes: 8 additions & 1 deletion qcelemental/models/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,14 @@
TorsionDriveInput,
TorsionDriveResult,
)
from .results import AtomicInput, AtomicResult, AtomicResultProperties, AtomicResultProtocols, WavefunctionProperties
from .results import (
AtomicInput,
AtomicResult,
AtomicResultProperties,
AtomicResultProtocols,
AtomicSpecification,
WavefunctionProperties,
)


def qcschema_models():
Expand Down
85 changes: 71 additions & 14 deletions qcelemental/models/v2/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,52 @@ class AtomicResultProtocols(ProtoModel):
model_config = ExtendedConfigDict(force_skip_defaults=True)


class AtomicSpecification(ProtoModel):
    """Specification for a single point QC calculation"""

    # NOTE: the schema stamps below are currently disabled for this model.
    # schema_name: Literal["qcschema_atomicspecification"] = "qcschema_atomicspecification"
    # schema_version: Literal[2] = Field(
    #     2,
    #     description="The version number of ``schema_name`` to which this model conforms.",
    # )
    keywords: Dict[str, Any] = Field(
        {},
        description="The program specific keywords to be used.",
    )
    # TODO interaction with cmdline
    program: str = Field(
        "",
        description="The program for which the Specification is intended.",
    )
    driver: DriverEnum = Field(..., description=DriverEnum.__doc__)
    model: Model = Field(..., description=Model.__doc__)
    protocols: AtomicResultProtocols = Field(
        AtomicResultProtocols(), description=AtomicResultProtocols.__doc__
    )
    extras: Dict[str, Any] = Field(
        {},
        description="Additional information to bundle with the computation. Use for schema development and scratch space.",
    )

    def convert_v(
        self, version: int
    ) -> Union["qcelemental.models.v1.QCInputSpecification", "qcelemental.models.v2.AtomicSpecification"]:
        """Return this specification as a model of the requested QCSchema version.

        Parameters
        ----------
        version
            Target QCSchema version. It is passed to ``check_convertible_version``;
            when that helper answers ``"self"`` this very instance is returned.

        Returns
        -------
        AtomicSpecification or QCInputSpecification
            ``self`` when no conversion is needed, otherwise a
            ``qcelemental.models.v1.QCInputSpecification``. The ``protocols`` and
            ``program`` values are removed from the top level and stored under
            ``extras["_qcsk_conversion_loss"]`` of the converted model.
        """
        import qcelemental as qcel

        target = check_convertible_version(version, error="AtomicSpecification")
        if target == "self":
            return self

        payload = self.model_dump()
        if version == 1:
            # Move the fields that are not passed on to the v1 model into a side store.
            dropped = {name: payload.pop(name) for name in ("protocols", "program")}
            if dropped:
                payload["extras"]["_qcsk_conversion_loss"] = dropped

            converted = qcel.models.v1.QCInputSpecification(**payload)

        return converted


### Primary models


Expand All @@ -664,22 +710,22 @@ class AtomicInput(ProtoModel):
r"""The MolSSI Quantum Chemistry Schema"""

id: Optional[str] = Field(None, description="The optional ID for the computation.")
schema_name: constr(strip_whitespace=True, pattern=r"^(qc\_?schema_input)$") = Field( # type: ignore
schema_name: Literal["qcschema_input"] = Field(
qcschema_input_default,
description=(
f"The QCSchema specification this model conforms to. Explicitly fixed as {qcschema_input_default}."
),
)
schema_version: Literal[2] = Field(
2,
description="The version number of :attr:`~qcelemental.models.AtomicInput.schema_name` to which this model conforms.",
description="The version number of ``schema_name`` to which this model conforms.",
)

molecule: Molecule = Field(..., description="The molecule to use in the computation.")
driver: DriverEnum = Field(..., description=str(DriverEnum.__doc__))
model: Model = Field(..., description=str(Model.__doc__))
keywords: Dict[str, Any] = Field({}, description="The program-specific keywords to be used.")
protocols: AtomicResultProtocols = Field(AtomicResultProtocols(), description=str(AtomicResultProtocols.__doc__))

specification: AtomicSpecification = Field(
..., description="Additional fields specifying how to run the single-point computation."
)

extras: Dict[str, Any] = Field(
{},
Expand All @@ -696,8 +742,8 @@ class AtomicInput(ProtoModel):

def __repr_args__(self) -> "ReprArgs":
return [
("driver", self.driver.value),
("model", self.model.model_dump()),
("driver", self.specification.driver.value),
("model", self.specification.model.model_dump()),
("molecule_hash", self.molecule.get_hash()[:7]),
]

Expand All @@ -718,6 +764,15 @@ def convert_v(

dself = self.model_dump()
if version == 1:
dself["driver"] = dself["specification"].pop("driver")
dself["model"] = dself["specification"].pop("model")
dself["keywords"] = dself["specification"].pop("keywords", None)
dself["protocols"] = dself["specification"].pop("protocols", None)
dself["extras"] = {**dself["specification"].pop("extras", {}), **dself["extras"]}
dself["specification"].pop("program", None) # TODO store?
assert not dself["specification"], dself["specification"]
dself.pop("specification") # now empty

self_vN = qcel.models.v1.AtomicInput(**dself)

return self_vN
Expand Down Expand Up @@ -784,7 +839,7 @@ def _validate_return_result(cls, v, info):
# Do not propagate validation errors
if "input_data" not in info.data:
raise ValueError("Input_data was not properly formed.")
driver = info.data["input_data"].driver
driver = info.data["input_data"].specification.driver
if driver == "energy":
if isinstance(v, np.ndarray) and v.size == 1:
v = v.item(0)
Expand Down Expand Up @@ -825,7 +880,7 @@ def _wavefunction_protocol(cls, value, info):
wfn.pop(k)

# Handle protocols
wfnp = info.data["input_data"].protocols.wavefunction
wfnp = info.data["input_data"].specification.protocols.wavefunction
return_keep = None
if wfnp == "all":
pass
Expand Down Expand Up @@ -875,7 +930,7 @@ def _stdout_protocol(cls, value, info):
if "input_data" not in info.data:
raise ValueError("Input_data was not properly formed.")

outp = info.data["input_data"].protocols.stdout
outp = info.data["input_data"].specification.protocols.stdout
if outp is True:
return value
elif outp is False:
Expand All @@ -890,7 +945,7 @@ def _native_file_protocol(cls, value, info):
if "input_data" not in info.data:
raise ValueError("Input_data was not properly formed.")

ancp = info.data["input_data"].protocols.native_files
ancp = info.data["input_data"].specification.protocols.native_files
if ancp == "all":
return value
elif ancp == "none":
Expand Down Expand Up @@ -920,12 +975,14 @@ def convert_v(

dself = self.model_dump()
if version == 1:
# input_data = self.input_data.convert_v(1) # TODO probably later
input_data = dself.pop("input_data")
# for input_data, work from model, not dict, to use convert_v
dself.pop("input_data")
input_data = self.input_data.convert_v(1).model_dump() # exclude_unset=True, exclude_none=True
input_data.pop("molecule", None) # discard
input_data.pop("provenance", None) # discard
dself["extras"] = {**input_data.pop("extras", {}), **dself.pop("extras", {})} # merge
dself = {**input_data, **dself}

self_vN = qcel.models.v1.AtomicResult(**dself)

return self_vN
14 changes: 11 additions & 3 deletions qcelemental/tests/test_model_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,17 @@ def test_repr_failed_op(schema_versions):
def test_repr_result(request, schema_versions):
AtomicInput = schema_versions.AtomicInput

result = AtomicInput(
**{"driver": "gradient", "model": {"method": "UFF"}, "molecule": {"symbols": ["He"], "geometry": [0, 0, 0]}}
)
if "v2" in request.node.name:
result = AtomicInput(
**{
"specification": {"driver": "gradient", "model": {"method": "UFF"}},
"molecule": {"symbols": ["He"], "geometry": [0, 0, 0]},
}
)
else:
result = AtomicInput(
**{"driver": "gradient", "model": {"method": "UFF"}, "molecule": {"symbols": ["He"], "geometry": [0, 0, 0]}}
)
drop_qcsk(result, request.node.name)
assert "molecule_hash" in str(result)
assert "molecule_hash" in repr(result)
Expand Down
Loading

0 comments on commit 60cedf1

Please sign in to comment.