diff --git a/libs/libcommon/src/libcommon/croissant_utils.py b/libs/libcommon/src/libcommon/croissant_utils.py index d9e403536..d22cd12d3 100644 --- a/libs/libcommon/src/libcommon/croissant_utils.py +++ b/libs/libcommon/src/libcommon/croissant_utils.py @@ -2,7 +2,7 @@ # Copyright 2024 The HuggingFace Authors. from collections.abc import Mapping -from typing import Any, Union +from typing import Any, Optional, Union from datasets import ClassLabel, Image, Sequence, Value @@ -53,8 +53,23 @@ def truncate_features_from_croissant_crumbs_response(content: Mapping[str, Any]) } +def get_source( + distribution_name: str, column: str, add_transform: bool, json_path: Optional[str] = None +) -> dict[str, Any]: + """Returns a Source dictionary for a Field.""" + source = {"fileSet": {"@id": distribution_name}, "extract": {"column": column}} + if add_transform and json_path: + source["transform"] = {"jsonPath": json_path} + return source + + def feature_to_croissant_field( - distribution_name: str, field_name: str, column: str, feature: Any + distribution_name: str, + field_name: str, + column: str, + feature: Any, + add_transform: bool = False, + json_path: Optional[str] = None, ) -> Union[dict[str, Any], None]: """Converts a Hugging Face Datasets feature to a Croissant field or None if impossible.""" if isinstance(feature, Value) and feature.dtype in HF_TO_CROISSANT_VALUE_TYPE: @@ -64,20 +79,21 @@ def feature_to_croissant_field( "name": field_name, "description": f"Column '{column}' from the Hugging Face parquet file.", "dataType": HF_TO_CROISSANT_VALUE_TYPE[feature.dtype], - "source": {"fileSet": {"@id": distribution_name}, "extract": {"column": column}}, + "source": get_source(distribution_name, column, add_transform, json_path), } elif isinstance(feature, Image): + source = get_source(distribution_name, column, add_transform, json_path) + if transform := source.get("transform"): + source["transform"] = [transform, {"jsonPath": "bytes"}] + else: + source["transform"] = {"jsonPath": "bytes"} return { "@type": "cr:Field", "@id": field_name, "name": field_name, "description": f"Image column '{column}' from the Hugging Face parquet file.", "dataType": "sc:ImageObject", - "source": { - "fileSet": {"@id": distribution_name}, - "extract": {"column": column}, - "transform": {"jsonPath": "bytes"}, - }, + "source": source, } elif isinstance(feature, ClassLabel): return { @@ -87,7 +103,26 @@ def feature_to_croissant_field( "description": f"ClassLabel column '{column}' from the Hugging Face parquet file.\nLabels:\n" + ", ".join(f"{name} ({i})" for i, name in enumerate(feature.names)), "dataType": "sc:Integer", - "source": {"fileSet": {"@id": distribution_name}, "extract": {"column": column}}, + "source": get_source(distribution_name, column, add_transform, json_path), + } + # Field with sub-fields. + elif isinstance(feature, dict): + return { + "@type": "cr:Field", + "@id": field_name, + "name": field_name, + "description": f"Column '{column}' from the Hugging Face parquet file.", + "subField": [ + feature_to_croissant_field( + distribution_name, + f"{field_name}/{subfeature_name}", + column, + sub_feature, + add_transform=True, + json_path=subfeature_name, + ) + for subfeature_name, sub_feature in feature.items() + ], } elif isinstance(feature, (Sequence, list)): if isinstance(feature, Sequence): diff --git a/services/worker/tests/job_runners/dataset/test_croissant_crumbs.py b/services/worker/tests/job_runners/dataset/test_croissant_crumbs.py index 6355cbd36..aa139d906 100644 --- a/services/worker/tests/job_runners/dataset/test_croissant_crumbs.py +++ b/services/worker/tests/job_runners/dataset/test_croissant_crumbs.py @@ -128,30 +128,44 @@ def test_get_croissant_crumbs_from_dataset_infos() -> None: assert croissant_crumbs["recordSet"][3]["name"] == "record_set_user_squad_with_space_0" assert isinstance(croissant_crumbs["recordSet"][1]["field"], list) assert isinstance(squad_info["features"], dict) - assert "1 skipped column: answers" in croissant_crumbs["recordSet"][1]["description"] + assert "skipped column" not in croissant_crumbs["recordSet"][1]["description"] assert croissant_crumbs["recordSet"][1]["@id"] == "record_set_user_squad_with_space" assert croissant_crumbs["recordSet"][3]["@id"] == "record_set_user_squad_with_space_0" for i in [1, 3]: for field in croissant_crumbs["recordSet"][i]["field"]: - assert "source" in field - assert "fileSet" in field["source"] - assert "@id" in field["source"]["fileSet"] - assert field["source"]["fileSet"]["@id"] - assert "extract" in field["source"] + if "subField" not in field: + assert "source" in field + assert "fileSet" in field["source"] + assert "@id" in field["source"]["fileSet"] + assert field["source"]["fileSet"]["@id"] + assert "extract" in field["source"] + else: + for sub_field in field["subField"]: + assert "source" in sub_field + assert "fileSet" in sub_field["source"] + assert "@id" in sub_field["source"]["fileSet"] + assert sub_field["source"]["fileSet"]["@id"] + assert "extract" in sub_field["source"] + assert "transform" in sub_field["source"] if field["description"] == "Split to which the example belongs to.": assert "regex" in field["source"]["transform"] assert field["source"]["extract"]["fileProperty"] == "fullpath" assert field["references"]["field"]["@id"] == croissant_crumbs["recordSet"][i - 1]["field"][0]["@id"] else: - assert field["source"]["extract"]["column"] == field["@id"].split("/")[-1] + if "subField" not in field: + assert field["source"]["extract"]["column"] == field["@id"].split("/")[-1] + else: + for sub_field in field["subField"]: + assert sub_field["source"]["extract"]["column"] == field["@id"].split("/")[-1] # Test fields. - assert len(croissant_crumbs["recordSet"][1]["field"]) == 5 - assert len(croissant_crumbs["recordSet"][3]["field"]) == 5 + assert len(croissant_crumbs["recordSet"][1]["field"]) == 6 + assert len(croissant_crumbs["recordSet"][3]["field"]) == 6 for field in croissant_crumbs["recordSet"][1]["field"]: assert field["@type"] == "cr:Field" - assert field["dataType"] == "sc:Text" - assert len(croissant_crumbs["recordSet"][1]["field"]) == len(squad_info["features"]) + if "subField" not in field: + assert field["dataType"] == "sc:Text" + assert len(croissant_crumbs["recordSet"][1]["field"]) == len(squad_info["features"]) + 1 # Test distribution. assert "distribution" in croissant_crumbs