From 096da716720ccae57a0b0e4aefd6a4f279a02929 Mon Sep 17 00:00:00 2001 From: Co Quach <43968221+daico007@users.noreply.github.com> Date: Wed, 13 Dec 2023 11:31:37 -0600 Subject: [PATCH] Update GMSO to work with pydantic 2.0 (#745) * first pass using bump-pydantic per https://docs.pydantic.dev/2.0/migration/ * Further update syntax in abc and core * checkpoint * checkpoint * rip out json_encoder since it iss being deprecated, will need to fix serialization related functions later * fix up atom related files and tests * fix bond.py * fix up dihedral and improper * fix various import in tests, removed json handler import * re-add __hash__ method for atom type and parametric potential * replace parse_obj with model_validate * add serializer for several fields, fix everything but test_serialization and test_xml_handling * fix some of the serialization test * reimplement json_dict, parse_raw (since model_validate_json raised unimplemeted error), all serialization tests fixed * found workaround to avoid using parse_raw * fix remainder of xml handling test by change atom_type._etree_attrib to be consistent previous version * update potential templates to use expected_parameters_dimensions_ for consistency * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typo * add the missing alias key * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * udpate env yml * remove pydantic from matrx in CI --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/workflows/CI.yaml | 2 - environment-dev.yml | 2 +- environment.yml | 2 +- gmso/abc/__init__.py | 2 +- gmso/abc/abstract_connection.py | 41 ++++---- gmso/abc/abstract_potential.py | 79 ++++++++++----- gmso/abc/abstract_site.py | 76 +++++++------- gmso/abc/auto_doc.py | 5 +- gmso/abc/gmso_base.py | 81 ++++++++------- gmso/abc/serialization_utils.py | 18 ---- gmso/core/angle.py | 41 ++++---- gmso/core/angle_type.py | 28 +++--- gmso/core/atom.py | 101 +++++++++++-------- gmso/core/atom_type.py | 112 ++++++++++++--------- gmso/core/bond.py | 43 ++++---- gmso/core/bond_type.py | 30 +++--- gmso/core/dihedral.py | 41 ++++---- gmso/core/dihedral_type.py | 28 +++--- gmso/core/element.py | 21 ++-- gmso/core/forcefield.py | 2 +- gmso/core/improper.py | 36 +++---- gmso/core/improper_type.py | 30 +++--- gmso/core/pairpotential_type.py | 20 ++-- gmso/core/parametric_potential.py | 34 ++++--- gmso/core/topology.py | 4 +- gmso/external/convert_parmed.py | 32 +++--- gmso/formats/gro.py | 3 +- gmso/formats/json.py | 16 +-- gmso/lib/potential_templates.py | 33 +++--- gmso/parameterization/parameterize.py | 2 +- gmso/tests/abc/test_serialization_utils.py | 8 +- gmso/tests/base_test.py | 6 +- gmso/tests/test_angle.py | 8 +- gmso/tests/test_atom.py | 8 +- gmso/tests/test_atom_type.py | 8 +- gmso/tests/test_bond.py | 8 +- gmso/tests/test_dihedral.py | 8 +- gmso/tests/test_improper.py | 8 +- gmso/tests/test_potential.py | 6 +- gmso/tests/test_serialization.py | 60 ++++++----- gmso/tests/test_views.py | 2 +- gmso/utils/decorators.py | 15 --- gmso/utils/expression.py | 3 - 43 files changed, 559 insertions(+), 554 deletions(-) diff --git a/.github/workflows/CI.yaml b/.github/workflows/CI.yaml index aed95a5c6..da7128976 100644 --- a/.github/workflows/CI.yaml +++ b/.github/workflows/CI.yaml @@ -20,7 +20,6 @@ jobs: matrix: os: [macOS-latest, ubuntu-latest] python-version: ["3.9", "3.10", "3.11"] - pydantic-version: ["2"] defaults: run: @@ -36,7 +35,6 @@ jobs: environment-file: environment-dev.yml create-args: >- python=${{ matrix.python-version }} - pydantic=${{ matrix.pydantic-version }} - name: Install Package run: python -m pip install -e . diff --git a/environment-dev.yml b/environment-dev.yml index 73a40d308..425d7d8cb 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -9,7 +9,7 @@ dependencies: - unyt>=2.9.5 - boltons - lxml - - pydantic + - pydantic>=2 - networkx - pytest - mbuild>=0.11.0 diff --git a/environment.yml b/environment.yml index 57144b91d..85a309123 100644 --- a/environment.yml +++ b/environment.yml @@ -9,7 +9,7 @@ dependencies: - unyt>=2.9.5 - boltons - lxml - - pydantic + - pydantic>=2 - networkx - ele>=0.2.0 - foyer>=0.11.3 diff --git a/gmso/abc/__init__.py b/gmso/abc/__init__.py index dece032fd..5049d4ce7 100644 --- a/gmso/abc/__init__.py +++ b/gmso/abc/__init__.py @@ -1 +1 @@ -from gmso.abc.serialization_utils import GMSOJSONHandler, unyt_to_dict +from gmso.abc.serialization_utils import unyt_to_dict diff --git a/gmso/abc/abstract_connection.py b/gmso/abc/abstract_connection.py index 69f96f983..e4a015f05 100644 --- a/gmso/abc/abstract_connection.py +++ b/gmso/abc/abstract_connection.py @@ -1,14 +1,11 @@ from typing import Optional, Sequence +from pydantic import ConfigDict, Field, model_validator + from gmso.abc.abstract_site import Site from gmso.abc.gmso_base import GMSOBase from gmso.exceptions import GMSOError -try: - from pydantic.v1 import Field, root_validator -except ImportError: - from pydantic import Field, root_validator - class Connection(GMSOBase): __base_doc__ = """An abstract class that stores data about connections between sites. @@ -18,12 +15,21 @@ class Connection(GMSOBase): """ name_: str = Field( - default="", description="Name of the connection. Defaults to class name" + default="", + description="Name of the connection. Defaults to class name.", + alias="name", ) connection_members_: Optional[Sequence[Site]] = Field( default=None, description="A list of constituents in this connection, in order.", + alias="connection_members", + ) + model_config = ConfigDict( + alias_to_fields={ + "name": "name_", + "connection_members": "connection_members_", + } ) @property @@ -37,12 +43,12 @@ def name(self): @property def member_types(self): """Return the atomtype of the connection members as a list of string.""" - return self._get_members_types_or_classes("member_types_") + return self._get_members_types_or_classes("member_types") @property def member_classes(self): """Return the class of the connection members as a list of string.""" - return self._get_members_types_or_classes("member_classes_") + return self._get_members_types_or_classes("member_classes") def _has_typed_members(self): """Check if all the members of this connection are typed.""" @@ -53,7 +59,7 @@ def _has_typed_members(self): def _get_members_types_or_classes(self, to_return): """Return types or classes for connection members if they exist.""" - assert to_return in {"member_types_", "member_classes_"} + assert to_return in {"member_types", "member_classes"} ctype = getattr(self, "connection_type") ctype_attr = getattr(ctype, to_return) if ctype else None @@ -62,15 +68,18 @@ def _get_members_types_or_classes(self, to_return): elif self._has_typed_members(): tc = [ member.atom_type.name - if to_return == "member_types_" + if to_return == "member_types" else member.atom_type.atomclass for member in self.__dict__.get("connection_members_") ] return tc if all(tc) else None - @root_validator(pre=True) + @model_validator(mode="before") def validate_fields(cls, values): - connection_members = values.get("connection_members") + if "connection_members" in values: + connection_members = values.get("connection_members") + else: + connection_members = values.get("connection_members_") if all(isinstance(member, dict) for member in connection_members): connection_members = [ @@ -103,11 +112,3 @@ def __repr__(self): def __str__(self): return f"<{self.__class__.__name__} {self.name}, id: {id(self)}> " - - class Config: - fields = {"name_": "name", "connection_members_": "connection_members"} - - alias_to_fields = { - "name": "name_", - "connection_members": "connection_members_", - } diff --git a/gmso/abc/abstract_potential.py b/gmso/abc/abstract_potential.py index 673519e0d..89abe9f0f 100644 --- a/gmso/abc/abstract_potential.py +++ b/gmso/abc/abstract_potential.py @@ -2,14 +2,13 @@ from abc import abstractmethod from typing import Any, Dict, Iterator, List +import unyt as u +from pydantic import ConfigDict, Field, field_serializer, field_validator + from gmso.abc.gmso_base import GMSOBase +from gmso.abc.serialization_utils import unyt_to_dict from gmso.utils.expression import PotentialExpression -try: - from pydantic.v1 import Field, validator -except ImportError: - from pydantic import Field, validator - class AbstractPotential(GMSOBase): __base_doc__ = """An abstract potential class. @@ -24,16 +23,28 @@ class AbstractPotential(GMSOBase): """ name_: str = Field( - "", description="The name of the potential. Defaults to class name" + "", + description="The name of the potential. Defaults to class name", + alias="name", ) potential_expression_: PotentialExpression = Field( PotentialExpression(expression="a*x+b", independent_variables={"x"}), description="The mathematical expression for the potential", + alias="potential_expression", ) tags_: Dict[str, Any] = Field( - {}, description="Tags associated with the potential" + {}, + description="Tags associated with the potential", + alias="tags", + ) + model_config = ConfigDict( + alias_to_fields={ + "name": "name_", + "potential_expression": "potential_expression_", + "tags": "tags_", + } ) def __init__( @@ -72,12 +83,12 @@ def name(self): @property def independent_variables(self): """Optional[Union[set, str]]\n\tThe independent variables in the `Potential`'s expression.""" - return self.potential_expression_.independent_variables + return self.potential_expression.independent_variables @property def expression(self): """Optional[Union[str, sympy.Expr]]\n\tThe mathematical expression of the functional form of the potential.""" - return self.potential_expression_.expression + return self.potential_expression.expression @property def potential_expression(self): @@ -96,6 +107,34 @@ def tag_names(self) -> List[str]: def tag_names_iter(self) -> Iterator[str]: return iter(self.__dict__.get("tags_")) + @field_serializer("potential_expression_") + def serialize_expression(self, potential_expression_: PotentialExpression): + expr = str(potential_expression_.expression) + ind = sorted( + list( + str(ind) for ind in potential_expression_.independent_variables + ) + ) + params = { + param: unyt_to_dict(val) + for param, val in potential_expression_.parameters.items() + } + return { + "expression": expr, + "independent_variables": ind, + "parameters": params, + } + + @field_serializer("tags_") + def serialize_tags(self, tags_): + return_dict = dict() + for key, val in tags_.items(): + if isinstance(val, u.unyt_array): + return_dict[key] = unyt_to_dict(val) + else: + return_dict[key] = val + return return_dict + def add_tag(self, tag: str, value: Any, overwrite=True) -> None: """Add metadata for a particular tag""" if self.tags.get(tag) and not overwrite: @@ -118,7 +157,8 @@ def delete_tag(self, tag: str) -> None: def pop_tag(self, tag: str) -> Any: return self.tags.pop(tag, None) - @validator("potential_expression_", pre=True) + @field_validator("potential_expression_", mode="before") + @classmethod def validate_potential_expression(cls, v): if isinstance(v, dict): v = PotentialExpression(**v) @@ -132,9 +172,9 @@ def set_expression(self): def __setattr__(self, key: Any, value: Any) -> None: """Set attributes of the potential.""" if key == "expression": - self.potential_expression_.expression = value + self.potential_expression.expression = value elif key == "independent_variables": - self.potential_expression_.independent_variables = value + self.potential_expression.independent_variables = value elif key == "set_ref_": return else: @@ -156,18 +196,3 @@ def __str__(self): f"expression: {self.expression}, " f"id: {id(self)}>" ) - - class Config: - """Pydantic configuration for the potential objects.""" - - fields = { - "name_": "name", - "potential_expression_": "potential_expression", - "tags_": "tags", - } - - alias_to_fields = { - "name": "name_", - "potential_expression": "potential_expression_", - "tags": "tags_", - } diff --git a/gmso/abc/abstract_site.py b/gmso/abc/abstract_site.py index 0234446ae..6fbed0f3c 100644 --- a/gmso/abc/abstract_site.py +++ b/gmso/abc/abstract_site.py @@ -4,16 +4,20 @@ import numpy as np import unyt as u +from pydantic import ( + ConfigDict, + Field, + StrictInt, + StrictStr, + field_serializer, + field_validator, +) from unyt.exceptions import InvalidUnitOperation from gmso.abc.gmso_base import GMSOBase +from gmso.abc.serialization_utils import unyt_to_dict from gmso.exceptions import GMSOError -try: - from pydantic.v1 import Field, StrictInt, StrictStr, validator -except ImportError: - from pydantic import Field, StrictInt, StrictStr, validator - PositionType = Union[Sequence[float], np.ndarray, u.unyt_array] MoleculeType = NamedTuple("Molecule", name=StrictStr, number=StrictInt) ResidueType = NamedTuple("Residue", name=StrictStr, number=StrictInt) @@ -30,10 +34,13 @@ def default_position(): class Site(GMSOBase): __iterable_attributes__: ClassVar[set] = { + "name", "label", "group", "molecule", "residue", + "position", + "model_config", } __base_doc__: ClassVar[ @@ -54,28 +61,47 @@ class Site(GMSOBase): name_: str = Field( "", + validate_default=True, description="Name of the site, defaults to class name", + alias="name", + ) + label_: str = Field( + "", description="Label to be assigned to the site", alias="label" ) - - label_: str = Field("", description="Label to be assigned to the site") group_: Optional[StrictStr] = Field( - None, description="Flexible alternative label relative to site" + None, + description="Flexible alternative label relative to site", + alias="group", ) molecule_: Optional[MoleculeType] = Field( None, description="Molecule label for the site, format of (molecule_name, molecule_number)", + alias="molecule", ) residue_: Optional[ResidueType] = Field( None, description="Residue label for the site, format of (residue_name, residue_number)", + alias="residue", ) position_: PositionType = Field( default_factory=default_position, description="The 3D Cartesian coordinates of the position of the site", + alias="position", + ) + + model_config = ConfigDict( + alias_to_fields={ + "name": "name_", + "label": "label_", + "group": "group_", + "molecule": "molecule_", + "residue": "residue_", + "position": "position_", + }, ) @property @@ -108,6 +134,10 @@ def residue(self): """Return the residue assigned to the site.""" return self.__dict__.get("residue_") + @field_serializer("position_") + def serialize_position(self, position_: PositionType): + return unyt_to_dict(position_) + def __repr__(self): """Return the formatted representation of the site.""" return ( @@ -124,7 +154,8 @@ def __str__(self): f"label: {self.label if self.label else None} id: {id(self)}>" ) - @validator("position_") + @field_validator("position_") + @classmethod def is_valid_position(cls, position): """Validate attribute position.""" if position is None: @@ -152,7 +183,7 @@ def is_valid_position(cls, position): return position - @validator("name_", pre=True, always=True) + @field_validator("name_") def inject_name(cls, value): if value == "" or value is None: return cls.__name__ @@ -165,28 +196,3 @@ def __new__(cls, *args: Any, **kwargs: Any) -> SiteT: raise TypeError("Cannot instantiate abstract class of type Site") else: return object.__new__(cls) - - class Config: - """Pydantic configuration for site objects.""" - - arbitrary_types_allowed = True - - fields = { - "name_": "name", - "position_": "position", - "label_": "label", - "group_": "group", - "molecule_": "molecule", - "residue_": "residue", - } - - alias_to_fields = { - "name": "name_", - "position": "position_", - "label": "label_", - "group": "group_", - "molecule": "molecule_", - "residue": "residue_", - } - - validate_assignment = True diff --git a/gmso/abc/auto_doc.py b/gmso/abc/auto_doc.py index 3a9634ed6..325f08cf5 100644 --- a/gmso/abc/auto_doc.py +++ b/gmso/abc/auto_doc.py @@ -4,10 +4,7 @@ from copy import deepcopy from typing import Any, Dict, List, Optional, Tuple, Type, Union -try: - from pydantic.v1 import BaseModel -except ImportError: - from pydantic import BaseModel +from pydantic import BaseModel BASE_DOC_ATTR = "__base_doc__" FIELDS_IN_DOCSTRING = "__alias_to_fields__" diff --git a/gmso/abc/gmso_base.py b/gmso/abc/gmso_base.py index 3cdc30952..38f8b3e25 100644 --- a/gmso/abc/gmso_base.py +++ b/gmso/abc/gmso_base.py @@ -4,16 +4,12 @@ from abc import ABC from typing import Any, ClassVar, Type -from gmso.abc import GMSOJSONHandler +from pydantic import BaseModel, ConfigDict, validators + from gmso.abc.auto_doc import apply_docs from gmso.abc.serialization_utils import dict_to_unyt -try: - from pydantic.v1 import BaseModel - from pydantic.v1.validators import dict_validator -except ImportError: - from pydantic import BaseModel - from pydantic.validators import dict_validator +dict_validator = validators.getattr_migration("dict_validator") class GMSOBase(BaseModel, ABC): @@ -25,6 +21,13 @@ class GMSOBase(BaseModel, ABC): __docs_generated__: ClassVar[bool] = False + model_config = ConfigDict( + arbitrary_types_allowed=True, + validate_assignment=True, + extra="forbid", + populate_by_name=True, + ) + def __hash__(self): """Return the unique hash of the object.""" return id(self) @@ -35,9 +38,9 @@ def __eq__(self, other): def __setattr__(self, name: Any, value: Any) -> None: """Set the attributes of the object.""" - if name in self.__config__.alias_to_fields: - name = self.__config__.alias_to_fields[name] - elif name in self.__config__.alias_to_fields.values(): + if name in self.model_config.get("alias_to_fields"): + name = self.model_config.get("alias_to_fields")[name] + elif name in self.model_config.get("alias_to_fields").values(): warnings.warn( "Use of internal fields is discouraged. " "Please use external fields to set attributes." @@ -62,15 +65,44 @@ def __init_subclass__(cls, **kwargs): apply_docs(cls, map_names=True, silent=False) @classmethod - def parse_obj(cls: Type["Model"], obj: Any) -> "Model": + def model_validate(cls: Type["Model"], obj: Any) -> "Model": dict_to_unyt(obj) - return super(GMSOBase, cls).parse_obj(obj) + return super(GMSOBase, cls).model_validate(obj) - def dict(self, **kwargs) -> "DictStrAny": + def model_dump(self, **kwargs) -> "DictStrAny": kwargs["by_alias"] = True - super_dict = super(GMSOBase, self).dict(**kwargs) + + additional_excludes = set() + if "exclude" in kwargs: + for term in kwargs["exclude"]: + if term in self.model_config["alias_to_fields"]: + additional_excludes.add( + self.model_config["alias_to_fields"][term] + ) + kwargs["exclude"] = kwargs["exclude"].union(additional_excludes) + super_dict = super(GMSOBase, self).model_dump(**kwargs) return super_dict + def model_dump_json(self, **kwargs): + kwargs["by_alias"] = True + + additional_excludes = set() + if "exclude" in kwargs: + for term in kwargs["exclude"]: + if term in self.model_config["alias_to_fields"]: + additional_excludes.add( + self.model_config["alias_to_fields"][term] + ) + kwargs["exclude"] = kwargs["exclude"].union(additional_excludes) + super_dict = super(GMSOBase, self).model_dump_json(**kwargs) + + return super_dict + + def json_dict(self, **kwargs): + """Return a JSON serializable dictionary from the object""" + raw_json = self.model_dump_json(**kwargs) + return json.loads(raw_json) + def _iter(self, **kwargs) -> "TupleGenerator": exclude = kwargs.get("exclude") include = kwargs.get("include") @@ -95,18 +127,6 @@ def _iter(self, **kwargs) -> "TupleGenerator": yield from super()._iter(**kwargs) - def json(self, **kwargs): - kwargs["by_alias"] = True - # FIXME: Pydantic>1.8 doesn't recognize json_encoders without this update - self.__config__.json_encoders.update(GMSOJSONHandler.json_encoders) - - return super(GMSOBase, self).json(**kwargs) - - def json_dict(self, **kwargs): - """Return a JSON serializable dictionary from the object""" - raw_json = self.json(**kwargs) - return json.loads(raw_json) - @classmethod def validate(cls, value): """Ensure that the object is validated before use.""" @@ -119,12 +139,3 @@ def validate(cls, value): def __get_validators__(cls) -> "CallableGenerator": """Get the validators of the object.""" yield cls.validate - - class Config: - """Pydantic configuration for base object.""" - - arbitrary_types_allowed = True - alias_to_fields = dict() - extra = "forbid" - json_encoders = GMSOJSONHandler.json_encoders - allow_population_by_field_name = True diff --git a/gmso/abc/serialization_utils.py b/gmso/abc/serialization_utils.py index 683343192..89f31de1e 100644 --- a/gmso/abc/serialization_utils.py +++ b/gmso/abc/serialization_utils.py @@ -33,21 +33,3 @@ def dict_to_unyt(dict_obj) -> None: unyt_func = u.unyt_array dict_obj[key] = unyt_func(np_array, value["unit"]) - - -class JSONHandler: - def __init__(self): - self.json_encoders = {} - - def register(self, type_, callable_, override=False): - """Register a new JSON encoder for an object for serialization in GMSO""" - if type_ not in self.json_encoders: - self.json_encoders[type_] = callable_ - else: - if override: - warn(f"Overriding json serializer for {type_}") - self.json_encoders[type_] = callable_ - - -GMSOJSONHandler = JSONHandler() -GMSOJSONHandler.register(u.unyt_array, unyt_to_dict) diff --git a/gmso/core/angle.py b/gmso/core/angle.py index 69a9f32f1..882a36bc3 100644 --- a/gmso/core/angle.py +++ b/gmso/core/angle.py @@ -1,15 +1,12 @@ """Support for 3-partner connections between gmso.core.Atoms.""" from typing import Callable, ClassVar, Optional, Tuple +from pydantic import ConfigDict, Field + from gmso.abc.abstract_connection import Connection from gmso.core.angle_type import AngleType from gmso.core.atom import Atom -try: - from pydantic.v1 import Field -except ImportError: - from pydantic import Field - class Angle(Connection): __base_doc__ = """A 3-partner connection between Atoms. @@ -24,14 +21,18 @@ class Angle(Connection): __eq__, __repr__, _validate methods Additional _validate methods are presented """ - __members_creator__: ClassVar[Callable] = Atom.parse_obj + __members_creator__: ClassVar[Callable] = Atom.model_validate connection_members_: Tuple[Atom, Atom, Atom] = Field( - ..., description="The 3 atoms involved in the angle." + ..., + description="The 3 atoms involved in the angle.", + alias="connection_members", ) angle_type_: Optional[AngleType] = Field( - default=None, description="AngleType of this angle." + default=None, + description="AngleType of this angle.", + alias="angle_type", ) restraint_: Optional[dict] = Field( @@ -42,6 +43,16 @@ class Angle(Connection): Refer to https://manual.gromacs.org/current/reference-manual/topologies/topology-file-formats.html for more information. """, + alias="restraint", + ) + model_config = ConfigDict( + alias_to_fields=dict( + **Connection.model_config["alias_to_fields"], + **{ + "angle_type": "angle_type_", + "restraint": "restraint_", + } + ) ) @property @@ -85,17 +96,3 @@ def __setattr__(self, key, value): super(Angle, self).__setattr__("angle_type", value) else: super(Angle, self).__setattr__(key, value) - - class Config: - """Support pydantic configuration for attributes and behavior.""" - - fields = { - "connection_members_": "connection_members", - "angle_type_": "angle_type", - "restraint_": "restraint", - } - alias_to_fields = { - "connection_members": "connection_members_", - "angle_type": "angle_type_", - "restraint": "restraint_", - } diff --git a/gmso/core/angle_type.py b/gmso/core/angle_type.py index 9c75dfbff..1e62a0a27 100644 --- a/gmso/core/angle_type.py +++ b/gmso/core/angle_type.py @@ -1,15 +1,11 @@ from typing import Optional, Tuple import unyt as u +from pydantic import ConfigDict, Field from gmso.core.parametric_potential import ParametricPotential from gmso.utils.expression import PotentialExpression -try: - from pydantic.v1 import Field -except ImportError: - from pydantic import Field - class AngleType(ParametricPotential): __base_doc__ = """A descripton of the interaction between 3 bonded partners. @@ -32,12 +28,23 @@ class AngleType(ParametricPotential): None, description="List-like of gmso.AtomType.name " "defining the members of this angle type", + alias="member_types", ) member_classes_: Optional[Tuple[str, str, str]] = Field( None, description="List-like of gmso.AtomType.atomclass " "defining the members of this angle type", + alias="member_classes", + ) + model_config = ConfigDict( + alias_to_fields=dict( + **ParametricPotential.model_config["alias_to_fields"], + **{ + "member_types": "member_types_", + "member_classes": "member_classes_", + }, + ), ) def __init__( @@ -80,14 +87,3 @@ def member_types(self): @property def member_classes(self): return self.__dict__.get("member_classes_") - - class Config: - fields = { - "member_types_": "member_types", - "member_classes_": "member_classes", - } - - alias_to_fields = { - "member_types": "member_types_", - "member_classes": "member_classes_", - } diff --git a/gmso/core/atom.py b/gmso/core/atom.py index bd1a1c551..8ee9993f1 100644 --- a/gmso/core/atom.py +++ b/gmso/core/atom.py @@ -3,18 +3,15 @@ from typing import Optional, Union import unyt as u +from pydantic import ConfigDict, Field, field_serializer, field_validator from gmso.abc.abstract_site import Site +from gmso.abc.serialization_utils import unyt_to_dict from gmso.core.atom_type import AtomType from gmso.core.element import Element from gmso.utils._constants import UNIT_WARNING_STRING from gmso.utils.misc import ensure_valid_dimensions -try: - from pydantic.v1 import Field, validator -except ImportError: - from pydantic import Field, validator - class Atom(Site): __base_doc__ = """An atom represents a single element association in a topology. @@ -43,21 +40,25 @@ class Atom(Site): the gmso.abc.abstract site class """ charge_: Optional[Union[u.unyt_quantity, float]] = Field( - None, - description="Charge of the atom", + None, description="Charge of the atom", alias="charge" ) mass_: Optional[Union[u.unyt_quantity, float]] = Field( - None, description="Mass of the atom" + None, + description="Mass of the atom", + alias="mass", ) element_: Optional[Element] = Field( - None, description="Element associated with the atom" + None, + description="Element associated with the atom", + alias="element", ) atom_type_: Optional[AtomType] = Field( - None, description="AtomType associated with the atom" + None, description="AtomType associated with the atom", alias="atom_type" ) + restraint_: Optional[dict] = Field( default=None, description=""" @@ -68,6 +69,19 @@ class Atom(Site): """, ) + model_config = ConfigDict( + alias_to_fields=dict( + **Site.model_config["alias_to_fields"], + **{ + "charge": "charge_", + "mass": "mass_", + "element": "element_", + "atom_type": "atom_type_", + "restraint": "restraint_", + }, + ), + ) + @property def charge(self) -> Union[u.unyt_quantity, None]: """Return the charge of the atom.""" @@ -83,7 +97,7 @@ def charge(self) -> Union[u.unyt_quantity, None]: @property def mass(self) -> Union[u.unyt_quantity, None]: """Return the mass of the atom.""" - mass = self.__dict__.get("mass_", None) + mass = self.__dict__.get("mass_", property) atom_type = self.__dict__.get("atom_type_", None) if mass is not None: return mass @@ -98,7 +112,7 @@ def element(self) -> Union[Element, None]: return self.__dict__.get("element_", None) @property - def atom_type(self) -> Union[AtomType, None]: + def atom_type(self) -> Union[AtomType, property]: """Return the atom_type associated with the atom.""" return self.__dict__.get("atom_type_", None) @@ -107,6 +121,30 @@ def restraint(self): """Return the restraint of this atom.""" return self.__dict__.get("restraint_") + @field_serializer("charge_") + def serialize_charge(self, charge_: Union[u.unyt_quantity, None]): + if charge_ is None: + return None + else: + return unyt_to_dict(charge_) + + @field_serializer("mass_") + def serialize_mass(self, mass_: Union[u.unyt_quantity, None]): + if mass_ is None: + return None + else: + return unyt_to_dict(mass_) + + @field_serializer("restraint_") + def serialize_restraint(self, restraint_: Union[dict, None]): + if restraint_ is None: + return None + else: + converted_restraint = { + key: unyt_to_dict(val) for key, val in restraint_.items() + } + return converted_restraint + def clone(self): """Clone this atom.""" return Atom( @@ -116,10 +154,12 @@ def clone(self): molecule=self.molecule, residue=self.residue, position=self.position, - charge=self.charge_, - mass=self.mass_, - element=self.element_, - atom_type=None if not self.atom_type else self.atom_type.clone(), + charge=self.charge, + mass=self.mass, + element=self.element, + atom_type=property + if not self.atom_type + else self.atom_type.clone(), ) def __le__(self, other): @@ -140,7 +180,8 @@ def __lt__(self, other): f"Cannot compare equality between {type(self)} and {type(other)}" ) - @validator("charge_") + @field_validator("charge_") + @classmethod def is_valid_charge(cls, charge): """Ensure that the charge is physically meaningful.""" if charge is None: @@ -155,7 +196,8 @@ def is_valid_charge(cls, charge): return charge - @validator("mass_") + @field_validator("mass_") + @classmethod def is_valid_mass(cls, mass): """Ensure that the mass is physically meaningful.""" if mass is None: @@ -167,26 +209,3 @@ def is_valid_mass(cls, mass): else: ensure_valid_dimensions(mass, default_mass_units) return mass - - class Config: - """Pydantic configuration for the atom class.""" - - extra = "forbid" - - fields = { - "charge_": "charge", - "mass_": "mass", - "element_": "element", - "atom_type_": "atom_type", - "restraint_": "restraint", - } - - alias_to_fields = { - "charge": "charge_", - "mass": "mass_", - "element": "element_", - "atom_type": "atom_type_", - "restraint": "restraint_", - } - - validate_assignment = True diff --git a/gmso/core/atom_type.py b/gmso/core/atom_type.py index 7f9f326bb..89f3ac99f 100644 --- a/gmso/core/atom_type.py +++ b/gmso/core/atom_type.py @@ -1,9 +1,11 @@ """Support non-bonded interactions between sites.""" import warnings -from typing import Optional, Set +from typing import Optional, Set, Union import unyt as u +from pydantic import ConfigDict, Field, field_serializer, field_validator +from gmso.abc.serialization_utils import unyt_to_dict from gmso.core.parametric_potential import ParametricPotential from gmso.utils._constants import UNIT_WARNING_STRING from gmso.utils.expression import PotentialExpression @@ -13,11 +15,6 @@ unyt_to_hashable, ) -try: - from pydantic.v1 import Field, validator -except ImportError: - from pydantic import Field, validator - class AtomType(ParametricPotential): __base_doc__ = """A description of non-bonded interactions between sites. @@ -33,33 +30,55 @@ class AtomType(ParametricPotential): """ mass_: Optional[u.unyt_array] = Field( - 0.0 * u.gram / u.mol, description="The mass of the atom type" + 0.0 * u.gram / u.mol, + description="The mass of the atom type", + alias="mass", ) charge_: Optional[u.unyt_array] = Field( - 0.0 * u.elementary_charge, description="The charge of the atom type" + 0.0 * u.elementary_charge, + description="The charge of the atom type", + alias="charge", ) atomclass_: Optional[str] = Field( - "", description="The class of the atomtype" + "", description="The class of the atomtype", alias="atomclass" ) doi_: Optional[str] = Field( "", description="Digital Object Identifier of publication where this atom type was introduced", + alias="doi", ) overrides_: Optional[Set[str]] = Field( set(), description="Set of other atom types that this atom type overrides", + alias="overrides", ) definition_: Optional[str] = Field( - "", description="SMARTS string defining this atom type" + "", + description="SMARTS string defining this atom type", + alias="definition", ) description_: Optional[str] = Field( - "", description="Description for the AtomType" + "", description="Description for the AtomType", alias="description" + ) + model_config = ConfigDict( + alias_to_fields=dict( + **ParametricPotential.model_config["alias_to_fields"], + **{ + "mass": "mass_", + "charge": "charge_", + "atomclass": "atomclass_", + "doi": "doi_", + "overrides": "overrides_", + "definition": "definition_", + "description": "description_", + }, + ), ) def __init__( @@ -132,6 +151,20 @@ def definition(self): """Return the SMARTS string of the atom_type.""" return self.__dict__.get("definition_") + @field_serializer("charge_") + def serialize_charge(self, charge_: Union[u.unyt_quantity, None]): + if charge_ is None: + return None + else: + return unyt_to_dict(charge_) + + @field_serializer("mass_") + def serialize_mass(self, mass_: Union[u.unyt_quantity, None]): + if mass_ is None: + return None + else: + return unyt_to_dict(mass_) + def clone(self, fast_copy=False): """Clone this AtomType, faster alternative to deepcopying.""" return AtomType( @@ -140,18 +173,22 @@ def clone(self, fast_copy=False): expression=None, parameters=None, independent_variables=None, - potential_expression=self.potential_expression_.clone(fast_copy), - mass=u.unyt_quantity(self.mass_.value, self.mass_.units), - charge=u.unyt_quantity(self.charge_.value, self.charge_.units), - atomclass=self.atomclass_, - doi=self.doi_, - overrides=set(o for o in self.overrides_) - if self.overrides_ + potential_expression=self.potential_expression.clone(fast_copy), + mass=u.unyt_quantity(self.mass.value, self.mass.units), + charge=u.unyt_quantity(self.charge.value, self.charge.units), + atomclass=self.atomclass, + doi=self.doi, + overrides=set(o for o in self.overrides) + if self.overrides else None, - description=self.description_, - definition=self.definition_, + description=self.description, + definition=self.definition, ) + def __hash__(self): + """Return the unique hash of the object.""" + return id(self) + def __eq__(self, other): if other is self: return True @@ -178,6 +215,12 @@ def _etree_attrib(self): attrib = super()._etree_attrib() if self.overrides == set(): attrib.pop("overrides") + mass = eval(attrib["mass"]) + charge = eval(attrib["charge"]) + + attrib["mass"] = str(mass["array"]) + attrib["charge"] = str(charge["array"]) + return attrib def __repr__(self): @@ -190,7 +233,8 @@ def __repr__(self): ) return desc - @validator("mass_", pre=True) + @field_validator("mass_", mode="before") + @classmethod def validate_mass(cls, mass): """Check to see that a mass is a unyt array of the right dimension.""" default_mass_units = u.gram / u.mol @@ -202,7 +246,8 @@ def validate_mass(cls, mass): return mass - @validator("charge_", pre=True) + @field_validator("charge_", mode="before") + @classmethod def validate_charge(cls, charge): """Check to see that a charge is a unyt array of the right dimension.""" if not isinstance(charge, u.unyt_array): @@ -225,26 +270,3 @@ def _default_potential_expr(): "epsilon": 0.3 * u.Unit("kJ"), }, ) - - class Config: - """Pydantic configuration of the attributes for an atom_type.""" - - fields = { - "mass_": "mass", - "charge_": "charge", - "atomclass_": "atomclass", - "overrides_": "overrides", - "doi_": "doi", - "description_": "description", - "definition_": "definition", - } - - alias_to_fields = { - "mass": "mass_", - "charge": "charge_", - "atomclass": "atomclass_", - "overrides": "overrides_", - "doi": "doi_", - "description": "description_", - "definition": "definition_", - } diff --git a/gmso/core/bond.py b/gmso/core/bond.py index 7b4ddca83..f3fcea650 100644 --- a/gmso/core/bond.py +++ b/gmso/core/bond.py @@ -1,15 +1,12 @@ """Module for 2-partner connections between sites.""" from typing import Callable, ClassVar, Optional, Tuple +from pydantic import ConfigDict, Field + from gmso.abc.abstract_connection import Connection from gmso.core.atom import Atom from gmso.core.bond_type import BondType -try: - from pydantic.v1 import Field -except ImportError: - from pydantic import Field - class Bond(Connection): __base_doc__ = """A 2-partner connection between sites. @@ -24,14 +21,17 @@ class Bond(Connection): __eq__, __repr__, _validate methods. Additional _validate methods are presented. """ - __members_creator__: ClassVar[Callable] = Atom.parse_obj + __members_creator__: ClassVar[Callable] = Atom.model_validate connection_members_: Tuple[Atom, Atom] = Field( - ..., description="The 2 atoms involved in the bond." + ..., + description="The 2 atoms involved in the bond.", + alias="connection_members", ) - bond_type_: Optional[BondType] = Field( - default=None, description="BondType of this bond." + default=None, + description="BondType of this bond.", + alias="bond_type", ) restraint_: Optional[dict] = Field( default=None, @@ -41,6 +41,16 @@ class Bond(Connection): Refer to https://manual.gromacs.org/current/reference-manual/topologies/topology-file-formats.html for more information. """, + alias="restraint", + ) + model_config = ConfigDict( + alias_to_fields=dict( + **Connection.model_config["alias_to_fields"], + **{ + "bond_type": "bond_type_", + "restraint": "restraint_", + } + ) ) @property @@ -51,7 +61,6 @@ def bond_type(self): @property def connection_type(self): """Return parameters of the potential type.""" - # ToDo: Deprecate this? return self.__dict__.get("bond_type_") @property @@ -85,17 +94,3 @@ def __setattr__(self, key, value): super(Bond, self).__setattr__("bond_type", value) else: super(Bond, self).__setattr__(key, value) - - class Config: - """Pydantic configuration for Bond.""" - - fields = { - "bond_type_": "bond_type", - "connection_members_": "connection_members", - "restraint_": "restraint", - } - alias_to_fields = { - "bond_type": "bond_type_", - "connection_members": "connection_members_", - "restraint": "restraint_", - } diff --git a/gmso/core/bond_type.py b/gmso/core/bond_type.py index 98c08c099..13fb6e8e9 100644 --- a/gmso/core/bond_type.py +++ b/gmso/core/bond_type.py @@ -2,15 +2,11 @@ from typing import Optional, Tuple import unyt as u +from pydantic import ConfigDict, Field from gmso.core.parametric_potential import ParametricPotential from gmso.utils.expression import PotentialExpression -try: - from pydantic.v1 import Field -except ImportError: - from pydantic import Field - class BondType(ParametricPotential): __base_doc__ = """A descripton of the interaction between 2 bonded partners. @@ -33,12 +29,23 @@ class BondType(ParametricPotential): None, description="List-like of of gmso.AtomType.name " "defining the members of this bond type", + alias="member_types", ) member_classes_: Optional[Tuple[str, str]] = Field( None, description="List-like of of gmso.AtomType.atomclass " "defining the members of this bond type", + alias="member_classes", + ) + model_config = ConfigDict( + alias_to_fields=dict( + **ParametricPotential.model_config["alias_to_fields"], + **{ + "member_types": "member_types_", + "member_classes": "member_classes_", + }, + ), ) def __init__( @@ -82,16 +89,3 @@ def _default_potential_expr(): "r_eq": 0.14 * u.nm, }, ) - - class Config: - """Pydantic configuration for class attributes.""" - - fields = { - "member_types_": "member_types", - "member_classes_": "member_classes", - } - - alias_to_fields = { - "member_types": "member_types_", - "member_classes": "member_classes_", - } diff --git a/gmso/core/dihedral.py b/gmso/core/dihedral.py index 22c56fa05..219720259 100644 --- a/gmso/core/dihedral.py +++ b/gmso/core/dihedral.py @@ -1,14 +1,11 @@ from typing import Callable, ClassVar, Optional, Tuple +from pydantic import ConfigDict, Field + from gmso.abc.abstract_connection import Connection from gmso.core.atom import Atom from gmso.core.dihedral_type import DihedralType -try: - from pydantic.v1 import Field -except ImportError: - from pydantic import Field - class Dihedral(Connection): __base_doc__ = """A 4-partner connection between sites. @@ -28,14 +25,18 @@ class Dihedral(Connection): Additional _validate methods are presented """ - __members_creator__: ClassVar[Callable] = Atom.parse_obj + __members_creator__: ClassVar[Callable] = Atom.model_validate connection_members_: Tuple[Atom, Atom, Atom, Atom] = Field( - ..., description="The 4 atoms involved in the dihedral." + ..., + description="The 4 atoms involved in the dihedral.", + alias="connection_members", ) dihedral_type_: Optional[DihedralType] = Field( - default=None, description="DihedralType of this dihedral." + default=None, + description="DihedralType of this dihedral.", + alias="dihedral_type", ) restraint_: Optional[dict] = Field( @@ -46,6 +47,16 @@ class Dihedral(Connection): Refer to https://manual.gromacs.org/current/reference-manual/topologies/topology-file-formats.html for more information. """, + alias="restraint", + ) + model_config = ConfigDict( + alias_to_fields=dict( + **Connection.model_config["alias_to_fields"], + **{ + "dihedral_type": "dihedral_type_", + "restraint": "restraint_", + } + ) ) @property @@ -84,18 +95,6 @@ def equivalent_members(self): def __setattr__(self, key, value): if key == "connection_type": - super(Dihedral, self).__setattr__("dihedral_type", value) + super(Dihedral, self).__setattr__("dihedral_type_", value) else: super(Dihedral, self).__setattr__(key, value) - - class Config: - fields = { - "dihedral_type_": "dihedral_type", - "connection_members_": "connection_members", - "restraint_": "restraint", - } - alias_to_fields = { - "dihedral_type": "dihedral_type_", - "connection_members": "connection_members_", - "restraint": "restraint_", - } diff --git a/gmso/core/dihedral_type.py b/gmso/core/dihedral_type.py index 6096c4d46..32cebba3c 100644 --- a/gmso/core/dihedral_type.py +++ b/gmso/core/dihedral_type.py @@ -1,15 +1,11 @@ from typing import Optional, Tuple import unyt as u +from pydantic import ConfigDict, Field from gmso.core.parametric_potential import ParametricPotential from gmso.utils.expression import PotentialExpression -try: - from pydantic.v1 import Field -except ImportError: - from pydantic import Field - class DihedralType(ParametricPotential): __base_doc__ = """A descripton of the interaction between 4 bonded partners. @@ -38,12 +34,23 @@ class DihedralType(ParametricPotential): None, description="List-like of of gmso.AtomType.name " "defining the members of this dihedral type", + alias="member_types", ) member_classes_: Optional[Tuple[str, str, str, str]] = Field( None, description="List-like of of gmso.AtomType.atomclass defining the " "members of this dihedral type", + alias="member_classes", + ) + model_config = ConfigDict( + alias_to_fields=dict( + **ParametricPotential.model_config["alias_to_fields"], + **{ + "member_types": "member_types_", + "member_classes": "member_classes_", + }, + ), ) def __init__( @@ -87,14 +94,3 @@ def _default_potential_expr(): }, independent_variables={"phi"}, ) - - class Config: - fields = { - "member_types_": "member_types", - "member_classes_": "member_classes", - } - - alias_to_fields = { - "member_types": "member_types_", - "member_classes": "member_classes_", - } diff --git a/gmso/core/element.py b/gmso/core/element.py index ccb3fed11..9987a2a53 100644 --- a/gmso/core/element.py +++ b/gmso/core/element.py @@ -2,20 +2,18 @@ import json import warnings from re import sub +from typing import Union import numpy as np import unyt as u from pkg_resources import resource_filename +from pydantic import ConfigDict, Field, field_serializer from gmso.abc.gmso_base import GMSOBase +from gmso.abc.serialization_utils import unyt_to_dict from gmso.exceptions import GMSOError from gmso.utils.misc import unyt_to_hashable -try: - from pydantic.v1 import Field -except ImportError: - from pydantic import Field - exported = [ "element_by_mass", "element_by_symbol", @@ -42,6 +40,13 @@ class Element(GMSOBase): mass: u.unyt_quantity = Field(..., description="Mass of the element.") + @field_serializer("mass") + def serialize_mass(self, mass: Union[u.unyt_quantity, None]): + if mass is None: + return None + else: + return unyt_to_dict(mass) + def __repr__(self): """Representation of the element.""" return ( @@ -61,11 +66,7 @@ def __eq__(self, other): and self.atomic_number == other.atomic_number ) - class Config: - """Pydantic configuration for element.""" - - arbitrary_types_allowed = True - allow_mutation = False + model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True) def element_by_symbol(symbol, verbose=False): diff --git a/gmso/core/forcefield.py b/gmso/core/forcefield.py index 691cb1788..ebf93ccc6 100644 --- a/gmso/core/forcefield.py +++ b/gmso/core/forcefield.py @@ -598,7 +598,7 @@ def to_xml(self, filename, overwrite=False, backend="gmso"): Can be "gmso" or "forcefield-utilities". This will define the methods to write the xml. """ - if backend == "gmso" or backend == "GMSO": + if backend.lower() == "gmso": self._xml_from_gmso(filename, overwrite) elif backend in [ "forcefield_utilities", diff --git a/gmso/core/improper.py b/gmso/core/improper.py index 15c88ee98..93f15f030 100644 --- a/gmso/core/improper.py +++ b/gmso/core/improper.py @@ -1,15 +1,12 @@ """Support for improper style connections (4-member connection).""" from typing import Callable, ClassVar, Optional, Tuple +from pydantic import ConfigDict, Field + from gmso.abc.abstract_connection import Connection from gmso.core.atom import Atom from gmso.core.improper_type import ImproperType -try: - from pydantic.v1 import Field -except ImportError: - from pydantic import Field - class Improper(Connection): __base_doc__ = """sA 4-partner connection between sites. @@ -34,16 +31,27 @@ class Improper(Connection): Additional _validate methods are presented """ - __members_creator__: ClassVar[Callable] = Atom.parse_obj + __members_creator__: ClassVar[Callable] = Atom.model_validate connection_members_: Tuple[Atom, Atom, Atom, Atom] = Field( ..., description="The 4 atoms of this improper. Central atom first, " "then the three atoms connected to the central site.", + alias="connection_members", ) improper_type_: Optional[ImproperType] = Field( - default=None, description="ImproperType of this improper." + default=None, + description="ImproperType of this improper.", + alias="improper_type", + ) + model_config = ConfigDict( + alias_to_fields=dict( + **Connection.model_config["alias_to_fields"], + **{ + "improper_type": "improper_type_", + } + ) ) @property @@ -85,18 +93,6 @@ def equivalent_members(self): def __setattr__(self, key, value): """Set attribute override to support connection_type key.""" if key == "connection_type": - super(Improper, self).__setattr__("improper_type", value) + super(Improper, self).__setattr__("improper_type_", value) else: super(Improper, self).__setattr__(key, value) - - class Config: - """Pydantic configuration to link fields to their public attribute.""" - - fields = { - "improper_type_": "improper_type", - "connection_members_": "connection_members", - } - alias_to_fields = { - "improper_type": "improper_type_", - "connection_members": "connection_members_", - } diff --git a/gmso/core/improper_type.py b/gmso/core/improper_type.py index 3cb3ebb59..5dc35c968 100644 --- a/gmso/core/improper_type.py +++ b/gmso/core/improper_type.py @@ -2,15 +2,11 @@ from typing import Optional, Tuple import unyt as u +from pydantic import ConfigDict, Field from gmso.core.parametric_potential import ParametricPotential from gmso.utils.expression import PotentialExpression -try: - from pydantic.v1 import Field -except ImportError: - from pydantic import Field - class ImproperType(ParametricPotential): __base_doc__ = """A description of the interaction between 4 bonded partners. @@ -44,12 +40,23 @@ class ImproperType(ParametricPotential): None, description="List-like of gmso.AtomType.name " "defining the members of this improper type", + alias="member_types", ) member_classes_: Optional[Tuple[str, str, str, str]] = Field( None, description="List-like of gmso.AtomType.atomclass " "defining the members of this improper type", + alias="member_classes", + ) + model_config = ConfigDict( + alias_to_fields=dict( + **ParametricPotential.model_config["alias_to_fields"], + **{ + "member_types": "member_types_", + "member_classes": "member_classes_", + }, + ), ) def __init__( @@ -93,16 +100,3 @@ def _default_potential_expr(): }, independent_variables={"phi"}, ) - - class Config: - """Pydantic configuration for attributes.""" - - fields = { - "member_types_": "member_types", - "member_classes_": "member_classes", - } - - alias_to_fields = { - "member_types": "member_types_", - "member_classes": "member_classes_", - } diff --git a/gmso/core/pairpotential_type.py b/gmso/core/pairpotential_type.py index 38b992396..8791dbb2a 100644 --- a/gmso/core/pairpotential_type.py +++ b/gmso/core/pairpotential_type.py @@ -1,15 +1,11 @@ from typing import Optional, Tuple import unyt as u +from pydantic import ConfigDict, Field from gmso.core.parametric_potential import ParametricPotential from gmso.utils.expression import PotentialExpression -try: - from pydantic.v1 import Field -except ImportError: - from pydantic import Field - class PairPotentialType(ParametricPotential): __base_doc__ = """A description of custom pairwise potential between 2 AtomTypes that does not follow combination rule. @@ -34,6 +30,15 @@ class PairPotentialType(ParametricPotential): None, description="List-like of strs, referring to gmso.Atomtype.name or gmso.Atomtype.atomclass, " "defining the members of this pair potential type", + alias="member_types", + ) + model_config = ConfigDict( + alias_to_fields=dict( + **ParametricPotential.model_config["alias_to_fields"], + **{ + "member_types": "member_types_", + }, + ), ) def __init__( @@ -67,8 +72,3 @@ def _default_potential_expr(): independent_variables={"r"}, parameters={"eps": 1 * u.Unit("kJ / mol"), "sigma": 1 * u.nm}, ) - - class Config: - fields = {"member_types_": "member_types"} - - alias_to_fields = {"member_types": "member_types_"} diff --git a/gmso/core/parametric_potential.py b/gmso/core/parametric_potential.py index fa7b56af4..9b151b5b8 100644 --- a/gmso/core/parametric_potential.py +++ b/gmso/core/parametric_potential.py @@ -3,6 +3,7 @@ import unyt as u from lxml import etree +from pydantic import ConfigDict from gmso.abc.abstract_potential import AbstractPotential from gmso.utils.expression import PotentialExpression @@ -21,6 +22,14 @@ class ParametricPotential(AbstractPotential): by classes that represent these potentials. """ + model_config = ConfigDict( + alias_to_fields=dict( + **AbstractPotential.model_config["alias_to_fields"], + **{"topology": "topology_", "set_ref": "set_ref_"}, + ), + validate_assignment=True, + ) + def __init__( self, name="ParametricPotential", @@ -90,7 +99,7 @@ def _default_potential_expr(): @property def parameters(self): """Optional[dict]\n\tThe parameters of the `Potential` expression and their corresponding values, as `unyt` quantities""" - return self.potential_expression_.parameters + return self.potential_expression.parameters def __setattr__(self, key: Any, value: Any) -> None: """Set the attributes of the potential.""" @@ -126,13 +135,12 @@ def set_expression( parameters=parameters, ) - def dict( + def model_dump( self, *, include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, by_alias: bool = False, - skip_defaults: bool = None, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, @@ -144,16 +152,19 @@ def dict( exclude = exclude.union({"topology_", "set_ref_"}) - return super().dict( + return super().model_dump( include=include, exclude=exclude, by_alias=True, - skip_defaults=skip_defaults, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, exclude_none=exclude_none, ) + def __hash__(self): + """Return the unique hash of the object.""" + return id(self) + def __eq__(self, other): if other is self: return True @@ -184,7 +195,7 @@ def get_parameters(self, copy=False): def clone(self, fast_copy=False): """Clone this parametric potential, faster alternative to deepcopying.""" Creator = self.__class__ - kwargs = {"tags": deepcopy(self.tags_)} + kwargs = {"tags": deepcopy(self.tags)} if hasattr(self, "member_classes"): kwargs["member_classes"] = ( copy(self.member_classes) if self.member_classes else None @@ -197,7 +208,7 @@ def clone(self, fast_copy=False): return Creator( name=self.name, - potential_expression=self.potential_expression_.clone(fast_copy), + potential_expression=self.potential_expression.clone(fast_copy), **kwargs, ) @@ -205,7 +216,7 @@ def _etree_attrib(self): """Return the XML equivalent representation of this ParametricPotential""" attrib = { key: get_xml_representation(value) - for key, value in self.dict( + for key, value in self.model_dump( by_alias=True, exclude_none=True, exclude={ @@ -333,10 +344,3 @@ def __repr__(self): f"member types: {member_types(self)}>", ) return desc - - class Config: - """Pydantic configuration class.""" - - fields = {"topology_": "topology", "set_ref_": "set_ref"} - alias_to_fields = {"topology": "topology_"} - validate_assignment = True diff --git a/gmso/core/topology.py b/gmso/core/topology.py index 814c06787..b7f176556 100644 --- a/gmso/core/topology.py +++ b/gmso/core/topology.py @@ -1357,7 +1357,7 @@ def get_forcefield(self): } for atom_type in self.atom_types: ff.atom_types[atom_type.name] = atom_type.copy( - deep=True, exclude={"topology_", "set_ref_"} + deep=True, exclude={"topology", "set_ref"} ) ff_conn_types = { @@ -1371,7 +1371,7 @@ def get_forcefield(self): ff_conn_types[type(connection_type)][ FF_TOKENS_SEPARATOR.join(connection_type.member_types) ] = connection_type.copy( - deep=True, exclude={"topology_", "set_ref_"} + deep=True, exclude={"topology", "set_ref"} ) return ff diff --git a/gmso/external/convert_parmed.py b/gmso/external/convert_parmed.py index 7b89812c7..55431cc15 100644 --- a/gmso/external/convert_parmed.py +++ b/gmso/external/convert_parmed.py @@ -351,22 +351,28 @@ def _sort_improper_members(top, site_map, atom1, atom2, atom3, atom4): def _add_conn_type_from_pmd( connStr, pmd_conn, gmso_conn, conn_params, name, expression, variables ): - """Convert ParmEd dihedral types to GMSO DihedralType. + """Create a GMSO connection type and add to the conneciton object. + + This function creates the connection type object and add it to the + connection object provided. - This function take in a Parmed Structure, iterate through its - dihedral_types, create a corresponding - GMSO.DihedralType, and finally return a dictionary containing all - pairs of pmd.Dihedraltype and GMSO.DihedralType Parameters ---------- - structure: pmd.Structure - Parmed Structure that needed to be converted. - - Returns - ------- - pmd_top_dihedraltypes : dict - A dictionary linking a pmd.DihedralType - object to its corresponding GMSO.DihedralType object. + connStr : str + The name of the connection type. Accepted values include + "BondType", "AngleType", "DihedralType", and "ImproperType". + pmd_conn : pmd.Bond/Angle/Dihedral/Improper + The parmed connection object. + gmso_conn : gmso.Bond/Angle/Dihedral/Improper + The GMSO connection object. + conn_params : dict + The potential expression parameters in dictionary form. + name : str + Name of the potential form. + expression : expression + The potential expression form. + variables : dict + The independent variables. """ try: member_types = list( diff --git a/gmso/formats/gro.py b/gmso/formats/gro.py index 76d57ca2f..3c2c07aeb 100644 --- a/gmso/formats/gro.py +++ b/gmso/formats/gro.py @@ -159,8 +159,7 @@ def _validate_positions(pos_array): "in order to ensure all coordinates are non-negative." ) min_xyz = np.min(pos_array, axis=0) - unit = min_xyz.units - min_xyz0 = np.where(min_xyz < 0 * unit, min_xyz, 0 * unit) + min_xyz0 = np.where(min_xyz < 0 * min_xyz.units, min_xyz, 0 * min_xyz.units) pos_array -= min_xyz0 diff --git a/gmso/formats/json.py b/gmso/formats/json.py index c3698ce6c..091f27d51 100644 --- a/gmso/formats/json.py +++ b/gmso/formats/json.py @@ -181,7 +181,7 @@ def _from_json(json_dict): id_to_type_map = {} for atom_dict in json_dict["atoms"]: atom_type_id = atom_dict.pop("atom_type", None) - atom = Atom.parse_obj(atom_dict) + atom = Atom.model_validate(atom_dict) top.add_site(atom) if atom_type_id: if not id_to_type_map.get(atom_type_id): @@ -194,7 +194,7 @@ def _from_json(json_dict): top._sites[member_idx] for member_idx in bond_dict["connection_members"] ] - bond = Bond.parse_obj(bond_dict) + bond = Bond.model_validate(bond_dict) top.add_connection(bond) if bond_type_id: if not id_to_type_map.get(bond_type_id): @@ -207,7 +207,7 @@ def _from_json(json_dict): top._sites[member_idx] for member_idx in angle_dict["connection_members"] ] - angle = Angle.parse_obj(angle_dict) + angle = Angle.model_validate(angle_dict) top.add_connection(angle) if angle_type_id: if not id_to_type_map.get(angle_type_id): @@ -220,7 +220,7 @@ def _from_json(json_dict): top._sites[member_idx] for member_idx in dihedral_dict["connection_members"] ] - dihedral = Dihedral.parse_obj(dihedral_dict) + dihedral = Dihedral.model_validate(dihedral_dict) top.add_connection(dihedral) if dihedral_type_id: if not id_to_type_map.get(dihedral_type_id): @@ -233,7 +233,7 @@ def _from_json(json_dict): top._sites[member_idx] for member_idx in improper_dict["connection_members"] ] - improper = Improper.parse_obj(improper_dict) + improper = Improper.model_validate(improper_dict) if improper_type_id: if not id_to_type_map.get(improper_type_id): id_to_type_map[improper_type_id] = [] @@ -241,7 +241,7 @@ def _from_json(json_dict): for atom_type_dict in json_dict["atom_types"]: atom_type_id = atom_type_dict.pop("id", None) - atom_type = AtomType.parse_obj(atom_type_dict) + atom_type = AtomType.model_validate(atom_type_dict) if atom_type_id in id_to_type_map: for associated_atom in id_to_type_map[atom_type_id]: associated_atom.atom_type = atom_type @@ -254,7 +254,7 @@ def _from_json(json_dict): ]: for connection_type_dict in connection_types: connection_type_id = connection_type_dict.pop("id") - connection_type = Creator.parse_obj(connection_type_dict) + connection_type = Creator.model_validate(connection_type_dict) if connection_type_id in id_to_type_map: for associated_connection in id_to_type_map[connection_type_id]: setattr(associated_connection, attr, connection_type) @@ -273,7 +273,7 @@ def _from_json(json_dict): # AtomTypes need to be updated for pairpotentialtype addition for pair_potentialtype_dict in json_dict["pair_potentialtypes"]: - pair_potentialtype = PairPotentialType.parse_obj( + pair_potentialtype = PairPotentialType.model_validate( pair_potentialtype_dict ) top.add_pairpotentialtype(pair_potentialtype, update=False) diff --git a/gmso/lib/potential_templates.py b/gmso/lib/potential_templates.py index a3ad09e4b..71804d8f8 100644 --- a/gmso/lib/potential_templates.py +++ b/gmso/lib/potential_templates.py @@ -5,6 +5,7 @@ import sympy import unyt as u +from pydantic import ConfigDict, Field, field_validator from gmso.abc.abstract_potential import AbstractPotential from gmso.exceptions import ( @@ -15,11 +16,6 @@ from gmso.utils.expression import PotentialExpression from gmso.utils.singleton import Singleton -try: - from pydantic.v1 import Field, validator -except ImportError: - from pydantic import Field, validator - POTENTIAL_JSONS = list(Path(__file__).parent.glob("jsons/*.json")) JSON_DIR = Path.joinpath(Path(__file__).parent, "jsons") @@ -53,7 +49,19 @@ class PotentialTemplate(AbstractPotential): """Template for potential objects to be re-used.""" expected_parameters_dimensions_: Dict[str, sympy.Expr] = Field( - ..., description="The expected dimensions for parameters." + ..., + description="The expected dimensions for parameters.", + alias="expected_parameters_dimensions", + ) + + model_config = ConfigDict( + frozen=True, + alias_to_fields=dict( + **AbstractPotential.model_config["alias_to_fields"], + **{ + "expected_parameters_dimensions": "expected_parameters_dimensions_", + }, + ), ) def __init__( @@ -81,7 +89,7 @@ def __init__( expected_parameters_dimensions=expected_parameters_dimensions, ) - @validator("expected_parameters_dimensions_", pre=True, always=True) + @field_validator("expected_parameters_dimensions_", mode="before") def validate_expected_parameters(cls, dim_dict): """Validate the expected parameters and dimensions for this template.""" if not isinstance(dim_dict, Dict): @@ -149,17 +157,6 @@ def assert_can_parameterize_with( f"parameters: {parameters}" ) - class Config: - """Pydantic configuration for potential template.""" - - allow_mutation = False - fields = { - "expected_parameters_dimensions_": "expected_parameters_dimensions" - } - alias_to_fields = { - "expected_parameters_dimensions": "expected_parameters_dimensions_" - } - class PotentialTemplateLibrary(Singleton): """A singleton collection of all the potential templates.""" diff --git a/gmso/parameterization/parameterize.py b/gmso/parameterization/parameterize.py index b8e58de27..851dfd26e 100644 --- a/gmso/parameterization/parameterize.py +++ b/gmso/parameterization/parameterize.py @@ -73,7 +73,7 @@ def apply( necessary post parameterization. """ ignore_params = set([option.lower() for option in ignore_params]) - config = TopologyParameterizationConfig.parse_obj( + config = TopologyParameterizationConfig.model_validate( dict( match_ff_by=match_ff_by, identify_connections=identify_connections, diff --git a/gmso/tests/abc/test_serialization_utils.py b/gmso/tests/abc/test_serialization_utils.py index 312af8949..ccebabcfe 100644 --- a/gmso/tests/abc/test_serialization_utils.py +++ b/gmso/tests/abc/test_serialization_utils.py @@ -1,7 +1,7 @@ import pytest import unyt as u -from gmso.abc.serialization_utils import JSONHandler, dict_to_unyt, unyt_to_dict +from gmso.abc.serialization_utils import dict_to_unyt, unyt_to_dict from gmso.tests.base_test import BaseTest @@ -52,9 +52,3 @@ def test_dict_to_unyt_nested(self): unyt_dict["level1"]["level2"]["my_quantity"], u.unyt_array([-10, 2.5, 100], u.C), ) - - def test_json_handler(self): - handler = JSONHandler() - handler.register(int, str, override=False) - with pytest.warns(UserWarning): - handler.register(int, str, override=True) diff --git a/gmso/tests/base_test.py b/gmso/tests/base_test.py index 383bcfd76..55dc41388 100644 --- a/gmso/tests/base_test.py +++ b/gmso/tests/base_test.py @@ -386,8 +386,10 @@ def test_atom_equality(atom1, atom2): if isinstance(x1, u.unyt_array) and isinstance(x2, u.unyt_array) else x1 == x2 ) - for prop in atom1.dict(by_alias=True): - if not equal(atom2.dict().get(prop), atom1.dict().get(prop)): + for prop in atom1.model_dump(by_alias=True): + if not equal( + atom2.model_dump().get(prop), atom1.model_dump().get(prop) + ): return False return True diff --git a/gmso/tests/test_angle.py b/gmso/tests/test_angle.py index 25693cebe..96739d9ee 100644 --- a/gmso/tests/test_angle.py +++ b/gmso/tests/test_angle.py @@ -1,4 +1,5 @@ import pytest +from pydantic import ValidationError from gmso.core.angle import Angle from gmso.core.angle_type import AngleType @@ -7,11 +8,6 @@ from gmso.core.topology import Topology from gmso.tests.base_test import BaseTest -try: - from pydantic.v1 import ValidationError -except ImportError: - from pydantic import ValidationError - class TestAngle(BaseTest): def test_angle_nonparametrized(self): @@ -43,7 +39,7 @@ def test_angle_fake(self): atom1 = Atom(name="atom1") atom2 = Atom(name="atom2") atom3 = Atom(name="atom3") - with pytest.raises(ValidationError): + with pytest.raises(TypeError): Angle(connection_members=["fakesite1", "fakesite2", 4.2]) def test_angle_fake_angletype(self): diff --git a/gmso/tests/test_atom.py b/gmso/tests/test_atom.py index db58241f2..34820d49c 100644 --- a/gmso/tests/test_atom.py +++ b/gmso/tests/test_atom.py @@ -1,6 +1,7 @@ import numpy as np import pytest import unyt as u +from pydantic import ValidationError from gmso.core.atom import Atom from gmso.core.atom_type import AtomType @@ -8,11 +9,6 @@ from gmso.exceptions import GMSOError from gmso.tests.base_test import BaseTest -try: - from pydantic.v1 import ValidationError -except ImportError: - from pydantic import ValidationError - class TestSite(BaseTest): def test_new_site(self): @@ -28,7 +24,7 @@ def test_dtype(self): assert isinstance(atom.position, np.ndarray) def test_name_none(self): - atom = Atom(name=None) + atom = Atom() assert atom.name == "Atom" def test_setters_and_getters(self): diff --git a/gmso/tests/test_atom_type.py b/gmso/tests/test_atom_type.py index b6e6c0088..0be6c5c7d 100644 --- a/gmso/tests/test_atom_type.py +++ b/gmso/tests/test_atom_type.py @@ -362,7 +362,7 @@ def test_metadata_delete_tags(self, atomtype_metadata): def test_atom_type_dict(self): atype = AtomType() - atype_dict = atype.dict(exclude={"potential_expression"}) + atype_dict = atype.model_dump(exclude={"potential_expression"}) assert "potential_expression" not in atype_dict assert "charge" in atype_dict @@ -393,8 +393,10 @@ def test_atom_type_clone(self): assert len(top.atom_types) == 2 - atype_dict = atype.dict(exclude={"topology", "set_ref"}) - atype_clone_dict = atype_clone.dict(exclude={"topology", "set_ref"}) + atype_dict = atype.model_dump(exclude={"topology", "set_ref"}) + atype_clone_dict = atype_clone.model_dump( + exclude={"topology", "set_ref"} + ) for key, value in atype_dict.items(): cloned = atype_clone_dict[key] diff --git a/gmso/tests/test_bond.py b/gmso/tests/test_bond.py index 1be149773..59cbeddc0 100644 --- a/gmso/tests/test_bond.py +++ b/gmso/tests/test_bond.py @@ -1,4 +1,5 @@ import pytest +from pydantic import ValidationError from gmso.core.atom import Atom from gmso.core.atom_type import AtomType @@ -7,11 +8,6 @@ from gmso.core.topology import Topology from gmso.tests.base_test import BaseTest -try: - from pydantic.v1 import ValidationError -except ImportError: - from pydantic import ValidationError - class TestBond(BaseTest): def test_bond_nonparametrized(self): @@ -41,7 +37,7 @@ def test_bond_parametrized(self): def test_bond_fake(self): atom1 = Atom(name="atom1") atom2 = Atom(name="atom2") - with pytest.raises(ValidationError): + with pytest.raises(TypeError): Bond(connection_members=["fakeatom1", "fakeatom2"]) def test_bond_fake_bondtype(self): diff --git a/gmso/tests/test_dihedral.py b/gmso/tests/test_dihedral.py index 13e2ad45d..10ddede53 100644 --- a/gmso/tests/test_dihedral.py +++ b/gmso/tests/test_dihedral.py @@ -1,4 +1,5 @@ import pytest +from pydantic import ValidationError from gmso.core.atom import Atom from gmso.core.atom_type import AtomType @@ -7,11 +8,6 @@ from gmso.core.topology import Topology from gmso.tests.base_test import BaseTest -try: - from pydantic.v1 import ValidationError -except ImportError: - from pydantic import ValidationError - class TestDihedral(BaseTest): def test_dihedral_nonparametrized(self): @@ -48,7 +44,7 @@ def test_dihedral_fake(self): atom2 = Atom(name="atom2") atom3 = Atom(name="atom3") atom4 = Atom(name="atom4") - with pytest.raises(ValidationError): + with pytest.raises(TypeError): Dihedral(connection_members=["fakeatom1", "fakeatom2", 4.2]) def test_dihedral_fake_dihedraltype(self): diff --git a/gmso/tests/test_improper.py b/gmso/tests/test_improper.py index 4188b444a..f72e488d9 100644 --- a/gmso/tests/test_improper.py +++ b/gmso/tests/test_improper.py @@ -1,4 +1,5 @@ import pytest +from pydantic import ValidationError from gmso.core.atom import Atom from gmso.core.atom_type import AtomType @@ -7,11 +8,6 @@ from gmso.core.topology import Topology from gmso.tests.base_test import BaseTest -try: - from pydantic.v1 import ValidationError -except ImportError: - from pydantic import ValidationError - class TestImproper(BaseTest): def test_improper_nonparametrized(self): @@ -48,7 +44,7 @@ def test_improper_fake(self): atom2 = Atom(name="atom2") atom3 = Atom(name="atom3") atom4 = Atom(name="atom4") - with pytest.raises(ValidationError): + with pytest.raises(TypeError): Improper(connection_members=["fakeatom1", "fakeatom2", 4.2]) def test_improper_fake_impropertype(self): diff --git a/gmso/tests/test_potential.py b/gmso/tests/test_potential.py index 2c9f9d2c2..346cac958 100644 --- a/gmso/tests/test_potential.py +++ b/gmso/tests/test_potential.py @@ -276,8 +276,10 @@ def test_bondtype_clone(self): assert len(top.bond_types) == 2 - btype_dict = btype.dict(exclude={"topology", "set_ref"}) - btype_clone_dict = btype_clone.dict(exclude={"topology", "set_ref"}) + btype_dict = btype.model_dump(exclude={"topology", "set_ref"}) + btype_clone_dict = btype_clone.model_dump( + exclude={"topology", "set_ref"} + ) for key, value in btype_dict.items(): cloned = btype_clone_dict[key] diff --git a/gmso/tests/test_serialization.py b/gmso/tests/test_serialization.py index 6a6298a46..fa508e872 100644 --- a/gmso/tests/test_serialization.py +++ b/gmso/tests/test_serialization.py @@ -1,3 +1,5 @@ +import json + import pytest import unyt as u @@ -43,21 +45,21 @@ def full_atom_type(self): def test_atom_to_json_loop(self, typed_ethane, are_equivalent_atoms): atoms_to_test = typed_ethane.sites for atom in atoms_to_test: - atom_json = atom.json() - atom_copy = Atom.parse_raw(atom_json) + atom_json = atom.model_dump_json() + atom_copy = Atom.model_validate(json.loads(atom_json)) assert are_equivalent_atoms(atom, atom_copy) def test_atom_types_to_json_loop(self, typed_ethane): atom_types_to_test = typed_ethane.atom_types for atom_type in atom_types_to_test: - atom_type_json = atom_type.json() - atom_type_copy = AtomType.parse_raw(atom_type_json) + atom_type_json = atom_type.model_dump_json() + atom_type_copy = AtomType.model_validate(json.loads(atom_type_json)) assert atom_type_copy == atom_type def test_bond_to_json_loop(self, typed_ethane, are_equivalent_atoms): for bond in typed_ethane.bonds: - bond_json = bond.json() - bond_copy = Bond.parse_raw(bond_json) + bond_json = bond.model_dump_json() + bond_copy = Bond.model_validate(json.loads(bond_json)) assert bond_copy.name == bond.name for member1, member2 in zip( bond.connection_members, bond_copy.connection_members @@ -68,14 +70,14 @@ def test_bond_to_json_loop(self, typed_ethane, are_equivalent_atoms): def test_bond_type_to_json_loop(self, typed_ethane): bond_types_to_test = typed_ethane.bond_types for bond_type in bond_types_to_test: - bond_type_json = bond_type.json() - bond_type_copy = BondType.parse_raw(bond_type_json) + bond_type_json = bond_type.model_dump_json() + bond_type_copy = BondType.model_validate(json.loads(bond_type_json)) assert bond_type_copy == bond_type def test_angle_to_json_loop(self, typed_ethane, are_equivalent_atoms): for angle in typed_ethane.angles: - angle_json = angle.json() - angle_copy = Angle.parse_raw(angle_json) + angle_json = angle.model_dump_json() + angle_copy = Angle.model_validate(json.loads(angle_json)) for member1, member2 in zip( angle.connection_members, angle_copy.connection_members ): @@ -85,14 +87,16 @@ def test_angle_to_json_loop(self, typed_ethane, are_equivalent_atoms): def test_angle_type_to_json_loop(self, typed_ethane): angle_types_to_test = typed_ethane.angle_types for angle_type in angle_types_to_test: - angle_type_json = angle_type.json() - angle_type_copy = AngleType.parse_raw(angle_type_json) + angle_type_json = angle_type.model_dump_json() + angle_type_copy = AngleType.model_validate( + json.loads(angle_type_json) + ) assert angle_type_copy == angle_type def test_dihedral_to_json_loop(self, typed_ethane, are_equivalent_atoms): for dihedral in typed_ethane.dihedrals: - dihedral_json = dihedral.json() - dihedral_copy = Dihedral.parse_raw(dihedral_json) + dihedral_json = dihedral.model_dump_json() + dihedral_copy = Dihedral.model_validate(json.loads(dihedral_json)) for member1, member2 in zip( dihedral.connection_members, dihedral_copy.connection_members ): @@ -102,14 +106,16 @@ def test_dihedral_to_json_loop(self, typed_ethane, are_equivalent_atoms): def test_dihedral_types_to_json_loop(self, typed_ethane): dihedral_types_to_test = typed_ethane.dihedral_types for dihedral_type in dihedral_types_to_test: - dihedral_type_json = dihedral_type.json() - dihedral_type_copy = DihedralType.parse_raw(dihedral_type_json) + dihedral_type_json = dihedral_type.model_dump_json() + dihedral_type_copy = DihedralType.model_validate( + json.loads(dihedral_type_json) + ) assert dihedral_type_copy == dihedral_type def test_improper_to_json_loop(self, typed_ethane, are_equivalent_atoms): for improper in typed_ethane.impropers: - improper_json = improper.json() - improper_copy = Improper.parse_raw(improper_json) + improper_json = improper.model_dump_json() + improper_copy = Improper.model_validate(json.loads(improper_json)) for member1, member2 in zip( improper_copy.connection_members, improper.connection_members ): @@ -119,8 +125,10 @@ def test_improper_to_json_loop(self, typed_ethane, are_equivalent_atoms): def test_improper_types_to_json_loop(self, typed_ethane): improper_types_to_test = typed_ethane.improper_types for improper_type in improper_types_to_test: - improper_type_json = improper_type.json() - improper_type_copy = ImproperType.parse_raw(improper_type_json) + improper_type_json = improper_type.model_dump_json() + improper_type_copy = ImproperType.model_validate( + json.loads(improper_type_json) + ) improper_type_copy.topology = improper_type.topology assert improper_type_copy == improper_type @@ -135,7 +143,7 @@ def test_atom_every_field_set(self, full_atom_type, are_equivalent_atoms): atom_type=full_atom_type, ) - atom_copy = Atom.parse_raw(atom.json()) + atom_copy = Atom.model_validate(atom.json_dict()) assert are_equivalent_atoms(atom, atom_copy) def test_bond_every_field_set(self, full_atom_type, are_equivalent_atoms): @@ -168,7 +176,7 @@ def test_bond_every_field_set(self, full_atom_type, are_equivalent_atoms): independent_variables={"c"}, ) - bond_copy = Bond.parse_raw(bond.json()) + bond_copy = Bond.model_validate(bond.json_dict()) assert bond_copy.name == bond.name for member1, member2 in zip( bond.connection_members, bond_copy.connection_members @@ -178,13 +186,13 @@ def test_bond_every_field_set(self, full_atom_type, are_equivalent_atoms): def test_include_and_exclude(self): atom = Atom(mass=2.0 * u.g / u.mol, charge=30.0 * u.C, name="TestAtom") - atom_json = atom.json(exclude={"mass"}) + atom_json = atom.model_dump_json(exclude={"mass"}) assert "mass" not in atom_json - atom_json = atom.json(exclude={"mass_"}) + atom_json = atom.model_dump_json(exclude={"mass_"}) assert "mass" not in atom_json - atom_json = atom.json(include={"mass"}) + atom_json = atom.model_dump_json(include={"mass"}) assert "name" not in atom_json - atom_json = atom.json(include={"mass_"}) + atom_json = atom.model_dump_json(include={"mass_"}) assert "name" not in atom_json def test_full_serialization( diff --git a/gmso/tests/test_views.py b/gmso/tests/test_views.py index 5ccd783ea..addd6f2b3 100644 --- a/gmso/tests/test_views.py +++ b/gmso/tests/test_views.py @@ -53,7 +53,7 @@ def custom_top(self): custom_top.sites[j], custom_top.sites[j + 1], ], - bond_type=BondType(member_classes=(j, j + 1)), + bond_type=BondType(member_classes=(str(j), str(j + 1))), ) custom_top.add_connection(bond) diff --git a/gmso/utils/decorators.py b/gmso/utils/decorators.py index 7c0314386..6df98bee0 100644 --- a/gmso/utils/decorators.py +++ b/gmso/utils/decorators.py @@ -2,21 +2,6 @@ import functools import warnings -from gmso.abc import GMSOJSONHandler - - -class register_pydantic_json(object): - """Provides a way to register json encoders for a non-JSON serializable class.""" - - def __init__(self, method="json"): - self.method = method - - def __call__(self, cls): - """Register this class's json encoder method to GMSOJSONHandler.""" - json_method = getattr(cls, self.method) - GMSOJSONHandler.register(cls, json_method) - return cls - def deprecate_kwargs(deprecated_kwargs=None): """Decorate functions with deprecated/deprecating kwargs.""" diff --git a/gmso/utils/expression.py b/gmso/utils/expression.py index 590040c70..73b492346 100644 --- a/gmso/utils/expression.py +++ b/gmso/utils/expression.py @@ -7,8 +7,6 @@ import sympy import unyt as u -from gmso.utils.decorators import register_pydantic_json - __all__ = ["PotentialExpression"] @@ -30,7 +28,6 @@ def _are_equal_parameters(u1, u2): return True -@register_pydantic_json(method="json") class PotentialExpression: """A general Expression class with parameters.