From a8cd9393116c1b9a6283b46e4926b7ecd145f36b Mon Sep 17 00:00:00 2001 From: Polina Kazakova Date: Mon, 13 May 2024 14:00:59 +0200 Subject: [PATCH] Fix image stats when image column type is bytes, not struct (#2793) * check for pa type of audio/image column and pass bytes directly to processing func * include null type * update test * check with isinstance instead of determining pa type --------- Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> --- .../split/descriptive_statistics.py | 20 +++++++++----- .../descriptive_statistics_dataset.py | 12 ++++----- .../split/test_descriptive_statistics.py | 27 +++++++++++++++++++ 3 files changed, 46 insertions(+), 13 deletions(-) diff --git a/services/worker/src/worker/job_runners/split/descriptive_statistics.py b/services/worker/src/worker/job_runners/split/descriptive_statistics.py index 183ea7f989..e15fd8636b 100644 --- a/services/worker/src/worker/job_runners/split/descriptive_statistics.py +++ b/services/worker/src/worker/job_runners/split/descriptive_statistics.py @@ -583,7 +583,7 @@ class MediaColumn(Column): transform_column: type[Column] @classmethod - def transform(cls, example: dict[str, Any]) -> Any: + def transform(cls, example: Optional[Union[bytes, dict[str, Any]]]) -> Any: """ Function to use to transform the original values to further pass these transformed values to statistics computation. Used inside ._compute_statistics() method. @@ -656,13 +656,16 @@ class AudioColumn(MediaColumn): transform_column = FloatColumn @staticmethod - def get_duration(example: dict[str, Any]) -> float: + def get_duration(example: Optional[Union[bytes, dict[str, Any]]]) -> Optional[float]: """Get audio durations""" - with io.BytesIO(example["bytes"]) as f: + if example is None: + return None + example_bytes = example["bytes"] if isinstance(example, dict) else example + with io.BytesIO(example_bytes) as f: return librosa.get_duration(path=f) # type: ignore # expects PathLike but BytesIO also works @classmethod - def transform(cls, example: dict[str, Any]) -> float: + def transform(cls, example: Optional[Union[bytes, dict[str, Any]]]) -> Optional[float]: return cls.get_duration(example) @@ -670,14 +673,17 @@ class ImageColumn(MediaColumn): transform_column = IntColumn @staticmethod - def get_width(example: dict[str, Any]) -> int: + def get_width(example: Optional[Union[bytes, dict[str, Any]]]) -> Optional[int]: """Get image widths.""" - with io.BytesIO(example["bytes"]) as f: + if example is None: + return None + example_bytes = example["bytes"] if isinstance(example, dict) else example + with io.BytesIO(example_bytes) as f: image = Image.open(f) return image.size[0] @classmethod - def transform(cls, example: dict[str, Any]) -> int: + def transform(cls, example: Optional[Union[bytes, dict[str, Any]]]) -> Optional[int]: return cls.get_width(example) diff --git a/services/worker/tests/fixtures/descriptive_statistics_dataset.py b/services/worker/tests/fixtures/descriptive_statistics_dataset.py index 307d4c56dc..6e6984a3ad 100644 --- a/services/worker/tests/fixtures/descriptive_statistics_dataset.py +++ b/services/worker/tests/fixtures/descriptive_statistics_dataset.py @@ -1640,9 +1640,9 @@ def nan_column() -> list[None]: }, features=Features( { - "audio": Audio(sampling_rate=1600), - "audio_nan": Audio(sampling_rate=1600), - "audio_all_nan": Audio(sampling_rate=1600), + "audio": Audio(sampling_rate=1600, decode=False), + "audio_nan": Audio(sampling_rate=1600, decode=False), + "audio_all_nan": Audio(sampling_rate=1600, decode=False), } ), ) @@ -1666,9 +1666,9 @@ def nan_column() -> list[None]: }, features=Features( { - "image": Image(), - "image_nan": Image(), - "image_all_nan": Image(), + "image": Image(decode=False), + "image_nan": Image(decode=False), + "image_all_nan": Image(decode=False), } ), ) diff --git a/services/worker/tests/job_runners/split/test_descriptive_statistics.py b/services/worker/tests/job_runners/split/test_descriptive_statistics.py index ea66c7ad58..4ce0eec47d 100644 --- a/services/worker/tests/job_runners/split/test_descriptive_statistics.py +++ b/services/worker/tests/job_runners/split/test_descriptive_statistics.py @@ -8,6 +8,7 @@ import numpy as np import pandas as pd import polars as pl +import pyarrow as pa import pyarrow.parquet as pq import pytest from datasets import ClassLabel, Dataset @@ -799,6 +800,19 @@ def test_audio_statistics( ) assert computed == expected + # write samples as bytes, not as struct {"bytes": b"", "path": ""} + audios = datasets["audio_statistics"][column_name][:] + pa_table_bytes = pa.Table.from_pydict( + {column_name: [open(audio["path"], "rb").read() if audio else None for audio in audios]} + ) + pq.write_table(pa_table_bytes, parquet_filename) + computed = AudioColumn.compute_statistics( + parquet_directory=parquet_directory, + column_name=column_name, + n_samples=4, + ) + assert computed == expected + @pytest.mark.parametrize( "column_name", @@ -823,6 +837,19 @@ def test_image_statistics( ) assert computed == expected + # write samples as bytes, not as struct {"bytes": b"", "path": ""} + images = datasets["image_statistics"][column_name][:] + pa_table_bytes = pa.Table.from_pydict( + {column_name: [open(image["path"], "rb").read() if image else None for image in images]} + ) + pq.write_table(pa_table_bytes, parquet_filename) + computed = ImageColumn.compute_statistics( + parquet_directory=parquet_directory, + column_name=column_name, + n_samples=4, + ) + assert computed == expected + @pytest.mark.parametrize( "hub_dataset_name,expected_error_code",