Skip to content

Commit

Permalink
Return Query[Self] and list[Self] where relevant
Browse files Browse the repository at this point in the history
  • Loading branch information
jace committed Dec 14, 2023
1 parent 736fb3d commit 6961920
Show file tree
Hide file tree
Showing 12 changed files with 45 additions and 39 deletions.
9 changes: 5 additions & 4 deletions funnel/models/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import itertools
from collections.abc import Iterable, Iterator, Sequence
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, ClassVar, Literal, cast, overload
from typing import TYPE_CHECKING, ClassVar, Literal, Self, cast, overload
from uuid import UUID

import phonenumbers
Expand Down Expand Up @@ -1099,7 +1099,7 @@ def all_public(cls) -> Query:
return query

@classmethod
def autocomplete(cls, prefix: str) -> list[Account]:
def autocomplete(cls, prefix: str) -> list[Self]:
"""
Return accounts whose names begin with the prefix, for autocomplete UI.
Expand Down Expand Up @@ -1332,8 +1332,9 @@ def __init__(self, **kwargs) -> None:
Account.userid = Account.uuid_b64


# TODO: Make an Actor Protocol as the base for both -- maybe placing that in Coaster
class DuckTypeAccount(RoleMixin):
"""User singleton constructor. Ducktypes a regular user object."""
"""User singleton constructor. Duck types a regular user object."""

id: None = None # noqa: A003
created_at: None = None
Expand Down Expand Up @@ -1872,7 +1873,7 @@ def get_by(
)

@classmethod
def all(cls, email: str) -> Query[AccountEmailClaim]: # noqa: A003
def all(cls, email: str) -> Query[Self]: # noqa: A003
"""
Return all instances with the matching email address.
Expand Down
20 changes: 9 additions & 11 deletions funnel/models/auth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import Iterable, Sequence
from datetime import datetime, timedelta
from hashlib import blake2b, sha256
from typing import cast, overload
from typing import Self, cast, overload

from sqlalchemy.orm import attribute_keyed_dict, load_only
from sqlalchemy.orm.query import Query as QueryBaseClass
Expand Down Expand Up @@ -264,7 +264,7 @@ def get(cls, buid: str) -> AuthClient | None:
return cls.query.filter(cls.buid == buid, cls.active.is_(True)).one_or_none()

@classmethod
def all_for(cls, account: Account | None) -> Query[AuthClient]:
def all_for(cls, account: Account | None) -> Query[Self]:
"""Return all clients, optionally all clients owned by the specified account."""
if account is None:
return cls.query.order_by(cls.title)
Expand Down Expand Up @@ -401,7 +401,7 @@ def is_valid(self) -> bool:
return not self.used and self.created_at >= utcnow() - timedelta(minutes=3)

@classmethod
def all_for(cls, account: Account) -> Query[AuthCode]:
def all_for(cls, account: Account) -> Query[Self]:
"""Return all auth codes for the specified account."""
return cls.query.filter(cls.account == account)

Expand Down Expand Up @@ -596,7 +596,7 @@ def get_for(
).one_or_none()

@classmethod
def all(cls, accounts: Query | Sequence[Account]) -> list[AuthToken]: # noqa: A003
def all(cls, accounts: Query | Sequence[Account]) -> list[Self]: # noqa: A003
"""Return all AuthToken for the specified accounts."""
query = cls.query.join(AuthClient)
if isinstance(accounts, QueryBaseClass):
Expand All @@ -622,7 +622,7 @@ def all(cls, accounts: Query | Sequence[Account]) -> list[AuthToken]: # noqa: A
return []

@classmethod
def all_for(cls, account: Account) -> Query[AuthToken]:
def all_for(cls, account: Account) -> Query[Self]:
"""Get all AuthTokens for a specified account (direct only)."""
return cls.query.filter(cls.account == account)

Expand Down Expand Up @@ -689,12 +689,12 @@ def get(
).one_or_none()

@classmethod
def all_for(cls, account: Account) -> Query[AuthClientPermissions]:
def all_for(cls, account: Account) -> Query[Self]:
"""Get all permissions assigned to account for various clients."""
return cls.query.filter(cls.account == account)

