Support private datasets (#2224)
* simplify test

* first delete, then create, when moving

indeed, when a repo is moved, we always want to remove the old entry, while creating
the new entry could be forbidden (if the repo is private, for example); see the sketch just before the file diffs

* refactor to simplify the logic

the idea is to always use "operations.py", so that access control will be
easier once we allow private datasets.

* first pass

* check if dataset is blocked in operations.py

* fix import

* add tests and fix bugs

* fix types

* simplify the e2e tests and fix the admin token

* let unexpected exceptions bubble up

* fix cookie + bugs in tests

* refactor: remove DatasetOrchestrator wrapper, and simplify code

* add logs and simplify logic

* add e2e tests on private datasets

* simplify code and fix e2e tests

* fix admin + move comments in e2e tests

* add e2e tests: private+gated, disabled viewer, disabled discussions, blocklist

* fix types

* add missing environment variables for e2e tests

* change logic of backfill job: update all current datasets

* remove assets/cached-assets when deleting a dataset

* fix tests

* fix quality

* Update libs/libcommon/tests/test_operations.py

Co-authored-by: Andrea Francis Soria Jimenez <[email protected]>

* Update libs/libcommon/src/libcommon/operations.py

Co-authored-by: Andrea Francis Soria Jimenez <[email protected]>

---------

Co-authored-by: Andrea Francis Soria Jimenez <[email protected]>
severo and AndreaFrancis authored Jan 30, 2024
1 parent c48c011 commit 9d157f3
Showing 4 changed files with 208 additions and 19 deletions.
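A minimal sketch of the "first delete, then create" move handling and of the route-everything-through-operations.py idea from the commit message; the handler name and its parameters are hypothetical, and the real `update_dataset`/`delete_dataset` entry points in `libcommon.operations` take more arguments (blocked datasets, storage clients, ...) than shown here.

```python
# Hypothetical webhook-style handler: it never touches the queue or the cache
# directly, it only goes through libcommon.operations, so the support checks
# (private, disabled viewer, blocked, ...) are applied in a single place.
from typing import Optional

from libcommon.operations import delete_dataset, update_dataset


def handle_repo_moved(
    old_dataset: str,
    new_dataset: str,
    hf_endpoint: str,
    hf_token: Optional[str] = None,
) -> None:
    # first delete: the old entry must go away even if the new repo is unsupported
    delete_dataset(dataset=old_dataset)
    # then create: this may raise a NotSupported* error (private repo of a free
    # user, disabled viewer, blocked dataset, ...); the old entry is already gone
    # at that point, which is exactly the intent
    update_dataset(dataset=new_dataset, hf_endpoint=hf_endpoint, hf_token=hf_token)
```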
e2e/tests/test_11_api.py (6 changes: 0 additions & 6 deletions)
@@ -157,8 +157,6 @@ def test_pro_user_private(csv_path: str) -> None:
poll_parquet_until_ready_and_assert(
dataset=dataset,
headers={"Authorization": f"Bearer {PRO_USER_TOKEN}"},
expected_status_code=501,
expected_error_code="NotSupportedPrivateRepositoryError",
)


@@ -187,8 +185,6 @@ def test_enterprise_org_private(csv_path: str) -> None:
poll_parquet_until_ready_and_assert(
dataset=dataset,
headers={"Authorization": f"Bearer {ENTERPRISE_USER_TOKEN}"},
expected_status_code=501,
expected_error_code="NotSupportedPrivateRepositoryError",
)


@@ -217,8 +213,6 @@ def test_pro_user_private_gated(csv_path: str) -> None:
poll_parquet_until_ready_and_assert(
dataset=dataset,
headers={"Authorization": f"Bearer {PRO_USER_TOKEN}"},
expected_status_code=501,
expected_error_code="NotSupportedPrivateRepositoryError",
)


libs/libcommon/src/libcommon/operations.py (100 changes: 97 additions & 3 deletions)
@@ -2,10 +2,17 @@
# Copyright 2022 The HuggingFace Authors.

import logging
from typing import Optional
from dataclasses import dataclass
from typing import Optional, Union

from huggingface_hub.hf_api import DatasetInfo, HfApi
from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError
from huggingface_hub.utils import (
HfHubHTTPError,
RepositoryNotFoundError,
get_session,
hf_raise_for_status,
validate_hf_hub_args,
)

from libcommon.dtos import Priority
from libcommon.exceptions import (
@@ -20,6 +27,70 @@
from libcommon.utils import raise_if_blocked


@dataclass
class EntityInfo:
"""
Contains (very partial) information about an entity on the Hub.
<Tip>
Most attributes of this class are optional. This is because the data returned by the Hub depends on the query made.
</Tip>
**Attributes**:
is_pro (`bool`, *optional*):
Is the entity a pro user.
is_enterprise (`bool`, *optional*):
Is the entity an enterprise organization.
"""

is_pro: Optional[bool]
is_enterprise: Optional[bool]

def __init__(self, **kwargs) -> None: # type: ignore
self.is_pro = kwargs.pop("isPro", None)
self.is_enterprise = kwargs.pop("isEnterprise", None)


class CustomHfApi(HfApi): # type: ignore
@validate_hf_hub_args # type: ignore
def whoisthis(
self,
name: str,
*,
timeout: Optional[float] = None,
token: Optional[Union[bool, str]] = None,
) -> EntityInfo:
"""
Get information on an entity on huggingface.co.
You have to pass an acceptable token.
Args:
name (`str`):
Name of a user or an organization.
timeout (`float`, *optional*):
Timeout in seconds for the request to the Hub.
token (`bool` or `str`, *optional*):
A valid authentication token (see https://huggingface.co/settings/token).
If `None` or `True` and machine is logged in (through `huggingface-cli login`
or [`~huggingface_hub.login`]), token will be retrieved from the cache.
If `False`, token is not sent in the request header.
Returns:
[`EntityInfo`]: The entity information.
"""
headers = self._build_hf_headers(token=token)
path = f"{self.endpoint}/api/whoisthis"
params = {"name": name}

r = get_session().get(path, headers=headers, timeout=timeout, params=params)
hf_raise_for_status(r)
data = r.json()
return EntityInfo(**data)


def get_dataset_info(
dataset: str,
hf_endpoint: str,
@@ -32,6 +103,20 @@ def get_dataset_info(
)


def get_entity_info(
author: str,
hf_endpoint: str,
hf_token: Optional[str] = None,
hf_timeout_seconds: Optional[float] = None,
) -> EntityInfo:
# let the exceptions bubble up, if any
return CustomHfApi(endpoint=hf_endpoint).whoisthis( # type: ignore
name=author,
token=hf_token,
timeout=hf_timeout_seconds,
)


def get_latest_dataset_revision_if_supported_or_raise(
dataset: str,
hf_endpoint: str,
@@ -60,7 +145,16 @@ def get_latest_dataset_revision_if_supported_or_raise(
# ^ in most cases, get_dataset_info should already have raised. Anyway, we double-check here.
raise NotSupportedDisabledRepositoryError(f"Not supported: dataset repository {dataset} is disabled.")
if dataset_info.private:
raise NotSupportedPrivateRepositoryError(f"Not supported: dataset repository {dataset} is private.")
author = dataset_info.author
if not author:
raise ValueError(f"Cannot get the author of dataset {dataset}.")
entity_info = get_entity_info(
author=author, hf_endpoint=hf_endpoint, hf_token=hf_token, hf_timeout_seconds=hf_timeout_seconds
)
if (not entity_info.is_pro) and (not entity_info.is_enterprise):
raise NotSupportedPrivateRepositoryError(
f"Not supported: dataset repository {dataset} is private. Private datasets are only supported for pro users and enterprise organizations."
)
if dataset_info.cardData and not dataset_info.cardData.get("viewer", True):
raise NotSupportedDisabledViewerError(f"Not supported: dataset viewer is disabled in {dataset} configuration.")
if blocked_datasets:
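For reference, a hedged usage sketch of the entity check added above: the endpoint and token values are placeholders, since `/api/whoisthis` is a Hub-internal route and needs a token that is allowed to call it. This is the same gate that `get_latest_dataset_revision_if_supported_or_raise` now applies to private datasets.

```python
from libcommon.operations import get_entity_info

HF_ENDPOINT = "https://huggingface.co"  # placeholder
HF_TOKEN = "hf_app_token"  # placeholder: must be allowed to call /api/whoisthis

entity_info = get_entity_info(author="some-namespace", hf_endpoint=HF_ENDPOINT, hf_token=HF_TOKEN)
if entity_info.is_pro or entity_info.is_enterprise:
    print("private datasets in this namespace are supported")
else:
    print("private datasets in this namespace are rejected (NotSupportedPrivateRepositoryError)")
```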
libs/libcommon/tests/test_operations.py (60 changes: 54 additions & 6 deletions)
@@ -6,6 +6,7 @@
from contextlib import contextmanager
from http import HTTPStatus
from pathlib import Path
from typing import Optional
from unittest.mock import patch

import pytest
@@ -19,7 +20,12 @@
NotSupportedPrivateRepositoryError,
NotSupportedRepositoryNotFoundError,
)
from libcommon.operations import delete_dataset, get_latest_dataset_revision_if_supported_or_raise, update_dataset
from libcommon.operations import (
CustomHfApi,
delete_dataset,
get_latest_dataset_revision_if_supported_or_raise,
update_dataset,
)
from libcommon.queue import Queue
from libcommon.resources import CacheMongoResource, QueueMongoResource
from libcommon.simple_cache import has_some_cache, upsert_response
@@ -58,6 +64,21 @@ def test_get_revision_timeout() -> None:
)


@pytest.mark.parametrize(
"name,expected_pro,expected_enterprise",
[
(NORMAL_USER, False, None),
(PRO_USER, True, None),
(NORMAL_ORG, None, False),
(ENTERPRISE_ORG, None, True),
],
)
def test_whoisthis(name: str, expected_pro: Optional[bool], expected_enterprise: Optional[bool]) -> None:
entity_info = CustomHfApi(endpoint=CI_HUB_ENDPOINT).whoisthis(name=name, token=CI_APP_TOKEN)
assert entity_info.is_pro == expected_pro
assert entity_info.is_enterprise == expected_enterprise


@contextmanager
def tmp_dataset(namespace: str, token: str, private: bool) -> Iterator[str]:
# create a test dataset in hub-ci, then delete it
@@ -79,27 +100,37 @@ def tmp_dataset(namespace: str, token: str, private: bool) -> Iterator[str]:
[
(NORMAL_USER_TOKEN, NORMAL_USER),
(NORMAL_USER_TOKEN, NORMAL_ORG),
(PRO_USER_TOKEN, PRO_USER),
(ENTERPRISE_USER_TOKEN, ENTERPRISE_USER),
(ENTERPRISE_USER_TOKEN, ENTERPRISE_ORG),
],
)
def test_get_revision_private(token: str, namespace: str) -> None:
def test_get_revision_private_raises(token: str, namespace: str) -> None:
with tmp_dataset(namespace=namespace, token=token, private=True) as dataset:
with pytest.raises(NotSupportedPrivateRepositoryError):
get_latest_dataset_revision_if_supported_or_raise(
dataset=dataset, hf_endpoint=CI_HUB_ENDPOINT, hf_token=CI_APP_TOKEN
)


@pytest.mark.parametrize(
"token,namespace",
[
(PRO_USER_TOKEN, PRO_USER),
(ENTERPRISE_USER_TOKEN, ENTERPRISE_ORG),
],
)
def test_get_revision_private(token: str, namespace: str) -> None:
with tmp_dataset(namespace=namespace, token=token, private=True) as dataset:
get_latest_dataset_revision_if_supported_or_raise(
dataset=dataset, hf_endpoint=CI_HUB_ENDPOINT, hf_token=CI_APP_TOKEN
)


@pytest.mark.parametrize(
"token,namespace",
[
(NORMAL_USER_TOKEN, NORMAL_USER),
(NORMAL_USER_TOKEN, NORMAL_ORG),
(PRO_USER_TOKEN, PRO_USER),
(ENTERPRISE_USER_TOKEN, ENTERPRISE_USER),
(ENTERPRISE_USER_TOKEN, ENTERPRISE_ORG),
],
)
def test_update_private_raises(
@@ -157,6 +188,23 @@ def test_update_disabled_dataset_raises_way_2(
update_dataset(dataset=dataset, hf_endpoint=CI_HUB_ENDPOINT, hf_token=CI_APP_TOKEN)


@pytest.mark.parametrize(
"token,namespace",
[
(PRO_USER_TOKEN, PRO_USER),
(ENTERPRISE_USER_TOKEN, ENTERPRISE_ORG),
],
)
def test_update_private(
queue_mongo_resource: QueueMongoResource,
cache_mongo_resource: CacheMongoResource,
token: str,
namespace: str,
) -> None:
with tmp_dataset(namespace=namespace, token=token, private=True) as dataset:
update_dataset(dataset=dataset, hf_endpoint=CI_HUB_ENDPOINT, hf_token=CI_APP_TOKEN)


@pytest.mark.parametrize(
"token,namespace",
[
services/api/tests/routes/test_endpoint.py (61 changes: 57 additions & 4 deletions)
@@ -13,6 +13,7 @@
NotSupportedDisabledViewerError,
NotSupportedPrivateRepositoryError,
)
from libcommon.operations import EntityInfo
from libcommon.processing_graph import processing_graph
from libcommon.queue import Queue
from libcommon.simple_cache import upsert_response
@@ -267,17 +268,17 @@ def test_get_cache_entry_from_steps_no_cache_private() -> None:
dataset = "dataset"
revision = "revision"
config = "config"
author = "author"

app_config = AppConfig.from_env()

no_cache = "config-is-valid"

with patch(
"libcommon.operations.get_dataset_info",
return_value=DatasetInfo(id=dataset, sha=revision, private=True, downloads=0, likes=0, tags=[]),
):
# ^ the dataset does not exist on the Hub, we don't want to raise an issue here

return_value=DatasetInfo(id=dataset, sha=revision, private=True, downloads=0, likes=0, tags=[], author=author),
), patch("libcommon.operations.get_entity_info", return_value=EntityInfo(isPro=False, isEnterprise=False)):
# ^ the dataset and the author do not exist on the Hub, we don't want to raise an issue here
with raises(NotSupportedPrivateRepositoryError):
get_cache_entry_from_steps(
processing_step_names=[no_cache],
@@ -289,6 +290,58 @@
)


def test_get_cache_entry_from_steps_no_cache_private_pro() -> None:
dataset = "dataset"
revision = "revision"
config = "config"
author = "author"

app_config = AppConfig.from_env()

no_cache = "config-is-valid"

with patch(
"libcommon.operations.get_dataset_info",
return_value=DatasetInfo(id=dataset, sha=revision, private=True, downloads=0, likes=0, tags=[], author=author),
), patch("libcommon.operations.get_entity_info", return_value=EntityInfo(isPro=True, isEnterprise=False)):
# ^ the dataset and the author do not exist on the Hub, we don't want to raise an issue here
with raises(ResponseNotReadyError):
get_cache_entry_from_steps(
processing_step_names=[no_cache],
dataset=dataset,
config=config,
split=None,
hf_endpoint=app_config.common.hf_endpoint,
blocked_datasets=[],
)


def test_get_cache_entry_from_steps_no_cache_private_enterprise() -> None:
dataset = "dataset"
revision = "revision"
config = "config"
author = "author"

app_config = AppConfig.from_env()

no_cache = "config-is-valid"

with patch(
"libcommon.operations.get_dataset_info",
return_value=DatasetInfo(id=dataset, sha=revision, private=True, downloads=0, likes=0, tags=[], author=author),
), patch("libcommon.operations.get_entity_info", return_value=EntityInfo(isPro=False, isEnterprise=True)):
# ^ the dataset and the author do not exist on the Hub, we don't want to raise an issue here
with raises(ResponseNotReadyError):
get_cache_entry_from_steps(
processing_step_names=[no_cache],
dataset=dataset,
config=config,
split=None,
hf_endpoint=app_config.common.hf_endpoint,
blocked_datasets=[],
)


def test_get_cache_entry_from_steps_no_cache_blocked() -> None:
dataset = "dataset"
revision = "revision"
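A small note on `EntityInfo` as used in the tests above: its `__init__` consumes the raw Hub payload keys through `**kwargs`, so the camelCase names (`isPro`, `isEnterprise`) are the ones to pass; snake_case keywords are silently ignored and the attributes stay `None`. A quick sketch:

```python
from libcommon.operations import EntityInfo

info = EntityInfo(isPro=True, isEnterprise=False)
assert info.is_pro is True
assert info.is_enterprise is False

# a snake_case keyword is not popped from kwargs, so the attribute falls back to None
ignored = EntityInfo(is_pro=True)
assert ignored.is_pro is None
```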
