Skip to content

Commit

Permalink
Use placeholder revision in urls in cached responses (#2966)
Browse files Browse the repository at this point in the history
* use placeholder revision in urls in cached responses

* fix tests

* again

* again

* and again

* Apply suggestions from code review

Co-authored-by: Sylvain Lesage <[email protected]>

---------

Co-authored-by: Sylvain Lesage <[email protected]>
  • Loading branch information
lhoestq and severo authored Jul 15, 2024
1 parent 2ff0da5 commit 57008ce
Show file tree
Hide file tree
Showing 19 changed files with 185 additions and 106 deletions.
15 changes: 8 additions & 7 deletions libs/libcommon/src/libcommon/cloudfront.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@
from functools import partial
from typing import Optional

from botocore.signers import CloudFrontSigner
import botocore.signers
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.hazmat.primitives.hashes import SHA1
from cryptography.hazmat.primitives.serialization import load_pem_private_key

from libcommon.config import CloudFrontConfig
from libcommon.url_signer import URLSigner
from libcommon.utils import get_expires


Expand All @@ -28,7 +27,7 @@ class InvalidPrivateKeyError(ValueError):
# but CloudFront mandates SHA1


class CloudFront(URLSigner):
class CloudFrontSigner:
"""
Signs CloudFront URLs using a private key.
Expand All @@ -38,7 +37,7 @@ class CloudFront(URLSigner):
"""

_expiration_seconds: int
_signer: CloudFrontSigner
_signer: botocore.signers.CloudFrontSigner

def __init__(self, key_pair_id: str, private_key: str, expiration_seconds: int) -> None:
"""
Expand All @@ -55,7 +54,9 @@ def __init__(self, key_pair_id: str, private_key: str, expiration_seconds: int)
raise InvalidPrivateKeyError("Expected an RSA private key")

self._expiration_seconds = expiration_seconds
self._signer = CloudFrontSigner(key_pair_id, partial(pk.sign, padding=padding, algorithm=algorithm))
self._signer = botocore.signers.CloudFrontSigner(
key_pair_id, partial(pk.sign, padding=padding, algorithm=algorithm)
)

def _sign_url(self, url: str, date_less_than: datetime.datetime) -> str:
"""
Expand Down Expand Up @@ -87,9 +88,9 @@ def sign_url(self, url: str) -> str:
return self._sign_url(url=url, date_less_than=date_less_than)


def get_cloudfront_signer(cloudfront_config: CloudFrontConfig) -> Optional[CloudFront]:
def get_cloudfront_signer(cloudfront_config: CloudFrontConfig) -> Optional[CloudFrontSigner]:
return (
CloudFront(
CloudFrontSigner(
key_pair_id=cloudfront_config.key_pair_id,
private_key=cloudfront_config.private_key,
expiration_seconds=cloudfront_config.expiration_seconds,
Expand Down
28 changes: 14 additions & 14 deletions libs/libcommon/src/libcommon/storage_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from libcommon.config import S3Config, StorageProtocol
from libcommon.constants import DATASET_SEPARATOR
from libcommon.url_signer import URLSigner
from libcommon.url_preparator import URLPreparator


class StorageClientInitializeError(Exception):
Expand All @@ -27,15 +27,15 @@ class StorageClient:
base_url (`str`): The base url for the publicly distributed assets
overwrite (`bool`, *optional*, defaults to `False`): Whether to overwrite existing files
s3_config (`S3Config`, *optional*): The S3 configuration to connect to the storage client. Only needed if the protocol is "s3"
url_signer (`URLSigner`, *optional*): The url signer to use for signing urls
url_preparator (`URLPreparator`, *optional*): The urlpreparator to use for signing urls and replacing revision in url
"""

_fs: Union[LocalFileSystem, S3FileSystem]
protocol: StorageProtocol
storage_root: str
base_url: str
overwrite: bool
url_signer: Optional[URLSigner] = None
url_preparator: Optional[URLPreparator] = None

def __init__(
self,
Expand All @@ -44,14 +44,14 @@ def __init__(
base_url: str,
overwrite: bool = False,
s3_config: Optional[S3Config] = None,
url_signer: Optional[URLSigner] = None,
url_preparator: Optional[URLPreparator] = None,
) -> None:
logging.info(f"trying to initialize storage client with {protocol=} {storage_root=} {base_url=} {overwrite=}")
self.storage_root = storage_root
self.protocol = protocol
self.base_url = base_url
self.overwrite = overwrite
self.url_signer = url_signer
self.url_preparator = url_preparator
if protocol == "s3":
if not s3_config:
raise StorageClientInitializeError("s3 config is required")
Expand Down Expand Up @@ -82,18 +82,18 @@ def get_full_path(self, path: str) -> str:
def exists(self, path: str) -> bool:
return bool(self._fs.exists(self.get_full_path(path)))

def get_url(self, path: str) -> str:
return self.sign_url_if_available(self.get_unsigned_url(path))
def get_url(self, path: str, revision: str) -> str:
return self.prepare_url(self.get_unprepared_url(path), revision=revision)

def get_unsigned_url(self, path: str) -> str:
def get_unprepared_url(self, path: str) -> str:
url = f"{self.base_url}/{path}"
logging.debug(f"unsigned url: {url}")
logging.debug(f"unprepared url: {url}")
return url

def sign_url_if_available(self, url: str) -> str:
if self.url_signer:
url = self.url_signer.sign_url(url=url)
logging.debug(f"signed url: {url}")
def prepare_url(self, url: str, revision: str) -> str:
if self.url_preparator:
url = self.url_preparator.prepare_url(url=url, revision=revision)
logging.debug(f"prepared url: {url}")
return url

def delete_dataset_directory(self, dataset: str) -> int:
Expand Down Expand Up @@ -151,4 +151,4 @@ def generate_object_key(
return f"{parse.quote(dataset)}/{DATASET_SEPARATOR}/{revision}/{DATASET_SEPARATOR}/{parse.quote(config)}/{parse.quote(split)}/{str(row_idx)}/{parse.quote(column)}/{filename}"

def __str__(self) -> str:
return f"StorageClient(protocol={self.protocol}, storage_root={self.storage_root}, base_url={self.base_url}, overwrite={self.overwrite}, url_signer={self.url_signer})"
return f"StorageClient(protocol={self.protocol}, storage_root={self.storage_root}, base_url={self.base_url}, overwrite={self.overwrite}, url_preparator={self.url_preparator})"
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The HuggingFace Authors.

from abc import ABC, abstractmethod
from abc import ABC
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Callable, Literal, Optional, Union

from datasets import Audio, Features, Image
from datasets.features.features import FeatureType, Sequence

from libcommon.cloudfront import CloudFrontSigner
from libcommon.dtos import FeatureItem
from libcommon.viewer_utils.asset import replace_dataset_git_revision_placeholder


class InvalidFirstRowsError(ValueError):
Expand Down Expand Up @@ -74,15 +76,26 @@ def classify(feature: FeatureType, visit_path: VisitPath) -> None:
return asset_url_paths


class URLSigner(ABC):
@abstractmethod
def sign_url(self, url: str) -> str:
pass
class URLPreparator(ABC):
def __init__(self, url_signer: Optional[CloudFrontSigner]) -> None:
self.url_signer = url_signer

def prepare_url(self, url: str, revision: str) -> str:
# Set the right revision in the URL e.g.
# Before: https://datasets-server.huggingface.co/assets/vidore/syntheticDocQA_artificial_intelligence_test/--/{dataset_git_revision}/--/default/test/0/image/image.jpg
# After: https://datasets-server.huggingface.co/assets/vidore/syntheticDocQA_artificial_intelligence_test/--/5fe59d7e52732b86d11ee0e9c4a8cdb0e8ba7a6e/--/default/test/0/image/image.jpg
url = replace_dataset_git_revision_placeholder(url, revision)
# Sign the URL since the assets require authentication to be accessed
# Before: https://datasets-server.huggingface.co/assets/vidore/syntheticDocQA_artificial_intelligence_test/--/5fe59d7e52732b86d11ee0e9c4a8cdb0e8ba7a6e/--/default/test/0/image/image.jpg
# After: https://datasets-server.huggingface.co/assets/vidore/syntheticDocQA_artificial_intelligence_test/--/5fe59d7e52732b86d11ee0e9c4a8cdb0e8ba7a6e/--/default/test/0/image/image.jpg?Expires=1...4&Signature=E...A__&Key-Pair-Id=K...3
if self.url_signer:
url = self.url_signer.sign_url(url)
return url

def __str__(self) -> str:
return self.__class__.__name__
return f"{self.__class__.__name__}(url_signer={self.url_signer})"

def _sign_asset_url_path_in_place(self, cell: Any, asset_url_path: AssetUrlPath) -> Any:
def _prepare_asset_url_path_in_place(self, cell: Any, asset_url_path: AssetUrlPath, revision: str) -> Any:
if not cell:
return cell
elif len(asset_url_path.path) == 0:
Expand All @@ -91,21 +104,25 @@ def _sign_asset_url_path_in_place(self, cell: Any, asset_url_path: AssetUrlPath)
src = cell.get("src")
if not isinstance(src, str):
raise InvalidFirstRowsError('Expected cell["src"] to be a string')
cell["src"] = self.sign_url(url=src)
# ^ sign the url in place
cell["src"] = self.prepare_url(src, revision=revision)
# ^ prepare the url in place
else:
key = asset_url_path.path[0]
if key == 0:
# it's a list, we have to sign each element
# it's a list, we have to prepare each element
if not isinstance(cell, list):
raise InvalidFirstRowsError("Expected the cell to be a list")
for cell_item in cell:
self._sign_asset_url_path_in_place(cell=cell_item, asset_url_path=asset_url_path.enter())
self._prepare_asset_url_path_in_place(
cell=cell_item, asset_url_path=asset_url_path.enter(), revision=revision
)
else:
# it's a dict, we have to sign the value of the key
# it's a dict, we have to prepare the value of the key
if not isinstance(cell, dict):
raise InvalidFirstRowsError("Expected the cell to be a dict")
self._sign_asset_url_path_in_place(cell=cell[key], asset_url_path=asset_url_path.enter())
self._prepare_asset_url_path_in_place(
cell=cell[key], asset_url_path=asset_url_path.enter(), revision=revision
)

def _get_asset_url_paths_from_first_rows(self, first_rows: Mapping[str, Any]) -> list[AssetUrlPath]:
# parse the features to find the paths to assets URLs
Expand All @@ -116,11 +133,11 @@ def _get_asset_url_paths_from_first_rows(self, first_rows: Mapping[str, Any]) ->
features = Features.from_dict(features_dict)
return get_asset_url_paths(features)

def sign_urls_in_first_rows_in_place(self, first_rows: Mapping[str, Any]) -> None:
def prepare_urls_in_first_rows_in_place(self, first_rows: Mapping[str, Any], revision: str) -> None:
asset_url_paths = self._get_asset_url_paths_from_first_rows(first_rows=first_rows)
if not asset_url_paths:
return
# sign the URLs
# prepare the URLs (set revision + sign)
row_items = first_rows.get("rows")
if not isinstance(row_items, list):
raise InvalidFirstRowsError('Expected response["rows"] to be a list')
Expand All @@ -135,6 +152,6 @@ def sign_urls_in_first_rows_in_place(self, first_rows: Mapping[str, Any]) -> Non
raise InvalidFirstRowsError('Expected response["rows"][i]["row"] to be a dict')
for asset_url_path in asset_url_paths:
if isinstance(asset_url_path.path[0], str) and asset_url_path.path[0] in truncated_cells:
# the cell has been truncated, nothing to sign in it
# the cell has been truncated, nothing to prepare in it
continue
self._sign_asset_url_path_in_place(cell=row, asset_url_path=asset_url_path)
self._prepare_asset_url_path_in_place(cell=row, asset_url_path=asset_url_path, revision=revision)
48 changes: 36 additions & 12 deletions libs/libcommon/src/libcommon/viewer_utils/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@

from io import BytesIO
from tempfile import NamedTemporaryFile
from typing import Optional, TypedDict
from typing import TYPE_CHECKING, Optional, TypedDict

from PIL import Image, ImageOps
from pydub import AudioSegment # type:ignore

from libcommon.storage_client import StorageClient
if TYPE_CHECKING:
from libcommon.storage_client import StorageClient


SUPPORTED_AUDIO_EXTENSION_TO_MEDIA_TYPE = {".wav": "audio/wav", ".mp3": "audio/mpeg", ".opus": "audio/opus"}
SUPPORTED_AUDIO_EXTENSIONS = SUPPORTED_AUDIO_EXTENSION_TO_MEDIA_TYPE.keys()
DATASET_GIT_REVISION_PLACEHOLDER = "{dataset_git_revision}"


class ImageSource(TypedDict):
Expand All @@ -35,25 +38,34 @@ def create_image_file(
filename: str,
image: Image.Image,
format: str,
storage_client: StorageClient,
storage_client: "StorageClient",
) -> ImageSource:
# We use a placeholder revision in the JSON stored in the database,
# while the path of the file stored on the disk/s3 contains the revision.
# The placeholder will be replaced later by the
# dataset_git_revision of cache responses when the data will be accessed.
# This is useful to allow moving files to a newer revision without having
# to modify the cached rows content.
object_key = storage_client.generate_object_key(
dataset=dataset,
revision=revision,
revision=DATASET_GIT_REVISION_PLACEHOLDER,
config=config,
split=split,
row_idx=row_idx,
column=column,
filename=filename,
)
if storage_client.overwrite or not storage_client.exists(object_key):
path = replace_dataset_git_revision_placeholder(object_key, revision=revision)
if storage_client.overwrite or not storage_client.exists(path):
image = ImageOps.exif_transpose(image) # type: ignore[assignment]
buffer = BytesIO()
image.save(fp=buffer, format=format)
buffer.seek(0)
with storage_client._fs.open(storage_client.get_full_path(object_key), "wb") as f:
with storage_client._fs.open(storage_client.get_full_path(path), "wb") as f:
f.write(buffer.read())
return ImageSource(src=storage_client.get_url(object_key), height=image.height, width=image.width)
return ImageSource(
src=storage_client.get_url(object_key, revision=revision), height=image.height, width=image.width
)


def create_audio_file(
Expand All @@ -66,11 +78,15 @@ def create_audio_file(
audio_file_bytes: bytes,
audio_file_extension: Optional[str],
filename: str,
storage_client: StorageClient,
storage_client: "StorageClient",
) -> list[AudioSource]:
# We use a placeholder revision that will be replaced later by the
# dataset_git_revision of cache responses when the data will be accessed.
# This is useful to allow moving files to a newer revision without having
# to modify the cached rows content.
object_key = storage_client.generate_object_key(
dataset=dataset,
revision=revision,
revision=DATASET_GIT_REVISION_PLACEHOLDER,
config=config,
split=split,
row_idx=row_idx,
Expand All @@ -85,8 +101,9 @@ def create_audio_file(
)
media_type = SUPPORTED_AUDIO_EXTENSION_TO_MEDIA_TYPE[suffix]

if storage_client.overwrite or not storage_client.exists(object_key):
audio_path = storage_client.get_full_path(object_key)
path = replace_dataset_git_revision_placeholder(object_key, revision=revision)
if storage_client.overwrite or not storage_client.exists(path):
audio_path = storage_client.get_full_path(path)
if audio_file_extension == suffix:
with storage_client._fs.open(audio_path, "wb") as f:
f.write(audio_file_bytes)
Expand All @@ -100,4 +117,11 @@ def create_audio_file(
buffer.seek(0)
with storage_client._fs.open(audio_path, "wb") as f:
f.write(buffer.read())
return [AudioSource(src=storage_client.get_url(object_key), type=media_type)]
return [AudioSource(src=storage_client.get_url(object_key, revision=revision), type=media_type)]


def replace_dataset_git_revision_placeholder(url_or_object_key: str, revision: str) -> str:
# Set the right revision in the URL e.g.
# Before: https://datasets-server.huggingface.co/assets/vidore/syntheticDocQA_artificial_intelligence_test/--/{dataset_git_revision}/--/default/test/0/image/image.jpg
# After: https://datasets-server.huggingface.co/assets/vidore/syntheticDocQA_artificial_intelligence_test/--/c844916c2920d2d01e8a15f8dc1caf6f017a293c/--/default/test/0/image/image.jpg
return url_or_object_key.replace(DATASET_GIT_REVISION_PLACEHOLDER, revision)
12 changes: 12 additions & 0 deletions libs/libcommon/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from libcommon.simple_cache import _clean_cache_database
from libcommon.storage import StrPath, init_parquet_metadata_dir
from libcommon.storage_client import StorageClient
from libcommon.url_preparator import URLPreparator

from .constants import ASSETS_BASE_URL

Expand Down Expand Up @@ -81,3 +82,14 @@ def storage_client(tmp_path_factory: TempPathFactory) -> StorageClient:
return StorageClient(
protocol="file", storage_root=str(tmp_path_factory.getbasetemp()), base_url=ASSETS_BASE_URL, overwrite=True
)


@fixture(scope="session")
def storage_client_with_url_preparator(tmp_path_factory: TempPathFactory) -> StorageClient:
return StorageClient(
protocol="file",
storage_root=str(tmp_path_factory.getbasetemp()),
base_url=ASSETS_BASE_URL,
overwrite=True,
url_preparator=URLPreparator(url_signer=None),
)
2 changes: 1 addition & 1 deletion libs/libcommon/tests/fixtures/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
)
from datasets.features.features import FeatureType

from libcommon.url_signer import AssetUrlPath
from libcommon.url_preparator import AssetUrlPath

from ..constants import (
ASSETS_BASE_URL,
Expand Down
Loading

0 comments on commit 57008ce

Please sign in to comment.