Skip to content

Commit

Permalink
update url preparator
Browse files Browse the repository at this point in the history
  • Loading branch information
lhoestq committed Oct 28, 2024
1 parent 4b9b936 commit 12c58c0
Show file tree
Hide file tree
Showing 14 changed files with 44 additions and 26 deletions.
2 changes: 1 addition & 1 deletion libs/libapi/src/libapi/rows_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ async def transform_rows(
offset=offset,
row_idx_column=row_idx_column,
)
if "Audio(" in str(features) or "Image(" in str(features) or "Video(" in str(features):
if "Audio(" in str(features) or "Image(" in str(features) or "Video(" in str(features):
# Use multithreading to parallelize image/audio file uploads.
# Also multithreading is ok to convert audio data
# (we use pydub which might spawn one ffmpeg process per conversion, which releases the GIL)
Expand Down
6 changes: 4 additions & 2 deletions libs/libapi/tests/test_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@


@pytest.fixture
def storage_client(tmp_path: Path) -> StorageClient:
def storage_client(tmp_path: Path, hf_endpoint: str) -> StorageClient:
return StorageClient(
protocol="file",
storage_root=str(tmp_path / CACHED_ASSETS_FOLDER),
base_url="http://localhost/cached-assets",
url_preparator=URLPreparator(url_signer=None),
url_preparator=URLPreparator(
url_signer=None, hf_endpoint=hf_endpoint, assets_base_url="http://localhost/cached-assets"
),
)


Expand Down
1 change: 0 additions & 1 deletion libs/libcommon/src/libcommon/storage_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright 2023 The HuggingFace Authors.
import logging
import re
from typing import Optional, Union
from urllib import parse

Expand Down
6 changes: 3 additions & 3 deletions libs/libcommon/src/libcommon/url_preparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,10 @@ def classify(feature: FeatureType, visit_path: VisitPath) -> None:


class URLPreparator(ABC):
def __init__(self, url_signer: Optional[CloudFrontSigner], hf_endpoint: str) -> None:
def __init__(self, url_signer: Optional[CloudFrontSigner], hf_endpoint: str, assets_base_url: str) -> None:
self.url_signer = url_signer
self.hf_endpoint = hf_endpoint
self.datasets_server_assets_endpoint = hf_endpoint.replace("://", "://" + DATASETS_SERVER_ASSETS_SUBDOMAIN_NAME + ".")
self.assets_base_url = assets_base_url

def prepare_url(self, url: str, revision: str) -> str:
# Set the right revision in the URL e.g.
Expand All @@ -91,7 +91,7 @@ def prepare_url(self, url: str, revision: str) -> str:
# Sign the URL since the assets require authentication to be accessed
# Before: https://datasets-server.huggingface.co/assets/vidore/syntheticDocQA_artificial_intelligence_test/--/5fe59d7e52732b86d11ee0e9c4a8cdb0e8ba7a6e/--/default/test/0/image/image.jpg
# After: https://datasets-server.huggingface.co/assets/vidore/syntheticDocQA_artificial_intelligence_test/--/5fe59d7e52732b86d11ee0e9c4a8cdb0e8ba7a6e/--/default/test/0/image/image.jpg?Expires=1...4&Signature=E...A__&Key-Pair-Id=K...3
if self.url_signer and url.startswith(self.datasets_server_assets_endpoint):
if self.url_signer and url.startswith(self.assets_base_url):
url = self.url_signer.sign_url(url)
# Convert HF URL to HF HTTP URL e.g.
# Before: hf://datasets/username/dataset_name@5fe59d7e52732b86d11ee0e9c4a8cdb0e8ba7a6e/video.mp4
Expand Down
7 changes: 1 addition & 6 deletions libs/libcommon/src/libcommon/viewer_utils/asset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright 2022 The HuggingFace Authors.

