Skip to content

Commit

Permalink
IdMixin's variable int/UUID primary key is now compatible with type h…
Browse files Browse the repository at this point in the history
…inting
  • Loading branch information
jace committed Nov 6, 2023
1 parent 0c74e97 commit 612dd28
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 36 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ dependencies = [
'sqlalchemy-utils',
'SQLAlchemy>=2.0.4',
'tldextract',
'typing_extensions',
'typing_extensions>=4.8.0',
'Unidecode',
'werkzeug',
]
Expand Down
99 changes: 86 additions & 13 deletions src/coaster/sqlalchemy/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@
Coaster provides a number of mixin classes for SQLAlchemy models. To use in
your Flask app::
from sqlalchemy.orm import DeclarativeBase
from flask_sqlalchemy import SQLAlchemy
from coaster.sqlalchemy import BaseMixin
from coaster.sqlalchemy import BaseMixin, ModelBase
class Model(ModelBase, DeclarativeBase):
'''Model base class.'''
db = SQLAlchemy(metadata=Model.metadata)
Model.init_flask_sqlalchemy(db)
class MyModel(BaseMixin[int], Model):
class MyModel(BaseMixin[int], Model): # Integer serial primary key; alt: UUID
__tablename__ = 'my_model'
Mixin classes must always appear *before* ``Model`` or ``db.Model`` in your model's
Expand All @@ -27,6 +28,7 @@ class MyModel(BaseMixin[int], Model):

import typing as t
import typing_extensions as te
import warnings
from collections import abc, namedtuple
from datetime import datetime
from decimal import Decimal
Expand Down Expand Up @@ -91,9 +93,17 @@ class MyModel(BaseMixin[int], Model):

_T = t.TypeVar('_T', bound=t.Any)

PkeyType = te.TypeVar('PkeyType', int, UUID, default=int)
# `default=int` is processed by type checkers implementing PEP 696, but seemingly has no
# impact in runtime, so no default will be received in `IdMixin.__init_subclass__`


class PkeyWarning(UserWarning):
"""Warning when the primary key type is not specified as a base class argument."""


