diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7849540..459e80d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,7 +17,7 @@ repos: - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: 'v0.2.1' + rev: 'v0.2.2' hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] diff --git a/src/wagtail_headless_preview/models.py b/src/wagtail_headless_preview/models.py index 6a71435..857c7bb 100644 --- a/src/wagtail_headless_preview/models.py +++ b/src/wagtail_headless_preview/models.py @@ -1,16 +1,24 @@ import datetime import json +from typing import TYPE_CHECKING, Optional + from django.contrib.contenttypes.models import ContentType from django.core.signing import TimestampSigner from django.db import models from django.shortcuts import redirect, render from django.utils.http import urlencode +from wagtail.models import Site from wagtail_headless_preview.settings import headless_preview_settings from wagtail_headless_preview.signals import preview_update +if TYPE_CHECKING: + from django.http import HttpRequest + from wagtail.models import Page, Site + + class PagePreview(models.Model): token = models.CharField(max_length=255, unique=True) content_type = models.ForeignKey( @@ -19,10 +27,10 @@ class PagePreview(models.Model): content_json = models.TextField() created_at = models.DateField(auto_now_add=True) - def __str__(self): + def __str__(self) -> str: return f"PagePreview: {self.token}, {self.created_at}" - def as_page(self): + def as_page(self) -> "Page": content = json.loads(self.content_json) page_model = ContentType.objects.get_for_id( content["content_type"] @@ -37,7 +45,7 @@ def garbage_collect(cls): cls.objects.filter(created_at__lt=yesterday).delete() -def get_client_root_url_from_site(site): +def get_client_root_url_from_site(site) -> str: try: root_url = headless_preview_settings.CLIENT_URLS[site.hostname] except (AttributeError, KeyError): @@ -51,17 +59,26 @@ def get_client_root_url_from_site(site): return root_url -class HeadlessPreviewMixin: +class HeadlessBase: + def get_client_root_url(self, request: "HttpRequest") -> str: + """ + Finds the client root URL based on the Site (as found from the request). + This can be overridden in the concrete Page subclasses to provide a different logic + """ + return get_client_root_url_from_site(Site.find_for_request(request)) + + +class HeadlessPreviewMixin(HeadlessBase): @classmethod - def get_preview_signer(cls): + def get_preview_signer(cls) -> TimestampSigner: return TimestampSigner(salt="headlesspreview.token") @classmethod - def get_content_type_str(cls): + def get_content_type_str(cls) -> str: return cls._meta.app_label + "." + cls.__name__.lower() @classmethod - def get_page_from_preview_token(cls, token): + def get_page_from_preview_token(cls, token: str) -> Optional["Page"]: content_type = ContentType.objects.get_for_model(cls) # Check token is valid @@ -74,7 +91,16 @@ def get_page_from_preview_token(cls, token): except PagePreview.DoesNotExist: return - def create_page_preview(self): + def update_page_preview(self, token: str) -> PagePreview: + return PagePreview.objects.update_or_create( + token=token, + defaults={ + "content_type": self.content_type, + "content_json": self.to_json(), + }, + ) + + def create_page_preview(self) -> PagePreview: if self.pk is None: identifier = ( f"parent_id={self.get_parent().pk};page_type={self._meta.label}" @@ -89,26 +115,14 @@ def create_page_preview(self): return preview - def update_page_preview(self, token): - return PagePreview.objects.update_or_create( - token=token, - defaults={ - "content_type": self.content_type, - "content_json": self.to_json(), - }, - ) - - def get_client_root_url(self): - return get_client_root_url_from_site(self.get_site()) - - def get_preview_url(self, token): + def get_preview_url(self, request: "HttpRequest", token: str) -> str: return ( - self.get_client_root_url() + self.get_client_root_url(request) + "?" + urlencode({"content_type": self.get_content_type_str(), "token": token}) ) - def serve_preview(self, request, preview_mode): + def serve_preview(self, request: "HttpRequest", preview_mode): PagePreview.garbage_collect() page_preview = self.create_page_preview() page_preview.save() @@ -116,7 +130,7 @@ def serve_preview(self, request, preview_mode): # Send the preview_update signal. Other apps can implement their own handling preview_update.send(sender=HeadlessPreviewMixin, token=page_preview.token) - preview_url = self.get_preview_url(page_preview.token) + preview_url = self.get_preview_url(request, page_preview.token) if headless_preview_settings.REDIRECT_ON_PREVIEW: return redirect(preview_url) @@ -129,8 +143,8 @@ def serve_preview(self, request, preview_mode): return response -class HeadlessServeMixin: - def serve(self, request): +class HeadlessServeMixin(HeadlessBase): + def serve(self, request: "HttpRequest"): """ Mixin overriding the default serve method with a redirect. The URL of the requested page is kept the same, only the host is @@ -139,10 +153,12 @@ def serve(self, request): However, you can enforce a single host using the HEADLESS_SERVE_BASE_URL setting. """ + if headless_preview_settings.SERVE_BASE_URL: base_url = headless_preview_settings.SERVE_BASE_URL else: - base_url = get_client_root_url_from_site(self.get_site()) + base_url = self.get_client_root_url(request) + site_id, site_root, relative_page_url = self.get_url_parts(request) return redirect(f"{base_url.rstrip('/')}{relative_page_url}") diff --git a/tests/test_frontend.py b/tests/test_frontend.py index 97e47be..64f325c 100644 --- a/tests/test_frontend.py +++ b/tests/test_frontend.py @@ -1,5 +1,5 @@ from django.contrib.auth.models import User -from django.test import TestCase, override_settings +from django.test import RequestFactory, TestCase, override_settings from django.urls import reverse from django.utils.http import urlencode from wagtail.models import Page @@ -9,7 +9,7 @@ from wagtail_headless_preview.settings import headless_preview_settings -class TestFrontendViews(TestCase): +class TestMixins(TestCase): fixtures = ["test.json"] @classmethod @@ -30,6 +30,13 @@ def setUpTestData(cls): cls.page.title = "Simple page with draft edit" cls.page.save_revision() + cls.headless_page = HeadlessPage( + title="Headless page original", slug="headless-page" + ) + cls.homepage.add_child(instance=cls.headless_page) + + cls.request = RequestFactory().get("/") + def setUp(self): self.client.login(username=self.admin_user.username, password="password") @@ -77,7 +84,7 @@ def test_redirect_on_preview(self): self.assertRedirects( response, - self.page.get_preview_url(preview_token), + self.page.get_preview_url(self.request, preview_token), fetch_redirect_response=False, ) @@ -86,7 +93,7 @@ def test_redirect_on_preview(self): ) def test_get_client_root_url_with_default_trailing_slash_enforcement(self): self.assertEqual( - self.page.get_client_root_url(), + self.page.get_client_root_url(self.request), "https://headless.site/", ) @@ -98,10 +105,20 @@ def test_get_client_root_url_with_default_trailing_slash_enforcement(self): ) def test_get_client_root_url_without_trailing_slash_enforcement(self): self.assertEqual( - self.page.get_client_root_url(), + self.page.get_client_root_url(self.request), "https://headless.site", ) + @override_settings( + WAGTAIL_HEADLESS_PREVIEW={"SERVE_BASE_URL": "https://headless.site"}, + TEST_OVERRIDE_CLIENT_ROOT_URL=True, + ) + def test_get_client_root_url_override_in_implementing_page(self): + self.assertEqual( + self.headless_page.get_client_root_url(self.request), + "https://wagtail.org", + ) + def test_create_page_preview_race(self): self.page.create_page_preview() @@ -111,36 +128,22 @@ def test_create_page_preview_race(self): # This shouldn't hit a unique constraint error self.page.create_page_preview() - -class TestHeadlessRedirectMixin(TestCase): - fixtures = ["test.json"] - - @classmethod - def setUpTestData(cls): - cls.admin_user = User.objects.create_superuser( - username="admin", - email="admin@example.com", - password="password", - ) - - cls.homepage = Page.objects.get(url_path="/home/").specific - cls.page = HeadlessPage(title="Simple page original", slug="simple-page") - cls.homepage.add_child(instance=cls.page) - - def test_serve(self): + def test_headless_serve_mixin_redirects_in_serve(self): client_url = headless_preview_settings.CLIENT_URLS["default"].rstrip("/") - response = self.client.get(self.page.url) + response = self.client.get(self.headless_page.url) self.assertRedirects( - response, f"{client_url}/{self.page.slug}/", fetch_redirect_response=False + response, + f"{client_url}/{self.headless_page.slug}/", + fetch_redirect_response=False, ) @override_settings( WAGTAIL_HEADLESS_PREVIEW={"SERVE_BASE_URL": "https://headless.site"} ) - def test_serve_with_headless_serve_base_url(self): - response = self.client.get(self.page.url) + def test_headless_serve_mixin_serve_with_headless_serve_base_url(self): + response = self.client.get(self.headless_page.url) self.assertRedirects( response, - f"https://headless.site/{self.page.slug}/", + f"https://headless.site/{self.headless_page.slug}/", fetch_redirect_response=False, ) diff --git a/tests/testapp/models.py b/tests/testapp/models.py index d49f4b9..fa5fc9f 100644 --- a/tests/testapp/models.py +++ b/tests/testapp/models.py @@ -1,11 +1,22 @@ +from typing import TYPE_CHECKING + +from django.conf import settings from wagtail.models import Page from wagtail_headless_preview.models import HeadlessMixin, HeadlessPreviewMixin +if TYPE_CHECKING: + from django.http import HttpRequest + + class SimplePage(HeadlessPreviewMixin, Page): pass class HeadlessPage(HeadlessMixin, Page): - pass + def get_client_root_url(self, request: "HttpRequest") -> str: + if getattr(settings, "TEST_OVERRIDE_CLIENT_ROOT_URL", False): + return "https://wagtail.org" + + return super().get_client_root_url(request)