@classmethod
def all_forclient(cls, auth_client: AuthClient) -> Query[AuthClientPermissions]:
def all_forclient(cls, auth_client: AuthClient) -> Query[Self]:
"""Get all permissions assigned on the specified auth client."""
return cls.query.filter(cls.auth_client == auth_client)

Expand Down Expand Up @@ -744,17 +744,15 @@ def get(
).one_or_none()

@classmethod
def all_for(
cls, auth_client: AuthClient, account: Account
) -> Query[AuthClientTeamPermissions]:
def all_for(cls, auth_client: AuthClient, account: Account) -> Query[Self]:
"""Get all permissions for the specified account via their teams."""
return cls.query.filter(
cls.auth_client == auth_client,
cls.team_id.in_([team.id for team in account.member_teams]),
)

@classmethod
def all_forclient(cls, auth_client: AuthClient) -> Query[AuthClientTeamPermissions]:
def all_forclient(cls, auth_client: AuthClient) -> Query[Self]:
"""Get all permissions assigned on the specified auth client."""
return cls.query.filter(cls.auth_client == auth_client)

Expand Down
3 changes: 2 additions & 1 deletion funnel/models/commentset_membership.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from datetime import datetime
from typing import Self

from werkzeug.utils import cached_property

Expand Down Expand Up @@ -90,7 +91,7 @@ def update_last_seen_at(self) -> None:
self.last_seen_at = sa.func.utcnow()