@declarative_mixin
class IdMixin:
class IdMixin(t.Generic[PkeyType]):
"""
Provides the :attr:`id` primary key column.
Expand All @@ -120,12 +130,75 @@ class MyModel(IdMixin, Model):
query_class: t.ClassVar[t.Type[Query]] = Query
query: t.ClassVar[QueryProperty]
#: Use UUID primary key? If yes, UUIDs are automatically generated without
#: the need to commit to the database
#: the need to commit to the database. Do not set this directly; pass UUID as a
#: Generic argument to the base class instead: ``class MyModel(IdMixin[UUID])``.
__uuid_primary_key__: t.ClassVar[bool] = False

def __init_subclass__(cls) -> None:
# If a generic arg is specified, set `__uuid_primary_key__` from it. Do this
# before `super().__init_subclass__` calls SQLAlchemy's implementation,
# which processes the `declared_attr` classmethods into class attributes. They
# depend on `__uuid_primary_key__` already being set on the class.
if '__uuid_primary_key__' in cls.__dict__:
# This is only a warning, but it will turn into an error below if the value
# varies from the generic arg
warnings.warn(
f"`{cls.__qualname__}` must specify primary key type as `int` or `UUID`"
" to the base class (`IdMixin[int]` or `IdMixin[UUID]`) instead of"
" specifying `__uuid_primary_key__` directly",
PkeyWarning,
)

for base in te.get_original_bases(cls):
# XXX: Is this the correct way to examine a generic subclass that may have
# more generic args in a redefined order? The docs suggest Generic args are
# assumed positional, but they may be reordered, so how do we determine the
# arg to IdMixin itself? There is no variant of `cls.mro()` that returns
# original base classes with their generic args. For now, we expect that
# generic subclasses _must_ use the static `PkeyType` typevar in their
# definitions. This may need to be revisited with Python 3.12's new type
# parameter syntax (via PEP 695).
origin_base = t.get_origin(base)
if (
origin_base is not None
and issubclass(origin_base, IdMixin)
and PkeyType in origin_base.__parameters__ # type: ignore[misc]
):
pkey_type = t.get_args(base)[
origin_base.__parameters__.index(PkeyType) # type: ignore[misc]
]
if pkey_type is int:
if (
'__uuid_primary_key__' in cls.__dict__
and cls.__uuid_primary_key__ is not False
):
raise TypeError(
f"{cls.__qualname__}.__uuid_primary_key__ conflicts with"
" pkey type argument to the base class"
)
cls.__uuid_primary_key__ = False
elif pkey_type is UUID:
if (
'__uuid_primary_key__' in cls.__dict__
and cls.__uuid_primary_key__ is not True
):
raise TypeError(
f"{cls.__qualname__}.__uuid_primary_key__ conflicts with"
" pkey type argument to the base class"
)
cls.__uuid_primary_key__ = True
elif pkey_type is PkeyType: # type: ignore[misc]
# This must be a generic subclass, ignore it
pass
else:
raise TypeError(f"Unsupported primary key type in {base!r}")
break

return super().__init_subclass__()

@immutable
@declared_attr
def id(cls) -> Mapped[t.Union[int, UUID]]: # noqa: A003
def id(cls) -> Mapped[PkeyType]: # noqa: A003
"""Database identity for this model."""
if cls.__uuid_primary_key__:
return sa.orm.mapped_column(
Expand Down Expand Up @@ -584,12 +657,12 @@ def _set_fields(self, fields: t.Mapping[str, t.Any]) -> None:


@declarative_mixin
class BaseMixin(IdMixin, NoIdMixin):
class BaseMixin(IdMixin[PkeyType], NoIdMixin):
"""Base mixin class for all tables that have an id column."""


@declarative_mixin
class BaseNameMixin(BaseMixin):
class BaseNameMixin(BaseMixin[PkeyType]):
"""
Base mixin class for named objects.
Expand Down Expand Up @@ -723,7 +796,7 @@ def checkused(c: str) -> bool:


@declarative_mixin
class BaseScopedNameMixin(BaseMixin):
class BaseScopedNameMixin(BaseMixin[PkeyType]):
"""
Base mixin class for named objects within containers.
Expand Down Expand Up @@ -905,7 +978,7 @@ def permissions(


@declarative_mixin
class BaseIdNameMixin(BaseMixin):
class BaseIdNameMixin(BaseMixin[PkeyType]):
"""
Base mixin class for named objects with an id tag.
Expand Down Expand Up @@ -1023,7 +1096,7 @@ def _url_name_uuid_b58_comparator(cls) -> SqlUuidB58Comparator:


@declarative_mixin
class BaseScopedIdMixin(BaseMixin):
class BaseScopedIdMixin(BaseMixin[PkeyType]):
"""
Base mixin class for objects with an id that is unique within a parent.
Expand Down Expand Up @@ -1087,7 +1160,7 @@ def permissions(


@declarative_mixin
class BaseScopedIdNameMixin(BaseScopedIdMixin):
class BaseScopedIdNameMixin(BaseScopedIdMixin[PkeyType]):
"""
Base mixin class for named objects with an id tag that is unique within a parent.
Expand Down Expand Up @@ -1269,12 +1342,12 @@ def coordinates(


# Setup listeners for UUID-based subclasses
def _configure_id_listener(mapper: t.Any, class_: IdMixin) -> None:
def _configure_id_listener(mapper: t.Any, class_: t.Type[IdMixin]) -> None:
if hasattr(class_, '__uuid_primary_key__') and class_.__uuid_primary_key__:
auto_init_default(mapper.column_attrs.id)


def _configure_uuid_listener(mapper: t.Any, class_: UuidMixin) -> None:
def _configure_uuid_listener(mapper: t.Any, class_: t.Type[UuidMixin]) -> None:
if hasattr(class_, '__uuid_primary_key__') and class_.__uuid_primary_key__:
return
# Only configure this listener if the class doesn't use UUID primary keys,
Expand Down
36 changes: 14 additions & 22 deletions tests/coaster_tests/sqlalchemy_models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,68 +220,60 @@ class MyUrlModel(Model):
)


