From 612dd28c1d7bda95ae059615f035690b66e3194f Mon Sep 17 00:00:00 2001 From: Kiran Jonnalagadda Date: Mon, 6 Nov 2023 23:35:26 +0530 Subject: [PATCH] IdMixin's variable int/UUID primary key is now compatible with type hinting --- pyproject.toml | 2 +- src/coaster/sqlalchemy/mixins.py | 99 ++++++++++++++++--- tests/coaster_tests/sqlalchemy_models_test.py | 36 +++---- 3 files changed, 101 insertions(+), 36 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7e5f9f3c..06e3f9a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ dependencies = [ 'sqlalchemy-utils', 'SQLAlchemy>=2.0.4', 'tldextract', - 'typing_extensions', + 'typing_extensions>=4.8.0', 'Unidecode', 'werkzeug', ] diff --git a/src/coaster/sqlalchemy/mixins.py b/src/coaster/sqlalchemy/mixins.py index ff78ad9b..8eb95edf 100644 --- a/src/coaster/sqlalchemy/mixins.py +++ b/src/coaster/sqlalchemy/mixins.py @@ -5,8 +5,9 @@ Coaster provides a number of mixin classes for SQLAlchemy models. To use in your Flask app:: + from sqlalchemy.orm import DeclarativeBase from flask_sqlalchemy import SQLAlchemy - from coaster.sqlalchemy import BaseMixin + from coaster.sqlalchemy import BaseMixin, ModelBase class Model(ModelBase, DeclarativeBase): '''Model base class.''' @@ -14,7 +15,7 @@ class Model(ModelBase, DeclarativeBase): db = SQLAlchemy(metadata=Model.metadata) Model.init_flask_sqlalchemy(db) - class MyModel(BaseMixin[int], Model): + class MyModel(BaseMixin[int], Model): # Integer serial primary key; alt: UUID __tablename__ = 'my_model' Mixin classes must always appear *before* ``Model`` or ``db.Model`` in your model's @@ -27,6 +28,7 @@ class MyModel(BaseMixin[int], Model): import typing as t import typing_extensions as te +import warnings from collections import abc, namedtuple from datetime import datetime from decimal import Decimal @@ -91,9 +93,17 @@ class MyModel(BaseMixin[int], Model): _T = t.TypeVar('_T', bound=t.Any) +PkeyType = te.TypeVar('PkeyType', int, UUID, default=int) +# `default=int` is processed by type checkers implementing PEP 696, but seemingly has no +# impact in runtime, so no default will be received in `IdMixin.__init_subclass__` + + +class PkeyWarning(UserWarning): + """Warning when the primary key type is not specified as a base class argument.""" + @declarative_mixin -class IdMixin: +class IdMixin(t.Generic[PkeyType]): """ Provides the :attr:`id` primary key column. @@ -120,12 +130,75 @@ class MyModel(IdMixin, Model): query_class: t.ClassVar[t.Type[Query]] = Query query: t.ClassVar[QueryProperty] #: Use UUID primary key? If yes, UUIDs are automatically generated without - #: the need to commit to the database + #: the need to commit to the database. Do not set this directly; pass UUID as a + #: Generic argument to the base class instead: ``class MyModel(IdMixin[UUID])``. __uuid_primary_key__: t.ClassVar[bool] = False + def __init_subclass__(cls) -> None: + # If a generic arg is specified, set `__uuid_primary_key__` from it. Do this + # before `super().__init_subclass__` calls SQLAlchemy's implementation, + # which processes the `declared_attr` classmethods into class attributes. They + # depend on `__uuid_primary_key__` already being set on the class. + if '__uuid_primary_key__' in cls.__dict__: + # This is only a warning, but it will turn into an error below if the value + # varies from the generic arg + warnings.warn( + f"`{cls.__qualname__}` must specify primary key type as `int` or `UUID`" + " to the base class (`IdMixin[int]` or `IdMixin[UUID]`) instead of" + " specifying `__uuid_primary_key__` directly", + PkeyWarning, + ) + + for base in te.get_original_bases(cls): + # XXX: Is this the correct way to examine a generic subclass that may have + # more generic args in a redefined order? The docs suggest Generic args are + # assumed positional, but they may be reordered, so how do we determine the + # arg to IdMixin itself? There is no variant of `cls.mro()` that returns + # original base classes with their generic args. For now, we expect that + # generic subclasses _must_ use the static `PkeyType` typevar in their + # definitions. This may need to be revisited with Python 3.12's new type + # parameter syntax (via PEP 695). + origin_base = t.get_origin(base) + if ( + origin_base is not None + and issubclass(origin_base, IdMixin) + and PkeyType in origin_base.__parameters__ # type: ignore[misc] + ): + pkey_type = t.get_args(base)[ + origin_base.__parameters__.index(PkeyType) # type: ignore[misc] + ] + if pkey_type is int: + if ( + '__uuid_primary_key__' in cls.__dict__ + and cls.__uuid_primary_key__ is not False + ): + raise TypeError( + f"{cls.__qualname__}.__uuid_primary_key__ conflicts with" + " pkey type argument to the base class" + ) + cls.__uuid_primary_key__ = False + elif pkey_type is UUID: + if ( + '__uuid_primary_key__' in cls.__dict__ + and cls.__uuid_primary_key__ is not True + ): + raise TypeError( + f"{cls.__qualname__}.__uuid_primary_key__ conflicts with" + " pkey type argument to the base class" + ) + cls.__uuid_primary_key__ = True + elif pkey_type is PkeyType: # type: ignore[misc] + # This must be a generic subclass, ignore it + pass + else: + raise TypeError(f"Unsupported primary key type in {base!r}") + break + + return super().__init_subclass__() + @immutable @declared_attr - def id(cls) -> Mapped[t.Union[int, UUID]]: # noqa: A003 + def id(cls) -> Mapped[PkeyType]: # noqa: A003 """Database identity for this model.""" if cls.__uuid_primary_key__: return sa.orm.mapped_column( @@ -584,12 +657,12 @@ def _set_fields(self, fields: t.Mapping[str, t.Any]) -> None: @declarative_mixin -class BaseMixin(IdMixin, NoIdMixin): +class BaseMixin(IdMixin[PkeyType], NoIdMixin): """Base mixin class for all tables that have an id column.""" @declarative_mixin -class BaseNameMixin(BaseMixin): +class BaseNameMixin(BaseMixin[PkeyType]): """ Base mixin class for named objects. @@ -723,7 +796,7 @@ def checkused(c: str) -> bool: @declarative_mixin -class BaseScopedNameMixin(BaseMixin): +class BaseScopedNameMixin(BaseMixin[PkeyType]): """ Base mixin class for named objects within containers. @@ -905,7 +978,7 @@ def permissions( @declarative_mixin -class BaseIdNameMixin(BaseMixin): +class BaseIdNameMixin(BaseMixin[PkeyType]): """ Base mixin class for named objects with an id tag. @@ -1023,7 +1096,7 @@ def _url_name_uuid_b58_comparator(cls) -> SqlUuidB58Comparator: @declarative_mixin -class BaseScopedIdMixin(BaseMixin): +class BaseScopedIdMixin(BaseMixin[PkeyType]): """ Base mixin class for objects with an id that is unique within a parent. @@ -1087,7 +1160,7 @@ def permissions( @declarative_mixin -class BaseScopedIdNameMixin(BaseScopedIdMixin): +class BaseScopedIdNameMixin(BaseScopedIdMixin[PkeyType]): """ Base mixin class for named objects with an id tag that is unique within a parent. @@ -1269,12 +1342,12 @@ def coordinates( # Setup listeners for UUID-based subclasses -def _configure_id_listener(mapper: t.Any, class_: IdMixin) -> None: +def _configure_id_listener(mapper: t.Any, class_: t.Type[IdMixin]) -> None: if hasattr(class_, '__uuid_primary_key__') and class_.__uuid_primary_key__: auto_init_default(mapper.column_attrs.id) -def _configure_uuid_listener(mapper: t.Any, class_: UuidMixin) -> None: +def _configure_uuid_listener(mapper: t.Any, class_: t.Type[UuidMixin]) -> None: if hasattr(class_, '__uuid_primary_key__') and class_.__uuid_primary_key__: return # Only configure this listener if the class doesn't use UUID primary keys, diff --git a/tests/coaster_tests/sqlalchemy_models_test.py b/tests/coaster_tests/sqlalchemy_models_test.py index 69153939..90a20f47 100644 --- a/tests/coaster_tests/sqlalchemy_models_test.py +++ b/tests/coaster_tests/sqlalchemy_models_test.py @@ -220,68 +220,60 @@ class MyUrlModel(Model): ) -class NonUuidKey(BaseMixin, Model): +class NonUuidKey(BaseMixin[int], Model): __tablename__ = 'non_uuid_key' - __uuid_primary_key__ = False -class UuidKey(BaseMixin, Model): +class UuidKey(BaseMixin[UUID], Model): __tablename__ = 'uuid_key' - __uuid_primary_key__ = True -class UuidKeyNoDefault(BaseMixin, Model): +class UuidKeyNoDefault(BaseMixin[UUID], Model): __tablename__ = 'uuid_key_no_default' - __uuid_primary_key__ = True id: Mapped[UUID] = sa.orm.mapped_column( # type: ignore[assignment] # noqa: A003 sa.Uuid, primary_key=True ) -class UuidForeignKey1(BaseMixin, Model): +class UuidForeignKey1(BaseMixin[int], Model): __tablename__ = 'uuid_foreign_key1' - __uuid_primary_key__ = False uuidkey_id: Mapped[UUID] = sa.orm.mapped_column(sa.ForeignKey('uuid_key.id')) uuidkey: Mapped[UuidKey] = relationship(UuidKey) -class UuidForeignKey2(BaseMixin, Model): +class UuidForeignKey2(BaseMixin[UUID], Model): __tablename__ = 'uuid_foreign_key2' - __uuid_primary_key__ = True uuidkey_id: Mapped[UUID] = sa.orm.mapped_column(sa.ForeignKey('uuid_key.id')) uuidkey: Mapped[UuidKey] = relationship(UuidKey) -class UuidIdName(BaseIdNameMixin, Model): +class UuidIdName(BaseIdNameMixin[UUID], Model): __tablename__ = 'uuid_id_name' - __uuid_primary_key__ = True -class UuidIdNameMixin(UuidMixin, BaseIdNameMixin, Model): +class UuidIdNameMixin(UuidMixin, BaseIdNameMixin[UUID], Model): __tablename__ = 'uuid_id_name_mixin' - __uuid_primary_key__ = True -class UuidIdNameSecondary(UuidMixin, BaseIdNameMixin, Model): +class UuidIdNameSecondary(UuidMixin, BaseIdNameMixin[int], Model): __tablename__ = 'uuid_id_name_secondary' - __uuid_primary_key__ = False -class NonUuidMixinKey(UuidMixin, BaseMixin, Model): +class NonUuidMixinKey(UuidMixin, BaseMixin[int], Model): __tablename__ = 'non_uuid_mixin_key' - __uuid_primary_key__ = False -class UuidMixinKey(UuidMixin, BaseMixin, Model): +class UuidMixinKey(UuidMixin, BaseMixin[UUID], Model): __tablename__ = 'uuid_mixin_key' - __uuid_primary_key__ = True class ParentForPrimary(BaseMixin, Model): __tablename__ = 'parent_for_primary' - __allow_unmapped__ = True # Required for primary_child not being wrapped in Mapped - primary_child: t.Optional[ChildForPrimary] # For the added relationship + # The relationship must be explicitly defined for type hinting to work. + # add_primary_relationship will replace this with a fleshed-out relationship + # for SQLAlchemy configuration + primary_child: Mapped[t.Optional[ChildForPrimary]] = relationship() class ChildForPrimary(BaseMixin, Model):