Skip to content

Commit

Permalink
Fix image stats when image column type is bytes, not struct (#2793)
Browse files Browse the repository at this point in the history
* check the pyarrow type of the audio/image column and pass bytes directly to the processing function

* include null type

* update test

* check with isinstance instead of determining the pyarrow type

---------

Co-authored-by: Quentin Lhoest <[email protected]>
  • Loading branch information
polinaeterna and lhoestq authored May 13, 2024
1 parent 6fb30a0 commit a8cd939
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ class MediaColumn(Column):
transform_column: type[Column]

@classmethod
def transform(cls, example: dict[str, Any]) -> Any:
def transform(cls, example: Optional[Union[bytes, dict[str, Any]]]) -> Any:
"""
Function to use to transform the original values to further pass these transformed values to statistics
computation. Used inside ._compute_statistics() method.
Expand Down Expand Up @@ -656,28 +656,34 @@ class AudioColumn(MediaColumn):
transform_column = FloatColumn

@staticmethod
def get_duration(example: dict[str, Any]) -> float:
def get_duration(example: Optional[Union[bytes, dict[str, Any]]]) -> Optional[float]:
"""Get audio durations"""
with io.BytesIO(example["bytes"]) as f:
if example is None:
return None
example_bytes = example["bytes"] if isinstance(example, dict) else example
with io.BytesIO(example_bytes) as f:
return librosa.get_duration(path=f) # type: ignore # expects PathLike but BytesIO also works

@classmethod
def transform(cls, example: dict[str, Any]) -> float:
def transform(cls, example: Optional[Union[bytes, dict[str, Any]]]) -> Optional[float]:
return cls.get_duration(example)


class ImageColumn(MediaColumn):
transform_column = IntColumn

@staticmethod
def get_width(example: dict[str, Any]) -> int:
def get_width(example: Optional[Union[bytes, dict[str, Any]]]) -> Optional[int]:
"""Get image widths."""
with io.BytesIO(example["bytes"]) as f:
if example is None:
return None
example_bytes = example["bytes"] if isinstance(example, dict) else example
with io.BytesIO(example_bytes) as f:
image = Image.open(f)
return image.size[0]

@classmethod
def transform(cls, example: dict[str, Any]) -> int:
def transform(cls, example: Optional[Union[bytes, dict[str, Any]]]) -> Optional[int]:
return cls.get_width(example)


Expand Down
12 changes: 6 additions & 6 deletions services/worker/tests/fixtures/descriptive_statistics_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1640,9 +1640,9 @@ def nan_column() -> list[None]:
},
features=Features(
{
"audio": Audio(sampling_rate=1600),
"audio_nan": Audio(sampling_rate=1600),
"audio_all_nan": Audio(sampling_rate=1600),
"audio": Audio(sampling_rate=1600, decode=False),
"audio_nan": Audio(sampling_rate=1600, decode=False),
"audio_all_nan": Audio(sampling_rate=1600, decode=False),
}
),
)
Expand All @@ -1666,9 +1666,9 @@ def nan_column() -> list[None]:
},
features=Features(
{
"image": Image(),
"image_nan": Image(),
"image_all_nan": Image(),
"image": Image(decode=False),
"image_nan": Image(decode=False),
"image_all_nan": Image(decode=False),
}
),
)
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import pyarrow.parquet as pq
import pytest
from datasets import ClassLabel, Dataset
Expand Down Expand Up @@ -799,6 +800,19 @@ def test_audio_statistics(
)
assert computed == expected

# write samples as bytes, not as struct {"bytes": b"", "path": ""}
audios = datasets["audio_statistics"][column_name][:]
pa_table_bytes = pa.Table.from_pydict(
{column_name: [open(audio["path"], "rb").read() if audio else None for audio in audios]}
)
pq.write_table(pa_table_bytes, parquet_filename)
computed = AudioColumn.compute_statistics(
parquet_directory=parquet_directory,
column_name=column_name,
n_samples=4,
)
assert computed == expected


@pytest.mark.parametrize(
"column_name",
Expand All @@ -823,6 +837,19 @@ def test_image_statistics(
)
assert computed == expected

# write samples as bytes, not as struct {"bytes": b"", "path": ""}
images = datasets["image_statistics"][column_name][:]
pa_table_bytes = pa.Table.from_pydict(
{column_name: [open(image["path"], "rb").read() if image else None for image in images]}
)
pq.write_table(pa_table_bytes, parquet_filename)
computed = ImageColumn.compute_statistics(
parquet_directory=parquet_directory,
column_name=column_name,
n_samples=4,
)
assert computed == expected


@pytest.mark.parametrize(
"hub_dataset_name,expected_error_code",
Expand Down

0 comments on commit a8cd939

Please sign in to comment.