From 48a6132a7d93eb5f7f547a0a8017e4bf6a782781 Mon Sep 17 00:00:00 2001 From: zerolab Date: Thu, 15 Feb 2024 18:57:51 +0000 Subject: [PATCH] Make `get_client_root_url` overridable both in both mixins --- src/wagtail_headless_preview/models.py | 77 +++++++++++++++++--------- tests/test_frontend.py | 10 ++-- 2 files changed, 57 insertions(+), 30 deletions(-) diff --git a/src/wagtail_headless_preview/models.py b/src/wagtail_headless_preview/models.py index 415c1b4..f2316a6 100644 --- a/src/wagtail_headless_preview/models.py +++ b/src/wagtail_headless_preview/models.py @@ -1,16 +1,24 @@ import datetime import json +from typing import TYPE_CHECKING, Optional + from django.contrib.contenttypes.models import ContentType from django.core.signing import TimestampSigner from django.db import models from django.shortcuts import redirect, render from django.utils.http import urlencode +from wagtail.models import Site from wagtail_headless_preview.settings import headless_preview_settings from wagtail_headless_preview.signals import preview_update +if TYPE_CHECKING: + from django.http import HttpRequest + from wagtail.models import Page, Site + + class PagePreview(models.Model): token = models.CharField(max_length=255, unique=True) content_type = models.ForeignKey( @@ -19,7 +27,7 @@ class PagePreview(models.Model): content_json = models.TextField() created_at = models.DateField(auto_now_add=True) - def as_page(self): + def as_page(self) -> "Page": content = json.loads(self.content_json) page_model = ContentType.objects.get_for_id( content["content_type"] @@ -34,7 +42,7 @@ def garbage_collect(cls): cls.objects.filter(created_at__lt=yesterday).delete() -def get_client_root_url_from_site(site): +def get_client_root_url_from_site(site) -> str: try: root_url = headless_preview_settings.CLIENT_URLS[site.hostname] except (AttributeError, KeyError): @@ -48,17 +56,35 @@ def get_client_root_url_from_site(site): return root_url -class HeadlessPreviewMixin: +class HeadlessBase: + def _get_site_from_request(self, request: "HttpRequest") -> "Site": + """ + Copy of Page.get_site() which passes the request object to Page.get_url_parts() + """ + url_parts = self.get_url_parts(request=request) + if url_parts is None: + # page is not routable + return + + site_id, root_url, page_path = url_parts + + return Site.objects.get(id=site_id) + + def get_client_root_url(self, request: "HttpRequest") -> str: + return get_client_root_url_from_site(self._get_site_from_request(request)) + + +class HeadlessPreviewMixin(HeadlessBase): @classmethod - def get_preview_signer(cls): + def get_preview_signer(cls) -> TimestampSigner: return TimestampSigner(salt="headlesspreview.token") @classmethod - def get_content_type_str(cls): + def get_content_type_str(cls) -> str: return cls._meta.app_label + "." + cls.__name__.lower() @classmethod - def get_page_from_preview_token(cls, token): + def get_page_from_preview_token(cls, token: str) -> Optional["Page"]: content_type = ContentType.objects.get_for_model(cls) # Check token is valid @@ -71,7 +97,16 @@ def get_page_from_preview_token(cls, token): except PagePreview.DoesNotExist: return - def create_page_preview(self): + def update_page_preview(self, token: str) -> PagePreview: + return PagePreview.objects.update_or_create( + token=token, + defaults={ + "content_type": self.content_type, + "content_json": self.to_json(), + }, + ) + + def create_page_preview(self) -> PagePreview: if self.pk is None: identifier = ( f"parent_id={self.get_parent().pk};page_type={self._meta.label}" @@ -86,26 +121,14 @@ def create_page_preview(self): return preview - def update_page_preview(self, token): - return PagePreview.objects.update_or_create( - token=token, - defaults={ - "content_type": self.content_type, - "content_json": self.to_json(), - }, - ) - - def get_client_root_url(self): - return get_client_root_url_from_site(self.get_site()) - - def get_preview_url(self, token): + def get_preview_url(self, request: "HttpRequest", token: str) -> str: return ( - self.get_client_root_url() + self.get_client_root_url(request) + "?" + urlencode({"content_type": self.get_content_type_str(), "token": token}) ) - def serve_preview(self, request, preview_mode): + def serve_preview(self, request: "HttpRequest", preview_mode): PagePreview.garbage_collect() page_preview = self.create_page_preview() page_preview.save() @@ -113,7 +136,7 @@ def serve_preview(self, request, preview_mode): # Send the preview_update signal. Other apps can implement their own handling preview_update.send(sender=HeadlessPreviewMixin, token=page_preview.token) - preview_url = self.get_preview_url(page_preview.token) + preview_url = self.get_preview_url(request, page_preview.token) if headless_preview_settings.REDIRECT_ON_PREVIEW: return redirect(preview_url) @@ -126,8 +149,8 @@ def serve_preview(self, request, preview_mode): return response -class HeadlessServeMixin: - def serve(self, request): +class HeadlessServeMixin(HeadlessBase): + def serve(self, request: "HttpRequest"): """ Mixin overriding the default serve method with a redirect. The URL of the requested page is kept the same, only the host is @@ -136,10 +159,12 @@ def serve(self, request): However, you can enforce a single host using the HEADLESS_SERVE_BASE_URL setting. """ + if headless_preview_settings.SERVE_BASE_URL: base_url = headless_preview_settings.SERVE_BASE_URL else: - base_url = get_client_root_url_from_site(self.get_site()) + base_url = self.get_client_root_url(request) + site_id, site_root, relative_page_url = self.get_url_parts(request) return redirect(f"{base_url.rstrip('/')}{relative_page_url}") diff --git a/tests/test_frontend.py b/tests/test_frontend.py index e7b174f..ff44dac 100644 --- a/tests/test_frontend.py +++ b/tests/test_frontend.py @@ -1,5 +1,5 @@ from django.contrib.auth.models import User -from django.test import TestCase, override_settings +from django.test import RequestFactory, TestCase, override_settings from django.urls import reverse from django.utils.http import urlencode from wagtail.models import Page @@ -30,6 +30,8 @@ def setUpTestData(cls): cls.page.title = "Simple page with draft edit" cls.page.save_revision() + cls.request = RequestFactory().get("/") + def setUp(self): self.client.login( username=self.admin_user.username, password="password" # noqa: S106 @@ -79,7 +81,7 @@ def test_redirect_on_preview(self): self.assertRedirects( response, - self.page.get_preview_url(preview_token), + self.page.get_preview_url(self.request, preview_token), fetch_redirect_response=False, ) @@ -88,7 +90,7 @@ def test_redirect_on_preview(self): ) def test_get_client_root_url_with_default_trailing_slash_enforcement(self): self.assertEqual( - self.page.get_client_root_url(), + self.page.get_client_root_url(self.request), "https://headless.site/", ) @@ -100,7 +102,7 @@ def test_get_client_root_url_with_default_trailing_slash_enforcement(self): ) def test_get_client_root_url_without_trailing_slash_enforcement(self): self.assertEqual( - self.page.get_client_root_url(), + self.page.get_client_root_url(self.request), "https://headless.site", )