Skip to content

Commit

Permalink
[update] pad audio if shorter than 30 seconds
Browse files Browse the repository at this point in the history
  • Loading branch information
Jourdelune committed Jun 20, 2024
1 parent 503e9db commit 5e2b98a
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 18 deletions.
6 changes: 3 additions & 3 deletions dataset/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ def process_file(self, audio_dir: str, filename: str, max_duration_ms: int = 300
for i in range(0, total_duration, max_duration_ms):
segment = audio[i : i + max_duration_ms]

# don't save if shorter than 3 second
if len(segment) < 3000:
continue
# pad the segment if it is shorter than the max duration
if len(segment) < max_duration_ms:
segment = segment + AudioSegment.silent(duration=max_duration_ms - len(segment))

segment_filename = (
f"{os.path.splitext(filename)[0]}_part{i // max_duration_ms + 1}.mp3"
Expand Down
52 changes: 37 additions & 15 deletions dataset/load_dataset.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,51 @@
"""
This script loads the dataset from the given path.
"""

import glob
import os

from datasets import Dataset
from datasets import Audio, Dataset, Features


def load_dataset(path: str, sampling_rate: int) -> Dataset:
    """Load the paired clean/compressed audio dataset stored under ``path``.

    The clean ``*.mp3`` files are looked up directly inside ``path``. Every
    sub-directory of ``path`` is treated as one codec folder, and the
    compressed counterpart of a clean file is the file with the same name
    inside that folder. One example is produced per (clean file, codec) pair
    whose compressed file exists on disk.

    Args:
        path (str): The path to the dataset.
        sampling_rate (int): The sampling rate of the audio.

    Returns:
        Dataset: The dataset loaded, with ``clean`` and ``compressed`` audio
            columns.
    """

    def gen():
        # glob returns entries in arbitrary, filesystem-dependent order; sort
        # so the generated dataset is reproducible across runs and machines.
        clean_audio = sorted(glob.glob(os.path.join(path, "*.mp3")))
        codecs = sorted(
            entry
            for entry in glob.glob(os.path.join(path, "*"))
            if os.path.isdir(entry)
        )

        for clean_file in clean_audio:
            clean_name = os.path.basename(clean_file)
            for codec in codecs:
                compressed_file = os.path.join(codec, clean_name)
                # Pairing is done by file name, so a missing counterpart never
                # raises: check explicitly and skip it instead of yielding a
                # path that would only fail later when the audio is decoded.
                if not os.path.isfile(compressed_file):
                    continue
                yield {
                    "clean": os.path.abspath(clean_file),
                    "compressed": os.path.abspath(compressed_file),
                }

    features = Features(
        {
            "clean": Audio(sampling_rate=sampling_rate),
            "compressed": Audio(sampling_rate=sampling_rate),
        }
    )
    return Dataset.from_generator(gen, features=features)

0 comments on commit 5e2b98a

Please sign in to comment.