diff --git a/src/matrix_content_scanner/scanner/file_downloader.py b/src/matrix_content_scanner/scanner/file_downloader.py index b1757f8..389030d 100644 --- a/src/matrix_content_scanner/scanner/file_downloader.py +++ b/src/matrix_content_scanner/scanner/file_downloader.py @@ -60,7 +60,9 @@ async def download_file( ContentScannerRestError: The file was not found or could not be downloaded due to an error on the remote homeserver's side. """ - url = await self._build_https_url(media_path) + url = await self._build_https_url( + media_path, for_thumbnail=thumbnail_params is not None + ) # Attempt to retrieve the file at the generated URL. try: @@ -71,7 +73,11 @@ async def download_file( # again with an r0 endpoint. logger.info("File not found, trying legacy r0 path") - url = await self._build_https_url(media_path, endpoint_version="r0") + url = await self._build_https_url( + media_path, + endpoint_version="r0", + for_thumbnail=thumbnail_params is not None, + ) try: file = await self._get_file_content(url, thumbnail_params) @@ -89,6 +95,8 @@ async def _build_https_url( self, media_path: str, endpoint_version: str = "v3", + *, + for_thumbnail: bool, ) -> str: """Turn a `server_name/media_id` path into an https:// one we can use to fetch the media. @@ -100,6 +108,9 @@ async def _build_https_url( media_path: The media path to translate. endpoint_version: The version of the download endpoint to use. As of Matrix v1.1, this is either "v3" or "r0". + for_thumbnail: True if a server-side thumbnail is desired instead of the full + media. In that case, the URL for the `/thumbnail` endpoint is returned + instead of the `/download` endpoint. Returns: An https URL to use. If `base_homeserver_url` is set in the config, this @@ -129,7 +140,9 @@ async def _build_https_url( # didn't find a .well-known file. base_url = "https://" + server_name - prefix = self.MEDIA_DOWNLOAD_PREFIX + prefix = ( + self.MEDIA_THUMBNAIL_PREFIX if for_thumbnail else self.MEDIA_DOWNLOAD_PREFIX + ) # Build the full URL. path_prefix = prefix % endpoint_version diff --git a/tests/scanner/test_file_downloader.py b/tests/scanner/test_file_downloader.py index 784f813..ce709d4 100644 --- a/tests/scanner/test_file_downloader.py +++ b/tests/scanner/test_file_downloader.py @@ -69,8 +69,8 @@ async def test_download(self) -> None: self.assertEqual(media.content_type, "image/png") # Check that we tried downloading from the set base URL. - args = self.get_mock.call_args - self.assertTrue(args[0][0].startswith("http://my-site.com/")) + args = self.get_mock.call_args.args + self.assertTrue(args[0].startswith("http://my-site.com/")) async def test_no_base_url(self) -> None: """Tests that configuring a base homeserver URL means files are downloaded from @@ -146,7 +146,9 @@ async def test_thumbnail(self) -> None: MEDIA_PATH, to_thumbnail_params({"height": "50"}) ) - query: CIMultiDictProxy[str] = self.get_mock.call_args[1]["query"] + url: str = self.get_mock.call_args.args[0] + query: CIMultiDictProxy[str] = self.get_mock.call_args.kwargs["query"] + self.assertIn("/thumbnail/", url) self.assertIn("height", query) self.assertEqual(query.get("height"), "50", query.getall("height"))