diff --git a/CHANGELOG.md b/CHANGELOG.md index 05e46a2..3184c8c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Syntax highlighting matching DevDocs. (#30) +- Support for setting a custom icon for produced ZIM files. (#32) ### Changed diff --git a/README.md b/README.md index 5365365..b716d4b 100644 --- a/README.md +++ b/README.md @@ -119,6 +119,8 @@ docker run -v my_dir:/output ghcr.io/openzim/devdocs devdocs2zim --first=2 Value will be truncated to 4000 chars.Default: '{full_name} documentation by DevDocs' * `--tag TAG`: Add tag to the ZIM. Use --tag several times to add multiple. Formatting is supported. Default: ['devdocs', '{slug_without_version}'] +* `--logo-format FORMAT`: URL/path for the ZIM logo in PNG, JPG, or SVG format. + Formatting placeholders are supported. If unset, a DevDocs logo will be used. **Formatting Placeholders** diff --git a/src/devdocs2zim/constants.py b/src/devdocs2zim/constants.py index 5eb74f5..4ecd8b4 100644 --- a/src/devdocs2zim/constants.py +++ b/src/devdocs2zim/constants.py @@ -10,6 +10,7 @@ NAME = "devdocs2zim" VERSION = __version__ ROOT_DIR = pathlib.Path(__file__).parent +DEFAULT_LOGO_PATH = ROOT_DIR.joinpath("third_party", "devdocs", "devdocs_48.png") DEVDOCS_FRONTEND_URL = "https://devdocs.io" DEVDOCS_DOCUMENTS_URL = "https://documents.devdocs.io" diff --git a/src/devdocs2zim/entrypoint.py b/src/devdocs2zim/entrypoint.py index f3afa7e..4292b93 100644 --- a/src/devdocs2zim/entrypoint.py +++ b/src/devdocs2zim/entrypoint.py @@ -4,6 +4,7 @@ from devdocs2zim.client import DevdocsClient from devdocs2zim.constants import ( + DEFAULT_LOGO_PATH, DEVDOCS_DOCUMENTS_URL, DEVDOCS_FRONTEND_URL, NAME, @@ -24,6 +25,7 @@ def zim_defaults() -> ZimConfig: description_format="{full_name} docs by DevDocs", long_description_format=None, tags="devdocs;{slug_without_version}", + logo_format=str(DEFAULT_LOGO_PATH), ) diff --git a/src/devdocs2zim/generator.py b/src/devdocs2zim/generator.py index 76f568d..4d410ba 100644 --- a/src/devdocs2zim/generator.py +++ b/src/devdocs2zim/generator.py @@ -1,5 +1,6 @@ import argparse import datetime +import io import os import re from collections import defaultdict @@ -13,6 +14,17 @@ MAXIMUM_LONG_DESCRIPTION_METADATA_LENGTH, RECOMMENDED_MAX_TITLE_LENGTH, ) +from zimscraperlib.image.conversion import ( # pyright: ignore[reportMissingTypeStubs] + convert_image, + convert_svg2png, + format_for, +) +from zimscraperlib.image.transformation import ( # pyright: ignore[reportMissingTypeStubs] + resize_image, +) +from zimscraperlib.inputs import ( # pyright: ignore[reportMissingTypeStubs] + handle_user_provided_file, +) from zimscraperlib.zim import ( # pyright: ignore[reportMissingTypeStubs] Creator, StaticItem, @@ -21,7 +33,6 @@ IndexData, ) -# pyright: ignore[reportMissingTypeStubs] from devdocs2zim.client import ( DevdocsClient, DevdocsIndex, @@ -71,6 +82,8 @@ class ZimConfig(BaseModel): long_description_format: str | None # Semicolon delimited list of tags to apply to the ZIM. tags: str + # Format to use for the logo. + logo_format: str @staticmethod def add_flags(parser: argparse.ArgumentParser, defaults: "ZimConfig"): @@ -134,12 +147,21 @@ def add_flags(parser: argparse.ArgumentParser, defaults: "ZimConfig"): # argparse doesn't work so we expose the underlying semicolon delimited string. parser.add_argument( "--tags", - help="A semicolon (;) delimited list of tags to add to the ZIM." + help="A semicolon (;) delimited list of tags to add to the ZIM. " "Formatting is supported. " f"Default: {defaults.tags!r}", default=defaults.tags, ) + parser.add_argument( + "--logo-format", + help="URL/path for the ZIM logo in PNG, JPG, or SVG format. " + "Formatting placeholders are supported. " + "If unset, a DevDocs logo will be used.", + default=defaults.logo_format, + metavar="FORMAT", + ) + @staticmethod def of(namespace: argparse.Namespace) -> "ZimConfig": """Parses a namespace to create a new ZimConfig.""" @@ -195,6 +217,7 @@ def check_length(string: str, field_name: str, length: int) -> str: else None ), tags=fmt(self.tags), + logo_format=fmt(self.logo_format), ) @@ -339,7 +362,6 @@ def __init__( self.page_template = self.env.get_template("page.html") # type: ignore self.licenses_template = self.env.get_template(LICENSE_FILE) # type: ignore - self.logo_path = self.asset_path("devdocs_48.png") self.copyright_path = self.asset_path("COPYRIGHT") self.license_path = self.asset_path("LICENSE") @@ -456,6 +478,7 @@ def generate_zim( logger.info(f" Writing to: {zim_path}") + logo_bytes = self.fetch_logo_bytes(formatted_config.logo_format) creator = Creator(zim_path, "index") creator.config_metadata( Name=formatted_config.name_format, @@ -469,7 +492,7 @@ def generate_zim( Language=LANGUAGE_ISO_639_3, Tags=formatted_config.tags, Scraper=f"{NAME} v{VERSION}", - Illustration_48x48_at_1=self.logo_path.read_bytes(), + Illustration_48x48_at_1=logo_bytes, ) # Start creator early to detect problems early. @@ -491,6 +514,34 @@ def generate_zim( ) return zim_path + @staticmethod + def fetch_logo_bytes(user_logo_path: str) -> bytes: + """Fetch a user-supplied logo for the ZIM and format/resize it. + + Parameters: + user_logo_path: Path or URL to the logo. + """ + logger.info(f" Fetching logo from: {user_logo_path}") + full_logo_path = handle_user_provided_file(source=user_logo_path) + if full_logo_path is None: + # This appears to only happen if the path is blank. + raise Exception(f"Fetching logo {user_logo_path!r} failed.") + + converted_buf = io.BytesIO() + if format_for(full_logo_path, from_suffix=False) == "SVG": + # SVG conversion generates a PNG in the correct size + # so immediately return it. + convert_svg2png(full_logo_path, converted_buf, 48, 48) + return converted_buf.getvalue() + else: + # Convert to PNG + convert_image(full_logo_path, converted_buf, fmt="PNG") + + # resize to 48x48 + resized_buf = io.BytesIO() + resize_image(converted_buf, 48, 48, resized_buf, allow_upscaling=True) + return resized_buf.getvalue() + @staticmethod def page_titles(pages: list[DevdocsIndexEntry]) -> dict[str, str]: """Returns a map between page paths in the DB and their "best" title. diff --git a/tests/test_generator.py b/tests/test_generator.py index 5d976a5..454380e 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -1,9 +1,12 @@ import argparse +import io from pathlib import Path from tempfile import TemporaryDirectory from unittest import TestCase from unittest.mock import create_autospec +from PIL.Image import open as pilopen + from devdocs2zim.client import ( DevdocsClient, DevdocsIndex, @@ -32,6 +35,7 @@ def defaults(self) -> ZimConfig: description_format="default_description_format", long_description_format="default_long_description_format", tags="default_tag1;default_tag2", + logo_format="default_logo_format", ) def test_flag_parsing_defaults(self): @@ -66,6 +70,8 @@ def test_flag_parsing_overrides(self): "long-description-format", "--tags", "tag1;tag2", + "--logo-format", + "logo-format", ] ) ) @@ -80,6 +86,7 @@ def test_flag_parsing_overrides(self): description_format="description-format", long_description_format="long-description-format", tags="tag1;tag2", + logo_format="logo-format", ), got, ) @@ -101,6 +108,7 @@ def test_format_only_allowed(self): description_format="{replace_me}", long_description_format="{replace_me}", tags="{replace_me}", + logo_format="{replace_me}", ) got = to_format.format({"replace_me": "replaced"}) @@ -115,6 +123,7 @@ def test_format_only_allowed(self): description_format="replaced", long_description_format="replaced", tags="replaced", + logo_format="replaced", ), got, ) @@ -426,3 +435,39 @@ def test_page_titles_only_fragment(self): # First fragment wins if no page points to the top self.assertEqual({"mock": "Mock Sub1"}, got) + + def test_fetch_logo_bytes_jpeg(self): + jpg_path = str(Path(__file__).parent / "testdata" / "test.jpg") + + got = Generator.fetch_logo_bytes(jpg_path) + + self.assertIsNotNone(got) + with pilopen(io.BytesIO(got)) as image: + self.assertEqual((48, 48), image.size) + self.assertEqual("PNG", image.format) + + def test_fetch_logo_bytes_png(self): + png_path = str(Path(__file__).parent / "testdata" / "test.png") + + got = Generator.fetch_logo_bytes(png_path) + + self.assertIsNotNone(got) + with pilopen(io.BytesIO(got)) as image: + self.assertEqual((48, 48), image.size) + self.assertEqual("PNG", image.format) + + def test_fetch_logo_bytes_svg(self): + png_path = str(Path(__file__).parent / "testdata" / "test.svg") + + got = Generator.fetch_logo_bytes(png_path) + + self.assertIsNotNone(got) + with pilopen(io.BytesIO(got)) as image: + self.assertEqual((48, 48), image.size) + self.assertEqual("PNG", image.format) + + def test_fetch_logo_bytes_does_not_exist_fails(self): + self.assertRaises(OSError, Generator.fetch_logo_bytes, "does_not_exist") + + def test_fetch_logo_bytes_returns_none_fails(self): + self.assertRaises(Exception, Generator.fetch_logo_bytes, "") diff --git a/tests/testdata/test.jpg b/tests/testdata/test.jpg new file mode 100644 index 0000000..5c66b6d Binary files /dev/null and b/tests/testdata/test.jpg differ diff --git a/tests/testdata/test.png b/tests/testdata/test.png new file mode 100644 index 0000000..ebf894d Binary files /dev/null and b/tests/testdata/test.png differ diff --git a/tests/testdata/test.svg b/tests/testdata/test.svg new file mode 100644 index 0000000..fa24d8b --- /dev/null +++ b/tests/testdata/test.svg @@ -0,0 +1,7 @@ + + + + + testSVG + +