diff --git a/docs/source/openapi.json b/docs/source/openapi.json
index 7079f20874..a2f1891efc 100644
--- a/docs/source/openapi.json
+++ b/docs/source/openapi.json
@@ -1084,7 +1084,7 @@
},
"ColumnType": {
"type": "string",
- "enum": ["float", "int", "class_label", "string_label", "string_text", "bool", "list"]
+ "enum": ["float", "int", "class_label", "string_label", "string_text", "bool", "list", "audio"]
},
"Histogram": {
"type": "object",
@@ -6128,6 +6128,176 @@
}
]
}
+ },
+ "A split (MLCommons/peoples_speech) with audio column": {
+ "summary": "Statistics on an audio column 'audio'.",
+ "description": "Try with https://datasets-server.huggingface.co/statistics?dataset=MLCommons/peoples_speech&config=validation&split=validation.",
+ "value": {
+ "num_examples": 18622,
+ "statistics": [
+ {
+ "column_name": "audio",
+ "column_type": "audio",
+ "column_statistics": {
+ "nan_count": 0,
+ "nan_proportion": 0.0,
+ "min": 0.653,
+ "max": 105.97,
+ "mean": 6.41103,
+ "median": 4.8815,
+ "std": 5.63269,
+ "histogram": {
+ "hist": [
+ 15867,
+ 2319,
+ 350,
+ 67,
+ 12,
+ 5,
+ 0,
+ 1,
+ 0,
+ 1
+ ],
+ "bin_edges": [
+ 0.653,
+ 11.1847,
+ 21.7164,
+ 32.2481,
+ 42.7798,
+ 53.3115,
+ 63.8432,
+ 74.3749,
+ 84.9066,
+ 95.4383,
+ 105.97
+ ]
+ }
+ }
+ },
+ {
+ "column_name": "duration_ms",
+ "column_type": "int",
+ "column_statistics": {
+ "nan_count": 0,
+ "nan_proportion": 0.0,
+ "min": 833,
+ "max": 105970,
+ "mean": 6411.06079,
+ "median": 4881.5,
+ "std": 5632.67057,
+ "histogram": {
+ "hist": [
+ 15950,
+ 2244,
+ 345,
+ 64,
+ 12,
+ 5,
+ 0,
+ 1,
+ 0,
+ 1
+ ],
+ "bin_edges": [
+ 833,
+ 11347,
+ 21861,
+ 32375,
+ 42889,
+ 53403,
+ 63917,
+ 74431,
+ 84945,
+ 95459,
+ 105970
+ ]
+ }
+ }
+ },
+ {
+ "column_name": "id",
+ "column_type": "string_text",
+ "column_statistics": {
+ "nan_count": 0,
+ "nan_proportion": 0.0,
+ "min": 43,
+ "max": 197,
+ "mean": 120.06675,
+ "median": 136.0,
+ "std": 44.49607,
+ "histogram": {
+ "hist": [
+ 3599,
+ 939,
+ 278,
+ 1914,
+ 1838,
+ 1646,
+ 4470,
+ 1443,
+ 1976,
+ 519
+ ],
+ "bin_edges": [
+ 43,
+ 59,
+ 75,
+ 91,
+ 107,
+ 123,
+ 139,
+ 155,
+ 171,
+ 187,
+ 197
+ ]
+ }
+ }
+ },
+ {
+ "column_name": "text",
+ "column_type": "string_text",
+ "column_statistics": {
+ "nan_count": 0,
+ "nan_proportion": 0.0,
+ "min": 1,
+ "max": 1219,
+ "mean": 94.52873,
+ "median": 75.0,
+ "std": 79.11078,
+ "histogram": {
+ "hist": [
+ 13703,
+ 3975,
+ 744,
+ 146,
+ 36,
+ 10,
+ 5,
+ 1,
+ 1,
+ 1
+ ],
+ "bin_edges": [
+ 1,
+ 123,
+ 245,
+ 367,
+ 489,
+ 611,
+ 733,
+ 855,
+ 977,
+ 1099,
+ 1219
+ ]
+ }
+ }
+ }
+ ],
+ "partial": false
+ }
}
}
}
diff --git a/docs/source/statistics.md b/docs/source/statistics.md
index dbbe137c84..16d2eaa4bc 100644
--- a/docs/source/statistics.md
+++ b/docs/source/statistics.md
@@ -165,16 +165,18 @@ The response JSON contains three keys:
## Response structure by data type
-Currently, statistics are supported for strings, float and integer numbers, and the special [`datasets.ClassLabel`](https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.ClassLabel) feature type of the [`datasets`](https://huggingface.co/docs/datasets/) library.
+Currently, statistics are supported for strings, float and integer numbers, lists, audio data, and the special [`datasets.ClassLabel`](https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.ClassLabel) feature type of the [`datasets`](https://huggingface.co/docs/datasets/) library.
`column_type` in response can be one of the following values:
-* `class_label` - for [`datasets.ClassLabel`](https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.ClassLabel) feature
-* `float` - for float dtypes
-* `int` - for integer dtypes
-* `bool` - for boolean dtype
-* `string_label` - for string dtypes being treated as categories (see below)
-* `string_text` - for string dtypes if they do not represent categories (see below)
+* `class_label` - for the [`datasets.ClassLabel`](https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.ClassLabel) feature, which represents categorical data
+* `float` - for float data types
+* `int` - for integer data types
+* `bool` - for the boolean data type
+* `string_label` - for string data types treated as categories (see below)
+* `string_text` - for string data types that do not represent categories (see below)
+* `list` - for lists of any other data type (including nested lists)
+* `audio` - for audio data
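+
+For example, here is a minimal sketch (assuming the `requests` Python library is installed) that prints the `column_type` of every column returned by the endpoint:
+
+```python
+import requests
+
+API_URL = "https://datasets-server.huggingface.co/statistics"
+params = {"dataset": "MLCommons/peoples_speech", "config": "validation", "split": "validation"}
+
+response = requests.get(API_URL, params=params)
+response.raise_for_status()
+
+for item in response.json()["statistics"]:
+    # `column_type` is one of the values listed above, e.g. "audio" or "string_text"
+    print(item["column_name"], item["column_type"])
+```
+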
### `class_label`
@@ -426,3 +428,107 @@ If string column does not satisfy the conditions to be treated as a `string_labe
+
+### `list`
+
+For lists, the distribution of their lengths is computed. The following measures are returned:
+
+* minimum, maximum, mean, median, and standard deviation of list lengths
+* number and proportion of `null` values
+* histogram of list lengths with up to 10 bins
+
+Example
+
+
+```json
+{
+ "column_name": "chat_history",
+ "column_type": "list",
+ "column_statistics": {
+ "nan_count": 0,
+ "nan_proportion": 0.0,
+ "min": 1,
+ "max": 3,
+ "mean": 1.01741,
+ "median": 1.0,
+ "std": 0.13146,
+ "histogram": {
+ "hist": [
+ 11177,
+ 196,
+ 1
+ ],
+ "bin_edges": [
+ 1,
+ 2,
+ 3,
+ 3
+ ]
+ }
+ }
+}
+```
+
+
+
+
+Note that dictionaries of lists are not supported.
+
+
+### `audio`
+
+For audio data, the distribution of audio file durations is computed. The following measures are returned:
+
+* minimum, maximum, mean, median, and standard deviation of audio file durations
+* number and proportion of `null` values
+* histogram of audio file durations with 10 bins
+
+
+Example
+
+
+```json
+{
+ "column_name": "audio",
+ "column_type": "audio",
+ "column_statistics": {
+ "nan_count": 0,
+ "nan_proportion": 0,
+ "min": 1.02,
+ "max": 15,
+ "mean": 13.93042,
+ "median": 14.77,
+ "std": 2.63734,
+ "histogram": {
+ "hist": [
+ 32,
+ 25,
+ 18,
+ 24,
+ 22,
+ 17,
+ 18,
+ 19,
+ 55,
+ 1770
+ ],
+ "bin_edges": [
+ 1.02,
+ 2.418,
+ 3.816,
+ 5.214,
+ 6.612,
+ 8.01,
+ 9.408,
+ 10.806,
+ 12.204,
+ 13.602,
+ 15
+ ]
+ }
+ }
+}
+```
+
+
+
\ No newline at end of file
diff --git a/services/worker/README.md b/services/worker/README.md
index e7535af786..27a7437f6b 100644
--- a/services/worker/README.md
+++ b/services/worker/README.md
@@ -117,6 +117,8 @@ The response has three fields: `num_examples`, `statistics`, and `partial`. `par
* `string_label` - for string dtypes ("string", "large_string") - if there are less than or equal to `MAX_NUM_STRING_LABELS` unique values (hardcoded in worker's code, for now it's 30)
* `string_text` - for string dtypes ("string", "large_string") - if there are more than `MAX_NUM_STRING_LABELS` unique values
* `bool` - for boolean dtype ("bool")
+* `list` - for lists of any other data type (including nested lists)
+* `audio` - for audio data
`column_statistics` content depends on the feature type, see examples below.
##### class_label
@@ -450,6 +452,96 @@ If a string column doesn't satisfy the conditions to be considered a category (s
+##### list
+
+Shows the distribution of list lengths. Note: dictionaries of lists are not supported (only lists of dictionaries are).
+
+example:
+
+
+```python
+{
+ "column_name": "list_col",
+ "column_type": "list",
+ "column_statistics": {
+ "nan_count": 0,
+ "nan_proportion": 0.0,
+ "min": 1,
+ "max": 3,
+ "mean": 1.01741,
+ "median": 1.0,
+ "std": 0.13146,
+ "histogram": {
+ "hist": [
+ 11177,
+ 196,
+ 1
+ ],
+ "bin_edges": [
+ 1,
+ 2,
+ 3,
+ 3
+ ]
+ }
+ }
+}
+```
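+
+For reference, list lengths are taken per row with polars' `list.len()` expression after dropping nulls; here is a minimal sketch of the idea (the toy frame below is made up, the column name matches the example above):
+
+```python
+import polars as pl
+
+# toy frame standing in for one parquet shard; nulls are counted separately as `nan_count`
+df = pl.DataFrame({"list_col": [["a"], ["a", "b"], ["a", "b", "c"], None]})
+
+lengths_df = df.drop_nulls().with_columns(pl.col("list_col").list.len().alias("list_col_len"))
+print(lengths_df["list_col_len"].to_list())  # [1, 2, 3]
+```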
+
+
+
+##### audio
+
+Shows the distribution of audio file durations.
+
+example:
+
+
+```python
+{
+ "column_name": "audio_col",
+ "column_type": "audio",
+ "column_statistics": {
+ "nan_count": 0,
+ "nan_proportion": 0,
+ "min": 1.02,
+ "max": 15,
+ "mean": 13.93042,
+ "median": 14.77,
+ "std": 2.63734,
+ "histogram": {
+ "hist": [
+ 32,
+ 25,
+ 18,
+ 24,
+ 22,
+ 17,
+ 18,
+ 19,
+ 55,
+ 1770
+ ],
+ "bin_edges": [
+ 1.02,
+ 2.418,
+ 3.816,
+ 5.214,
+ 6.612,
+ 8.01,
+ 9.408,
+ 10.806,
+ 12.204,
+ 13.602,
+ 15
+ ]
+ }
+ }
+}
+```
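+
+Durations are derived from the raw audio bytes stored in the split's parquet files; here is a minimal sketch of the idea with pyarrow and librosa (the parquet file name below is made up, the column name matches the example above):
+
+```python
+import io
+
+import librosa
+import pyarrow.parquet as pq
+
+# read one shard's audio column; each non-null row is a dict holding the encoded audio under "bytes"
+audios = pq.read_table("shard.parquet", columns=["audio_col"]).drop_null().to_pydict()["audio_col"]
+durations = [librosa.get_duration(path=io.BytesIO(audio["bytes"])) for audio in audios]
+```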
+
+
+
### Splits worker
diff --git a/services/worker/src/worker/job_runners/split/descriptive_statistics.py b/services/worker/src/worker/job_runners/split/descriptive_statistics.py
index f2c9929989..c8c7f83e0f 100644
--- a/services/worker/src/worker/job_runners/split/descriptive_statistics.py
+++ b/services/worker/src/worker/job_runners/split/descriptive_statistics.py
@@ -3,13 +3,16 @@
import enum
import functools
+import io
import logging
from collections import Counter
from pathlib import Path
-from typing import Any, Optional, Protocol, TypedDict, Union
+from typing import Any, Callable, Optional, TypedDict, Union
+import librosa
import numpy as np
import polars as pl
+import pyarrow.parquet as pq
from datasets import Features
from libcommon.dtos import JobInfo
from libcommon.exceptions import (
@@ -30,6 +33,7 @@
from libcommon.storage import StrPath
from libcommon.utils import download_file_from_hub
from polars import List
+from tqdm.contrib.concurrent import thread_map
from worker.config import AppConfig, DescriptiveStatisticsConfig
from worker.dtos import CompleteJobResult
@@ -66,6 +70,7 @@ class ColumnType(str, enum.Enum):
CLASS_LABEL = "class_label"
STRING_LABEL = "string_label"
STRING_TEXT = "string_text"
+ AUDIO = "audio"
class Histogram(TypedDict):
@@ -263,29 +268,35 @@ def nan_count_proportion(data: pl.DataFrame, column_name: str, n_samples: int) -
return nan_count, nan_proportion
-class _ComputeStatisticsFuncT(Protocol):
- def __call__(self, data: pl.DataFrame, column_name: str, n_samples: int, *args: Any, **kwargs: Any) -> Any:
- ...
-
-
-def raise_with_column_name(func: _ComputeStatisticsFuncT) -> _ComputeStatisticsFuncT:
+def raise_with_column_name(func: Callable) -> Callable: # type: ignore
"""
Wraps error from Column._compute_statistics() so that we always keep information about which
column caused an error.
"""
@functools.wraps(func)
- def _compute_statistics_wrapper(
- data: pl.DataFrame, column_name: str, n_samples: int, *args: Any, **kwargs: Any
- ) -> Any:
+ def _compute_statistics_wrapper(*args: Any, column_name: str, **kwargs: Any) -> Any:
try:
- return func(data, column_name, n_samples, *args, **kwargs)
+ return func(*args, column_name=column_name, **kwargs)
except Exception as error:
raise StatisticsComputationError(f"Error for column={column_name}: {error=}", error)
return _compute_statistics_wrapper
+def all_nan_statistics_item(n_samples: int) -> NumericalStatisticsItem:
+ return NumericalStatisticsItem(
+ nan_count=n_samples,
+ nan_proportion=1.0,
+ min=None,
+ max=None,
+ mean=None,
+ median=None,
+ std=None,
+ histogram=None,
+ )
+
+
class Column:
"""Abstract class to compute stats for columns of all supported data types."""
@@ -299,15 +310,12 @@ def __init__(
@staticmethod
def _compute_statistics(
- data: pl.DataFrame,
- column_name: str,
- n_samples: int,
*args: Any,
**kwargs: Any,
) -> SupportedStatistics:
raise NotImplementedError
- def compute_and_prepare_response(self, data: pl.DataFrame) -> StatisticsPerColumnItem:
+ def compute_and_prepare_response(self, *args: Any, **kwargs: Any) -> StatisticsPerColumnItem:
raise NotImplementedError
@@ -355,7 +363,9 @@ def _compute_statistics(
)
def compute_and_prepare_response(self, data: pl.DataFrame) -> StatisticsPerColumnItem:
- stats = self._compute_statistics(data, self.name, self.n_samples, feature_dict=self.feature_dict)
+ stats = self._compute_statistics(
+ data, column_name=self.name, n_samples=self.n_samples, feature_dict=self.feature_dict
+ )
return StatisticsPerColumnItem(
column_name=self.name,
column_type=ColumnType.CLASS_LABEL,
@@ -376,16 +386,8 @@ def _compute_statistics(
logging.info(f"Compute statistics for float column {column_name} with polars. ")
nan_count, nan_proportion = nan_count_proportion(data, column_name, n_samples)
if nan_count == n_samples: # all values are None
- return NumericalStatisticsItem(
- nan_count=n_samples,
- nan_proportion=1.0,
- min=None,
- max=None,
- mean=None,
- median=None,
- std=None,
- histogram=None,
- )
+ return all_nan_statistics_item(n_samples)
+
minimum, maximum, mean, median, std = min_max_mean_median_std(data, column_name)
logging.debug(f"{minimum=}, {maximum=}, {mean=}, {median=}, {std=}, {nan_count=} {nan_proportion=}")
@@ -411,7 +413,7 @@ def _compute_statistics(
)
def compute_and_prepare_response(self, data: pl.DataFrame) -> StatisticsPerColumnItem:
- stats = self._compute_statistics(data, self.name, self.n_samples, n_bins=self.n_bins)
+ stats = self._compute_statistics(data, column_name=self.name, n_samples=self.n_samples, n_bins=self.n_bins)
return StatisticsPerColumnItem(
column_name=self.name,
column_type=ColumnType.FLOAT,
@@ -432,16 +434,7 @@ def _compute_statistics(
logging.info(f"Compute statistics for integer column {column_name} with polars. ")
nan_count, nan_proportion = nan_count_proportion(data, column_name, n_samples=n_samples)
if nan_count == n_samples:
- return NumericalStatisticsItem(
- nan_count=n_samples,
- nan_proportion=1.0,
- min=None,
- max=None,
- mean=None,
- median=None,
- std=None,
- histogram=None,
- )
+ return all_nan_statistics_item(n_samples)
minimum, maximum, mean, median, std = min_max_mean_median_std(data, column_name)
logging.debug(f"{minimum=}, {maximum=}, {mean=}, {median=}, {std=}, {nan_count=} {nan_proportion=}")
@@ -469,7 +462,7 @@ def _compute_statistics(
)
def compute_and_prepare_response(self, data: pl.DataFrame) -> StatisticsPerColumnItem:
- stats = self._compute_statistics(data, self.name, self.n_samples, n_bins=self.n_bins)
+ stats = self._compute_statistics(data, column_name=self.name, n_samples=self.n_samples, n_bins=self.n_bins)
return StatisticsPerColumnItem(
column_name=self.name,
column_type=ColumnType.INT,
@@ -511,12 +504,12 @@ def _compute_statistics(
pl.col(column_name).str.len_chars().alias(lengths_column_name)
)
lengths_stats: NumericalStatisticsItem = IntColumn._compute_statistics(
- lengths_df, lengths_column_name, n_bins=n_bins, n_samples=n_samples
+ lengths_df, column_name=lengths_column_name, n_bins=n_bins, n_samples=n_samples
)
return lengths_stats
def compute_and_prepare_response(self, data: pl.DataFrame) -> StatisticsPerColumnItem:
- stats = self._compute_statistics(data, self.name, self.n_samples, n_bins=self.n_bins)
+ stats = self._compute_statistics(data, column_name=self.name, n_samples=self.n_samples, n_bins=self.n_bins)
string_type = ColumnType.STRING_LABEL if "frequencies" in stats else ColumnType.STRING_TEXT
return StatisticsPerColumnItem(
column_name=self.name,
@@ -543,7 +536,7 @@ def _compute_statistics(data: pl.DataFrame, column_name: str, n_samples: int) ->
)
def compute_and_prepare_response(self, data: pl.DataFrame) -> StatisticsPerColumnItem:
- stats = self._compute_statistics(data, self.name, self.n_samples)
+ stats = self._compute_statistics(data, column_name=self.name, n_samples=self.n_samples)
return StatisticsPerColumnItem(
column_name=self.name,
column_type=ColumnType.BOOL,
@@ -564,22 +557,13 @@ def _compute_statistics(
logging.info(f"Compute statistics for list/Sequence column {column_name} with polars. ")
nan_count, nan_proportion = nan_count_proportion(data, column_name, n_samples)
if nan_count == n_samples:
- return NumericalStatisticsItem(
- nan_count=n_samples,
- nan_proportion=1.0,
- min=None,
- max=None,
- mean=None,
- median=None,
- std=None,
- histogram=None,
- )
- df_without_na = data.select(pl.col(column_name)).drop_nulls()
+ return all_nan_statistics_item(n_samples)
+ df_without_na = data.select(pl.col(column_name)).drop_nulls()
lengths_column_name = f"{column_name}_len"
lengths_df = df_without_na.with_columns(pl.col(column_name).list.len().alias(lengths_column_name))
lengths_stats = IntColumn._compute_statistics(
- lengths_df, lengths_column_name, n_bins=n_bins, n_samples=n_samples - nan_count
+ lengths_df, column_name=lengths_column_name, n_bins=n_bins, n_samples=n_samples - nan_count
)
return NumericalStatisticsItem(
@@ -594,7 +578,7 @@ def _compute_statistics(
)
def compute_and_prepare_response(self, data: pl.DataFrame) -> StatisticsPerColumnItem:
- stats = self._compute_statistics(data, self.name, self.n_samples, n_bins=self.n_bins)
+ stats = self._compute_statistics(data, column_name=self.name, n_samples=self.n_samples, n_bins=self.n_bins)
return StatisticsPerColumnItem(
column_name=self.name,
column_type=ColumnType.LIST,
@@ -602,7 +586,76 @@ def compute_and_prepare_response(self, data: pl.DataFrame) -> StatisticsPerColum
)
-SupportedColumns = Union[ClassLabelColumn, IntColumn, FloatColumn, StringColumn, BoolColumn, ListColumn]
+class AudioColumn(Column):
+ def __init__(self, *args: Any, n_bins: int, **kwargs: Any):
+ super().__init__(*args, **kwargs)
+ self.n_bins = n_bins
+
+ @staticmethod
+ def get_duration(example: dict[str, Any]) -> float:
+ with io.BytesIO(example["bytes"]) as f:
+ return librosa.get_duration(path=f) # type: ignore # expects PathLike but BytesIO also works
+
+ @staticmethod
+ @raise_with_column_name
+ def _compute_statistics(
+ parquet_directory: Path,
+ column_name: str,
+ n_samples: int,
+ n_bins: int,
+ ) -> NumericalStatisticsItem:
+ logging.info(f"Compute statistics for Audio column {column_name} with librosa and polars. ")
+ parquet_files = list(parquet_directory.glob("*.parquet"))
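+ # collect durations shard by shard; rows with null audio are dropped by drop_null() and accounted for below via n_samples - len(durations)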
+ durations = []
+ for filename in parquet_files:
+ shard_audios = pq.read_table(filename, columns=[column_name]).drop_null().to_pydict()[column_name]
+ shard_durations = (
+ thread_map(
+ AudioColumn.get_duration,
+ shard_audios,
+ desc=f"Computing durations of audio for {filename.name}",
+ leave=False,
+ )
+ if shard_audios
+ else []
+ )
+ durations.extend(shard_durations)
+
+ if not durations:
+ return all_nan_statistics_item(n_samples)
+
+ nan_count = n_samples - len(durations)
+ nan_proportion = np.round(nan_count / n_samples, DECIMALS).item() if nan_count != 0 else 0.0
+ duration_df = pl.from_dict({column_name: durations})
+ duration_stats: NumericalStatisticsItem = FloatColumn._compute_statistics(
+ data=duration_df,
+ column_name=column_name,
+ n_samples=len(durations),
+ n_bins=n_bins,
+ )
+ return NumericalStatisticsItem(
+ nan_count=nan_count,
+ nan_proportion=nan_proportion,
+ min=duration_stats["min"],
+ max=duration_stats["max"],
+ mean=duration_stats["mean"],
+ median=duration_stats["median"],
+ std=duration_stats["std"],
+ histogram=duration_stats["histogram"],
+ )
+
+ def compute_and_prepare_response(self, parquet_directory: Path) -> StatisticsPerColumnItem:
+ stats = self._compute_statistics(
+ parquet_directory=parquet_directory, column_name=self.name, n_samples=self.n_samples, n_bins=self.n_bins
+ )
+ return StatisticsPerColumnItem(
+ column_name=self.name,
+ column_type=ColumnType.AUDIO,
+ column_statistics=stats,
+ )
+
+
+SupportedColumns = Union[ClassLabelColumn, IntColumn, FloatColumn, StringColumn, BoolColumn, ListColumn, AudioColumn]
def compute_descriptive_statistics_response(
@@ -720,10 +773,11 @@ def compute_descriptive_statistics_response(
resume_download=False,
)
- local_parquet_glob_path = Path(local_parquet_directory) / config / f"{split_directory}/*.parquet"
+ local_parquet_split_directory = Path(local_parquet_directory) / config / split_directory
+ local_parquet_split_glob = local_parquet_split_directory / "*.parquet"
num_examples = pl.read_parquet(
- local_parquet_glob_path, columns=[pl.scan_parquet(local_parquet_glob_path).columns[0]]
+ local_parquet_split_glob, columns=[pl.scan_parquet(local_parquet_split_glob).columns[0]]
).shape[0]
def _column_from_feature(
@@ -732,7 +786,7 @@ def _column_from_feature(
if isinstance(dataset_feature, list) or (
isinstance(dataset_feature, dict) and dataset_feature.get("_type") == "Sequence"
):
- schema = pl.scan_parquet(local_parquet_glob_path).schema[dataset_feature_name]
+ schema = pl.scan_parquet(local_parquet_split_glob).schema[dataset_feature_name]
# Compute only if it's internally a List! because it can also be Struct, see
# https://huggingface.co/docs/datasets/v2.18.0/en/package_reference/main_classes#datasets.Features
if isinstance(schema, List):
@@ -744,6 +798,11 @@ def _column_from_feature(
feature_name=dataset_feature_name, n_samples=num_examples, feature_dict=dataset_feature
)
+ if dataset_feature.get("_type") == "Audio":
+ return AudioColumn(
+ feature_name=dataset_feature_name, n_samples=num_examples, n_bins=histogram_num_bins
+ )
+
if dataset_feature.get("_type") == "Value":
if dataset_feature.get("dtype") in INTEGER_DTYPES:
return IntColumn(
@@ -784,13 +843,17 @@ def _column_from_feature(
)
for column in columns:
- try:
- data = pl.read_parquet(local_parquet_glob_path, columns=[column.name])
- except Exception as error:
- raise PolarsParquetReadError(
- f"Error reading parquet file(s) at {local_parquet_glob_path=}, columns=[{column.name}]: {error}", error
- )
- column_stats = column.compute_and_prepare_response(data)
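+ # AudioColumn computes durations directly from the parquet files (raw audio bytes), so it receives the split directory instead of a polars dataframe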
+ if isinstance(column, AudioColumn):
+ column_stats = column.compute_and_prepare_response(local_parquet_split_directory)
+ else:
+ try:
+ data = pl.read_parquet(local_parquet_split_glob, columns=[column.name])
+ except Exception as error:
+ raise PolarsParquetReadError(
+ f"Error reading parquet file(s) at {local_parquet_split_glob=}, columns=[{column.name}]: {error}",
+ error,
+ )
+ column_stats = column.compute_and_prepare_response(data)
all_stats.append(column_stats)
if not all_stats:
diff --git a/services/worker/tests/fixtures/data/audio/audio_1.wav b/services/worker/tests/fixtures/data/audio/audio_1.wav
new file mode 100644
index 0000000000..3bd88c544b
Binary files /dev/null and b/services/worker/tests/fixtures/data/audio/audio_1.wav differ
diff --git a/services/worker/tests/fixtures/data/audio/audio_2.wav b/services/worker/tests/fixtures/data/audio/audio_2.wav
new file mode 100644
index 0000000000..ad98f6928a
Binary files /dev/null and b/services/worker/tests/fixtures/data/audio/audio_2.wav differ
diff --git a/services/worker/tests/fixtures/data/audio/audio_3.wav b/services/worker/tests/fixtures/data/audio/audio_3.wav
new file mode 100644
index 0000000000..6a2372b0da
Binary files /dev/null and b/services/worker/tests/fixtures/data/audio/audio_3.wav differ
diff --git a/services/worker/tests/fixtures/data/audio/audio_4.wav b/services/worker/tests/fixtures/data/audio/audio_4.wav
new file mode 100644
index 0000000000..5d7b5c2652
Binary files /dev/null and b/services/worker/tests/fixtures/data/audio/audio_4.wav differ
diff --git a/services/worker/tests/fixtures/datasets.py b/services/worker/tests/fixtures/datasets.py
index b9a80627d1..44f2d0c646 100644
--- a/services/worker/tests/fixtures/datasets.py
+++ b/services/worker/tests/fixtures/datasets.py
@@ -27,6 +27,7 @@
from datasets.features.features import FeatureType
from .descriptive_statistics_dataset import (
+ audio_dataset,
statistics_dataset,
statistics_not_supported_dataset,
statistics_string_text_dataset,
@@ -178,4 +179,5 @@ def datasets() -> Mapping[str, Dataset]:
"descriptive_statistics": statistics_dataset,
"descriptive_statistics_string_text": statistics_string_text_dataset,
"descriptive_statistics_not_supported": statistics_not_supported_dataset,
+ "audio_statistics": audio_dataset,
}
diff --git a/services/worker/tests/fixtures/descriptive_statistics_dataset.py b/services/worker/tests/fixtures/descriptive_statistics_dataset.py
index d307f97f40..1e4b308029 100644
--- a/services/worker/tests/fixtures/descriptive_statistics_dataset.py
+++ b/services/worker/tests/fixtures/descriptive_statistics_dataset.py
@@ -1,9 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The HuggingFace Authors.
+from pathlib import Path
from typing import Optional
-from datasets import ClassLabel, Dataset, Features, Sequence, Value
+from datasets import Audio, ClassLabel, Dataset, Features, Sequence, Value
# from GLUE dataset, "ax" subset
LONG_TEXTS = """The cat sat on the mat.
@@ -1366,22 +1367,18 @@ def nan_column() -> list[None]:
"list__sequence_of_sequence_bool_nan_column": Sequence(Sequence(Value("bool"))),
"list__sequence_of_sequence_bool_all_nan_column": Sequence(Sequence(Value("bool"))),
"list__sequence_of_sequence_dict_column": Sequence(
- Sequence({"author": Value("string"), "content": Value("string"), "likes": Value("int32")})
+ Sequence({"author": Value("string"), "likes": Value("int32")})
),
"list__sequence_of_sequence_dict_nan_column": Sequence(
- Sequence({"author": Value("string"), "content": Value("string"), "likes": Value("int32")})
+ Sequence({"author": Value("string"), "likes": Value("int32")})
),
"list__sequence_of_sequence_dict_all_nan_column": Sequence(
- Sequence({"author": Value("string"), "content": Value("string"), "likes": Value("int32")})
- ),
- "list__sequence_of_list_dict_column": Sequence(
- [{"author": Value("string"), "content": Value("string"), "likes": Value("int32")}]
- ),
- "list__sequence_of_list_dict_nan_column": Sequence(
- [{"author": Value("string"), "content": Value("string"), "likes": Value("int32")}]
+ Sequence({"author": Value("string"), "likes": Value("int32")})
),
+ "list__sequence_of_list_dict_column": Sequence([{"author": Value("string"), "likes": Value("int32")}]),
+ "list__sequence_of_list_dict_nan_column": Sequence([{"author": Value("string"), "likes": Value("int32")}]),
"list__sequence_of_list_dict_all_nan_column": Sequence(
- [{"author": Value("string"), "content": Value("string"), "likes": Value("int32")}]
+ [{"author": Value("string"), "likes": Value("int32")}]
),
}
),
@@ -1577,3 +1574,29 @@ def nan_column() -> list[None]:
}
),
)
+
+
+audio_dataset = Dataset.from_dict(
+ {
+ "audio": [
+ str(Path(__file__).resolve().parent / "data" / "audio" / "audio_1.wav"),
+ str(Path(__file__).resolve().parent / "data" / "audio" / "audio_2.wav"),
+ str(Path(__file__).resolve().parent / "data" / "audio" / "audio_3.wav"),
+ str(Path(__file__).resolve().parent / "data" / "audio" / "audio_4.wav"),
+ ],
+ "audio_nan": [
+ str(Path(__file__).resolve().parent / "data" / "audio" / "audio_1.wav"),
+ None,
+ str(Path(__file__).resolve().parent / "data" / "audio" / "audio_3.wav"),
+ None,
+ ],
+ "audio_all_nan": [None, None, None, None],
+ },
+ features=Features(
+ {
+ "audio": Audio(sampling_rate=1600),
+ "audio_nan": Audio(sampling_rate=1600),
+ "audio_all_nan": Audio(sampling_rate=1600),
+ }
+ ),
+)
diff --git a/services/worker/tests/fixtures/hub.py b/services/worker/tests/fixtures/hub.py
index 24f83a0a3f..8c28bdb7d0 100644
--- a/services/worker/tests/fixtures/hub.py
+++ b/services/worker/tests/fixtures/hub.py
@@ -333,6 +333,13 @@ def hub_public_descriptive_statistics_not_supported(datasets: Mapping[str, Datas
delete_hub_dataset_repo(repo_id=repo_id)
+@pytest.fixture(scope="session")
+def hub_public_audio_statistics(datasets: Mapping[str, Dataset]) -> Iterator[str]:
+ repo_id = create_hub_dataset_repo(prefix="audio_statistics", dataset=datasets["audio_statistics"])
+ yield repo_id
+ delete_hub_dataset_repo(repo_id=repo_id)
+
+
@pytest.fixture(scope="session")
def hub_public_n_configs_with_default(datasets: Mapping[str, Dataset]) -> Iterator[str]:
default_config_name, _ = get_default_config_split()
@@ -1068,6 +1075,19 @@ def hub_responses_descriptive_statistics_not_supported(
}
+@pytest.fixture
+def hub_responses_audio_statistics(
+ hub_public_audio_statistics: str,
+) -> HubDatasetTest:
+ return {
+ "name": hub_public_audio_statistics,
+ "config_names_response": create_config_names_response(hub_public_audio_statistics),
+ "splits_response": create_splits_response(hub_public_audio_statistics),
+ "first_rows_response": None,
+ "parquet_and_info_response": None,
+ }
+
+
@pytest.fixture
def hub_responses_descriptive_statistics_parquet_builder(
hub_public_descriptive_statistics_parquet_builder: str,
diff --git a/services/worker/tests/job_runners/split/test_descriptive_statistics.py b/services/worker/tests/job_runners/split/test_descriptive_statistics.py
index 5b6ad6ab87..c971ce469c 100644
--- a/services/worker/tests/job_runners/split/test_descriptive_statistics.py
+++ b/services/worker/tests/job_runners/split/test_descriptive_statistics.py
@@ -9,8 +9,10 @@
import numpy as np
import pandas as pd
import polars as pl
+import pyarrow.parquet as pq
import pytest
from datasets import ClassLabel, Dataset
+from datasets.table import embed_table_storage
from huggingface_hub.hf_api import HfApi
from libcommon.dtos import Priority
from libcommon.exceptions import StatisticsComputationError
@@ -27,6 +29,7 @@
MAX_NUM_STRING_LABELS,
MAX_PROPORTION_STRING_LABELS,
NO_LABEL_VALUE,
+ AudioColumn,
BoolColumn,
ClassLabelColumn,
ColumnType,
@@ -471,6 +474,51 @@ def descriptive_statistics_string_text_partial_expected(datasets: Mapping[str, D
return {"num_examples": df.shape[0], "statistics": expected_statistics, "partial": True}
+@pytest.fixture
+def audio_statistics_expected() -> dict: # type: ignore
+ audio_lengths = [1.0, 2.0, 3.0, 4.0]  # the dataset consists of 4 audio files with durations of 1, 2, 3, and 4 seconds
+ audio_statistics = count_expected_statistics_for_numerical_column(
+ column=pd.Series(audio_lengths), dtype=ColumnType.FLOAT
+ )
+ expected_statistics = {
+ "column_name": "audio",
+ "column_type": ColumnType.AUDIO,
+ "column_statistics": audio_statistics,
+ }
+ nan_audio_lengths = [1.0, None, 3.0, None]  # only the first and third audio files are kept in this test case, the rest are null
+ nan_audio_statistics = count_expected_statistics_for_numerical_column(
+ column=pd.Series(nan_audio_lengths), dtype=ColumnType.FLOAT
+ )
+ expected_nan_statistics = {
+ "column_name": "audio_nan",
+ "column_type": ColumnType.AUDIO,
+ "column_statistics": nan_audio_statistics,
+ }
+ expected_all_nan_statistics = {
+ "column_name": "audio_all_nan",
+ "column_type": ColumnType.AUDIO,
+ "column_statistics": {
+ "nan_count": 4,
+ "nan_proportion": 1.0,
+ "min": None,
+ "max": None,
+ "mean": None,
+ "median": None,
+ "std": None,
+ "histogram": None,
+ },
+ }
+ return {
+ "num_examples": 4,
+ "statistics": {
+ "audio": expected_statistics,
+ "audio_nan": expected_nan_statistics,
+ "audio_all_nan": expected_all_nan_statistics,
+ },
+ "partial": False,
+ }
+
+
@pytest.mark.parametrize(
"column_name",
[
@@ -710,6 +758,31 @@ def test_polars_struct_thread_panic_error(struct_thread_panic_error_parquet_file
assert df.schema["conversations"] == conversations_schema
+@pytest.mark.parametrize(
+ "column_name",
+ ["audio", "audio_nan", "audio_all_nan"],
+)
+def test_audio_statistics(
+ column_name: str,
+ audio_statistics_expected: dict, # type: ignore
+ datasets: Mapping[str, Dataset],
+ tmp_path_factory: pytest.TempPathFactory,
+) -> None:
+ expected = audio_statistics_expected["statistics"][column_name]["column_statistics"]
+ parquet_directory = tmp_path_factory.mktemp("data")
+ parquet_filename = parquet_directory / "data.parquet"
+ dataset_table = datasets["audio_statistics"].data
+ dataset_table_embedded = embed_table_storage(dataset_table) # store audio as bytes instead of paths to files
+ pq.write_table(dataset_table_embedded, parquet_filename)
+ computed = AudioColumn._compute_statistics(
+ parquet_directory=parquet_directory,
+ column_name=column_name,
+ n_samples=4,
+ n_bins=N_BINS,
+ )
+ assert computed == expected
+
+
@pytest.mark.parametrize(
"hub_dataset_name,expected_error_code",
[
@@ -717,8 +790,8 @@ def test_polars_struct_thread_panic_error(struct_thread_panic_error_parquet_file
("descriptive_statistics_string_text", None),
("descriptive_statistics_string_text_partial", None),
("descriptive_statistics_not_supported", "NoSupportedFeaturesError"),
+ ("audio_statistics", None),
("gated", None),
- ("audio", "NoSupportedFeaturesError"),
],
)
def test_compute(
@@ -732,12 +805,13 @@ def test_compute(
hub_responses_descriptive_statistics_parquet_builder: HubDatasetTest,
hub_responses_gated_descriptive_statistics: HubDatasetTest,
hub_responses_descriptive_statistics_not_supported: HubDatasetTest,
- hub_responses_audio: HubDatasetTest,
+ hub_responses_audio_statistics: HubDatasetTest,
hub_dataset_name: str,
expected_error_code: Optional[str],
descriptive_statistics_expected: dict, # type: ignore
descriptive_statistics_string_text_expected: dict, # type: ignore
descriptive_statistics_string_text_partial_expected: dict, # type: ignore
+ audio_statistics_expected: dict, # type: ignore
) -> None:
hub_datasets = {
"descriptive_statistics": hub_responses_descriptive_statistics,
@@ -745,7 +819,7 @@ def test_compute(
"descriptive_statistics_string_text_partial": hub_responses_descriptive_statistics_parquet_builder,
"descriptive_statistics_not_supported": hub_responses_descriptive_statistics_not_supported,
"gated": hub_responses_gated_descriptive_statistics,
- "audio": hub_responses_audio,
+ "audio_statistics": hub_responses_audio_statistics,
}
expected = {
"descriptive_statistics": descriptive_statistics_expected,
@@ -753,6 +827,7 @@ def test_compute(
"gated": descriptive_statistics_expected,
"descriptive_statistics_string_text": descriptive_statistics_string_text_expected,
"descriptive_statistics_string_text_partial": descriptive_statistics_string_text_partial_expected,
+ "audio_statistics": audio_statistics_expected,
}
dataset = hub_datasets[hub_dataset_name]["name"]
splits_response = hub_datasets[hub_dataset_name]["splits_response"]
@@ -849,6 +924,7 @@ def test_compute(
ColumnType.INT,
ColumnType.STRING_TEXT,
ColumnType.LIST,
+ ColumnType.AUDIO,
]:
hist, expected_hist = (
column_response_stats.pop("histogram"),