class NonUuidKey(BaseMixin, Model):
class NonUuidKey(BaseMixin[int], Model):
__tablename__ = 'non_uuid_key'
__uuid_primary_key__ = False


class UuidKey(BaseMixin, Model):
class UuidKey(BaseMixin[UUID], Model):
__tablename__ = 'uuid_key'
__uuid_primary_key__ = True


class UuidKeyNoDefault(BaseMixin, Model):
class UuidKeyNoDefault(BaseMixin[UUID], Model):
__tablename__ = 'uuid_key_no_default'
__uuid_primary_key__ = True
id: Mapped[UUID] = sa.orm.mapped_column( # type: ignore[assignment] # noqa: A003
sa.Uuid, primary_key=True
)


class UuidForeignKey1(BaseMixin, Model):
class UuidForeignKey1(BaseMixin[int], Model):
__tablename__ = 'uuid_foreign_key1'
__uuid_primary_key__ = False
uuidkey_id: Mapped[UUID] = sa.orm.mapped_column(sa.ForeignKey('uuid_key.id'))
uuidkey: Mapped[UuidKey] = relationship(UuidKey)


class UuidForeignKey2(BaseMixin, Model):
class UuidForeignKey2(BaseMixin[UUID], Model):
__tablename__ = 'uuid_foreign_key2'
__uuid_primary_key__ = True
uuidkey_id: Mapped[UUID] = sa.orm.mapped_column(sa.ForeignKey('uuid_key.id'))
uuidkey: Mapped[UuidKey] = relationship(UuidKey)


class UuidIdName(BaseIdNameMixin, Model):
class UuidIdName(BaseIdNameMixin[UUID], Model):
__tablename__ = 'uuid_id_name'
__uuid_primary_key__ = True


class UuidIdNameMixin(UuidMixin, BaseIdNameMixin, Model):
class UuidIdNameMixin(UuidMixin, BaseIdNameMixin[UUID], Model):
__tablename__ = 'uuid_id_name_mixin'
__uuid_primary_key__ = True


class UuidIdNameSecondary(UuidMixin, BaseIdNameMixin, Model):
class UuidIdNameSecondary(UuidMixin, BaseIdNameMixin[int], Model):
__tablename__ = 'uuid_id_name_secondary'
__uuid_primary_key__ = False


class NonUuidMixinKey(UuidMixin, BaseMixin, Model):
class NonUuidMixinKey(UuidMixin, BaseMixin[int], Model):
__tablename__ = 'non_uuid_mixin_key'
__uuid_primary_key__ = False


class UuidMixinKey(UuidMixin, BaseMixin, Model):
class UuidMixinKey(UuidMixin, BaseMixin[UUID], Model):
__tablename__ = 'uuid_mixin_key'
__uuid_primary_key__ = True


class ParentForPrimary(BaseMixin, Model):
__tablename__ = 'parent_for_primary'
__allow_unmapped__ = True # Required for primary_child not being wrapped in Mapped

primary_child: t.Optional[ChildForPrimary] # For the added relationship
# The relationship must be explicitly defined for type hinting to work.
# add_primary_relationship will replace this with a fleshed-out relationship
# for SQLAlchemy configuration
primary_child: Mapped[t.Optional[ChildForPrimary]] = relationship()


class ChildForPrimary(BaseMixin, Model):
Expand Down

0 comments on commit 612dd28

Please sign in to comment.