Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use placeholder revision in urls in cached responses #2966

Merged
merged 7 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions libs/libcommon/src/libcommon/cloudfront.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@
from functools import partial
from typing import Optional

from botocore.signers import CloudFrontSigner
import botocore.signers
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.hazmat.primitives.hashes import SHA1
from cryptography.hazmat.primitives.serialization import load_pem_private_key

from libcommon.config import CloudFrontConfig
from libcommon.url_signer import URLSigner
from libcommon.utils import get_expires


Expand All @@ -28,7 +27,7 @@ class InvalidPrivateKeyError(ValueError):
# but CloudFront mandates SHA1


class CloudFront(URLSigner):
class CloudFrontSigner:
"""
Signs CloudFront URLs using a private key.

Expand All @@ -38,7 +37,7 @@ class CloudFront(URLSigner):
"""

_expiration_seconds: int
_signer: CloudFrontSigner
_signer: botocore.signers.CloudFrontSigner

def __init__(self, key_pair_id: str, private_key: str, expiration_seconds: int) -> None:
"""
Expand All @@ -55,7 +54,9 @@ def __init__(self, key_pair_id: str, private_key: str, expiration_seconds: int)
raise InvalidPrivateKeyError("Expected an RSA private key")

self._expiration_seconds = expiration_seconds
self._signer = CloudFrontSigner(key_pair_id, partial(pk.sign, padding=padding, algorithm=algorithm))
self._signer = botocore.signers.CloudFrontSigner(
key_pair_id, partial(pk.sign, padding=padding, algorithm=algorithm)
)

def _sign_url(self, url: str, date_less_than: datetime.datetime) -> str:
"""
Expand Down Expand Up @@ -87,9 +88,9 @@ def sign_url(self, url: str) -> str:
return self._sign_url(url=url, date_less_than=date_less_than)


