Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support for custom feature encoding/decoding #7284

Merged
merged 2 commits into from
Nov 21, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/datasets/features/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -1348,7 +1348,7 @@ def encode_nested_example(schema, obj, level=0):
return list(obj)
# Object with special encoding:
# ClassLabel will convert from string to int, TranslationVariableLanguages does some checks
elif isinstance(schema, (Audio, Image, ClassLabel, TranslationVariableLanguages, Value, _ArrayXD, Video)):
elif hasattr(schema, "encode_example"):
return schema.encode_example(obj) if obj is not None else None
# Other object should be directly convertible to a native Arrow type (like Translation and Translation)
return obj
Expand Down Expand Up @@ -1399,10 +1399,9 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni
else:
return decode_nested_example([schema.feature], obj)
# Object with special decoding:
elif isinstance(schema, (Audio, Image, Video)):
elif hasattr(schema, "decode_example") and getattr(schema, "decode", True):
# we pass the token to read and decode files from private repositories in streaming mode
if obj is not None and schema.decode:
return schema.decode_example(obj, token_per_repo_id=token_per_repo_id)
return schema.decode_example(obj, token_per_repo_id=token_per_repo_id)
lhoestq marked this conversation as resolved.
Show resolved Hide resolved
return obj


Expand Down Expand Up @@ -1629,7 +1628,9 @@ def require_decoding(feature: FeatureType, ignore_decode_attribute: bool = False
elif isinstance(feature, Sequence):
return require_decoding(feature.feature)
else:
return hasattr(feature, "decode_example") and (feature.decode if not ignore_decode_attribute else True)
return hasattr(feature, "decode_example") and (
getattr(feature, "decode", True) if not ignore_decode_attribute else True
)


def require_storage_cast(feature: FeatureType) -> bool:
Expand Down
Loading