Skip to content

Commit

Permalink
Fix bug where the thumbnail endpoint was not used for downloading thu…
Browse files Browse the repository at this point in the history
…mbnails (#9)

Closes #8
  • Loading branch information
reivilibre committed Nov 19, 2024
1 parent ed4bf23 commit 1d0a3fd
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 6 deletions.
19 changes: 16 additions & 3 deletions src/matrix_content_scanner/scanner/file_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ async def download_file(
ContentScannerRestError: The file was not found or could not be downloaded due
to an error on the remote homeserver's side.
"""
url = await self._build_https_url(media_path)
url = await self._build_https_url(
media_path, for_thumbnail=thumbnail_params is not None
)

# Attempt to retrieve the file at the generated URL.
try:
Expand All @@ -71,7 +73,11 @@ async def download_file(
# again with an r0 endpoint.
logger.info("File not found, trying legacy r0 path")

url = await self._build_https_url(media_path, endpoint_version="r0")
url = await self._build_https_url(
media_path,
endpoint_version="r0",
for_thumbnail=thumbnail_params is not None,
)

try:
file = await self._get_file_content(url, thumbnail_params)
Expand All @@ -89,6 +95,8 @@ async def _build_https_url(
self,
media_path: str,
endpoint_version: str = "v3",
*,
for_thumbnail: bool,
) -> str:
"""Turn a `server_name/media_id` path into an https:// one we can use to fetch
the media.
Expand All @@ -100,6 +108,9 @@ async def _build_https_url(
media_path: The media path to translate.
endpoint_version: The version of the download endpoint to use. As of Matrix
v1.1, this is either "v3" or "r0".
for_thumbnail: True if a server-side thumbnail is desired instead of the full
media. In that case, the URL for the `/thumbnail` endpoint is returned
instead of the `/download` endpoint.
Returns:
An https URL to use. If `base_homeserver_url` is set in the config, this
Expand Down Expand Up @@ -129,7 +140,9 @@ async def _build_https_url(
# didn't find a .well-known file.
base_url = "https://" + server_name

prefix = self.MEDIA_DOWNLOAD_PREFIX
prefix = (
self.MEDIA_THUMBNAIL_PREFIX if for_thumbnail else self.MEDIA_DOWNLOAD_PREFIX
)

# Build the full URL.
path_prefix = prefix % endpoint_version
Expand Down
8 changes: 5 additions & 3 deletions tests/scanner/test_file_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ async def test_download(self) -> None:
self.assertEqual(media.content_type, "image/png")

# Check that we tried downloading from the set base URL.
args = self.get_mock.call_args
self.assertTrue(args[0][0].startswith("http://my-site.com/"))
args = self.get_mock.call_args.args
self.assertTrue(args[0].startswith("http://my-site.com/"))

async def test_no_base_url(self) -> None:
"""Tests that configuring a base homeserver URL means files are downloaded from
Expand Down Expand Up @@ -146,7 +146,9 @@ async def test_thumbnail(self) -> None:
MEDIA_PATH, to_thumbnail_params({"height": "50"})
)

query: CIMultiDictProxy[str] = self.get_mock.call_args[1]["query"]
url: str = self.get_mock.call_args.args[0]
query: CIMultiDictProxy[str] = self.get_mock.call_args.kwargs["query"]
self.assertIn("/thumbnail/", url)
self.assertIn("height", query)
self.assertEqual(query.get("height"), "50", query.getall("height"))

Expand Down

0 comments on commit 1d0a3fd

Please sign in to comment.