Skip to content

Commit

Permalink
Fix WebDatasets KeyError for user-defined Features when a field is mi…
Browse files Browse the repository at this point in the history
…ssing in an example (#7004)

* Fix KeyError bug

* Add additional check

Co-authored-by: Quentin Lhoest <[email protected]>

* Add test for missing key handling

* update test

---------

Co-authored-by: Quentin Lhoest <[email protected]>
Co-authored-by: Quentin Lhoest <[email protected]>
  • Loading branch information
3 people authored Jun 28, 2024
1 parent a16477d commit 83d2860
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
10 changes: 9 additions & 1 deletion src/datasets/packaged_modules/webdataset/webdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,18 @@ def _generate_examples(self, tar_paths, tar_iterators):
audio_field_names = [
field_name for field_name, feature in self.info.features.items() if isinstance(feature, datasets.Audio)
]
all_field_names = list(self.info.features.keys())
for tar_idx, (tar_path, tar_iterator) in enumerate(zip(tar_paths, tar_iterators)):
for example_idx, example in enumerate(self._get_pipeline_from_tar(tar_path, tar_iterator)):
for field_name in all_field_names:
if field_name not in example:
example[field_name] = None
for field_name in image_field_names + audio_field_names:
example[field_name] = {"path": example["__key__"] + "." + field_name, "bytes": example[field_name]}
if example[field_name] is not None:
example[field_name] = {
"path": example["__key__"] + "." + field_name,
"bytes": example[field_name],
}
yield f"{tar_idx}_{example_idx}", example


Expand Down
31 changes: 31 additions & 0 deletions tests/packaged_modules/test_webdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,37 @@ def test_image_webdataset(image_wds_file):
assert isinstance(decoded["jpg"], PIL.Image.Image)


@require_pil
def test_image_webdataset_missing_keys(image_wds_file):
import PIL.Image

data_files = {"train": [image_wds_file]}
features = Features(
{
"__key__": Value("string"),
"__url__": Value("string"),
"json": {"caption": Value("string")},
"jpg": Image(),
"jpeg": Image(), # additional field
"txt": Value("string"), # additional field
}
)
webdataset = WebDataset(data_files=data_files, features=features)
split_generators = webdataset._split_generators(DownloadManager())
assert webdataset.info.features == features
split_generator = split_generators[0]
assert split_generator.name == "train"
generator = webdataset._generate_examples(**split_generator.gen_kwargs)
_, example = next(iter(generator))
encoded = webdataset.info.features.encode_example(example)
decoded = webdataset.info.features.decode_example(encoded)
assert isinstance(decoded["json"], dict)
assert isinstance(decoded["json"]["caption"], str)
assert isinstance(decoded["jpg"], PIL.Image.Image)
assert decoded["jpeg"] is None
assert decoded["txt"] is None


@require_sndfile
def test_audio_webdataset(audio_wds_file):
data_files = {"train": [audio_wds_file]}
Expand Down

0 comments on commit 83d2860

Please sign in to comment.