Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor views to use generic arg support in ModelView #1941

Merged
merged 7 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions funnel/forms/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
PASSWORD_MIN_LENGTH,
Account,
Anchor,
User,
check_password_strength,
getuser,
)
Expand Down Expand Up @@ -158,7 +159,7 @@ class PasswordForm(forms.Form):
"""Form to validate a user's password, for password-gated sudo actions."""

__expects__ = ('edit_user',)
edit_user: Account
edit_user: User

password = forms.PasswordField(
__("Password"),
Expand All @@ -181,7 +182,7 @@ class PasswordPolicyForm(forms.Form):

__expects__ = ('edit_user',)
__returns__ = ('password_strength', 'is_weak', 'warning', 'suggestions')
edit_user: Account
edit_user: User
password_strength: int | None = None
is_weak: bool | None = None
warning: str | None = None
Expand Down Expand Up @@ -252,7 +253,7 @@ class PasswordCreateForm(forms.Form):

__returns__ = ('password_strength',)
__expects__ = ('edit_user',)
edit_user: Account
edit_user: User
password_strength: int | None = None

password = forms.PasswordField(
Expand Down Expand Up @@ -334,7 +335,7 @@ class PasswordChangeForm(forms.Form):

__returns__ = ('password_strength',)
__expects__ = ('edit_user',)
edit_user: Account
edit_user: User
password_strength: int | None = None

old_password = forms.PasswordField(
Expand Down Expand Up @@ -473,7 +474,7 @@ class UsernameAvailableForm(forms.Form):
"""Form to check for whether a username is available to use."""

__expects__ = ('edit_user',)
edit_user: Account
edit_user: User

username = forms.StringField(
__("Username"),
Expand Down Expand Up @@ -519,7 +520,7 @@ class NewEmailAddressForm(
"""Form to add a new email address to an account."""

__expects__ = ('edit_user',)
edit_user: Account
edit_user: User

email = forms.EmailField(
__("Email address"),
Expand Down Expand Up @@ -566,7 +567,7 @@ class NewPhoneForm(EnableNotificationsDescriptionProtoMixin, forms.RecaptchaForm
"""Form to add a new mobile number (SMS-capable) to an account."""

__expects__ = ('edit_user',)
edit_user: Account
edit_user: User

phone = forms.TelField(
__("Phone number"),
Expand Down
6 changes: 4 additions & 2 deletions funnel/forms/auth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
AuthClient,
AuthClientCredential,
AuthClientPermissions,
User,
valid_name,
)
from .helpers import strip_filters
Expand All @@ -29,7 +30,8 @@ class AuthClientForm(forms.Form):
"""Register a new OAuth client application."""

__returns__ = ('account',)
account: Account | None = None
edit_user: User
account: Account

title = forms.StringField(
__("Application title"),
Expand Down Expand Up @@ -127,7 +129,7 @@ def _urls_match(self, url1: str, url2: str) -> bool:
def validate_redirect_uri(self, field: forms.Field) -> None:
"""Validate redirect URI points to the website for confidential clients."""
if self.confidential.data and not self._urls_match(
self.website.data, field.data
self.website.data or '', field.data
):
raise forms.validators.ValidationError(
_("The scheme, domain and port must match that of the website URL")
Expand Down
4 changes: 2 additions & 2 deletions funnel/forms/organization.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from baseframe import _, __, forms

from ..models import Account, Team
from ..models import Account, Team, User

__all__ = ['OrganizationForm', 'TeamForm']

Expand All @@ -19,7 +19,7 @@ class OrganizationForm(forms.Form):
"""Form for an organization's name and title."""

__expects__: Iterable[str] = ('edit_user',)
edit_user: Account
edit_user: User
edit_obj: Account | None

title = forms.StringField(
Expand Down
4 changes: 2 additions & 2 deletions funnel/models/auth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ class AuthClient(ScopeMixin, UuidMixin, BaseMixin, Model):
__scope_null_allowed__ = True
#: Account that owns this client
account_id: Mapped[int] = sa.orm.mapped_column(
sa.ForeignKey('account.id'), nullable=True
sa.ForeignKey('account.id'), nullable=False
)
account: Mapped[Account | None] = with_roles(
account: Mapped[Account] = with_roles(
relationship(
Account,
foreign_keys=[account_id],
Expand Down
12 changes: 10 additions & 2 deletions funnel/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,11 @@ def decorator(attr: T) -> T:
# None or '' not allowed
raise ValueError(f"Could not determine name for {attr!r}")
if use_name in cls.__dict__:
raise AttributeError(f"{cls.__name__} already has attribute {use_name}")
raise AttributeError(
f"{cls.__name__} already has attribute {use_name}",
name=use_name,
obj=cls,
)
setattr(cls, use_name, attr)
return attr

Expand Down Expand Up @@ -291,7 +295,11 @@ def decorator(temp_cls: TempType) -> ReopenedType:
):
# Refuse to overwrite existing attributes
if hasattr(cls, attr):
raise AttributeError(f"{cls.__name__} already has attribute {attr}")
raise AttributeError(
f"{cls.__name__} already has attribute {attr}",
name=attr,
obj=cls,
)
# All good? Copy the attribute over...
setattr(cls, attr, value)
# ...And remove it from the temporary class
Expand Down
4 changes: 2 additions & 2 deletions funnel/models/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def __getattr__(self, name: str) -> bool | str | None:
Label.name == name, Label.project == self._obj.project
).one_or_none()
if label is None:
raise AttributeError
raise AttributeError(f"No label {name} in {self._obj.project}")

if not label.has_options:
return label in self._obj.labels
Expand All @@ -357,7 +357,7 @@ def __setattr__(self, name: str, value: bool) -> None:
Label._archived.is_(False),
).one_or_none()
if label is None:
raise AttributeError
raise AttributeError(f"No label {name} in {self._obj.project}")

if not label.has_options:
if value is True:
Expand Down
23 changes: 19 additions & 4 deletions funnel/models/membership_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from collections.abc import Callable, Iterable
from datetime import datetime as datetime_type
from types import SimpleNamespace
from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar
from uuid import UUID

Expand Down Expand Up @@ -643,15 +644,19 @@ class AmendMembership(Generic[MembershipType]):
to any attribute listed as a data column.
"""

membership: MembershipType
_actor: Account
_new: dict[str, Any]

def __init__(self, membership: MembershipType, actor: Account) -> None:
"""Create an amendment placeholder."""
if membership.revoked_at is not None:
raise MembershipRevokedError(
"This membership record has already been revoked"
)
object.__setattr__(self, 'membership', membership)
object.__setattr__(self, '_new', {})
object.__setattr__(self, '_actor', actor)
object.__setattr__(self, '_new', {})

def __getattr__(self, attr: str) -> Any:
"""Get an attribute from the underlying record."""
Expand All @@ -662,7 +667,13 @@ def __getattr__(self, attr: str) -> Any:
def __setattr__(self, attr: str, value: Any) -> None:
"""Set an amended value."""
if attr not in self.membership.__data_columns__:
raise AttributeError(f"{attr} cannot be set")
raise AttributeError(
f"{attr} cannot be set",
name=attr,
obj=SimpleNamespace(
**{_: None for _ in self.membership.__data_columns__}
),
)
self._new[attr] = value

def __enter__(self) -> AmendMembership:
Expand Down Expand Up @@ -697,10 +708,14 @@ def _confirm_enumerated_mixins(_mapper: Any, cls: type[Account]) -> None:
if attr_relationship is None:
raise AttributeError(
f'{cls.__name__} does not have a relationship named'
f' {attr_name!r} targeting a subclass of {expected_class.__name__}'
f' {attr_name!r} targeting a subclass of {expected_class.__name__}',
name=attr_name,
obj=cls,
)
if not issubclass(attr_relationship.property.mapper.class_, expected_class):
raise AttributeError(
f'{cls.__name__}.{attr_name} should be a relationship to a'
f' subclass of {expected_class.__name__}'
f' subclass of {expected_class.__name__}',
name=attr_name,
obj=cls,
)
13 changes: 6 additions & 7 deletions funnel/proxies/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

from collections.abc import Callable
from functools import wraps
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

from flask import has_request_context, request
from flask.globals import request_ctx
from werkzeug.local import LocalProxy
from werkzeug.utils import cached_property

Expand Down Expand Up @@ -121,25 +122,23 @@ def hx_prompt(self) -> str | None:

def _get_current_object(self) -> RequestWants:
"""Type hint for the LocalProxy wrapper method."""
return self


def _get_request_wants() -> RequestWants:
"""Get request_wants from the request."""
# Flask 2.0 deprecated use of _request_ctx_stack.top and recommends using `g`.
# However, `g` is not suitable for us as we must cache results for a request only.
# Therefore we stick it in the request object itself.
if has_request_context():
# pylint: disable=protected-access
wants = getattr(request, '_request_wants', None)
wants = getattr(request_ctx, 'request_wants', None)
if wants is None:
wants = RequestWants()
request._request_wants = wants # type: ignore[attr-defined]
request_ctx.request_wants = wants # type: ignore[attr-defined]
return wants
# Return an empty handler
return RequestWants()


request_wants: RequestWants = LocalProxy(_get_request_wants) # type: ignore[assignment]
request_wants: RequestWants = cast(RequestWants, LocalProxy(_get_request_wants))


def response_varies(response: ResponseType) -> ResponseType:
Expand Down
5 changes: 4 additions & 1 deletion funnel/transports/sms/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from enum import Enum
from re import Pattern
from string import Formatter
from types import SimpleNamespace
from typing import Any, ClassVar, cast

from flask import Flask
Expand Down Expand Up @@ -249,7 +250,9 @@ def __getattr__(self, attr: str) -> Any:
try:
return self._format_kwargs[attr]
except KeyError as exc:
raise AttributeError(attr) from exc
raise AttributeError(
attr, name=attr, obj=SimpleNamespace(**self._format_kwargs)
) from exc

def __getitem__(self, key: str) -> Any:
"""Get a format variable via dictionary access, defaulting to ''."""
Expand Down
10 changes: 1 addition & 9 deletions funnel/views/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,20 +290,15 @@ def login_session_service(obj: LoginSession) -> str | None:
return None


@route('/account')
@route('/account', init_app=app)
class AccountView(ClassView):
"""Account management views."""

__decorators__ = [requires_login]

obj: Account
current_section = 'account' # needed for showing active tab
SavedProjectForm = SavedProjectForm

def loader(self, **kwargs) -> Account:
"""Return current user."""
return current_auth.user

@route('', endpoint='account')
@render_with('account.html.jinja2')
def account(self) -> ReturnRenderWith:
Expand Down Expand Up @@ -882,9 +877,6 @@ def delete(self):
)


AccountView.init_app(app)


# --- Compatibility routes -------------------------------------------------------------


Expand Down
Loading