Skip to content

Commit

Permalink
Update GMSO to work with pydantic 2.0 (#745)
Browse files Browse the repository at this point in the history
* first pass using bump-pydantic per https://docs.pydantic.dev/2.0/migration/

* Further update syntax in abc and core

* checkpoint

* checkpoint

* rip out json_encoder since it iss being deprecated, will need to fix serialization related functions later

* fix up atom related files and tests

* fix bond.py

* fix up dihedral and improper

* fix various import in tests, removed json handler import

* re-add __hash__ method for atom type and parametric potential

* replace parse_obj with model_validate

* add serializer for several fields, fix everything but test_serialization and test_xml_handling

* fix some of the serialization test

* reimplement json_dict, parse_raw (since model_validate_json raised unimplemeted error), all serialization tests fixed

* found workaround to avoid using parse_raw

* fix remainder of xml handling test by change atom_type._etree_attrib to be consistent previous version

* update potential templates to use expected_parameters_dimensions_ for consistency

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* fix typo

* add the missing alias key

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* udpate env yml

* remove pydantic from matrx in CI

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
daico007 and pre-commit-ci[bot] authored Dec 13, 2023
1 parent c014a44 commit 096da71
Show file tree
Hide file tree
Showing 43 changed files with 559 additions and 554 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/CI.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ jobs:
matrix:
os: [macOS-latest, ubuntu-latest]
python-version: ["3.9", "3.10", "3.11"]
pydantic-version: ["2"]

defaults:
run:
Expand All @@ -36,7 +35,6 @@ jobs:
environment-file: environment-dev.yml
create-args: >-
python=${{ matrix.python-version }}
pydantic=${{ matrix.pydantic-version }}
- name: Install Package
run: python -m pip install -e .
Expand Down
2 changes: 1 addition & 1 deletion environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ dependencies:
- unyt>=2.9.5
- boltons
- lxml
- pydantic
- pydantic>=2
- networkx
- pytest
- mbuild>=0.11.0
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ dependencies:
- unyt>=2.9.5
- boltons
- lxml
- pydantic
- pydantic>=2
- networkx
- ele>=0.2.0
- foyer>=0.11.3
Expand Down
2 changes: 1 addition & 1 deletion gmso/abc/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from gmso.abc.serialization_utils import GMSOJSONHandler, unyt_to_dict
from gmso.abc.serialization_utils import unyt_to_dict
41 changes: 21 additions & 20 deletions gmso/abc/abstract_connection.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
from typing import Optional, Sequence

from pydantic import ConfigDict, Field, model_validator

from gmso.abc.abstract_site import Site
from gmso.abc.gmso_base import GMSOBase
from gmso.exceptions import GMSOError

try:
from pydantic.v1 import Field, root_validator
except ImportError:
from pydantic import Field, root_validator


class Connection(GMSOBase):
__base_doc__ = """An abstract class that stores data about connections between sites.
Expand All @@ -18,12 +15,21 @@ class Connection(GMSOBase):
"""

name_: str = Field(
default="", description="Name of the connection. Defaults to class name"
default="",
description="Name of the connection. Defaults to class name.",
alias="name",
)

connection_members_: Optional[Sequence[Site]] = Field(
default=None,
description="A list of constituents in this connection, in order.",
alias="connection_members",
)
model_config = ConfigDict(
alias_to_fields={
"name": "name_",
"connection_members": "connection_members_",
}
)

@property
Expand All @@ -37,12 +43,12 @@ def name(self):
@property
def member_types(self):
"""Return the atomtype of the connection members as a list of string."""
return self._get_members_types_or_classes("member_types_")
return self._get_members_types_or_classes("member_types")

@property
def member_classes(self):
"""Return the class of the connection members as a list of string."""
return self._get_members_types_or_classes("member_classes_")
return self._get_members_types_or_classes("member_classes")

def _has_typed_members(self):
"""Check if all the members of this connection are typed."""
Expand All @@ -53,7 +59,7 @@ def _has_typed_members(self):

def _get_members_types_or_classes(self, to_return):
"""Return types or classes for connection members if they exist."""
assert to_return in {"member_types_", "member_classes_"}
assert to_return in {"member_types", "member_classes"}
ctype = getattr(self, "connection_type")
ctype_attr = getattr(ctype, to_return) if ctype else None

Expand All @@ -62,15 +68,18 @@ def _get_members_types_or_classes(self, to_return):
elif self._has_typed_members():
tc = [
member.atom_type.name
if to_return == "member_types_"
if to_return == "member_types"
else member.atom_type.atomclass
for member in self.__dict__.get("connection_members_")
]
return tc if all(tc) else None

@root_validator(pre=True)
@model_validator(mode="before")
def validate_fields(cls, values):
connection_members = values.get("connection_members")
if "connection_members" in values:
connection_members = values.get("connection_members")
else:
connection_members = values.get("connection_members_")

if all(isinstance(member, dict) for member in connection_members):
connection_members = [
Expand Down Expand Up @@ -103,11 +112,3 @@ def __repr__(self):

def __str__(self):
return f"<{self.__class__.__name__} {self.name}, id: {id(self)}> "

class Config:
fields = {"name_": "name", "connection_members_": "connection_members"}

alias_to_fields = {
"name": "name_",
"connection_members": "connection_members_",
}
79 changes: 52 additions & 27 deletions gmso/abc/abstract_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
from abc import abstractmethod
from typing import Any, Dict, Iterator, List

import unyt as u
from pydantic import ConfigDict, Field, field_serializer, field_validator

from gmso.abc.gmso_base import GMSOBase
from gmso.abc.serialization_utils import unyt_to_dict
from gmso.utils.expression import PotentialExpression

try:
from pydantic.v1 import Field, validator
except ImportError:
from pydantic import Field, validator


class AbstractPotential(GMSOBase):
__base_doc__ = """An abstract potential class.
Expand All @@ -24,16 +23,28 @@ class AbstractPotential(GMSOBase):
"""

name_: str = Field(
"", description="The name of the potential. Defaults to class name"
"",
description="The name of the potential. Defaults to class name",
alias="name",
)

potential_expression_: PotentialExpression = Field(
PotentialExpression(expression="a*x+b", independent_variables={"x"}),
description="The mathematical expression for the potential",
alias="potential_expression",
)

tags_: Dict[str, Any] = Field(
{}, description="Tags associated with the potential"
{},
description="Tags associated with the potential",
alias="tags",
)
model_config = ConfigDict(
alias_to_fields={
"name": "name_",
"potential_expression": "potential_expression_",
"tags": "tags_",
}
)

def __init__(
Expand Down Expand Up @@ -72,12 +83,12 @@ def name(self):
@property
def independent_variables(self):
"""Optional[Union[set, str]]\n\tThe independent variables in the `Potential`'s expression."""
return self.potential_expression_.independent_variables
return self.potential_expression.independent_variables

@property
def expression(self):
"""Optional[Union[str, sympy.Expr]]\n\tThe mathematical expression of the functional form of the potential."""
return self.potential_expression_.expression
return self.potential_expression.expression

@property
def potential_expression(self):
Expand All @@ -96,6 +107,34 @@ def tag_names(self) -> List[str]:
def tag_names_iter(self) -> Iterator[str]:
return iter(self.__dict__.get("tags_"))

@field_serializer("potential_expression_")
def serialize_expression(self, potential_expression_: PotentialExpression):
expr = str(potential_expression_.expression)
ind = sorted(
list(
str(ind) for ind in potential_expression_.independent_variables
)
)
params = {
param: unyt_to_dict(val)
for param, val in potential_expression_.parameters.items()
}
return {
"expression": expr,
"independent_variables": ind,
"parameters": params,
}

@field_serializer("tags_")
def serialize_tags(self, tags_):
return_dict = dict()
for key, val in tags_.items():
if isinstance(val, u.unyt_array):
return_dict[key] = unyt_to_dict(val)
else:
return_dict[key] = val
return return_dict

def add_tag(self, tag: str, value: Any, overwrite=True) -> None:
"""Add metadata for a particular tag"""
if self.tags.get(tag) and not overwrite:
Expand All @@ -118,7 +157,8 @@ def delete_tag(self, tag: str) -> None:
def pop_tag(self, tag: str) -> Any:
return self.tags.pop(tag, None)

@validator("potential_expression_", pre=True)
@field_validator("potential_expression_", mode="before")
@classmethod
def validate_potential_expression(cls, v):
if isinstance(v, dict):
v = PotentialExpression(**v)
Expand All @@ -132,9 +172,9 @@ def set_expression(self):
def __setattr__(self, key: Any, value: Any) -> None:
"""Set attributes of the potential."""
if key == "expression":
self.potential_expression_.expression = value
self.potential_expression.expression = value
elif key == "independent_variables":
self.potential_expression_.independent_variables = value
self.potential_expression.independent_variables = value
elif key == "set_ref_":
return
else:
Expand All @@ -156,18 +196,3 @@ def __str__(self):
f"expression: {self.expression}, "
f"id: {id(self)}>"
)

class Config:
"""Pydantic configuration for the potential objects."""

fields = {
"name_": "name",
"potential_expression_": "potential_expression",
"tags_": "tags",
}

alias_to_fields = {
"name": "name_",
"potential_expression": "potential_expression_",
"tags": "tags_",
}
Loading

0 comments on commit 096da71

Please sign in to comment.