Skip to content

Commit

Permalink
Add support for authenticated media (#69)
Browse files Browse the repository at this point in the history
This PR adds support for authenticated media by passing the
Authorization header information through from the client to the
homeserver.
Once MAS supports scoped access tokens, this code should be changed over
to use that.

Huge shoutout to @S7evinK for doing the bulk of the implementation on
this.
  • Loading branch information
devonh authored Dec 5, 2024
1 parent 2056293 commit 694914f
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 34 deletions.
77 changes: 55 additions & 22 deletions src/matrix_content_scanner/scanner/file_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,24 @@ class _PathNotFoundException(Exception):
class FileDownloader:
MEDIA_DOWNLOAD_PREFIX = "_matrix/media/%s/download"
MEDIA_THUMBNAIL_PREFIX = "_matrix/media/%s/thumbnail"
MEDIA_DOWNLOAD_AUTHENTICATED_PREFIX = "_matrix/client/%s/media/download"
MEDIA_THUMBNAIL_AUTHENTICATED_PREFIX = "_matrix/client/%s/media/thumbnail"

def __init__(self, mcs: "MatrixContentScanner"):
self._base_url = mcs.config.download.base_homeserver_url
self._well_known_cache: Dict[str, Optional[str]] = {}
self._proxy_url = mcs.config.download.proxy
self._headers = mcs.config.download.additional_headers
self._headers = (
mcs.config.download.additional_headers
if mcs.config.download.additional_headers is not None
else {}
)

async def download_file(
self,
media_path: str,
thumbnail_params: Optional[MultiMapping[str]] = None,
auth_header: Optional[str] = None,
) -> MediaDescription:
"""Retrieve the file with the given `server_name/media_id` path, and stores it on
disk.
Expand All @@ -52,6 +59,8 @@ async def download_file(
media_path: The path identifying the media to retrieve.
thumbnail_params: If present, then we want to request and scan a thumbnail
generated with the provided parameters instead of the full media.
auth_header: If present, we forward the given Authorization header, this is
required for authenticated media endpoints.
Returns:
A description of the file (including its full content).
Expand All @@ -60,27 +69,45 @@ async def download_file(
ContentScannerRestError: The file was not found or could not be downloaded due
to an error on the remote homeserver's side.
"""

auth_media = True if auth_header is not None else False

prefix = (
self.MEDIA_DOWNLOAD_AUTHENTICATED_PREFIX
if auth_media
else self.MEDIA_DOWNLOAD_PREFIX
)
if thumbnail_params is not None:
prefix = (
self.MEDIA_THUMBNAIL_AUTHENTICATED_PREFIX
if auth_media
else self.MEDIA_THUMBNAIL_PREFIX
)

url = await self._build_https_url(
media_path, for_thumbnail=thumbnail_params is not None
media_path, prefix, "v1" if auth_media else "v3"
)

# Attempt to retrieve the file at the generated URL.
try:
file = await self._get_file_content(url, thumbnail_params)
file = await self._get_file_content(url, thumbnail_params, auth_header)
except _PathNotFoundException:
if auth_media:
raise ContentScannerRestError(
http_status=HTTPStatus.NOT_FOUND,
reason=ErrCode.NOT_FOUND,
info="File not found",
)

# If the file could not be found, it might be because the homeserver hasn't
# been upgraded to a version that supports Matrix v1.1 endpoints yet, so try
# again with an r0 endpoint.
logger.info("File not found, trying legacy r0 path")

url = await self._build_https_url(
media_path,
endpoint_version="r0",
for_thumbnail=thumbnail_params is not None,
)
url = await self._build_https_url(media_path, prefix, "r0")

try:
file = await self._get_file_content(url, thumbnail_params)
file = await self._get_file_content(url, thumbnail_params, auth_header)
except _PathNotFoundException:
# If that still failed, raise an error.
raise ContentScannerRestError(
Expand All @@ -94,9 +121,8 @@ async def download_file(
async def _build_https_url(
self,
media_path: str,
endpoint_version: str = "v3",
*,
for_thumbnail: bool,
prefix: str,
endpoint_version: str,
) -> str:
"""Turn a `server_name/media_id` path into an https:// one we can use to fetch
the media.
Expand All @@ -107,10 +133,8 @@ async def _build_https_url(
Args:
media_path: The media path to translate.
endpoint_version: The version of the download endpoint to use. As of Matrix
v1.1, this is either "v3" or "r0".
for_thumbnail: True if a server-side thumbnail is desired instead of the full
media. In that case, the URL for the `/thumbnail` endpoint is returned
instead of the `/download` endpoint.
v1.11, this is "v1" for authenticated media. For unauthenticated media
this is either "v3" or "r0".
Returns:
An https URL to use. If `base_homeserver_url` is set in the config, this
Expand Down Expand Up @@ -140,10 +164,6 @@ async def _build_https_url(
# didn't find a .well-known file.
base_url = "https://" + server_name

prefix = (
self.MEDIA_THUMBNAIL_PREFIX if for_thumbnail else self.MEDIA_DOWNLOAD_PREFIX
)

# Build the full URL.
path_prefix = prefix % endpoint_version
url = "%s/%s/%s/%s" % (
Expand All @@ -159,12 +179,15 @@ async def _get_file_content(
self,
url: str,
thumbnail_params: Optional[MultiMapping[str]],
auth_header: Optional[str] = None,
) -> MediaDescription:
"""Retrieve the content of the file at a given URL.
Args:
url: The URL to query.
thumbnail_params: Query parameters used if the request is for a thumbnail.
auth_header: If present, we forward the given Authorization header, this is
required for authenticated media endpoints.
Returns:
A description of the file (including its full content).
Expand All @@ -178,7 +201,9 @@ async def _get_file_content(
ContentScannerRestError: the server returned a non-200 status which cannot
mean that the path wasn't understood.
"""
code, body, headers = await self._get(url, query=thumbnail_params)
code, body, headers = await self._get(
url, query=thumbnail_params, auth_header=auth_header
)

logger.info("Remote server responded with %d", code)

Expand Down Expand Up @@ -307,12 +332,15 @@ async def _get(
self,
url: str,
query: Optional[MultiMapping[str]] = None,
auth_header: Optional[str] = None,
) -> Tuple[int, bytes, CIMultiDictProxy[str]]:
"""Sends a GET request to the provided URL.
Args:
url: The URL to send requests to.
query: Optional parameters to use in the request's query string.
auth_header: If present, we forward the given Authorization header, this is
required for authenticated media endpoints.
Returns:
The HTTP status code, body and headers the remote server responded with.
Expand All @@ -324,10 +352,15 @@ async def _get(
try:
logger.info("Sending GET request to %s", url)
async with aiohttp.ClientSession() as session:
if auth_header is not None:
request_headers = {"Authorization": auth_header, **self._headers}
else:
request_headers = self._headers

async with session.get(
url,
proxy=self._proxy_url,
headers=self._headers,
headers=request_headers,
params=query,
) as resp:
return resp.status, await resp.read(), resp.headers
Expand Down
10 changes: 9 additions & 1 deletion src/matrix_content_scanner/scanner/scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ async def scan_file(
media_path: str,
metadata: Optional[JsonDict] = None,
thumbnail_params: Optional["MultiMapping[str]"] = None,
auth_header: Optional[str] = None,
) -> MediaDescription:
"""Download and scan the given media.
Expand All @@ -119,6 +120,8 @@ async def scan_file(
the file isn't encrypted.
thumbnail_params: If present, then we want to request and scan a thumbnail
generated with the provided parameters instead of the full media.
auth_header: If present, we forward the given Authorization header, this is
required for authenticated media endpoints.
Returns:
A description of the media.
Expand All @@ -141,7 +144,7 @@ async def scan_file(
# Try to download and scan the file.
try:
res = await self._scan_file(
cache_key, media_path, metadata, thumbnail_params
cache_key, media_path, metadata, thumbnail_params, auth_header
)
# Set the future's result, and mark it as done.
f.set_result(res)
Expand All @@ -168,6 +171,7 @@ async def _scan_file(
media_path: str,
metadata: Optional[JsonDict] = None,
thumbnail_params: Optional[MultiMapping[str]] = None,
auth_header: Optional[str] = None,
) -> MediaDescription:
"""Download and scan the given media.
Expand All @@ -185,6 +189,8 @@ async def _scan_file(
the file isn't encrypted.
thumbnail_params: If present, then we want to request and scan a thumbnail
generated with the provided parameters instead of the full media.
auth_header: If present, we forward the given Authorization header, this is
required for authenticated media endpoints.
Returns:
A description of the media.
Expand Down Expand Up @@ -218,6 +224,7 @@ async def _scan_file(
media = await self._file_downloader.download_file(
media_path=media_path,
thumbnail_params=thumbnail_params,
auth_header=auth_header,
)

# Compare the media's hash to ensure the server hasn't changed the file since
Expand Down Expand Up @@ -251,6 +258,7 @@ async def _scan_file(
media = await self._file_downloader.download_file(
media_path=media_path,
thumbnail_params=thumbnail_params,
auth_header=auth_header,
)

# Download and scan the file.
Expand Down
13 changes: 10 additions & 3 deletions src/matrix_content_scanner/servlets/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@ async def _scan(
self,
media_path: str,
metadata: Optional[JsonDict] = None,
auth_header: Optional[str] = None,
) -> Tuple[int, _BytesResponse]:
media = await self._scanner.scan_file(media_path, metadata)
media = await self._scanner.scan_file(
media_path, metadata, auth_header=auth_header
)

return 200, _BytesResponse(
headers=media.response_headers,
Expand All @@ -38,7 +41,9 @@ async def _scan(
async def handle_plain(self, request: web.Request) -> Tuple[int, _BytesResponse]:
"""Handles GET requests to ../download/serverName/mediaId"""
media_path = request.match_info["media_path"]
return await self._scan(media_path)
return await self._scan(
media_path, auth_header=request.headers.get("Authorization")
)

@web_handler
async def handle_encrypted(
Expand All @@ -49,4 +54,6 @@ async def handle_encrypted(
request, self._crypto_handler
)

return await self._scan(media_path, metadata)
return await self._scan(
media_path, metadata, auth_header=request.headers.get("Authorization")
)
11 changes: 8 additions & 3 deletions src/matrix_content_scanner/servlets/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ async def _scan_and_format(
self,
media_path: str,
metadata: Optional[JsonDict] = None,
auth_header: Optional[str] = None,
) -> Tuple[int, JsonDict]:
try:
await self._scanner.scan_file(media_path, metadata)
await self._scanner.scan_file(media_path, metadata, auth_header=auth_header)
except FileDirtyError as e:
res = {"clean": False, "info": e.info}
else:
Expand All @@ -37,12 +38,16 @@ async def _scan_and_format(
async def handle_plain(self, request: web.Request) -> Tuple[int, JsonDict]:
"""Handles GET requests to ../scan/serverName/mediaId"""
media_path = request.match_info["media_path"]
return await self._scan_and_format(media_path)
return await self._scan_and_format(
media_path, auth_header=request.headers.get("Authorization")
)

@web_handler
async def handle_encrypted(self, request: web.Request) -> Tuple[int, JsonDict]:
"""Handles GET requests to ../scan_encrypted"""
media_path, metadata = await get_media_metadata_from_request(
request, self._crypto_handler
)
return await self._scan_and_format(media_path, metadata)
return await self._scan_and_format(
media_path, metadata, auth_header=request.headers.get("Authorization")
)
1 change: 1 addition & 0 deletions src/matrix_content_scanner/servlets/thumbnail.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ async def handle_thumbnail(
media = await self._scanner.scan_file(
media_path=media_path,
thumbnail_params=request.query,
auth_header=request.headers.get("Authorization"),
)

return 200, _BytesResponse(
Expand Down
Loading

0 comments on commit 694914f

Please sign in to comment.