def get_cloudfront_signer(cloudfront_config: CloudFrontConfig) -> Optional[CloudFront]:
def get_cloudfront_signer(cloudfront_config: CloudFrontConfig) -> Optional[CloudFrontSigner]:
return (
CloudFront(
CloudFrontSigner(
key_pair_id=cloudfront_config.key_pair_id,
private_key=cloudfront_config.private_key,
expiration_seconds=cloudfront_config.expiration_seconds,
Expand Down
28 changes: 14 additions & 14 deletions libs/libcommon/src/libcommon/storage_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from libcommon.config import S3Config, StorageProtocol
from libcommon.constants import DATASET_SEPARATOR
from libcommon.url_signer import URLSigner
from libcommon.url_preparator import URLPreparator


class StorageClientInitializeError(Exception):
Expand All @@ -27,15 +27,15 @@ class StorageClient:
base_url (`str`): The base url for the publicly distributed assets
overwrite (`bool`, *optional*, defaults to `False`): Whether to overwrite existing files
s3_config (`S3Config`, *optional*): The S3 configuration to connect to the storage client. Only needed if the protocol is "s3"
url_signer (`URLSigner`, *optional*): The url signer to use for signing urls
url_preparator (`URLPreparator`, *optional*): The URL preparator to use for replacing the revision placeholder in urls and signing them
"""

_fs: Union[LocalFileSystem, S3FileSystem]
protocol: StorageProtocol
storage_root: str
base_url: str
overwrite: bool
url_signer: Optional[URLSigner] = None
url_preparator: Optional[URLPreparator] = None

def __init__(
self,
Expand All @@ -44,14 +44,14 @@ def __init__(
base_url: str,
overwrite: bool = False,
s3_config: Optional[S3Config] = None,
url_signer: Optional[URLSigner] = None,
url_preparator: Optional[URLPreparator] = None,
) -> None:
logging.info(f"trying to initialize storage client with {protocol=} {storage_root=} {base_url=} {overwrite=}")
self.storage_root = storage_root
self.protocol = protocol
self.base_url = base_url
self.overwrite = overwrite
self.url_signer = url_signer
self.url_preparator = url_preparator
if protocol == "s3":
if not s3_config:
raise StorageClientInitializeError("s3 config is required")
Expand Down Expand Up @@ -82,18 +82,18 @@ def get_full_path(self, path: str) -> str:
def exists(self, path: str) -> bool:
return bool(self._fs.exists(self.get_full_path(path)))

def get_url(self, path: str) -> str:
return self.sign_url_if_available(self.get_unsigned_url(path))
def get_url(self, path: str, revision: str) -> str:
return self.prepare_url(self.get_unprepared_url(path), revision=revision)

def get_unsigned_url(self, path: str) -> str:
def get_unprepared_url(self, path: str) -> str:
url = f"{self.base_url}/{path}"
logging.debug(f"unsigned url: {url}")
logging.debug(f"unprepared url: {url}")
return url

def sign_url_if_available(self, url: str) -> str:
if self.url_signer:
url = self.url_signer.sign_url(url=url)
logging.debug(f"signed url: {url}")
def prepare_url(self, url: str, revision: str) -> str:
if self.url_preparator:
url = self.url_preparator.prepare_url(url=url, revision=revision)
logging.debug(f"prepared url: {url}")
return url

def delete_dataset_directory(self, dataset: str) -> int:
Expand Down Expand Up @@ -151,4 +151,4 @@ def generate_object_key(
return f"{parse.quote(dataset)}/{DATASET_SEPARATOR}/{revision}/{DATASET_SEPARATOR}/{parse.quote(config)}/{parse.quote(split)}/{str(row_idx)}/{parse.quote(column)}/{filename}"

def __str__(self) -> str:
return f"StorageClient(protocol={self.protocol}, storage_root={self.storage_root}, base_url={self.base_url}, overwrite={self.overwrite}, url_signer={self.url_signer})"
return f"StorageClient(protocol={self.protocol}, storage_root={self.storage_root}, base_url={self.base_url}, overwrite={self.overwrite}, url_preparator={self.url_preparator})"
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The HuggingFace Authors.

from abc import ABC, abstractmethod
from abc import ABC
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Callable, Literal, Optional, Union

from datasets import Audio, Features, Image
from datasets.features.features import FeatureType, Sequence

from libcommon.cloudfront import CloudFrontSigner
from libcommon.dtos import FeatureItem
from libcommon.viewer_utils.asset import replace_dataset_git_revision_placeholder


class InvalidFirstRowsError(ValueError):
Expand Down Expand Up @@ -74,15 +76,26 @@ def classify(feature: FeatureType, visit_path: VisitPath) -> None:
return asset_url_paths


class URLSigner(ABC):
@abstractmethod
def sign_url(self, url: str) -> str:
pass
class URLPreparator(ABC):
def __init__(self, url_signer: Optional[CloudFrontSigner]) -> None:
self.url_signer = url_signer

def prepare_url(self, url: str, revision: str) -> str:
# Set the right revision in the URL e.g.
# Before: https://datasets-server.huggingface.co/assets/vidore/syntheticDocQA_artificial_intelligence_test/--/{dataset_git_revision}/--/default/test/0/image/image.jpg
# After: https://datasets-server.huggingface.co/assets/vidore/syntheticDocQA_artificial_intelligence_test/--/5fe59d7e52732b86d11ee0e9c4a8cdb0e8ba7a6e/--/default/test/0/image/image.jpg
url = replace_dataset_git_revision_placeholder(url, revision)
# Sign the URL since the assets require authentication to be accessed
# Before: https://datasets-server.huggingface.co/assets/vidore/syntheticDocQA_artificial_intelligence_test/--/5fe59d7e52732b86d11ee0e9c4a8cdb0e8ba7a6e/--/default/test/0/image/image.jpg
# After: https://datasets-server.huggingface.co/assets/vidore/syntheticDocQA_artificial_intelligence_test/--/5fe59d7e52732b86d11ee0e9c4a8cdb0e8ba7a6e/--/default/test/0/image/image.jpg?Expires=1...4&Signature=E...A__&Key-Pair-Id=K...3
lhoestq marked this conversation as resolved.
Show resolved Hide resolved
if self.url_signer:
url = self.url_signer.sign_url(url)
return url

def __str__(self) -> str:
return self.__class__.__name__
return f"{self.__class__.__name__}(url_signer={self.url_signer})"

def _sign_asset_url_path_in_place(self, cell: Any, asset_url_path: AssetUrlPath) -> Any:
def _prepare_asset_url_path_in_place(self, cell: Any, asset_url_path: AssetUrlPath, revision: str) -> Any:
if not cell:
return cell
elif len(asset_url_path.path) == 0:
Expand All @@ -91,21 +104,25 @@ def _sign_asset_url_path_in_place(self, cell: Any, asset_url_path: AssetUrlPath)
src = cell.get("src")
if not isinstance(src, str):
raise InvalidFirstRowsError('Expected cell["src"] to be a string')
cell["src"] = self.sign_url(url=src)
# ^ sign the url in place
cell["src"] = self.prepare_url(src, revision=revision)
# ^ prepare the url in place
else:
key = asset_url_path.path[0]
if key == 0:
# it's a list, we have to sign each element
# it's a list, we have to prepare each element
if not isinstance(cell, list):
raise InvalidFirstRowsError("Expected the cell to be a list")
for cell_item in cell:
self._sign_asset_url_path_in_place(cell=cell_item, asset_url_path=asset_url_path.enter())
self._prepare_asset_url_path_in_place(
cell=cell_item, asset_url_path=asset_url_path.enter(), revision=revision
)
else:
# it's a dict, we have to sign the value of the key
# it's a dict, we have to prepare the value of the key
if not isinstance(cell, dict):
raise InvalidFirstRowsError("Expected the cell to be a dict")
self._sign_asset_url_path_in_place(cell=cell[key], asset_url_path=asset_url_path.enter())
self._prepare_asset_url_path_in_place(
cell=cell[key], asset_url_path=asset_url_path.enter(), revision=revision
)

def _get_asset_url_paths_from_first_rows(self, first_rows: Mapping[str, Any]) -> list[AssetUrlPath]:
# parse the features to find the paths to assets URLs
Expand All @@ -116,11 +133,11 @@ def _get_asset_url_paths_from_first_rows(self, first_rows: Mapping[str, Any]) ->
features = Features.from_dict(features_dict)
return get_asset_url_paths(features)

def sign_urls_in_first_rows_in_place(self, first_rows: Mapping[str, Any]) -> None:
def prepare_urls_in_first_rows_in_place(self, first_rows: Mapping[str, Any], revision: str) -> None:
asset_url_paths = self._get_asset_url_paths_from_first_rows(first_rows=first_rows)
if not asset_url_paths:
return
# sign the URLs
# prepare the URLs (set revision + sign)
row_items = first_rows.get("rows")
if not isinstance(row_items, list):
raise InvalidFirstRowsError('Expected response["rows"] to be a list')
Expand All @@ -135,6 +152,6 @@ def sign_urls_in_first_rows_in_place(self, first_rows: Mapping[str, Any]) -> Non
raise InvalidFirstRowsError('Expected response["rows"][i]["row"] to be a dict')
for asset_url_path in asset_url_paths:
if isinstance(asset_url_path.path[0], str) and asset_url_path.path[0] in truncated_cells:
# the cell has been truncated, nothing to sign in it
# the cell has been truncated, nothing to prepare in it
continue
self._sign_asset_url_path_in_place(cell=row, asset_url_path=asset_url_path)
self._prepare_asset_url_path_in_place(cell=row, asset_url_path=asset_url_path, revision=revision)
46 changes: 34 additions & 12 deletions libs/libcommon/src/libcommon/viewer_utils/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@

from io import BytesIO
from tempfile import NamedTemporaryFile
from typing import Optional, TypedDict
from typing import TYPE_CHECKING, Optional, TypedDict

from PIL import Image, ImageOps
from pydub import AudioSegment # type:ignore

from libcommon.storage_client import StorageClient
if TYPE_CHECKING:
from libcommon.storage_client import StorageClient


SUPPORTED_AUDIO_EXTENSION_TO_MEDIA_TYPE = {".wav": "audio/wav", ".mp3": "audio/mpeg", ".opus": "audio/opus"}
SUPPORTED_AUDIO_EXTENSIONS = SUPPORTED_AUDIO_EXTENSION_TO_MEDIA_TYPE.keys()
DATASET_GIT_REVISION_PLACEHOLDER = "{dataset_git_revision}"


class ImageSource(TypedDict):
Expand All @@ -35,25 +38,32 @@ def create_image_file(
filename: str,
image: Image.Image,
format: str,
storage_client: StorageClient,
storage_client: "StorageClient",
) -> ImageSource:
# We use a placeholder revision that will be replaced later by the
lhoestq marked this conversation as resolved.
Show resolved Hide resolved
# dataset_git_revision of cache responses when the data will be accessed.
# This is useful to allow moving files to a newer revision without having
# to modify the cached rows content.
object_key = storage_client.generate_object_key(
dataset=dataset,
revision=revision,
revision=DATASET_GIT_REVISION_PLACEHOLDER,
Comment on lines 49 to +51
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

main change is here

config=config,
split=split,
row_idx=row_idx,
column=column,
filename=filename,
)
if storage_client.overwrite or not storage_client.exists(object_key):
path = replace_dataset_git_revision_placeholder(object_key, revision=revision)
if storage_client.overwrite or not storage_client.exists(path):
image = ImageOps.exif_transpose(image) # type: ignore[assignment]
buffer = BytesIO()
image.save(fp=buffer, format=format)
buffer.seek(0)
with storage_client._fs.open(storage_client.get_full_path(object_key), "wb") as f:
with storage_client._fs.open(storage_client.get_full_path(path), "wb") as f:
f.write(buffer.read())
return ImageSource(src=storage_client.get_url(object_key), height=image.height, width=image.width)
return ImageSource(
src=storage_client.get_url(object_key, revision=revision), height=image.height, width=image.width
)
Comment on lines +66 to +68
Copy link
Member Author

@lhoestq lhoestq Jul 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few lines later, get_url applies the right revision if there is a URL preparator associated with the storage client.

In practice:

  • /rows, /search and /filter have a URL preparator that inserts the revision and signs the urls from the parquet/duckdb data
  • split-first-rows worker has no URL preparator: URLs are stored in the cache unprepared (i.e. with the revision placeholder and unsigned)
  • /first-rows has a URL preparator that inserts the revision and signs the urls that come from the cache



def create_audio_file(
Expand All @@ -66,11 +76,15 @@ def create_audio_file(
audio_file_bytes: bytes,
audio_file_extension: Optional[str],
filename: str,
storage_client: StorageClient,
storage_client: "StorageClient",
) -> list[AudioSource]:
# We use a placeholder revision that will be replaced later by the
# dataset_git_revision of cache responses when the data will be accessed.
# This is useful to allow moving files to a newer revision without having
# to modify the cached rows content.
object_key = storage_client.generate_object_key(
dataset=dataset,
revision=revision,
revision=DATASET_GIT_REVISION_PLACEHOLDER,
config=config,
split=split,
row_idx=row_idx,
Expand All @@ -85,8 +99,9 @@ def create_audio_file(
)
media_type = SUPPORTED_AUDIO_EXTENSION_TO_MEDIA_TYPE[suffix]

if storage_client.overwrite or not storage_client.exists(object_key):
audio_path = storage_client.get_full_path(object_key)
path = replace_dataset_git_revision_placeholder(object_key, revision=revision)
if storage_client.overwrite or not storage_client.exists(path):
audio_path = storage_client.get_full_path(path)
if audio_file_extension == suffix:
with storage_client._fs.open(audio_path, "wb") as f:
f.write(audio_file_bytes)
Expand All @@ -100,4 +115,11 @@ def create_audio_file(
buffer.seek(0)
with storage_client._fs.open(audio_path, "wb") as f:
f.write(buffer.read())
return [AudioSource(src=storage_client.get_url(object_key), type=media_type)]
return [AudioSource(src=storage_client.get_url(object_key, revision=revision), type=media_type)]


def replace_dataset_git_revision_placeholder(url_or_object_key: str, revision: str) -> str:
    """
    Substitute the revision placeholder in a URL or object key with a concrete revision.

    Every occurrence of `DATASET_GIT_REVISION_PLACEHOLDER` is replaced; a string that
    does not contain the placeholder is returned unchanged.

    Example:
        Before: https://datasets-server.huggingface.co/assets/vidore/syntheticDocQA_artificial_intelligence_test/--/{dataset_git_revision}/--/default/test/0/image/image.jpg
        After: https://datasets-server.huggingface.co/assets/vidore/syntheticDocQA_artificial_intelligence_test/--/c844916c2920d2d01e8a15f8dc1caf6f017a293c/--/default/test/0/image/image.jpg

    Args:
        url_or_object_key (`str`): The URL or object key that may contain the placeholder.
        revision (`str`): The revision to put in place of the placeholder.

    Returns:
        `str`: The input string with the placeholder replaced by `revision`.
    """
    placeholder = DATASET_GIT_REVISION_PLACEHOLDER
    resolved = url_or_object_key.replace(placeholder, revision)
    return resolved
12 changes: 12 additions & 0 deletions libs/libcommon/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from libcommon.simple_cache import _clean_cache_database
from libcommon.storage import StrPath, init_parquet_metadata_dir
from libcommon.storage_client import StorageClient
from libcommon.url_preparator import URLPreparator

from .constants import ASSETS_BASE_URL

Expand Down Expand Up @@ -81,3 +82,14 @@ def storage_client(tmp_path_factory: TempPathFactory) -> StorageClient:
return StorageClient(
protocol="file", storage_root=str(tmp_path_factory.getbasetemp()), base_url=ASSETS_BASE_URL, overwrite=True
)


@fixture(scope="session")
def storage_client_with_url_preparator(tmp_path_factory: TempPathFactory) -> StorageClient:
    """Session-scoped "file" storage client that carries a URL preparator built without a signer."""
    # Root the storage in pytest's base temporary directory.
    root_dir = str(tmp_path_factory.getbasetemp())
    # The preparator is created with url_signer=None, so no signer is attached to it.
    preparator = URLPreparator(url_signer=None)
    return StorageClient(
        protocol="file",
        storage_root=root_dir,
        base_url=ASSETS_BASE_URL,
        overwrite=True,
        url_preparator=preparator,
    )
2 changes: 1 addition & 1 deletion libs/libcommon/tests/fixtures/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
)
from datasets.features.features import FeatureType

from libcommon.url_signer import AssetUrlPath
from libcommon.url_preparator import AssetUrlPath

from ..constants import (
ASSETS_BASE_URL,
Expand Down
Loading
Loading