From f031ec75e041fe58fe08af463d2f4a8f3a5bd945 Mon Sep 17 00:00:00 2001 From: alex-hh Date: Sat, 12 Oct 2024 17:10:37 +0100 Subject: [PATCH] Fall back to default feature casting when user-defined features are not available during dataset loading --- src/datasets/features/features.py | 7 ++++++- src/datasets/info.py | 8 +++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 1d241e0b7b7..49eab94d0a0 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -1788,7 +1788,12 @@ def from_arrow_schema(cls, pa_schema: pa.Schema) -> "Features": if pa_schema.metadata is not None and "huggingface".encode("utf-8") in pa_schema.metadata: metadata = json.loads(pa_schema.metadata["huggingface".encode("utf-8")].decode()) if "info" in metadata and "features" in metadata["info"] and metadata["info"]["features"] is not None: - metadata_features = Features.from_dict(metadata["info"]["features"]) + try: + metadata_features = Features.from_dict(metadata["info"]["features"]) + except Exception as e: + logger.warning( + f"Warning: failed to load features from Arrow schema metadata: {e}, decoding may not be as intended" + ) metadata_features_schema = metadata_features.arrow_schema obj = { field.name: ( diff --git a/src/datasets/info.py b/src/datasets/info.py index d9e4cad598f..99f5f7f739d 100644 --- a/src/datasets/info.py +++ b/src/datasets/info.py @@ -314,7 +314,13 @@ def _to_yaml_dict(self) -> dict: def _from_yaml_dict(cls, yaml_data: dict) -> "DatasetInfo": yaml_data = copy.deepcopy(yaml_data) if yaml_data.get("features") is not None: - yaml_data["features"] = Features._from_yaml_list(yaml_data["features"]) + try: + yaml_data["features"] = Features._from_yaml_list(yaml_data["features"]) + except Exception as e: + logger.warning( + f"Warning: failed to load features from Arrow schema metadata: {e}, decoding may not be as intended" + ) + del yaml_data["features"] if 
yaml_data.get("splits") is not None: yaml_data["splits"] = SplitDict._from_yaml_list(yaml_data["splits"]) field_names = {f.name for f in dataclasses.fields(cls)}