From 07999461c5c5202f2ca0880ba94f6f26098f9f0f Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 3 Dec 2024 11:35:39 -0700 Subject: [PATCH 1/6] Initial auth media work --- .../scanner/file_downloader.py | 43 +++++++++++-------- src/matrix_content_scanner/scanner/scanner.py | 10 ++++- .../servlets/download.py | 7 +-- src/matrix_content_scanner/servlets/scan.py | 7 +-- .../servlets/thumbnail.py | 1 + tests/scanner/test_file_downloader.py | 10 ++--- tests/scanner/test_scanner.py | 1 + 7 files changed, 48 insertions(+), 31 deletions(-) diff --git a/src/matrix_content_scanner/scanner/file_downloader.py b/src/matrix_content_scanner/scanner/file_downloader.py index 389030d..935c2cc 100644 --- a/src/matrix_content_scanner/scanner/file_downloader.py +++ b/src/matrix_content_scanner/scanner/file_downloader.py @@ -33,6 +33,8 @@ class _PathNotFoundException(Exception): class FileDownloader: MEDIA_DOWNLOAD_PREFIX = "_matrix/media/%s/download" MEDIA_THUMBNAIL_PREFIX = "_matrix/media/%s/thumbnail" + MEDIA_DOWNLOAD_AUTHENTICATED_PREFIX = "_matrix/client/v1/media/download" + MEDIA_THUMBNAIL_AUTHENTICATED_PREFIX = "_matrix/client/v1/media/thumbnail" def __init__(self, mcs: "MatrixContentScanner"): self._base_url = mcs.config.download.base_homeserver_url @@ -44,6 +46,7 @@ async def download_file( self, media_path: str, thumbnail_params: Optional[MultiMapping[str]] = None, + auth_header: Optional[str] = None, ) -> MediaDescription: """Retrieve the file with the given `server_name/media_id` path, and stores it on disk. @@ -52,6 +55,8 @@ async def download_file( media_path: The path identifying the media to retrieve. thumbnail_params: If present, then we want to request and scan a thumbnail generated with the provided parameters instead of the full media. + auth_header: If present, we forward the given Authorization header, this is + required for authenticated media endpoints. Returns: A description of the file (including its full content). @@ -60,27 +65,26 @@ async def download_file( ContentScannerRestError: The file was not found or could not be downloaded due to an error on the remote homeserver's side. """ - url = await self._build_https_url( - media_path, for_thumbnail=thumbnail_params is not None - ) + + prefix = self.MEDIA_DOWNLOAD_AUTHENTICATED_PREFIX if auth_header is not None else self.MEDIA_DOWNLOAD_PREFIX + if thumbnail_params is not None: + prefix = self.MEDIA_THUMBNAIL_AUTHENTICATED_PREFIX if auth_header is not None else self.MEDIA_THUMBNAIL_PREFIX + + url = await self._build_https_url(media_path, prefix) # Attempt to retrieve the file at the generated URL. try: - file = await self._get_file_content(url, thumbnail_params) + file = await self._get_file_content(url, thumbnail_params, auth_header) except _PathNotFoundException: # If the file could not be found, it might be because the homeserver hasn't # been upgraded to a version that supports Matrix v1.1 endpoints yet, so try # again with an r0 endpoint. logger.info("File not found, trying legacy r0 path") - url = await self._build_https_url( - media_path, - endpoint_version="r0", - for_thumbnail=thumbnail_params is not None, - ) + url = await self._build_https_url(media_path, prefix, endpoint_version="r0") try: - file = await self._get_file_content(url, thumbnail_params) + file = await self._get_file_content(url, thumbnail_params, auth_header) except _PathNotFoundException: # If that still failed, raise an error. raise ContentScannerRestError( @@ -94,9 +98,8 @@ async def download_file( async def _build_https_url( self, media_path: str, + prefix: str, endpoint_version: str = "v3", - *, - for_thumbnail: bool, ) -> str: """Turn a `server_name/media_id` path into an https:// one we can use to fetch the media. @@ -108,9 +111,6 @@ async def _build_https_url( media_path: The media path to translate. endpoint_version: The version of the download endpoint to use. As of Matrix v1.1, this is either "v3" or "r0". - for_thumbnail: True if a server-side thumbnail is desired instead of the full - media. In that case, the URL for the `/thumbnail` endpoint is returned - instead of the `/download` endpoint. Returns: An https URL to use. If `base_homeserver_url` is set in the config, this @@ -140,9 +140,6 @@ async def _build_https_url( # didn't find a .well-known file. base_url = "https://" + server_name - prefix = ( - self.MEDIA_THUMBNAIL_PREFIX if for_thumbnail else self.MEDIA_DOWNLOAD_PREFIX - ) # Build the full URL. path_prefix = prefix % endpoint_version @@ -159,12 +156,15 @@ async def _get_file_content( self, url: str, thumbnail_params: Optional[MultiMapping[str]], + auth_header: Optional[str] = None, ) -> MediaDescription: """Retrieve the content of the file at a given URL. Args: url: The URL to query. thumbnail_params: Query parameters used if the request is for a thumbnail. + auth_header: If present, we forward the given Authorization header, this is + required for authenticated media endpoints. Returns: A description of the file (including its full content). @@ -178,7 +178,7 @@ async def _get_file_content( ContentScannerRestError: the server returned a non-200 status which cannot meant that the path wasn't understood. """ - code, body, headers = await self._get(url, query=thumbnail_params) + code, body, headers = await self._get(url, query=thumbnail_params, auth_header=auth_header) logger.info("Remote server responded with %d", code) @@ -307,12 +307,15 @@ async def _get( self, url: str, query: Optional[MultiMapping[str]] = None, + auth_header: Optional[str] = None, ) -> Tuple[int, bytes, CIMultiDictProxy[str]]: """Sends a GET request to the provided URL. Args: url: The URL to send requests to. query: Optional parameters to use in the request's query string. + auth_header: If present, we forward the given Authorization header, this is + required for authenticated media endpoints. Returns: The HTTP status code, body and headers the remote server responded with. @@ -324,6 +327,8 @@ async def _get( try: logger.info("Sending GET request to %s", url) async with aiohttp.ClientSession() as session: + if auth_header is not None: + self._headers.update("Authorization", auth_header) async with session.get( url, proxy=self._proxy_url, diff --git a/src/matrix_content_scanner/scanner/scanner.py b/src/matrix_content_scanner/scanner/scanner.py index 53e9806..4314bd1 100644 --- a/src/matrix_content_scanner/scanner/scanner.py +++ b/src/matrix_content_scanner/scanner/scanner.py @@ -100,6 +100,7 @@ async def scan_file( media_path: str, metadata: Optional[JsonDict] = None, thumbnail_params: Optional["MultiMapping[str]"] = None, + auth_header: Optional[str] = None ) -> MediaDescription: """Download and scan the given media. @@ -119,6 +120,8 @@ async def scan_file( the file isn't encrypted. thumbnail_params: If present, then we want to request and scan a thumbnail generated with the provided parameters instead of the full media. + auth_header: If present, we forward the given Authorization header, this is + required for authenticated media endpoints. Returns: A description of the media. @@ -141,7 +144,7 @@ async def scan_file( # Try to download and scan the file. try: res = await self._scan_file( - cache_key, media_path, metadata, thumbnail_params + cache_key, media_path, metadata, thumbnail_params, auth_header ) # Set the future's result, and mark it as done. f.set_result(res) @@ -168,6 +171,7 @@ async def _scan_file( media_path: str, metadata: Optional[JsonDict] = None, thumbnail_params: Optional[MultiMapping[str]] = None, + auth_header: Optional[str] = None, ) -> MediaDescription: """Download and scan the given media. @@ -185,6 +189,8 @@ async def _scan_file( the file isn't encrypted. thumbnail_params: If present, then we want to request and scan a thumbnail generated with the provided parameters instead of the full media. + auth_header: If present, we forward the given Authorization header, this is + required for authenticated media endpoints. Returns: A description of the media. @@ -218,6 +224,7 @@ async def _scan_file( media = await self._file_downloader.download_file( media_path=media_path, thumbnail_params=thumbnail_params, + auth_header=auth_header, ) # Compare the media's hash to ensure the server hasn't changed the file since @@ -251,6 +258,7 @@ async def _scan_file( media = await self._file_downloader.download_file( media_path=media_path, thumbnail_params=thumbnail_params, + auth_header=auth_header, ) # Download and scan the file. diff --git a/src/matrix_content_scanner/servlets/download.py b/src/matrix_content_scanner/servlets/download.py index 681cfb4..7f3fb54 100644 --- a/src/matrix_content_scanner/servlets/download.py +++ b/src/matrix_content_scanner/servlets/download.py @@ -26,8 +26,9 @@ async def _scan( self, media_path: str, metadata: Optional[JsonDict] = None, + auth_header: Optional[str] = None, ) -> Tuple[int, _BytesResponse]: - media = await self._scanner.scan_file(media_path, metadata) + media = await self._scanner.scan_file(media_path, metadata, auth_header=auth_header) return 200, _BytesResponse( headers=media.response_headers, @@ -38,7 +39,7 @@ async def _scan( async def handle_plain(self, request: web.Request) -> Tuple[int, _BytesResponse]: """Handles GET requests to ../download/serverName/mediaId""" media_path = request.match_info["media_path"] - return await self._scan(media_path) + return await self._scan(media_path, auth_header=request.headers.get("Authorization")) @web_handler async def handle_encrypted( @@ -49,4 +50,4 @@ async def handle_encrypted( request, self._crypto_handler ) - return await self._scan(media_path, metadata) + return await self._scan(media_path, metadata, auth_header=request.headers.get("Authorization")) diff --git a/src/matrix_content_scanner/servlets/scan.py b/src/matrix_content_scanner/servlets/scan.py index c3f2060..0153458 100644 --- a/src/matrix_content_scanner/servlets/scan.py +++ b/src/matrix_content_scanner/servlets/scan.py @@ -23,9 +23,10 @@ async def _scan_and_format( self, media_path: str, metadata: Optional[JsonDict] = None, + auth_header: Optional[str] = None, ) -> Tuple[int, JsonDict]: try: - await self._scanner.scan_file(media_path, metadata) + await self._scanner.scan_file(media_path, metadata, auth_header=auth_header) except FileDirtyError as e: res = {"clean": False, "info": e.info} else: @@ -37,7 +38,7 @@ async def _scan_and_format( async def handle_plain(self, request: web.Request) -> Tuple[int, JsonDict]: """Handles GET requests to ../scan/serverName/mediaId""" media_path = request.match_info["media_path"] - return await self._scan_and_format(media_path) + return await self._scan_and_format(media_path, auth_header=request.headers.get("Authorization")) @web_handler async def handle_encrypted(self, request: web.Request) -> Tuple[int, JsonDict]: @@ -45,4 +46,4 @@ async def handle_encrypted(self, request: web.Request) -> Tuple[int, JsonDict]: media_path, metadata = await get_media_metadata_from_request( request, self._crypto_handler ) - return await self._scan_and_format(media_path, metadata) + return await self._scan_and_format(media_path, metadata, auth_header=request.headers.get("Authorization")) diff --git a/src/matrix_content_scanner/servlets/thumbnail.py b/src/matrix_content_scanner/servlets/thumbnail.py index dfae44a..9a553f7 100644 --- a/src/matrix_content_scanner/servlets/thumbnail.py +++ b/src/matrix_content_scanner/servlets/thumbnail.py @@ -26,6 +26,7 @@ async def handle_thumbnail( media = await self._scanner.scan_file( media_path=media_path, thumbnail_params=request.query, + auth_header=request.headers.get("Authorization") ) return 200, _BytesResponse( diff --git a/tests/scanner/test_file_downloader.py b/tests/scanner/test_file_downloader.py index ce709d4..bcc5edc 100644 --- a/tests/scanner/test_file_downloader.py +++ b/tests/scanner/test_file_downloader.py @@ -37,7 +37,7 @@ def setUp(self) -> None: self.media_headers = get_base_media_headers() async def _get( - url: str, query: Optional[MultiDictProxy[str]] = None + url: str, query: Optional[MultiDictProxy[str]] = None, auth_header: Optional[str] = None, ) -> Tuple[int, bytes, CIMultiDictProxy[str]]: """Mock for the _get method on the file downloader that doesn't serve a .well-known client file. @@ -88,7 +88,7 @@ async def test_no_base_url(self) -> None: ) self.assertEqual( self.get_mock.mock_calls[1], - call("https://foo/_matrix/media/v3/download/" + MEDIA_PATH, query=None), + call("https://foo/_matrix/media/v3/download/" + MEDIA_PATH, query=None, auth_header=None), ) async def test_retry_on_404(self) -> None: @@ -128,13 +128,13 @@ async def _test_retry(self) -> None: self.assertEqual( self.get_mock.mock_calls[0], call( - "http://my-site.com/_matrix/media/v3/download/" + MEDIA_PATH, query=None + "http://my-site.com/_matrix/media/v3/download/" + MEDIA_PATH, query=None, auth_header=None, ), ) self.assertEqual( self.get_mock.mock_calls[1], call( - "http://my-site.com/_matrix/media/r0/download/" + MEDIA_PATH, query=None + "http://my-site.com/_matrix/media/r0/download/" + MEDIA_PATH, query=None, auth_header=None, ), ) @@ -203,7 +203,7 @@ def setUp(self) -> None: self.versions_status = 200 async def _get( - url: str, query: Optional[MultiDictProxy[str]] = None + url: str, query: Optional[MultiDictProxy[str]] = None, auth_header: Optional[str] = None, ) -> Tuple[int, bytes, CIMultiDictProxy[str]]: """Mock for the _get method on the file downloader that serves a .well-known client file. diff --git a/tests/scanner/test_scanner.py b/tests/scanner/test_scanner.py index 6ea61fd..f91950d 100644 --- a/tests/scanner/test_scanner.py +++ b/tests/scanner/test_scanner.py @@ -42,6 +42,7 @@ def setUp(self) -> None: async def download_file( media_path: str, thumbnail_params: Optional[Dict[str, List[str]]] = None, + auth_header: Optional[str] = None, ) -> MediaDescription: """Mock for the file downloader's `download_file` method.""" return self.downloader_res From 910bbfd1747a3f5f7bc9a1198096df891e815c0c Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 3 Dec 2024 15:18:17 -0700 Subject: [PATCH 2/6] Fixup auth media logic --- .../scanner/file_downloader.py | 49 ++++++++++++++----- 1 file changed, 38 insertions(+), 11 deletions(-) diff --git a/src/matrix_content_scanner/scanner/file_downloader.py b/src/matrix_content_scanner/scanner/file_downloader.py index 935c2cc..49c9230 100644 --- a/src/matrix_content_scanner/scanner/file_downloader.py +++ b/src/matrix_content_scanner/scanner/file_downloader.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: AGPL-3.0-only # Please see LICENSE in the repository root for full details. +import copy import json import logging import urllib.parse @@ -33,8 +34,8 @@ class _PathNotFoundException(Exception): class FileDownloader: MEDIA_DOWNLOAD_PREFIX = "_matrix/media/%s/download" MEDIA_THUMBNAIL_PREFIX = "_matrix/media/%s/thumbnail" - MEDIA_DOWNLOAD_AUTHENTICATED_PREFIX = "_matrix/client/v1/media/download" - MEDIA_THUMBNAIL_AUTHENTICATED_PREFIX = "_matrix/client/v1/media/thumbnail" + MEDIA_DOWNLOAD_AUTHENTICATED_PREFIX = "_matrix/client/%s/media/download" + MEDIA_THUMBNAIL_AUTHENTICATED_PREFIX = "_matrix/client/%s/media/thumbnail" def __init__(self, mcs: "MatrixContentScanner"): self._base_url = mcs.config.download.base_homeserver_url @@ -66,22 +67,41 @@ async def download_file( to an error on the remote homeserver's side. """ - prefix = self.MEDIA_DOWNLOAD_AUTHENTICATED_PREFIX if auth_header is not None else self.MEDIA_DOWNLOAD_PREFIX + auth_media = True if auth_header is not None else False + + prefix = ( + self.MEDIA_DOWNLOAD_AUTHENTICATED_PREFIX + if auth_media + else self.MEDIA_DOWNLOAD_PREFIX + ) if thumbnail_params is not None: - prefix = self.MEDIA_THUMBNAIL_AUTHENTICATED_PREFIX if auth_header is not None else self.MEDIA_THUMBNAIL_PREFIX + prefix = ( + self.MEDIA_THUMBNAIL_AUTHENTICATED_PREFIX + if auth_media + else self.MEDIA_THUMBNAIL_PREFIX + ) - url = await self._build_https_url(media_path, prefix) + url = await self._build_https_url( + media_path, prefix, "v1" if auth_media else "v3" + ) # Attempt to retrieve the file at the generated URL. try: file = await self._get_file_content(url, thumbnail_params, auth_header) except _PathNotFoundException: + if auth_media: + raise ContentScannerRestError( + http_status=HTTPStatus.NOT_FOUND, + reason=ErrCode.NOT_FOUND, + info="File not found", + ) + # If the file could not be found, it might be because the homeserver hasn't # been upgraded to a version that supports Matrix v1.1 endpoints yet, so try # again with an r0 endpoint. logger.info("File not found, trying legacy r0 path") - url = await self._build_https_url(media_path, prefix, endpoint_version="r0") + url = await self._build_https_url(media_path, prefix, "r0") try: file = await self._get_file_content(url, thumbnail_params, auth_header) @@ -99,7 +119,7 @@ async def _build_https_url( self, media_path: str, prefix: str, - endpoint_version: str = "v3", + endpoint_version: str, ) -> str: """Turn a `server_name/media_id` path into an https:// one we can use to fetch the media. @@ -110,7 +130,8 @@ async def _build_https_url( Args: media_path: The media path to translate. endpoint_version: The version of the download endpoint to use. As of Matrix - v1.1, this is either "v3" or "r0". + v1.11, this is "v1" for authenticated media. For unauthenticated media + this is either "v3" or "r0". Returns: An https URL to use. If `base_homeserver_url` is set in the config, this @@ -140,7 +161,6 @@ async def _build_https_url( # didn't find a .well-known file. base_url = "https://" + server_name - # Build the full URL. path_prefix = prefix % endpoint_version url = "%s/%s/%s/%s" % ( @@ -327,12 +347,19 @@ async def _get( try: logger.info("Sending GET request to %s", url) async with aiohttp.ClientSession() as session: + # TODO: Test we don't persist auth token + request_headers = copy.deepcopy(self._headers) if auth_header is not None: - self._headers.update("Authorization", auth_header) + auth_dict = {"Authorization": auth_header} + if request_headers is None: + request_headers = auth_dict + else: + request_headers.update(auth_dict) + async with session.get( url, proxy=self._proxy_url, - headers=self._headers, + headers=request_headers, params=query, ) as resp: return resp.status, await resp.read(), resp.headers From d3fdb3ba8400b01899f497fa2cf7862ac8f62510 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 3 Dec 2024 15:18:34 -0700 Subject: [PATCH 3/6] Run code format --- .../scanner/file_downloader.py | 4 +++- src/matrix_content_scanner/scanner/scanner.py | 2 +- .../servlets/download.py | 12 +++++++--- src/matrix_content_scanner/servlets/scan.py | 8 +++++-- .../servlets/thumbnail.py | 2 +- tests/scanner/test_file_downloader.py | 22 ++++++++++++++----- 6 files changed, 37 insertions(+), 13 deletions(-) diff --git a/src/matrix_content_scanner/scanner/file_downloader.py b/src/matrix_content_scanner/scanner/file_downloader.py index 49c9230..63abd3a 100644 --- a/src/matrix_content_scanner/scanner/file_downloader.py +++ b/src/matrix_content_scanner/scanner/file_downloader.py @@ -198,7 +198,9 @@ async def _get_file_content( ContentScannerRestError: the server returned a non-200 status which cannot meant that the path wasn't understood. """ - code, body, headers = await self._get(url, query=thumbnail_params, auth_header=auth_header) + code, body, headers = await self._get( + url, query=thumbnail_params, auth_header=auth_header + ) logger.info("Remote server responded with %d", code) diff --git a/src/matrix_content_scanner/scanner/scanner.py b/src/matrix_content_scanner/scanner/scanner.py index 4314bd1..3807aa4 100644 --- a/src/matrix_content_scanner/scanner/scanner.py +++ b/src/matrix_content_scanner/scanner/scanner.py @@ -100,7 +100,7 @@ async def scan_file( media_path: str, metadata: Optional[JsonDict] = None, thumbnail_params: Optional["MultiMapping[str]"] = None, - auth_header: Optional[str] = None + auth_header: Optional[str] = None, ) -> MediaDescription: """Download and scan the given media. diff --git a/src/matrix_content_scanner/servlets/download.py b/src/matrix_content_scanner/servlets/download.py index 7f3fb54..a80d68b 100644 --- a/src/matrix_content_scanner/servlets/download.py +++ b/src/matrix_content_scanner/servlets/download.py @@ -28,7 +28,9 @@ async def _scan( metadata: Optional[JsonDict] = None, auth_header: Optional[str] = None, ) -> Tuple[int, _BytesResponse]: - media = await self._scanner.scan_file(media_path, metadata, auth_header=auth_header) + media = await self._scanner.scan_file( + media_path, metadata, auth_header=auth_header + ) return 200, _BytesResponse( headers=media.response_headers, @@ -39,7 +41,9 @@ async def _scan( async def handle_plain(self, request: web.Request) -> Tuple[int, _BytesResponse]: """Handles GET requests to ../download/serverName/mediaId""" media_path = request.match_info["media_path"] - return await self._scan(media_path, auth_header=request.headers.get("Authorization")) + return await self._scan( + media_path, auth_header=request.headers.get("Authorization") + ) @web_handler async def handle_encrypted( @@ -50,4 +54,6 @@ async def handle_encrypted( request, self._crypto_handler ) - return await self._scan(media_path, metadata, auth_header=request.headers.get("Authorization")) + return await self._scan( + media_path, metadata, auth_header=request.headers.get("Authorization") + ) diff --git a/src/matrix_content_scanner/servlets/scan.py b/src/matrix_content_scanner/servlets/scan.py index 0153458..343c324 100644 --- a/src/matrix_content_scanner/servlets/scan.py +++ b/src/matrix_content_scanner/servlets/scan.py @@ -38,7 +38,9 @@ async def _scan_and_format( async def handle_plain(self, request: web.Request) -> Tuple[int, JsonDict]: """Handles GET requests to ../scan/serverName/mediaId""" media_path = request.match_info["media_path"] - return await self._scan_and_format(media_path, auth_header=request.headers.get("Authorization")) + return await self._scan_and_format( + media_path, auth_header=request.headers.get("Authorization") + ) @web_handler async def handle_encrypted(self, request: web.Request) -> Tuple[int, JsonDict]: @@ -46,4 +48,6 @@ async def handle_encrypted(self, request: web.Request) -> Tuple[int, JsonDict]: media_path, metadata = await get_media_metadata_from_request( request, self._crypto_handler ) - return await self._scan_and_format(media_path, metadata, auth_header=request.headers.get("Authorization")) + return await self._scan_and_format( + media_path, metadata, auth_header=request.headers.get("Authorization") + ) diff --git a/src/matrix_content_scanner/servlets/thumbnail.py b/src/matrix_content_scanner/servlets/thumbnail.py index 9a553f7..ddba2bf 100644 --- a/src/matrix_content_scanner/servlets/thumbnail.py +++ b/src/matrix_content_scanner/servlets/thumbnail.py @@ -26,7 +26,7 @@ async def handle_thumbnail( media = await self._scanner.scan_file( media_path=media_path, thumbnail_params=request.query, - auth_header=request.headers.get("Authorization") + auth_header=request.headers.get("Authorization"), ) return 200, _BytesResponse( diff --git a/tests/scanner/test_file_downloader.py b/tests/scanner/test_file_downloader.py index bcc5edc..3b6f895 100644 --- a/tests/scanner/test_file_downloader.py +++ b/tests/scanner/test_file_downloader.py @@ -37,7 +37,9 @@ def setUp(self) -> None: self.media_headers = get_base_media_headers() async def _get( - url: str, query: Optional[MultiDictProxy[str]] = None, auth_header: Optional[str] = None, + url: str, + query: Optional[MultiDictProxy[str]] = None, + auth_header: Optional[str] = None, ) -> Tuple[int, bytes, CIMultiDictProxy[str]]: """Mock for the _get method on the file downloader that doesn't serve a .well-known client file. @@ -88,7 +90,11 @@ async def test_no_base_url(self) -> None: ) self.assertEqual( self.get_mock.mock_calls[1], - call("https://foo/_matrix/media/v3/download/" + MEDIA_PATH, query=None, auth_header=None), + call( + "https://foo/_matrix/media/v3/download/" + MEDIA_PATH, + query=None, + auth_header=None, + ), ) async def test_retry_on_404(self) -> None: @@ -128,13 +134,17 @@ async def _test_retry(self) -> None: self.assertEqual( self.get_mock.mock_calls[0], call( - "http://my-site.com/_matrix/media/v3/download/" + MEDIA_PATH, query=None, auth_header=None, + "http://my-site.com/_matrix/media/v3/download/" + MEDIA_PATH, + query=None, + auth_header=None, ), ) self.assertEqual( self.get_mock.mock_calls[1], call( - "http://my-site.com/_matrix/media/r0/download/" + MEDIA_PATH, query=None, auth_header=None, + "http://my-site.com/_matrix/media/r0/download/" + MEDIA_PATH, + query=None, + auth_header=None, ), ) @@ -203,7 +213,9 @@ def setUp(self) -> None: self.versions_status = 200 async def _get( - url: str, query: Optional[MultiDictProxy[str]] = None, auth_header: Optional[str] = None, + url: str, + query: Optional[MultiDictProxy[str]] = None, + auth_header: Optional[str] = None, ) -> Tuple[int, bytes, CIMultiDictProxy[str]]: """Mock for the _get method on the file downloader that serves a .well-known client file. From 7402baedea8dd17cc84323ec276a1991879b1c36 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 3 Dec 2024 18:45:25 -0700 Subject: [PATCH 4/6] Add tests for downloading auth media --- tests/scanner/test_file_downloader.py | 64 +++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/tests/scanner/test_file_downloader.py b/tests/scanner/test_file_downloader.py index 3b6f895..765cadf 100644 --- a/tests/scanner/test_file_downloader.py +++ b/tests/scanner/test_file_downloader.py @@ -55,6 +55,14 @@ async def _get( or "/_matrix/media/r0/thumbnail/" + MEDIA_PATH in url ): return self.media_status, self.media_body, self.media_headers + if ( + url.endswith(("/_matrix/client/v1/media/download/" + MEDIA_PATH,)) + or "/_matrix/client/v1/media/thumbnail/" + MEDIA_PATH in url + ): + if auth_header is not None: + return self.media_status, self.media_body, self.media_headers + else: + return 404, b"Not found", CIMultiDictProxy(CIMultiDict()) elif url.endswith("/.well-known/matrix/client"): return 404, b"Not found", CIMultiDictProxy(CIMultiDict()) @@ -74,6 +82,19 @@ async def test_download(self) -> None: args = self.get_mock.call_args.args self.assertTrue(args[0].startswith("http://my-site.com/")) + async def test_download_auth_media(self) -> None: + """Tests that downloading a file works using authenticated media.""" + media = await self.downloader.download_file( + MEDIA_PATH, auth_header="Bearer access_token" + ) + self.assertEqual(media.content, SMALL_PNG) + self.assertEqual(media.content_type, "image/png") + + # Check that we tried downloading from the set base URL. + args = self.get_mock.call_args.args + self.assertTrue(args[0].startswith("http://my-site.com/")) + self.assertIn("/_matrix/client/v1/media/download/" + MEDIA_PATH, args[0]) + async def test_no_base_url(self) -> None: """Tests that configuring a base homeserver URL means files are downloaded from that homeserver (rather than the one the files were uploaded to) and .well-known @@ -148,6 +169,34 @@ async def _test_retry(self) -> None: ), ) + async def test_no_retry(self) -> None: + """Tests that in a set specific case a failure to download a file from a v1 + authenticated media download path means we don't retry the request. + """ + self.media_status = 400 + self.media_body = b'{"errcode":"M_UNRECOGNIZED","error":"Unrecognized request"}' + self._set_headers({"content-type": ["application/json"]}) + + # Check that we eventually fail at downloading the file. + with self.assertRaises(ContentScannerRestError) as cm: + await self.downloader.download_file( + MEDIA_PATH, auth_header="Bearer access_token" + ) + + self.assertEqual(cm.exception.http_status, 404) + self.assertEqual(cm.exception.info, "File not found") + + # Check that we sent out only one request. + self.assertEqual(self.get_mock.call_count, 1) + self.assertEqual( + self.get_mock.mock_calls[0], + call( + "http://my-site.com/_matrix/client/v1/media/download/" + MEDIA_PATH, + query=None, + auth_header="Bearer access_token", + ), + ) + async def test_thumbnail(self) -> None: """Tests that we can download a thumbnail and that the parameters to generate the thumbnail are correctly passed on to the homeserver. @@ -162,6 +211,21 @@ async def test_thumbnail(self) -> None: self.assertIn("height", query) self.assertEqual(query.get("height"), "50", query.getall("height")) + async def test_thumbnail_auth_media(self) -> None: + """Tests that we can download a thumbnail and that the parameters to generate the + thumbnail are correctly passed on to the homeserver using authenticated media. + """ + await self.downloader.download_file( + MEDIA_PATH, to_thumbnail_params({"height": "50"}), "Bearer access_token" + ) + + url: str = self.get_mock.call_args.args[0] + query: CIMultiDictProxy[str] = self.get_mock.call_args.kwargs["query"] + self.assertIn("/thumbnail/", url) + self.assertIn("/_matrix/client/v1/media/thumbnail/" + MEDIA_PATH, url) + self.assertIn("height", query) + self.assertEqual(query.get("height"), "50", query.getall("height")) + async def test_multiple_content_type(self) -> None: """Tests that we raise an error if the homeserver responds with too many Content-Type headers. From c2e0d0469c87c78f43c3bc4a18873dae385a82d7 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Thu, 5 Dec 2024 13:39:58 -0700 Subject: [PATCH 5/6] Wrangle the request_headers more cleanly --- .../scanner/file_downloader.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/matrix_content_scanner/scanner/file_downloader.py b/src/matrix_content_scanner/scanner/file_downloader.py index 63abd3a..255a562 100644 --- a/src/matrix_content_scanner/scanner/file_downloader.py +++ b/src/matrix_content_scanner/scanner/file_downloader.py @@ -2,7 +2,6 @@ # # SPDX-License-Identifier: AGPL-3.0-only # Please see LICENSE in the repository root for full details. -import copy import json import logging import urllib.parse @@ -349,14 +348,12 @@ async def _get( try: logger.info("Sending GET request to %s", url) async with aiohttp.ClientSession() as session: - # TODO: Test we don't persist auth token - request_headers = copy.deepcopy(self._headers) - if auth_header is not None: - auth_dict = {"Authorization": auth_header} - if request_headers is None: - request_headers = auth_dict - else: - request_headers.update(auth_dict) + if auth_header is None: + request_headers = self._headers + else: + request_headers = {"Authorization": auth_header} + if self._headers is not None: + request_headers = {**request_headers, **self._headers} async with session.get( url, From 4382430b02b89071aa4e8c71cc9e3a8b205ab498 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Thu, 5 Dec 2024 13:44:43 -0700 Subject: [PATCH 6/6] Wrangle headers more --- .../scanner/file_downloader.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/matrix_content_scanner/scanner/file_downloader.py b/src/matrix_content_scanner/scanner/file_downloader.py index 255a562..b1df7cb 100644 --- a/src/matrix_content_scanner/scanner/file_downloader.py +++ b/src/matrix_content_scanner/scanner/file_downloader.py @@ -40,7 +40,11 @@ def __init__(self, mcs: "MatrixContentScanner"): self._base_url = mcs.config.download.base_homeserver_url self._well_known_cache: Dict[str, Optional[str]] = {} self._proxy_url = mcs.config.download.proxy - self._headers = mcs.config.download.additional_headers + self._headers = ( + mcs.config.download.additional_headers + if mcs.config.download.additional_headers is not None + else {} + ) async def download_file( self, @@ -348,12 +352,10 @@ async def _get( try: logger.info("Sending GET request to %s", url) async with aiohttp.ClientSession() as session: - if auth_header is None: - request_headers = self._headers + if auth_header is not None: + request_headers = {"Authorization": auth_header, **self._headers} else: - request_headers = {"Authorization": auth_header} - if self._headers is not None: - request_headers = {**request_headers, **self._headers} + request_headers = self._headers async with session.get( url,