Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add CustomFeature base class to support user-defined features with encoding/decoding logic #7221

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
16 changes: 15 additions & 1 deletion src/datasets/features/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"Array4D",
"Array5D",
"ClassLabel",
"CustomFeature",
"Features",
"LargeList",
"Sequence",
Expand All @@ -14,8 +15,21 @@
"TranslationVariableLanguages",
"Video",
]


from .audio import Audio
from .features import Array2D, Array3D, Array4D, Array5D, ClassLabel, Features, LargeList, Sequence, Value
from .features import (
Array2D,
Array3D,
Array4D,
Array5D,
ClassLabel,
CustomFeature,
Features,
LargeList,
Sequence,
Value,
)
from .image import Image
from .translation import Translation, TranslationVariableLanguages
from .video import Video
69 changes: 56 additions & 13 deletions src/datasets/features/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,31 @@ def cast_to_python_objects(obj: Any, only_1d_for_numpy=False, optimize_list_cast
)[0]


class CustomFeature:
"""
Base class for feature types like Audio, Image, ClassLabel, etc that require special treatment (encoding/decoding).
"""

requires_encoding: ClassVar[bool] = False
requires_decoding: ClassVar[bool] = False

def encode_example(self, example):
if self.requires_encoding:
return self._encode_example(example)
return example

def _encode_example(self, example):
raise NotImplementedError("Should be implemented by child class if `requires_encoding` is True")

def decode_example(self, example):
if self.requires_decoding:
return self._decode_example(example)
return example

def _decode_example(self, example):
raise NotImplementedError("Should be implemented by child class if `requires_decoding` is True")


@dataclass
class Value:
"""
Expand Down Expand Up @@ -542,7 +567,7 @@ def __call__(self):
pa_type = globals()[self.__class__.__name__ + "ExtensionType"](self.shape, self.dtype)
return pa_type

def encode_example(self, value):
def _encode_example(self, value):
return value


Expand Down Expand Up @@ -1091,7 +1116,7 @@ def int2str(self, values: Union[int, Iterable]) -> Union[str, Iterable]:
output = [self._int2str[int(v)] for v in values]
return output if return_list else output[0]

def encode_example(self, example_data):
def _encode_example(self, example_data):
if self.num_classes is None:
raise ValueError(
"Trying to use ClassLabel feature with undefined number of class. "
Expand Down Expand Up @@ -1180,6 +1205,8 @@ class LargeList:
Child feature data type of each item within the large list.
"""

requires_encoding: ClassVar[bool] = True
requires_decoding: ClassVar[bool] = True
feature: Any
id: Optional[str] = None
# Automatically constructed
Expand All @@ -1203,6 +1230,7 @@ class LargeList:
Array5D,
Audio,
Image,
CustomFeature,
Video,
]

Expand Down Expand Up @@ -1267,19 +1295,20 @@ def get_nested_type(schema: FeatureType) -> pa.DataType:
return schema()


def encode_nested_example(schema, obj, level=0):
def encode_nested_example(schema, obj, is_nested: bool = False):
"""Encode a nested example.
This is used since some features (in particular ClassLabel) have some logic during encoding.

To avoid iterating over possibly long lists, it first checks (recursively) if the first element that is not None or empty (if it is a sequence) has to be encoded.
If the first element needs to be encoded, then all the elements of the list will be encoded, otherwise they'll stay the same.
"""

# Nested structures: we allow dict, list/tuples, sequences
if isinstance(schema, dict):
if level == 0 and obj is None:
if not is_nested and obj is None:
raise ValueError("Got None but expected a dictionary instead")
return (
{k: encode_nested_example(schema[k], obj.get(k), level=level + 1) for k in schema}
{k: encode_nested_example(schema[k], obj.get(k), is_nested=True) for k in schema}
if obj is not None
else None
)
Expand All @@ -1295,9 +1324,10 @@ def encode_nested_example(schema, obj, level=0):
for first_elmt in obj:
if _check_non_null_non_empty_recursive(first_elmt, sub_schema):
break
if encode_nested_example(sub_schema, first_elmt, level=level + 1) != first_elmt:
return [encode_nested_example(sub_schema, o, level=level + 1) for o in obj]
if encode_nested_example(sub_schema, first_elmt, is_nested=True) != first_elmt:
return [encode_nested_example(sub_schema, o, is_nested=True) for o in obj]
return list(obj)

