diff --git a/src/matrix_content_scanner/scanner/file_downloader.py b/src/matrix_content_scanner/scanner/file_downloader.py index bcdab50..492e87f 100644 --- a/src/matrix_content_scanner/scanner/file_downloader.py +++ b/src/matrix_content_scanner/scanner/file_downloader.py @@ -33,17 +33,24 @@ class _PathNotFoundException(Exception): class FileDownloader: MEDIA_DOWNLOAD_PREFIX = "_matrix/media/%s/download" MEDIA_THUMBNAIL_PREFIX = "_matrix/media/%s/thumbnail" + MEDIA_DOWNLOAD_AUTHENTICATED_PREFIX = "_matrix/client/%s/media/download" + MEDIA_THUMBNAIL_AUTHENTICATED_PREFIX = "_matrix/client/%s/media/thumbnail" def __init__(self, mcs: "MatrixContentScanner"): self._base_url = mcs.config.download.base_homeserver_url self._well_known_cache: Dict[str, Optional[str]] = {} self._proxy_url = mcs.config.download.proxy - self._headers = mcs.config.download.additional_headers + self._headers = ( + mcs.config.download.additional_headers + if mcs.config.download.additional_headers is not None + else {} + ) async def download_file( self, media_path: str, thumbnail_params: Optional[MultiMapping[str]] = None, + auth_header: Optional[str] = None, ) -> MediaDescription: """Retrieve the file with the given `server_name/media_id` path, and stores it on disk. @@ -52,6 +59,8 @@ async def download_file( media_path: The path identifying the media to retrieve. thumbnail_params: If present, then we want to request and scan a thumbnail generated with the provided parameters instead of the full media. + auth_header: If present, we forward the given Authorization header, this is + required for authenticated media endpoints. Returns: A description of the file (including its full content). @@ -60,27 +69,45 @@ async def download_file( ContentScannerRestError: The file was not found or could not be downloaded due to an error on the remote homeserver's side. """ + + auth_media = True if auth_header is not None else False + + prefix = ( + self.MEDIA_DOWNLOAD_AUTHENTICATED_PREFIX + if auth_media + else self.MEDIA_DOWNLOAD_PREFIX + ) + if thumbnail_params is not None: + prefix = ( + self.MEDIA_THUMBNAIL_AUTHENTICATED_PREFIX + if auth_media + else self.MEDIA_THUMBNAIL_PREFIX + ) + url = await self._build_https_url( - media_path, for_thumbnail=thumbnail_params is not None + media_path, prefix, "v1" if auth_media else "v3" ) # Attempt to retrieve the file at the generated URL. try: - file = await self._get_file_content(url, thumbnail_params) + file = await self._get_file_content(url, thumbnail_params, auth_header) except _PathNotFoundException: + if auth_media: + raise ContentScannerRestError( + http_status=HTTPStatus.NOT_FOUND, + reason=ErrCode.NOT_FOUND, + info="File not found", + ) + # If the file could not be found, it might be because the homeserver hasn't # been upgraded to a version that supports Matrix v1.1 endpoints yet, so try # again with an r0 endpoint. logger.info("File not found, trying legacy r0 path") - url = await self._build_https_url( - media_path, - endpoint_version="r0", - for_thumbnail=thumbnail_params is not None, - ) + url = await self._build_https_url(media_path, prefix, "r0") try: - file = await self._get_file_content(url, thumbnail_params) + file = await self._get_file_content(url, thumbnail_params, auth_header) except _PathNotFoundException: # If that still failed, raise an error. raise ContentScannerRestError( @@ -94,9 +121,8 @@ async def download_file( async def _build_https_url( self, media_path: str, - endpoint_version: str = "v3", - *, - for_thumbnail: bool, + prefix: str, + endpoint_version: str, ) -> str: """Turn a `server_name/media_id` path into an https:// one we can use to fetch the media. @@ -107,10 +133,8 @@ async def _build_https_url( Args: media_path: The media path to translate. endpoint_version: The version of the download endpoint to use. As of Matrix - v1.1, this is either "v3" or "r0". - for_thumbnail: True if a server-side thumbnail is desired instead of the full - media. In that case, the URL for the `/thumbnail` endpoint is returned - instead of the `/download` endpoint. + v1.11, this is "v1" for authenticated media. For unauthenticated media + this is either "v3" or "r0". Returns: An https URL to use. If `base_homeserver_url` is set in the config, this @@ -140,10 +164,6 @@ async def _build_https_url( # didn't find a .well-known file. base_url = "https://" + server_name - prefix = ( - self.MEDIA_THUMBNAIL_PREFIX if for_thumbnail else self.MEDIA_DOWNLOAD_PREFIX - ) - # Build the full URL. path_prefix = prefix % endpoint_version url = "%s/%s/%s/%s" % ( @@ -159,12 +179,15 @@ async def _get_file_content( self, url: str, thumbnail_params: Optional[MultiMapping[str]], + auth_header: Optional[str] = None, ) -> MediaDescription: """Retrieve the content of the file at a given URL. Args: url: The URL to query. thumbnail_params: Query parameters used if the request is for a thumbnail. + auth_header: If present, we forward the given Authorization header, this is + required for authenticated media endpoints. Returns: A description of the file (including its full content). @@ -178,7 +201,9 @@ async def _get_file_content( ContentScannerRestError: the server returned a non-200 status which cannot meant that the path wasn't understood. """ - code, body, headers = await self._get(url, query=thumbnail_params) + code, body, headers = await self._get( + url, query=thumbnail_params, auth_header=auth_header + ) logger.info("Remote server responded with %d", code) @@ -307,12 +332,15 @@ async def _get( self, url: str, query: Optional[MultiMapping[str]] = None, + auth_header: Optional[str] = None, ) -> Tuple[int, bytes, CIMultiDictProxy[str]]: """Sends a GET request to the provided URL. Args: url: The URL to send requests to. query: Optional parameters to use in the request's query string. + auth_header: If present, we forward the given Authorization header, this is + required for authenticated media endpoints. Returns: The HTTP status code, body and headers the remote server responded with. @@ -324,10 +352,15 @@ async def _get( try: logger.info("Sending GET request to %s", url) async with aiohttp.ClientSession() as session: + if auth_header is not None: + request_headers = {"Authorization": auth_header, **self._headers} + else: + request_headers = self._headers + async with session.get( url, proxy=self._proxy_url, - headers=self._headers, + headers=request_headers, params=query, ) as resp: return resp.status, await resp.read(), resp.headers diff --git a/src/matrix_content_scanner/scanner/scanner.py b/src/matrix_content_scanner/scanner/scanner.py index 9e27958..1d49b2d 100644 --- a/src/matrix_content_scanner/scanner/scanner.py +++ b/src/matrix_content_scanner/scanner/scanner.py @@ -100,6 +100,7 @@ async def scan_file( media_path: str, metadata: Optional[JsonDict] = None, thumbnail_params: Optional["MultiMapping[str]"] = None, + auth_header: Optional[str] = None, ) -> MediaDescription: """Download and scan the given media. @@ -119,6 +120,8 @@ async def scan_file( the file isn't encrypted. thumbnail_params: If present, then we want to request and scan a thumbnail generated with the provided parameters instead of the full media. + auth_header: If present, we forward the given Authorization header, this is + required for authenticated media endpoints. Returns: A description of the media. @@ -141,7 +144,7 @@ async def scan_file( # Try to download and scan the file. try: res = await self._scan_file( - cache_key, media_path, metadata, thumbnail_params + cache_key, media_path, metadata, thumbnail_params, auth_header ) # Set the future's result, and mark it as done. f.set_result(res) @@ -168,6 +171,7 @@ async def _scan_file( media_path: str, metadata: Optional[JsonDict] = None, thumbnail_params: Optional[MultiMapping[str]] = None, + auth_header: Optional[str] = None, ) -> MediaDescription: """Download and scan the given media. @@ -185,6 +189,8 @@ async def _scan_file( the file isn't encrypted. thumbnail_params: If present, then we want to request and scan a thumbnail generated with the provided parameters instead of the full media. + auth_header: If present, we forward the given Authorization header, this is + required for authenticated media endpoints. Returns: A description of the media. @@ -218,6 +224,7 @@ async def _scan_file( media = await self._file_downloader.download_file( media_path=media_path, thumbnail_params=thumbnail_params, + auth_header=auth_header, ) # Compare the media's hash to ensure the server hasn't changed the file since @@ -251,6 +258,7 @@ async def _scan_file( media = await self._file_downloader.download_file( media_path=media_path, thumbnail_params=thumbnail_params, + auth_header=auth_header, ) # Download and scan the file. diff --git a/src/matrix_content_scanner/servlets/download.py b/src/matrix_content_scanner/servlets/download.py index 6f1626c..d18f502 100644 --- a/src/matrix_content_scanner/servlets/download.py +++ b/src/matrix_content_scanner/servlets/download.py @@ -26,8 +26,11 @@ async def _scan( self, media_path: str, metadata: Optional[JsonDict] = None, + auth_header: Optional[str] = None, ) -> Tuple[int, _BytesResponse]: - media = await self._scanner.scan_file(media_path, metadata) + media = await self._scanner.scan_file( + media_path, metadata, auth_header=auth_header + ) return 200, _BytesResponse( headers=media.response_headers, @@ -38,7 +41,9 @@ async def _scan( async def handle_plain(self, request: web.Request) -> Tuple[int, _BytesResponse]: """Handles GET requests to ../download/serverName/mediaId""" media_path = request.match_info["media_path"] - return await self._scan(media_path) + return await self._scan( + media_path, auth_header=request.headers.get("Authorization") + ) @web_handler async def handle_encrypted( @@ -49,4 +54,6 @@ async def handle_encrypted( request, self._crypto_handler ) - return await self._scan(media_path, metadata) + return await self._scan( + media_path, metadata, auth_header=request.headers.get("Authorization") + ) diff --git a/src/matrix_content_scanner/servlets/scan.py b/src/matrix_content_scanner/servlets/scan.py index dc13cdb..461581e 100644 --- a/src/matrix_content_scanner/servlets/scan.py +++ b/src/matrix_content_scanner/servlets/scan.py @@ -23,9 +23,10 @@ async def _scan_and_format( self, media_path: str, metadata: Optional[JsonDict] = None, + auth_header: Optional[str] = None, ) -> Tuple[int, JsonDict]: try: - await self._scanner.scan_file(media_path, metadata) + await self._scanner.scan_file(media_path, metadata, auth_header=auth_header) except FileDirtyError as e: res = {"clean": False, "info": e.info} else: @@ -37,7 +38,9 @@ async def _scan_and_format( async def handle_plain(self, request: web.Request) -> Tuple[int, JsonDict]: """Handles GET requests to ../scan/serverName/mediaId""" media_path = request.match_info["media_path"] - return await self._scan_and_format(media_path) + return await self._scan_and_format( + media_path, auth_header=request.headers.get("Authorization") + ) @web_handler async def handle_encrypted(self, request: web.Request) -> Tuple[int, JsonDict]: @@ -45,4 +48,6 @@ async def handle_encrypted(self, request: web.Request) -> Tuple[int, JsonDict]: media_path, metadata = await get_media_metadata_from_request( request, self._crypto_handler ) - return await self._scan_and_format(media_path, metadata) + return await self._scan_and_format( + media_path, metadata, auth_header=request.headers.get("Authorization") + ) diff --git a/src/matrix_content_scanner/servlets/thumbnail.py b/src/matrix_content_scanner/servlets/thumbnail.py index 9b0dd72..08bd959 100644 --- a/src/matrix_content_scanner/servlets/thumbnail.py +++ b/src/matrix_content_scanner/servlets/thumbnail.py @@ -26,6 +26,7 @@ async def handle_thumbnail( media = await self._scanner.scan_file( media_path=media_path, thumbnail_params=request.query, + auth_header=request.headers.get("Authorization"), ) return 200, _BytesResponse( diff --git a/tests/scanner/test_file_downloader.py b/tests/scanner/test_file_downloader.py index 9d10d61..46b8828 100644 --- a/tests/scanner/test_file_downloader.py +++ b/tests/scanner/test_file_downloader.py @@ -37,7 +37,9 @@ def setUp(self) -> None: self.media_headers = get_base_media_headers() async def _get( - url: str, query: Optional[MultiDictProxy[str]] = None + url: str, + query: Optional[MultiDictProxy[str]] = None, + auth_header: Optional[str] = None, ) -> Tuple[int, bytes, CIMultiDictProxy[str]]: """Mock for the _get method on the file downloader that doesn't serve a .well-known client file. @@ -53,6 +55,14 @@ async def _get( or "/_matrix/media/r0/thumbnail/" + MEDIA_PATH in url ): return self.media_status, self.media_body, self.media_headers + if ( + url.endswith(("/_matrix/client/v1/media/download/" + MEDIA_PATH,)) + or "/_matrix/client/v1/media/thumbnail/" + MEDIA_PATH in url + ): + if auth_header is not None: + return self.media_status, self.media_body, self.media_headers + else: + return 404, b"Not found", CIMultiDictProxy(CIMultiDict()) elif url.endswith("/.well-known/matrix/client"): return 404, b"Not found", CIMultiDictProxy(CIMultiDict()) @@ -72,6 +82,19 @@ async def test_download(self) -> None: args = self.get_mock.call_args.args self.assertTrue(args[0].startswith("http://my-site.com/")) + async def test_download_auth_media(self) -> None: + """Tests that downloading a file works using authenticated media.""" + media = await self.downloader.download_file( + MEDIA_PATH, auth_header="Bearer access_token" + ) + self.assertEqual(media.content, SMALL_PNG) + self.assertEqual(media.content_type, "image/png") + + # Check that we tried downloading from the set base URL. + args = self.get_mock.call_args.args + self.assertTrue(args[0].startswith("http://my-site.com/")) + self.assertIn("/_matrix/client/v1/media/download/" + MEDIA_PATH, args[0]) + async def test_no_base_url(self) -> None: """Tests that configuring a base homeserver URL means files are downloaded from that homeserver (rather than the one the files were uploaded to) and .well-known @@ -88,7 +111,11 @@ async def test_no_base_url(self) -> None: ) self.assertEqual( self.get_mock.mock_calls[1], - call("https://foo/_matrix/media/v3/download/" + MEDIA_PATH, query=None), + call( + "https://foo/_matrix/media/v3/download/" + MEDIA_PATH, + query=None, + auth_header=None, + ), ) async def test_retry_on_404(self) -> None: @@ -128,13 +155,45 @@ async def _test_retry(self) -> None: self.assertEqual( self.get_mock.mock_calls[0], call( - "http://my-site.com/_matrix/media/v3/download/" + MEDIA_PATH, query=None + "http://my-site.com/_matrix/media/v3/download/" + MEDIA_PATH, + query=None, + auth_header=None, ), ) self.assertEqual( self.get_mock.mock_calls[1], call( - "http://my-site.com/_matrix/media/r0/download/" + MEDIA_PATH, query=None + "http://my-site.com/_matrix/media/r0/download/" + MEDIA_PATH, + query=None, + auth_header=None, + ), + ) + + async def test_no_retry(self) -> None: + """Tests that in a set specific case a failure to download a file from a v1 + authenticated media download path means we don't retry the request. + """ + self.media_status = 400 + self.media_body = b'{"errcode":"M_UNRECOGNIZED","error":"Unrecognized request"}' + self._set_headers({"content-type": ["application/json"]}) + + # Check that we eventually fail at downloading the file. + with self.assertRaises(ContentScannerRestError) as cm: + await self.downloader.download_file( + MEDIA_PATH, auth_header="Bearer access_token" + ) + + self.assertEqual(cm.exception.http_status, 404) + self.assertEqual(cm.exception.info, "File not found") + + # Check that we sent out only one request. + self.assertEqual(self.get_mock.call_count, 1) + self.assertEqual( + self.get_mock.mock_calls[0], + call( + "http://my-site.com/_matrix/client/v1/media/download/" + MEDIA_PATH, + query=None, + auth_header="Bearer access_token", ), ) @@ -152,6 +211,21 @@ async def test_thumbnail(self) -> None: self.assertIn("height", query) self.assertEqual(query.get("height"), "50", query.getall("height")) + async def test_thumbnail_auth_media(self) -> None: + """Tests that we can download a thumbnail and that the parameters to generate the + thumbnail are correctly passed on to the homeserver using authenticated media. + """ + await self.downloader.download_file( + MEDIA_PATH, to_thumbnail_params({"height": "50"}), "Bearer access_token" + ) + + url: str = self.get_mock.call_args.args[0] + query: CIMultiDictProxy[str] = self.get_mock.call_args.kwargs["query"] + self.assertIn("/thumbnail/", url) + self.assertIn("/_matrix/client/v1/media/thumbnail/" + MEDIA_PATH, url) + self.assertIn("height", query) + self.assertEqual(query.get("height"), "50", query.getall("height")) + async def test_multiple_content_type(self) -> None: """Tests that we raise an error if the homeserver responds with too many Content-Type headers. @@ -203,7 +277,9 @@ def setUp(self) -> None: self.versions_status = 200 async def _get( - url: str, query: Optional[MultiDictProxy[str]] = None + url: str, + query: Optional[MultiDictProxy[str]] = None, + auth_header: Optional[str] = None, ) -> Tuple[int, bytes, CIMultiDictProxy[str]]: """Mock for the _get method on the file downloader that serves a .well-known client file. diff --git a/tests/scanner/test_scanner.py b/tests/scanner/test_scanner.py index 9c7d66b..9ce57a9 100644 --- a/tests/scanner/test_scanner.py +++ b/tests/scanner/test_scanner.py @@ -42,6 +42,7 @@ def setUp(self) -> None: async def download_file( media_path: str, thumbnail_params: Optional[Dict[str, List[str]]] = None, + auth_header: Optional[str] = None, ) -> MediaDescription: """Mock for the file downloader's `download_file` method.""" return self.downloader_res