From 694914f92bf100a37f7e1baee527cfb3cd01303f Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Thu, 5 Dec 2024 20:59:27 +0000 Subject: [PATCH] Add support for authenticated media (#69) This PR adds support for authenticated media by passing the Authorization header information through from the client to the homeserver. Once MAS supports scoped access tokens this code should be changed over to use that. Huge shoutout to @S7evinK for doing the bulk of the implementation on this. --- .../scanner/file_downloader.py | 77 ++++++++++++----- src/matrix_content_scanner/scanner/scanner.py | 10 ++- .../servlets/download.py | 13 ++- src/matrix_content_scanner/servlets/scan.py | 11 ++- .../servlets/thumbnail.py | 1 + tests/scanner/test_file_downloader.py | 86 +++++++++++++++++-- tests/scanner/test_scanner.py | 1 + 7 files changed, 165 insertions(+), 34 deletions(-) diff --git a/src/matrix_content_scanner/scanner/file_downloader.py b/src/matrix_content_scanner/scanner/file_downloader.py index bcdab50..492e87f 100644 --- a/src/matrix_content_scanner/scanner/file_downloader.py +++ b/src/matrix_content_scanner/scanner/file_downloader.py @@ -33,17 +33,24 @@ class _PathNotFoundException(Exception): class FileDownloader: MEDIA_DOWNLOAD_PREFIX = "_matrix/media/%s/download" MEDIA_THUMBNAIL_PREFIX = "_matrix/media/%s/thumbnail" + MEDIA_DOWNLOAD_AUTHENTICATED_PREFIX = "_matrix/client/%s/media/download" + MEDIA_THUMBNAIL_AUTHENTICATED_PREFIX = "_matrix/client/%s/media/thumbnail" def __init__(self, mcs: "MatrixContentScanner"): self._base_url = mcs.config.download.base_homeserver_url self._well_known_cache: Dict[str, Optional[str]] = {} self._proxy_url = mcs.config.download.proxy - self._headers = mcs.config.download.additional_headers + self._headers = ( + mcs.config.download.additional_headers + if mcs.config.download.additional_headers is not None + else {} + ) async def download_file( self, media_path: str, thumbnail_params: Optional[MultiMapping[str]] = None, + auth_header: Optional[str] = None, ) -> MediaDescription: """Retrieve the file with the given `server_name/media_id` path, and stores it on disk. @@ -52,6 +59,8 @@ async def download_file( media_path: The path identifying the media to retrieve. thumbnail_params: If present, then we want to request and scan a thumbnail generated with the provided parameters instead of the full media. + auth_header: If present, we forward the given Authorization header, this is + required for authenticated media endpoints. Returns: A description of the file (including its full content). @@ -60,27 +69,45 @@ async def download_file( ContentScannerRestError: The file was not found or could not be downloaded due to an error on the remote homeserver's side. """ + + auth_media = True if auth_header is not None else False + + prefix = ( + self.MEDIA_DOWNLOAD_AUTHENTICATED_PREFIX + if auth_media + else self.MEDIA_DOWNLOAD_PREFIX + ) + if thumbnail_params is not None: + prefix = ( + self.MEDIA_THUMBNAIL_AUTHENTICATED_PREFIX + if auth_media + else self.MEDIA_THUMBNAIL_PREFIX + ) + url = await self._build_https_url( - media_path, for_thumbnail=thumbnail_params is not None + media_path, prefix, "v1" if auth_media else "v3" ) # Attempt to retrieve the file at the generated URL. try: - file = await self._get_file_content(url, thumbnail_params) + file = await self._get_file_content(url, thumbnail_params, auth_header) except _PathNotFoundException: + if auth_media: + raise ContentScannerRestError( + http_status=HTTPStatus.NOT_FOUND, + reason=ErrCode.NOT_FOUND, + info="File not found", + ) + # If the file could not be found, it might be because the homeserver hasn't # been upgraded to a version that supports Matrix v1.1 endpoints yet, so try # again with an r0 endpoint. logger.info("File not found, trying legacy r0 path") - url = await self._build_https_url( - media_path, - endpoint_version="r0", - for_thumbnail=thumbnail_params is not None, - ) + url = await self._build_https_url(media_path, prefix, "r0") try: - file = await self._get_file_content(url, thumbnail_params) + file = await self._get_file_content(url, thumbnail_params, auth_header) except _PathNotFoundException: # If that still failed, raise an error. raise ContentScannerRestError( @@ -94,9 +121,8 @@ async def download_file( async def _build_https_url( self, media_path: str, - endpoint_version: str = "v3", - *, - for_thumbnail: bool, + prefix: str, + endpoint_version: str, ) -> str: """Turn a `server_name/media_id` path into an https:// one we can use to fetch the media. @@ -107,10 +133,8 @@ async def _build_https_url( Args: media_path: The media path to translate. endpoint_version: The version of the download endpoint to use. As of Matrix - v1.1, this is either "v3" or "r0". - for_thumbnail: True if a server-side thumbnail is desired instead of the full - media. In that case, the URL for the `/thumbnail` endpoint is returned - instead of the `/download` endpoint. + v1.11, this is "v1" for authenticated media. For unauthenticated media + this is either "v3" or "r0". Returns: An https URL to use. If `base_homeserver_url` is set in the config, this @@ -140,10 +164,6 @@ async def _build_https_url( # didn't find a .well-known file. base_url = "https://" + server_name - prefix = ( - self.MEDIA_THUMBNAIL_PREFIX if for_thumbnail else self.MEDIA_DOWNLOAD_PREFIX - ) - # Build the full URL. path_prefix = prefix % endpoint_version url = "%s/%s/%s/%s" % ( @@ -159,12 +179,15 @@ async def _get_file_content( self, url: str, thumbnail_params: Optional[MultiMapping[str]], + auth_header: Optional[str] = None, ) -> MediaDescription: """Retrieve the content of the file at a given URL. Args: url: The URL to query. thumbnail_params: Query parameters used if the request is for a thumbnail. + auth_header: If present, we forward the given Authorization header, this is + required for authenticated media endpoints. Returns: A description of the file (including its full content). @@ -178,7 +201,9 @@ async def _get_file_content( ContentScannerRestError: the server returned a non-200 status which cannot meant that the path wasn't understood. """ - code, body, headers = await self._get(url, query=thumbnail_params) + code, body, headers = await self._get( + url, query=thumbnail_params, auth_header=auth_header + ) logger.info("Remote server responded with %d", code) @@ -307,12 +332,15 @@ async def _get( self, url: str, query: Optional[MultiMapping[str]] = None, + auth_header: Optional[str] = None, ) -> Tuple[int, bytes, CIMultiDictProxy[str]]: """Sends a GET request to the provided URL. Args: url: The URL to send requests to. query: Optional parameters to use in the request's query string. + auth_header: If present, we forward the given Authorization header, this is + required for authenticated media endpoints. Returns: The HTTP status code, body and headers the remote server responded with. @@ -324,10 +352,15 @@ async def _get( try: logger.info("Sending GET request to %s", url) async with aiohttp.ClientSession() as session: + if auth_header is not None: + request_headers = {"Authorization": auth_header, **self._headers} + else: + request_headers = self._headers + async with session.get( url, proxy=self._proxy_url, - headers=self._headers, + headers=request_headers, params=query, ) as resp: return resp.status, await resp.read(), resp.headers diff --git a/src/matrix_content_scanner/scanner/scanner.py b/src/matrix_content_scanner/scanner/scanner.py index 9e27958..1d49b2d 100644 --- a/src/matrix_content_scanner/scanner/scanner.py +++ b/src/matrix_content_scanner/scanner/scanner.py @@ -100,6 +100,7 @@ async def scan_file( media_path: str, metadata: Optional[JsonDict] = None, thumbnail_params: Optional["MultiMapping[str]"] = None, + auth_header: Optional[str] = None, ) -> MediaDescription: """Download and scan the given media. @@ -119,6 +120,8 @@ async def scan_file( the file isn't encrypted. thumbnail_params: If present, then we want to request and scan a thumbnail generated with the provided parameters instead of the full media. + auth_header: If present, we forward the given Authorization header, this is + required for authenticated media endpoints. Returns: A description of the media. @@ -141,7 +144,7 @@ async def scan_file( # Try to download and scan the file. try: res = await self._scan_file( - cache_key, media_path, metadata, thumbnail_params + cache_key, media_path, metadata, thumbnail_params, auth_header ) # Set the future's result, and mark it as done. f.set_result(res) @@ -168,6 +171,7 @@ async def _scan_file( media_path: str, metadata: Optional[JsonDict] = None, thumbnail_params: Optional[MultiMapping[str]] = None, + auth_header: Optional[str] = None, ) -> MediaDescription: """Download and scan the given media. @@ -185,6 +189,8 @@ async def _scan_file( the file isn't encrypted. thumbnail_params: If present, then we want to request and scan a thumbnail generated with the provided parameters instead of the full media. + auth_header: If present, we forward the given Authorization header, this is + required for authenticated media endpoints. Returns: A description of the media. @@ -218,6 +224,7 @@ async def _scan_file( media = await self._file_downloader.download_file( media_path=media_path, thumbnail_params=thumbnail_params, + auth_header=auth_header, ) # Compare the media's hash to ensure the server hasn't changed the file since @@ -251,6 +258,7 @@ async def _scan_file( media = await self._file_downloader.download_file( media_path=media_path, thumbnail_params=thumbnail_params, + auth_header=auth_header, ) # Download and scan the file. diff --git a/src/matrix_content_scanner/servlets/download.py b/src/matrix_content_scanner/servlets/download.py index 6f1626c..d18f502 100644 --- a/src/matrix_content_scanner/servlets/download.py +++ b/src/matrix_content_scanner/servlets/download.py @@ -26,8 +26,11 @@ async def _scan( self, media_path: str, metadata: Optional[JsonDict] = None, + auth_header: Optional[str] = None, ) -> Tuple[int, _BytesResponse]: - media = await self._scanner.scan_file(media_path, metadata) + media = await self._scanner.scan_file( + media_path, metadata, auth_header=auth_header + ) return 200, _BytesResponse( headers=media.response_headers, @@ -38,7 +41,9 @@ async def _scan( async def handle_plain(self, request: web.Request) -> Tuple[int, _BytesResponse]: """Handles GET requests to ../download/serverName/mediaId""" media_path = request.match_info["media_path"] - return await self._scan(media_path) + return await self._scan( + media_path, auth_header=request.headers.get("Authorization") + ) @web_handler async def handle_encrypted( @@ -49,4 +54,6 @@ async def handle_encrypted( request, self._crypto_handler ) - return await self._scan(media_path, metadata) + return await self._scan( + media_path, metadata, auth_header=request.headers.get("Authorization") + ) diff --git a/src/matrix_content_scanner/servlets/scan.py b/src/matrix_content_scanner/servlets/scan.py index dc13cdb..461581e 100644 --- a/src/matrix_content_scanner/servlets/scan.py +++ b/src/matrix_content_scanner/servlets/scan.py @@ -23,9 +23,10 @@ async def _scan_and_format( self, media_path: str, metadata: Optional[JsonDict] = None, + auth_header: Optional[str] = None, ) -> Tuple[int, JsonDict]: try: - await self._scanner.scan_file(media_path, metadata) + await self._scanner.scan_file(media_path, metadata, auth_header=auth_header) except FileDirtyError as e: res = {"clean": False, "info": e.info} else: @@ -37,7 +38,9 @@ async def _scan_and_format( async def handle_plain(self, request: web.Request) -> Tuple[int, JsonDict]: """Handles GET requests to ../scan/serverName/mediaId""" media_path = request.match_info["media_path"] - return await self._scan_and_format(media_path) + return await self._scan_and_format( + media_path, auth_header=request.headers.get("Authorization") + ) @web_handler async def handle_encrypted(self, request: web.Request) -> Tuple[int, JsonDict]: @@ -45,4 +48,6 @@ async def handle_encrypted(self, request: web.Request) -> Tuple[int, JsonDict]: media_path, metadata = await get_media_metadata_from_request( request, self._crypto_handler ) - return await self._scan_and_format(media_path, metadata) + return await self._scan_and_format( + media_path, metadata, auth_header=request.headers.get("Authorization") + ) diff --git a/src/matrix_content_scanner/servlets/thumbnail.py b/src/matrix_content_scanner/servlets/thumbnail.py index 9b0dd72..08bd959 100644 --- a/src/matrix_content_scanner/servlets/thumbnail.py +++ b/src/matrix_content_scanner/servlets/thumbnail.py @@ -26,6 +26,7 @@ async def handle_thumbnail( media = await self._scanner.scan_file( media_path=media_path, thumbnail_params=request.query, + auth_header=request.headers.get("Authorization"), ) return 200, _BytesResponse( diff --git a/tests/scanner/test_file_downloader.py b/tests/scanner/test_file_downloader.py index 9d10d61..46b8828 100644 --- a/tests/scanner/test_file_downloader.py +++ b/tests/scanner/test_file_downloader.py @@ -37,7 +37,9 @@ def setUp(self) -> None: self.media_headers = get_base_media_headers() async def _get( - url: str, query: Optional[MultiDictProxy[str]] = None + url: str, + query: Optional[MultiDictProxy[str]] = None, + auth_header: Optional[str] = None, ) -> Tuple[int, bytes, CIMultiDictProxy[str]]: """Mock for the _get method on the file downloader that doesn't serve a .well-known client file. @@ -53,6 +55,14 @@ async def _get( or "/_matrix/media/r0/thumbnail/" + MEDIA_PATH in url ): return self.media_status, self.media_body, self.media_headers + if ( + url.endswith(("/_matrix/client/v1/media/download/" + MEDIA_PATH,)) + or "/_matrix/client/v1/media/thumbnail/" + MEDIA_PATH in url + ): + if auth_header is not None: + return self.media_status, self.media_body, self.media_headers + else: + return 404, b"Not found", CIMultiDictProxy(CIMultiDict()) elif url.endswith("/.well-known/matrix/client"): return 404, b"Not found", CIMultiDictProxy(CIMultiDict()) @@ -72,6 +82,19 @@ async def test_download(self) -> None: args = self.get_mock.call_args.args self.assertTrue(args[0].startswith("http://my-site.com/")) + async def test_download_auth_media(self) -> None: + """Tests that downloading a file works using authenticated media.""" + media = await self.downloader.download_file( + MEDIA_PATH, auth_header="Bearer access_token" + ) + self.assertEqual(media.content, SMALL_PNG) + self.assertEqual(media.content_type, "image/png") + + # Check that we tried downloading from the set base URL. + args = self.get_mock.call_args.args + self.assertTrue(args[0].startswith("http://my-site.com/")) + self.assertIn("/_matrix/client/v1/media/download/" + MEDIA_PATH, args[0]) + async def test_no_base_url(self) -> None: """Tests that configuring a base homeserver URL means files are downloaded from that homeserver (rather than the one the files were uploaded to) and .well-known @@ -88,7 +111,11 @@ async def test_no_base_url(self) -> None: ) self.assertEqual( self.get_mock.mock_calls[1], - call("https://foo/_matrix/media/v3/download/" + MEDIA_PATH, query=None), + call( + "https://foo/_matrix/media/v3/download/" + MEDIA_PATH, + query=None, + auth_header=None, + ), ) async def test_retry_on_404(self) -> None: @@ -128,13 +155,45 @@ async def _test_retry(self) -> None: self.assertEqual( self.get_mock.mock_calls[0], call( - "http://my-site.com/_matrix/media/v3/download/" + MEDIA_PATH, query=None + "http://my-site.com/_matrix/media/v3/download/" + MEDIA_PATH, + query=None, + auth_header=None, ), ) self.assertEqual( self.get_mock.mock_calls[1], call( - "http://my-site.com/_matrix/media/r0/download/" + MEDIA_PATH, query=None + "http://my-site.com/_matrix/media/r0/download/" + MEDIA_PATH, + query=None, + auth_header=None, + ), + ) + + async def test_no_retry(self) -> None: + """Tests that in a set specific case a failure to download a file from a v1 + authenticated media download path means we don't retry the request. + """ + self.media_status = 400 + self.media_body = b'{"errcode":"M_UNRECOGNIZED","error":"Unrecognized request"}' + self._set_headers({"content-type": ["application/json"]}) + + # Check that we eventually fail at downloading the file. + with self.assertRaises(ContentScannerRestError) as cm: + await self.downloader.download_file( + MEDIA_PATH, auth_header="Bearer access_token" + ) + + self.assertEqual(cm.exception.http_status, 404) + self.assertEqual(cm.exception.info, "File not found") + + # Check that we sent out only one request. + self.assertEqual(self.get_mock.call_count, 1) + self.assertEqual( + self.get_mock.mock_calls[0], + call( + "http://my-site.com/_matrix/client/v1/media/download/" + MEDIA_PATH, + query=None, + auth_header="Bearer access_token", ), ) @@ -152,6 +211,21 @@ async def test_thumbnail(self) -> None: self.assertIn("height", query) self.assertEqual(query.get("height"), "50", query.getall("height")) + async def test_thumbnail_auth_media(self) -> None: + """Tests that we can download a thumbnail and that the parameters to generate the + thumbnail are correctly passed on to the homeserver using authenticated media. + """ + await self.downloader.download_file( + MEDIA_PATH, to_thumbnail_params({"height": "50"}), "Bearer access_token" + ) + + url: str = self.get_mock.call_args.args[0] + query: CIMultiDictProxy[str] = self.get_mock.call_args.kwargs["query"] + self.assertIn("/thumbnail/", url) + self.assertIn("/_matrix/client/v1/media/thumbnail/" + MEDIA_PATH, url) + self.assertIn("height", query) + self.assertEqual(query.get("height"), "50", query.getall("height")) + async def test_multiple_content_type(self) -> None: """Tests that we raise an error if the homeserver responds with too many Content-Type headers. @@ -203,7 +277,9 @@ def setUp(self) -> None: self.versions_status = 200 async def _get( - url: str, query: Optional[MultiDictProxy[str]] = None + url: str, + query: Optional[MultiDictProxy[str]] = None, + auth_header: Optional[str] = None, ) -> Tuple[int, bytes, CIMultiDictProxy[str]]: """Mock for the _get method on the file downloader that serves a .well-known client file. diff --git a/tests/scanner/test_scanner.py b/tests/scanner/test_scanner.py index 9c7d66b..9ce57a9 100644 --- a/tests/scanner/test_scanner.py +++ b/tests/scanner/test_scanner.py @@ -42,6 +42,7 @@ def setUp(self) -> None: async def download_file( media_path: str, thumbnail_params: Optional[Dict[str, List[str]]] = None, + auth_header: Optional[str] = None, ) -> MediaDescription: """Mock for the file downloader's `download_file` method.""" return self.downloader_res