elif isinstance(schema, LargeList):
if obj is None:
return None
Expand All @@ -1307,8 +1337,8 @@ def encode_nested_example(schema, obj, level=0):
for first_elmt in obj:
if _check_non_null_non_empty_recursive(first_elmt, sub_schema):
break
if encode_nested_example(sub_schema, first_elmt, level=level + 1) != first_elmt:
return [encode_nested_example(sub_schema, o, level=level + 1) for o in obj]
if encode_nested_example(sub_schema, first_elmt, is_nested=True) != first_elmt:
return [encode_nested_example(sub_schema, o, is_nested=True) for o in obj]
return list(obj)
elif isinstance(schema, Sequence):
if obj is None:
Expand All @@ -1320,13 +1350,13 @@ def encode_nested_example(schema, obj, level=0):
if isinstance(obj, (list, tuple)):
# obj is a list of dict
for k in schema.feature:
list_dict[k] = [encode_nested_example(schema.feature[k], o.get(k), level=level + 1) for o in obj]
list_dict[k] = [encode_nested_example(schema.feature[k], o.get(k), is_nested=True) for o in obj]
return list_dict
else:
# obj is a single dict
for k in schema.feature:
list_dict[k] = (
[encode_nested_example(schema.feature[k], o, level=level + 1) for o in obj[k]]
[encode_nested_example(schema.feature[k], o, is_nested=True) for o in obj[k]]
if k in obj
else None
)
Expand All @@ -1342,14 +1372,18 @@ def encode_nested_example(schema, obj, level=0):
# be careful when comparing tensors here
if (
not isinstance(first_elmt, list)
or encode_nested_example(schema.feature, first_elmt, level=level + 1) != first_elmt
or encode_nested_example(schema.feature, first_elmt, is_nested=True) != first_elmt
):
return [encode_nested_example(schema.feature, o, level=level + 1) for o in obj]
return [encode_nested_example(schema.feature, o, is_nested=True) for o in obj]
return list(obj)
# Object with special encoding:
# ClassLabel will convert from string to int, TranslationVariableLanguages does some checks
elif isinstance(schema, (Audio, Image, ClassLabel, TranslationVariableLanguages, Value, _ArrayXD, Video)):
return schema.encode_example(obj) if obj is not None else None

# Custom features
elif isinstance(schema, CustomFeature) and schema.requires_encoding:
return schema.encode_example(obj) if obj is not None else None
# Other object should be directly convertible to a native Arrow type (like Translation and Translation)
return obj

Expand Down Expand Up @@ -1403,6 +1437,10 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni
# we pass the token to read and decode files from private repositories in streaming mode
if obj is not None and schema.decode:
return schema.decode_example(obj, token_per_repo_id=token_per_repo_id)
# Custom features
elif isinstance(schema, CustomFeature) and schema.requires_decoding:
if obj is not None and schema.decode:
return schema.decode_example(obj, token_per_repo_id=token_per_repo_id)
return obj


Expand Down Expand Up @@ -1432,6 +1470,9 @@ def register_feature(
Register a Feature object using a name and class.
This function must be used on a Feature class.
"""
assert issubclass(
feature_cls, CustomFeature
), f"Custom feature class {feature_cls.__name__} must inherit from datasets.CustomFeature"
if feature_type in _FEATURE_TYPES:
logger.warning(
f"Overwriting feature type '{feature_type}' ({_FEATURE_TYPES[feature_type].__name__} -> {feature_cls.__name__})"
Expand Down Expand Up @@ -1628,6 +1669,8 @@ def require_decoding(feature: FeatureType, ignore_decode_attribute: bool = False
return require_decoding(feature.feature)
elif isinstance(feature, Sequence):
return require_decoding(feature.feature)
elif isinstance(feature, CustomFeature):
return feature.requires_decoding and (feature.decode if not ignore_decode_attribute else True)
else:
return hasattr(feature, "decode_example") and (feature.decode if not ignore_decode_attribute else True)

Expand Down