From 5a498aa8521f4e6637eb7767bc6b1e465b30c9c6 Mon Sep 17 00:00:00 2001
From: Sylvain Lesage
Date: Fri, 14 Jun 2024 19:18:50 +0200
Subject: [PATCH] Detect dataset modalities using dataset-filetypes (#2909)

* trigger dataset-modalities from dataset-filetypes
* prepare for another modality detection method
* add detection of modality per file extension
* fix test and types
* log in case of exception (removes bandit warning) + add tests
* fix test
* refresh all the entries
* Additional modalities detection (#2912)
* extract features modality detection to method
* detect tabular datasets
* add two modalities + fix comment
* add simple time-series detection
* add comment
* add test

---------

Co-authored-by: Quentin Lhoest

* .tif/.tiff -> image, not geospatial

---------

Co-authored-by: Quentin Lhoest
---
 .../src/libcommon/processing_graph.py         |   4 +-
 libs/libcommon/tests/test_operations.py       |   2 +-
 libs/libcommon/tests/test_processing_graph.py |   7 +-
 services/worker/src/worker/dtos.py            |   2 +-
 .../worker/job_runners/dataset/modalities.py  | 253 ++++++++++++++++--
 .../job_runners/dataset/test_modalities.py    | 202 +++++++++++++-
 6 files changed, 434 insertions(+), 36 deletions(-)

diff --git a/libs/libcommon/src/libcommon/processing_graph.py b/libs/libcommon/src/libcommon/processing_graph.py
index e66e2703b6..81b321a262 100644
--- a/libs/libcommon/src/libcommon/processing_graph.py
+++ b/libs/libcommon/src/libcommon/processing_graph.py
@@ -708,8 +708,8 @@ def parse_id(id: str) -> tuple[str, str, Optional[str], Optional[str], str]:
         },
         "dataset-modalities": {
             "input_type": "dataset",
-            "triggered_by": "dataset-info",
-            "job_runner_version": 1,
+            "triggered_by": ["dataset-info", "dataset-filetypes"],
+            "job_runner_version": 2,
             "difficulty": 20,
         },
         "dataset-croissant-crumbs": {
diff --git a/libs/libcommon/tests/test_operations.py b/libs/libcommon/tests/test_operations.py
index 73d0636381..2e1ff101b1 100644
--- a/libs/libcommon/tests/test_operations.py
+++ b/libs/libcommon/tests/test_operations.py
@@ -427,7 +427,7 @@ def test_2274_only_first_steps(
         }
     )

-    assert len(queue.get_pending_jobs_df(dataset=dataset)) == 7
+    assert len(queue.get_pending_jobs_df(dataset=dataset)) == 8
     assert len(get_cache_entries_df(dataset=dataset)) == 2

     # let's delete all the jobs, to get in the same state as the bug
diff --git a/libs/libcommon/tests/test_processing_graph.py b/libs/libcommon/tests/test_processing_graph.py
index f3fcb78743..6d94d7a477 100644
--- a/libs/libcommon/tests/test_processing_graph.py
+++ b/libs/libcommon/tests/test_processing_graph.py
@@ -168,8 +168,8 @@ def test_graph() -> None:
         (
             "dataset-modalities",
             ["dataset-hub-cache"],
-            ["dataset-info"],
-            ["dataset-config-names", "config-parquet-and-info", "config-info", "dataset-info"],
+            ["dataset-info", "dataset-filetypes"],
+            ["dataset-config-names", "config-parquet-and-info", "config-info", "dataset-info", "dataset-filetypes"],
         ),
         (
             "dataset-is-valid",
@@ -363,6 +363,7 @@ def test_graph() -> None:
             "config-size",
             "config-split-names",
             "dataset-config-names",
+            "dataset-filetypes",
             "dataset-info",
             "dataset-is-valid",
             "dataset-compatible-libraries",
@@ -417,7 +418,7 @@ def test_graph() -> None:
         ),
         (
             "dataset-filetypes",
-            [],
+            ["dataset-modalities"],
             [],
             [],
         ),
diff --git a/services/worker/src/worker/dtos.py b/services/worker/src/worker/dtos.py
index 1a6d210a66..ac57c6eaa9 100644
--- a/services/worker/src/worker/dtos.py
+++ b/services/worker/src/worker/dtos.py
@@ -333,7 +333,7 @@ class DatasetCompatibleLibrariesResponse(TypedDict):
     formats: list[DatasetFormat]


-DatasetModality = Literal["image", "audio", "text"]
+DatasetModality = Literal["image", "audio", "text", "video", "geospatial", "3d", "tabular", "timeseries"]


 class DatasetModalitiesResponse(TypedDict):
diff --git a/services/worker/src/worker/job_runners/dataset/modalities.py b/services/worker/src/worker/job_runners/dataset/modalities.py
index 1f4c03eea3..c834ce6d28 100644
--- a/services/worker/src/worker/job_runners/dataset/modalities.py
+++ b/services/worker/src/worker/job_runners/dataset/modalities.py
@@ -3,7 +3,7 @@

 import logging

-from datasets import Audio, Features, Image, Translation, TranslationVariableLanguages, Value
+from datasets import Audio, Features, Image, Sequence, Translation, TranslationVariableLanguages, Value
 from datasets.features.features import FeatureType, _visit
 from libcommon.exceptions import PreviousStepFormatError
 from libcommon.simple_cache import (
@@ -18,9 +18,65 @@ from worker.job_runners.dataset.dataset_job_runner import DatasetJobRunner


-def compute_modalities_response(dataset: str) -> DatasetModalitiesResponse:
+def detect_features_modalities(features: Features) -> set[DatasetModality]:
     """
-    Get the response of 'dataset-modalities' for one specific dataset on huggingface.co.
+    Detect modalities of a dataset using the features (column types).
+
+    Args:
+        features (`datasets.Features`):
+            The features of a config.
+
+    Returns:
+        `set[DatasetModality]`: A set of modalities.
+    """
+    modalities: set[DatasetModality] = set()
+
+    def classify_modality(feature: FeatureType) -> None:
+        nonlocal modalities
+        if isinstance(feature, Audio):
+            modalities.add("audio")
+        elif isinstance(feature, Image):
+            modalities.add("image")
+        elif isinstance(feature, Value) and feature.dtype in ("string", "large_string"):
+            modalities.add("text")
+        elif isinstance(feature, (Translation, TranslationVariableLanguages)):
+            modalities.add("text")
+
+    _visit(features, classify_modality)
+
+    # detection of tabular data: if there are at least two top-level numerical columns, and no "media" columns
+    if (
+        not ("audio" in modalities or "image" in modalities)
+        and len(
+            [
+                feature
+                for feature in features.values()
+                if isinstance(feature, Value) and ("int" in feature.dtype or "float" in feature.dtype)
+            ]
+        )
+        >= 2
+    ):
+        modalities.add("tabular")
+
+    # detection of time series
+    if any(
+        "emb" not in column_name  # ignore lists of floats that may be embeddings
+        and (
+            (isinstance(feature, Sequence) and feature.feature == Value("float32"))
+            or (isinstance(feature, list) and feature[0] == Value("float32"))
+        )
+        for column_name, feature in features.items()
+    ):
+        modalities.add("timeseries")
+    # other idea: detect datasets with only numerical columns and one timestamp column
+    # (and ideally be able to detect dates/timestamps even from a column with string type)
+
+    return modalities
+
+
+def detect_modalities_from_features(dataset: str) -> set[DatasetModality]:
+    """
+    Detect modalities of a dataset using the features (column types).

     Args:
         dataset (`str`):
             A namespace (user or an organization) and a repo name separated by a `/`.
@@ -33,10 +89,8 @@ def compute_modalities_response(dataset: str) -> DatasetModalitiesResponse:
             If the content of the previous step has not the expected format

     Returns:
-        `tuple[DatasetModalitiesResponse, float]`: An object with the modalities_response and the progress.
""" - logging.info(f"compute 'dataset-modalities' for {dataset=}") - dataset_info_response = get_previous_step_or_raise(kind="dataset-info", dataset=dataset) content = dataset_info_response["content"] if "dataset_info" not in content or not isinstance(content["dataset_info"], dict): @@ -44,25 +98,186 @@ def compute_modalities_response(dataset: str) -> DatasetModalitiesResponse: try: modalities: set[DatasetModality] = set() + for config_info in content["dataset_info"].values(): + modalities.update(detect_features_modalities(features=Features.from_dict(config_info["features"]))) + except Exception as e: + raise PreviousStepFormatError("Previous step did not return the expected content.", e) from e - def classify_modality(feature: FeatureType) -> None: - nonlocal modalities - if isinstance(feature, Audio): - modalities.add("audio") - elif isinstance(feature, Image): - modalities.add("image") - elif isinstance(feature, Value) and feature.dtype in ("string", "large_string"): - modalities.add("text") - elif isinstance(feature, (Translation, TranslationVariableLanguages)): - modalities.add("text") + return modalities - for config_info in content["dataset_info"].values(): - features = Features.from_dict(config_info["features"]) - _visit(features, classify_modality) +def detect_modalities_from_filetypes(dataset: str) -> set[DatasetModality]: + """ + Detect modalities of a dataset using the repository file extensions. + + Args: + dataset (`str`): + A namespace (user or an organization) and a repo name separated by a `/`. + + Raises: + [~`libcommon.simple_cache.CachedArtifactError`]: + If the previous step gave an error. + [~`libcommon.exceptions.PreviousStepFormatError`]: + If the content of the previous step has not the expected format + + Returns: + `set[DatasetModality]`: A set of modalities. 
+ """ + dataset_filetypes_response = get_previous_step_or_raise(kind="dataset-filetypes", dataset=dataset) + content = dataset_filetypes_response["content"] + if "filetypes" not in content or not isinstance(content["filetypes"], list): + raise PreviousStepFormatError("Previous step did not return the expected content: 'filetypes'.") + + # from https://developer.mozilla.org/en-US/docs/Web/Media/Formats/Image_types + IMAGE_EXTENSIONS = ( + ".apng", + ".avif", + ".gif", + ".jpg", + ".jpeg", + ".jfif", + ".pjpeg", + ".pjp", + ".png", + ".svg", + "webp", + ".bmp", + ".ico", + ".cur", + ".tif", + ".tiff", + ) + # from https://developer.mozilla.org/en-US/docs/Web/Media/Formats/Containers#browser_compatibility + others + AUDIO_EXTENSIONS = ( + ".aac", + ".flac", + ".mp3", + ".m4a", + ".oga", + ".wav", + # other audio formats + ".weba", + ".opus", + ".spx", + ".wma", + ".aiff", + ".ape", + ".mka", + ".wv", + ".tak", + ) + AUDIO_BUT_COULD_ALSO_BE_VIDEO_EXTENSIONS = (".ogg",) + VIDEO_EXTENSIONS = ( + ".m4v", + ".m4p", + ".ogv", + ".mov", + ".mkv", + # other video formats + ".avi", + ".wmv", + ".flv", + ) + VIDEO_BUT_COULD_ALSO_BE_AUDIO_EXTENSIONS = (".3gp", ".mpg", ".mpeg", ".mp4", ".webm") + GEOSPATIAL_EXTENSIONS = ( + # vectorial + ".shp", + ".shx", + ".dbf", + ".prj", + ".cpg", + ".kml", + ".kmz", + ".gpx", + ".geojson", + ".topojson", + ".gml", + ".geoparquet", + ".fgb", + # raster + ".img", + ".bil", + ".bip", + ".bsq", + # geotiff uses .tif or .tiff, but better to just show "image" modality + # than wrongly put "geospatial" if it only contains tif images + # ".tif", + # ".tiff", + # vectorial or raster + ".gpkg", + ".mbtiles", + ".pmtiles", + ) + _3D_EXTENSIONS = ( + # from https://docs.unity3d.com/Manual/3D-formats.html + ".fbx", + ".dae", + ".dxf", + ".obj", + # other 3D formats + ".stl", + ".ply", + ".gltf", + ".glb", + ".usdz", + ) + TEXT_EXTENSIONS = (".txt",) + try: + modalities: set[DatasetModality] = set() + for filetype in content["filetypes"]: + # TODO: should we condition by a number of files (filetype["count"] > threshold) to avoid false positives? + if filetype["extension"] in IMAGE_EXTENSIONS: + modalities.add("image") + elif filetype["extension"] in AUDIO_EXTENSIONS + AUDIO_BUT_COULD_ALSO_BE_VIDEO_EXTENSIONS: + modalities.add("audio") + elif filetype["extension"] in VIDEO_EXTENSIONS + VIDEO_BUT_COULD_ALSO_BE_AUDIO_EXTENSIONS: + modalities.add("video") + elif filetype["extension"] in GEOSPATIAL_EXTENSIONS: + modalities.add("geospatial") + elif filetype["extension"] in _3D_EXTENSIONS: + modalities.add("3d") + elif filetype["extension"] in TEXT_EXTENSIONS: + modalities.add("text") except Exception as e: raise PreviousStepFormatError("Previous step did not return the expected content.", e) from e + return modalities + + +def compute_modalities_response(dataset: str) -> DatasetModalitiesResponse: + """ + Get the response of 'dataset-modalities' for one specific dataset on huggingface.co. + + Args: + dataset (`str`): + A namespace (user or an organization) and a repo name separated by a `/`. + + Raises: + [~`libcommon.exceptions.PreviousStepFormatError`]: + If the content of the previous step has not the expected format + + Returns: + `tuple[DatasetModalitiesResponse, float]`: An object with the modalities_response and the progress. 
+ """ + logging.info(f"compute 'dataset-modalities' for {dataset=}") + + modalities: set[DatasetModality] = set() + try: + modalities.update(detect_modalities_from_features(dataset)) + except PreviousStepFormatError: + raise + except Exception: + logging.info(f"failed to detect modalities from features of {dataset=}") + pass + + try: + modalities.update(detect_modalities_from_filetypes(dataset)) + except PreviousStepFormatError: + raise + except Exception: + logging.info(f"failed to detect modalities from file types of {dataset=}") + pass + return DatasetModalitiesResponse( { "modalities": sorted(modalities), diff --git a/services/worker/tests/job_runners/dataset/test_modalities.py b/services/worker/tests/job_runners/dataset/test_modalities.py index ad282db02e..e86271ac80 100644 --- a/services/worker/tests/job_runners/dataset/test_modalities.py +++ b/services/worker/tests/job_runners/dataset/test_modalities.py @@ -6,15 +6,15 @@ from typing import Any import pytest -from datasets import Features, Image, Value +from datasets import Features, Image, Sequence, Value from libcommon.dtos import Priority +from libcommon.exceptions import PreviousStepFormatError from libcommon.resources import CacheMongoResource, QueueMongoResource -from libcommon.simple_cache import CachedArtifactError, upsert_response +from libcommon.simple_cache import upsert_response from worker.config import AppConfig -from worker.job_runners.dataset.modalities import ( - DatasetModalitiesJobRunner, -) +from worker.dtos import DatasetModalitiesResponse +from worker.job_runners.dataset.modalities import DatasetModalitiesJobRunner from worker.resources import LibrariesResource from ..utils import REVISION_NAME, UpstreamResponse @@ -29,11 +29,18 @@ def prepare_and_clean_mongo(app_config: AppConfig) -> None: GetJobRunner = Callable[[str, AppConfig], DatasetModalitiesJobRunner] TEXT_DATASET = "text-dataset" +TABULAR_DATASET = "tabular-dataset" IMAGE_TEXT_DATASET = "image-text-dataset" +IMAGE_DATASET = "image-dataset" +TIME_SERIES_DATASET = "time-series-dataset" ERROR_DATASET = "error-dataset" text_features = Features({"conversations": [{"from": Value("string"), "value": Value("string")}]}) image_text_features = Features({"image": Image(), "caption": Value("string")}) +tabular_features = Features({"col1": Value("int8"), "col2": Value("float32")}) +not_tabular_features_1 = Features({"col1": Value("int8"), "col2": Value("float32"), "image": Image()}) +not_tabular_features_2 = Features({"col1": Value("int8"), "col2": Value("string")}) +time_series_features = Features({"window": Sequence(Value("float32")), "target": Value("float32")}) UPSTREAM_RESPONSE_INFO_TEXT: UpstreamResponse = UpstreamResponse( kind="dataset-info", @@ -57,7 +64,51 @@ def prepare_and_clean_mongo(app_config: AppConfig) -> None: }, progress=1.0, ) -UPSTREAM_RESPONSE_INFD_ERROR: UpstreamResponse = UpstreamResponse( +UPSTREAM_RESPONSE_INFO_TABULAR: UpstreamResponse = UpstreamResponse( + kind="dataset-info", + dataset=TABULAR_DATASET, + dataset_git_revision=REVISION_NAME, + http_status=HTTPStatus.OK, + content={ + "dataset_info": {"default": {"config_name": "default", "features": tabular_features.to_dict()}}, + "partial": False, + }, + progress=1.0, +) +UPSTREAM_RESPONSE_INFO_NOT_TABULAR_1: UpstreamResponse = UpstreamResponse( + kind="dataset-info", + dataset=IMAGE_DATASET, + dataset_git_revision=REVISION_NAME, + http_status=HTTPStatus.OK, + content={ + "dataset_info": {"default": {"config_name": "default", "features": not_tabular_features_1.to_dict()}}, + "partial": 
+    },
+    progress=1.0,
+)
+UPSTREAM_RESPONSE_INFO_NOT_TABULAR_2: UpstreamResponse = UpstreamResponse(
+    kind="dataset-info",
+    dataset=TEXT_DATASET,
+    dataset_git_revision=REVISION_NAME,
+    http_status=HTTPStatus.OK,
+    content={
+        "dataset_info": {"default": {"config_name": "default", "features": not_tabular_features_2.to_dict()}},
+        "partial": False,
+    },
+    progress=1.0,
+)
+UPSTREAM_RESPONSE_INFO_TIME_SERIES: UpstreamResponse = UpstreamResponse(
+    kind="dataset-info",
+    dataset=TIME_SERIES_DATASET,
+    dataset_git_revision=REVISION_NAME,
+    http_status=HTTPStatus.OK,
+    content={
+        "dataset_info": {"default": {"config_name": "default", "features": time_series_features.to_dict()}},
+        "partial": False,
+    },
+    progress=1.0,
+)
+UPSTREAM_RESPONSE_INFO_ERROR: UpstreamResponse = UpstreamResponse(
     kind="dataset-info",
     dataset=ERROR_DATASET,
     dataset_git_revision=REVISION_NAME,
@@ -65,14 +116,88 @@ def prepare_and_clean_mongo(app_config: AppConfig) -> None:
     content={},
     progress=0.0,
 )
-EXPECTED_TEXT = (
+UPSTREAM_RESPONSE_INFO_MALFORMED: UpstreamResponse = UpstreamResponse(
+    kind="dataset-info",
+    dataset=ERROR_DATASET,
+    dataset_git_revision=REVISION_NAME,
+    http_status=HTTPStatus.OK,
+    # The content is missing the "dataset_info" key
+    content={"bad": "content"},
+    progress=0.0,
+)
+
+UPSTREAM_RESPONSE_FILETYPES_TEXT: UpstreamResponse = UpstreamResponse(
+    kind="dataset-filetypes",
+    dataset=TEXT_DATASET,
+    dataset_git_revision=REVISION_NAME,
+    http_status=HTTPStatus.OK,
+    content={
+        "filetypes": [
+            {"extension": ".txt", "count": 1, "compressed_in": ".gz"},
+            {"extension": ".gz", "count": 1},
+        ],
+        "partial": False,
+    },
+    progress=1.0,
+)
+UPSTREAM_RESPONSE_FILETYPES_ALL: UpstreamResponse = UpstreamResponse(
+    kind="dataset-filetypes",
+    dataset=TEXT_DATASET,
+    dataset_git_revision=REVISION_NAME,
+    http_status=HTTPStatus.OK,
+    content={
+        "filetypes": [
+            {"extension": ".txt", "count": 1, "compressed_in": ".gz"},
+            {"extension": ".avi", "count": 1},
+            {"extension": ".geoparquet", "count": 1, "archived_in": ".zip"},
+            {"extension": ".gz", "count": 1},
+            {"extension": ".zip", "count": 1},
+            {"extension": ".jpg", "count": 1},
+            {"extension": ".wav", "count": 1},
+            {"extension": ".gltf", "count": 1},
+        ],
+        "partial": False,
+    },
+    progress=1.0,
+)
+
+EXPECTED_TEXT: tuple[DatasetModalitiesResponse, float] = (
     {"modalities": ["text"]},
     1.0,
 )
-EXPECTED_IMAGE_TEXT = (
+EXPECTED_TABULAR: tuple[DatasetModalitiesResponse, float] = (
+    {"modalities": ["tabular"]},
+    1.0,
+)
+EXPECTED_IMAGE: tuple[DatasetModalitiesResponse, float] = (
+    {"modalities": ["image"]},
+    1.0,
+)
+EXPECTED_IMAGE_TEXT: tuple[DatasetModalitiesResponse, float] = (
     {"modalities": ["image", "text"]},
     1.0,
 )
+EXPECTED_ALL_MODALITIES: tuple[DatasetModalitiesResponse, float] = (
+    {
+        "modalities": [
+            "3d",
+            "audio",
+            "geospatial",
+            "image",
+            "text",
+            "video",
+        ]
+    },
+    1.0,
+)
+EXPECTED_EMPTY: tuple[DatasetModalitiesResponse, float] = (
+    {"modalities": []},
+    1.0,
+)
+EXPECTED_TIME_SERIES: tuple[DatasetModalitiesResponse, float] = (
+    {"modalities": ["timeseries"]},
+    1.0,
+)


 @pytest.fixture
@@ -121,6 +246,63 @@ def _get_job_runner(
             ],
             EXPECTED_IMAGE_TEXT,
         ),
+        (
+            TEXT_DATASET,
+            [
+                UPSTREAM_RESPONSE_FILETYPES_TEXT,
+            ],
+            EXPECTED_TEXT,
+        ),
+        (
+            TEXT_DATASET,
+            [
+                UPSTREAM_RESPONSE_INFO_TEXT,
+                UPSTREAM_RESPONSE_FILETYPES_TEXT,
+            ],
+            EXPECTED_TEXT,
+        ),
+        (
+            TEXT_DATASET,
+            [
+                UPSTREAM_RESPONSE_FILETYPES_ALL,
+            ],
+            EXPECTED_ALL_MODALITIES,
+        ),
+        (
+            ERROR_DATASET,
+            [
+                UPSTREAM_RESPONSE_INFO_ERROR,
+            ],
+            EXPECTED_EMPTY,
+        ),
+        (
+            TABULAR_DATASET,
+            [
+                UPSTREAM_RESPONSE_INFO_TABULAR,
+            ],
+            EXPECTED_TABULAR,
+        ),
+        (
+            IMAGE_DATASET,
+            [
+                UPSTREAM_RESPONSE_INFO_NOT_TABULAR_1,
+            ],
+            EXPECTED_IMAGE,
+        ),
+        (
+            TEXT_DATASET,
+            [
+                UPSTREAM_RESPONSE_INFO_NOT_TABULAR_2,
+            ],
+            EXPECTED_TEXT,
+        ),
+        (
+            TIME_SERIES_DATASET,
+            [
+                UPSTREAM_RESPONSE_INFO_TIME_SERIES,
+            ],
+            EXPECTED_TIME_SERIES,
+        ),
     ],
 )
 def test_compute(
@@ -144,9 +326,9 @@
         (
             ERROR_DATASET,
             [
-                UPSTREAM_RESPONSE_INFD_ERROR,
+                UPSTREAM_RESPONSE_INFO_MALFORMED,
            ],
-            pytest.raises(CachedArtifactError),
+            pytest.raises(PreviousStepFormatError),
         )
     ],
 )
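
Note for reviewers: the feature-based heuristics introduced above are pure functions over datasets.Features, so they can be poked at without the cache or a job runner. A minimal sketch, assuming the services/worker package and the datasets library are installed in the dev environment; it uses the same feature sets as the tests in this patch:

from datasets import Features, Image, Sequence, Value

from worker.job_runners.dataset.modalities import detect_features_modalities

# At least two top-level numerical columns and no media columns -> "tabular".
assert detect_features_modalities(Features({"col1": Value("int8"), "col2": Value("float32")})) == {"tabular"}

# An Image column yields "image" and disables the tabular heuristic.
assert detect_features_modalities(
    Features({"col1": Value("int8"), "col2": Value("float32"), "image": Image()})
) == {"image"}

# A Sequence of float32 whose column name does not contain "emb" -> "timeseries".
assert detect_features_modalities(
    Features({"window": Sequence(Value("float32")), "target": Value("float32")})
) == {"timeseries"}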
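The extension-based detection is a plain first-match mapping, so its precedence rules can be summarized with a hypothetical, trimmed-down mirror. The classify helper below is an illustration only, not part of the worker API, and the tuples are abbreviated; the full lists live in the patched modalities.py:

from typing import Optional

IMAGE = (".jpg", ".png", ".webp", ".tif", ".tiff")
AUDIO = (".wav", ".mp3", ".ogg")  # .ogg is resolved as audio, not video
VIDEO = (".avi", ".mkv", ".mp4", ".webm")
GEOSPATIAL = (".geojson", ".geoparquet", ".shp", ".gpkg")
THREE_D = (".obj", ".gltf", ".glb")
TEXT = (".txt",)


def classify(extension: str) -> Optional[str]:
    # First match wins, mirroring the if/elif chain in detect_modalities_from_filetypes.
    for modality, extensions in (
        ("image", IMAGE),
        ("audio", AUDIO),
        ("video", VIDEO),
        ("geospatial", GEOSPATIAL),
        ("3d", THREE_D),
        ("text", TEXT),
    ):
        if extension in extensions:
            return modality
    return None  # unknown extensions (e.g. .gz, .zip) contribute no modality


# .tif/.tiff are reported as image, not geospatial (see the commit message).
assert classify(".tif") == "image"
# Consistent with EXPECTED_ALL_MODALITIES for UPSTREAM_RESPONSE_FILETYPES_ALL.
assert {classify(e) for e in (".txt", ".avi", ".geoparquet", ".jpg", ".wav", ".gltf")} == {
    "text",
    "video",
    "geospatial",
    "image",
    "audio",
    "3d",
}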