@classmethod
def for_user(cls, account: Account) -> Query[CommentsetMembership]:
def for_user(cls, account: Account) -> Query[Self]:
"""
Return a query representing all active commentset memberships for a user.
Expand Down
5 changes: 3 additions & 2 deletions funnel/models/contact_exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dataclasses import dataclass
from datetime import date as date_type, datetime
from itertools import groupby
from typing import Self
from uuid import UUID

from pytz import timezone
Expand Down Expand Up @@ -258,7 +259,7 @@ def grouped_counts_for(
@classmethod
def contacts_for_project_and_date(
cls, account: Account, project: Project, date: date_type, archived: bool = False
) -> Query[ContactExchange]:
) -> Query[Self]:
"""Return contacts for a given user, project and date."""
query = cls.query.join(TicketParticipant).filter(
cls.account == account,
Expand All @@ -285,7 +286,7 @@ def contacts_for_project_and_date(
@classmethod
def contacts_for_project(
cls, account: Account, project: Project, archived: bool = False
) -> Query[ContactExchange]:
) -> Query[Self]:
"""Return contacts for a given user and project."""
query = cls.query.join(TicketParticipant).filter(
cls.account == account,
Expand Down
17 changes: 9 additions & 8 deletions funnel/models/email_address.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import hashlib
import unicodedata
from datetime import datetime
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast, overload
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self, cast, overload

import base58
import idna
Expand Down Expand Up @@ -473,7 +473,7 @@ def get_filter(
email_hash: str | None = None,
) -> sa.ColumnElement[bool] | None:
"""
Get an filter condition for retriving an :class:`EmailAddress`.
Get an filter condition for retrieving an :class:`EmailAddress`.
Accepts an email address or a blake2b160 hash in either bytes or base58 form.
Internally converts all lookups to a bytes-based hash lookup. Returns an
Expand Down Expand Up @@ -528,14 +528,15 @@ def get(
Internally converts an email-based lookup into a hash-based lookup.
"""
return cls.query.filter(
cls.get_filter(email=email, blake2b160=blake2b160, email_hash=email_hash)
).one_or_none()
email_filter = cls.get_filter(
email=email, blake2b160=blake2b160, email_hash=email_hash
)
if email_filter is None:
return None
return cls.query.filter(email_filter).one_or_none()

@classmethod
def get_canonical(
cls, email: str, is_blocked: bool | None = None
) -> Query[EmailAddress]:
def get_canonical(cls, email: str, is_blocked: bool | None = None) -> Query[Self]:
"""
Get :class:`EmailAddress` instances matching the canonical representation.
Expand Down
4 changes: 2 additions & 2 deletions funnel/models/geoname.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import Collection
from datetime import date
from decimal import Decimal
from typing import cast
from typing import Self, cast

from sqlalchemy.dialects.postgresql import ARRAY

Expand Down Expand Up @@ -569,7 +569,7 @@ def parse_locations(
return results

@classmethod
def autocomplete(cls, prefix: str, lang: str | None = None) -> Query[GeoName]:
def autocomplete(cls, prefix: str, lang: str | None = None) -> Query[Self]:
"""
Autocomplete a geoname record.
Expand Down
9 changes: 5 additions & 4 deletions funnel/models/notification.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
Generic,
Optional,
Protocol,
Self,
TypeVar,
Union,
cast,
Expand Down Expand Up @@ -1160,12 +1161,12 @@ def get_for(cls, user: Account, eventid_b58: str) -> NotificationRecipient | Non
@classmethod
def web_notifications_for(
cls, user: Account, unread_only: bool = False
) -> Query[NotificationRecipient]:
) -> Query[Self]:
"""Return web notifications for a user, optionally returning unread-only."""
query = NotificationRecipient.query.join(Notification).filter(
query = cls.query.join(Notification).filter(
Notification.type.in_(notification_web_types),
NotificationRecipient.recipient == user,
NotificationRecipient.revoked_at.is_(None),
cls.recipient == user,
cls.revoked_at.is_(None),
)
if unread_only:
query = query.filter(NotificationRecipient.read_at.is_(None))
Expand Down
5 changes: 3 additions & 2 deletions funnel/models/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from collections.abc import Sequence
from datetime import datetime
from typing import Self

from furl import furl
from pytz import BaseTzInfo, utc
Expand Down Expand Up @@ -770,7 +771,7 @@ def order_by_date(cls) -> sa.Case:
return clause

@classmethod
def all_unsorted(cls) -> Query[Project]:
def all_unsorted(cls) -> Query[Self]:
"""Return query of all published projects, without ordering criteria."""
return (
cls.query.join(Account, Project.account)
Expand All @@ -779,7 +780,7 @@ def all_unsorted(cls) -> Query[Project]:
)

@classmethod
def all(cls) -> Query[Project]: # noqa: A003
def all(cls) -> Query[Self]: # noqa: A003
"""Return all published projects, ordered by date."""
return cls.all_unsorted().order_by(cls.order_by_date())

Expand Down
3 changes: 2 additions & 1 deletion funnel/models/proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from collections.abc import Sequence
from datetime import datetime as datetime_type
from typing import Self

from baseframe import __
from baseframe.filters import preview
Expand Down Expand Up @@ -504,7 +505,7 @@ def roles_for(
return roles

@classmethod
def all_public(cls) -> Query[Proposal]:
def all_public(cls) -> Query[Self]:
return cls.query.join(Project).filter(Project.state.PUBLISHED, cls.state.PUBLIC)

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions funnel/models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from collections import OrderedDict, defaultdict
from datetime import datetime, timedelta
from typing import Any
from typing import Any, Self

from flask_babel import format_date, get_locale
from isoweek import Week
Expand Down Expand Up @@ -284,7 +284,7 @@ def make_unscheduled(self) -> None:
self.end_at = None

@classmethod
def all_public(cls) -> Query[Session]:
def all_public(cls) -> Query[Self]:
return cls.query.join(Project).filter(Project.state.PUBLISHED, cls.scheduled)


Expand Down
3 changes: 2 additions & 1 deletion funnel/models/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from collections.abc import Sequence
from datetime import datetime
from typing import Self

from sqlalchemy.orm import Query as BaseQuery

Expand Down Expand Up @@ -365,7 +366,7 @@ def roles_for(
return roles

@classmethod
def all_published_public(cls) -> Query[Update]:
def all_published_public(cls) -> Query[Self]:
return cls.query.join(Project).filter(
Project.state.PUBLISHED, cls.state.PUBLISHED, cls.visibility_state.PUBLIC
)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/models/account_User_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_user(db_session) -> None:


def test_user_pickername(user_twoflower, user_rincewind) -> None:
"""Test to verify pickername contains fullname and optional username."""
"""Test to verify `pickername` contains fullname and optional username."""
assert user_twoflower.pickername == "Twoflower"
assert user_rincewind.pickername == "Rincewind (@rincewind)"

Expand Down

0 comments on commit 6961920

Please sign in to comment.