import re
from io import BytesIO
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Any, Optional, TypedDict
Expand Down Expand Up @@ -149,11 +148,7 @@ def create_video_file(
# dataset_git_revision of cache responses when the data will be accessed.
# This is useful to allow moving files to a newer revision without having
# to modify the cached rows content.
if (
"path" in encoded_video
and isinstance(encoded_video["path"], str)
and "://" in encoded_video["path"]
):
if "path" in encoded_video and isinstance(encoded_video["path"], str) and "://" in encoded_video["path"]:
object_path = encoded_video["path"].replace(revision, DATASET_GIT_REVISION_PLACEHOLDER)
else:
object_path = storage_client.generate_object_path(
Expand Down
1 change: 0 additions & 1 deletion libs/libcommon/src/libcommon/viewer_utils/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,6 @@ def get_cell_value(
featureName: str,
fieldType: Any,
storage_client: StorageClient,
hf_endpoint: str,
json_path: Optional[list[Union[str, int]]] = None,
) -> Any:
# always allow None values in the cells
Expand Down
4 changes: 2 additions & 2 deletions libs/libcommon/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from libcommon.storage_client import StorageClient
from libcommon.url_preparator import URLPreparator

from .constants import ASSETS_BASE_URL
from .constants import ASSETS_BASE_URL, CI_HUB_ENDPOINT

# Import fixture modules as plugins
pytest_plugins = ["tests.fixtures.datasets", "tests.fixtures.fsspec"]
Expand Down Expand Up @@ -91,5 +91,5 @@ def storage_client_with_url_preparator(tmp_path_factory: TempPathFactory) -> Sto
storage_root=str(tmp_path_factory.getbasetemp()),
base_url=ASSETS_BASE_URL,
overwrite=True,
url_preparator=URLPreparator(url_signer=None),
url_preparator=URLPreparator(url_signer=None, hf_endpoint=CI_HUB_ENDPOINT, assets_base_url=ASSETS_BASE_URL),
)
6 changes: 5 additions & 1 deletion libs/libcommon/tests/test_integration_s3_cloudfront.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from libcommon.storage_client import StorageClient
from libcommon.url_preparator import URLPreparator

from .constants import CI_HUB_ENDPOINT

BUCKET = "hf-datasets-server-statics-test"
CLOUDFRONT_KEY_PAIR_ID = "K3814DK2QUJ71H"

Expand All @@ -34,7 +36,9 @@ def test_real_cloudfront(monkeypatch: pytest.MonkeyPatch) -> None:
storage_root=f"{BUCKET}/assets",
)
url_signer = get_cloudfront_signer(cloudfront_config=cloudfront_config)
url_preparator = URLPreparator(url_signer=url_signer)
url_preparator = URLPreparator(
url_signer=url_signer, hf_endpoint=CI_HUB_ENDPOINT, assets_base_url=assets_config.base_url
)
if not s3_config.access_key_id or not s3_config.secret_access_key or not url_signer:
pytest.skip("the S3 and/or CloudFront credentials are not set in environment variables, so we skip the test")

Expand Down
16 changes: 12 additions & 4 deletions libs/libcommon/tests/test_url_preparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from libcommon.viewer_utils.rows import create_first_rows_response

from .constants import (
ASSETS_BASE_URL,
CI_HUB_ENDPOINT,
DATASETS_NAMES,
DEFAULT_COLUMNS_MAX_NUMBER,
DEFAULT_CONFIG,
Expand Down Expand Up @@ -63,7 +65,7 @@ def sign_url(self, url: str) -> str:
def test__prepare_asset_url_path_in_place(datasets_fixtures: Mapping[str, DatasetFixture], dataset_name: str) -> None:
dataset_fixture = datasets_fixtures[dataset_name]
url_signer = FakeUrlSigner()
url_preparator = URLPreparator(url_signer=url_signer)
url_preparator = URLPreparator(url_signer=url_signer, hf_endpoint=CI_HUB_ENDPOINT, assets_base_url=ASSETS_BASE_URL)
for asset_url_path in dataset_fixture.expected_asset_url_paths:
cell_asset_url_path = asset_url_path.enter()
# ^ remove the column name, as we will sign the cell, not the row
Expand Down Expand Up @@ -101,7 +103,9 @@ def get_fake_rows_content(rows_max_number: int) -> RowsContent: # noqa: ARG001
)

url_signer = FakeUrlSigner()
url_preparator = URLPreparator(url_signer=url_signer)
url_preparator = URLPreparator(
url_signer=url_signer, hf_endpoint=CI_HUB_ENDPOINT, assets_base_url=storage_client.base_url
)
asset_url_paths = url_preparator._get_asset_url_paths_from_first_rows(first_rows=first_rows)

assert asset_url_paths == dataset_fixture.expected_asset_url_paths
Expand Down Expand Up @@ -130,7 +134,9 @@ def test_prepare_urls_in_first_rows_in_place(
)

url_signer = FakeUrlSigner()
url_preparator = URLPreparator(url_signer=url_signer)
url_preparator = URLPreparator(
url_signer=url_signer, hf_endpoint=CI_HUB_ENDPOINT, assets_base_url=storage_client.base_url
)
url_preparator.prepare_urls_in_first_rows_in_place(first_rows=first_rows, revision=DEFAULT_REVISION)

assert url_signer.counter == dataset_fixture.expected_num_asset_urls
Expand Down Expand Up @@ -185,7 +191,9 @@ def test_prepare_urls_in_first_rows_in_place_with_truncated_cells(
)

url_signer = FakeUrlSigner()
url_preparator = URLPreparator(url_signer=url_signer)
url_preparator = URLPreparator(
url_signer=url_signer, hf_endpoint=CI_HUB_ENDPOINT, assets_base_url=storage_client.base_url
)
url_preparator.prepare_urls_in_first_rows_in_place(first_rows=first_rows, revision=DEFAULT_REVISION)

if expected == "complete":
Expand Down
5 changes: 4 additions & 1 deletion libs/libcommon/tests/viewer_utils/test_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from ..constants import (
ASSETS_BASE_URL,
CI_HUB_ENDPOINT,
DATASETS_NAMES,
DEFAULT_COLUMN_NAME,
DEFAULT_CONFIG,
Expand Down Expand Up @@ -147,7 +148,9 @@ def test_ogg_audio_with_s3(
secret_access_key="fake_secret_access_key",
region_name="us-east-1",
),
url_preparator=URLPreparator(url_signer=None),
url_preparator=URLPreparator(
url_signer=None, hf_endpoint=CI_HUB_ENDPOINT, assets_base_url=ASSETS_BASE_URL
),
)

# patch aiobotocore.endpoint.convert_to_response_dict because of a known issue in aiobotocore
Expand Down
4 changes: 3 additions & 1 deletion services/api/src/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def create_app_with_config(app_config: AppConfig, endpoint_config: EndpointConfi
)

url_signer = get_cloudfront_signer(cloudfront_config=app_config.cloudfront)
url_preparator = URLPreparator(url_signer=url_signer)
url_preparator = URLPreparator(
url_signer=url_signer, hf_endpoint=app_config.common.hf_endpoint, assets_base_url=app_config.assets.base_url
)
assets_storage_client = StorageClient(
protocol=app_config.assets.storage_protocol,
storage_root=app_config.assets.storage_root,
Expand Down
4 changes: 3 additions & 1 deletion services/rows/src/rows/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def create_app_with_config(app_config: AppConfig) -> Starlette:
queue_resource = QueueMongoResource(database=app_config.queue.mongo_database, host=app_config.queue.mongo_url)

url_signer = get_cloudfront_signer(cloudfront_config=app_config.cloudfront)
url_preparator = URLPreparator(url_signer=url_signer)
url_preparator = URLPreparator(
url_signer=url_signer, hf_endpoint=app_config.common.hf_endpoint, assets_base_url=app_config.assets.base_url
)
cached_assets_storage_client = StorageClient(
protocol=app_config.cached_assets.storage_protocol,
storage_root=app_config.cached_assets.storage_root,
Expand Down
4 changes: 3 additions & 1 deletion services/search/src/search/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def create_app_with_config(app_config: AppConfig) -> Starlette:
cache_resource = CacheMongoResource(database=app_config.cache.mongo_database, host=app_config.cache.mongo_url)
queue_resource = QueueMongoResource(database=app_config.queue.mongo_database, host=app_config.queue.mongo_url)
url_signer = get_cloudfront_signer(cloudfront_config=app_config.cloudfront)
url_preparator = URLPreparator(url_signer=url_signer)
url_preparator = URLPreparator(
url_signer=url_signer, hf_endpoint=app_config.common.hf_endpoint, assets_base_url=app_config.assets.base_url
)
cached_assets_storage_client = StorageClient(
protocol=app_config.cached_assets.storage_protocol,
storage_root=app_config.cached_assets.storage_root,
Expand Down
4 changes: 3 additions & 1 deletion services/webhook/src/webhook/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def create_app_with_config(app_config: AppConfig) -> Starlette:
)

url_signer = get_cloudfront_signer(cloudfront_config=app_config.cloudfront)
url_preparator = URLPreparator(url_signer=url_signer)
url_preparator = URLPreparator(
url_signer=url_signer, hf_endpoint=app_config.common.hf_endpoint, assets_base_url=app_config.assets.base_url
)
assets_storage_client = StorageClient(
protocol=app_config.assets.storage_protocol,
storage_root=app_config.assets.storage_root,
Expand Down

0 comments on commit 12c58c0

Please sign in to comment.