-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[update] pad audio if shorter than 30 seconds
- Loading branch information
1 parent
503e9db
commit 5e2b98a
Showing
2 changed files
with
40 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,29 +1,51 @@ | ||
""" | ||
This script loads the dataset from the given path. | ||
""" | ||
|
||
import glob | ||
import os | ||
|
||
from datasets import Dataset | ||
from datasets import Audio, Dataset, Features | ||
|
||
|
||
def load_dataset(path: str) -> Dataset: | ||
""" | ||
Function that gathers the dataset. | ||
def load_dataset(path: str, sampling_rate: int) -> Dataset: | ||
"""Function that loads the dataset. | ||
Args: | ||
path (str): The path to the dataset. | ||
sampling_rate (int): The sampling rate of the audio. | ||
Returns: | ||
Dataset: The dataset loaded. | ||
""" | ||
|
||
def gen(): | ||
i = 4 | ||
clean_audio = glob.glob(path + "/*.mp3") | ||
codecs = [f for f in glob.glob(path+"/*", ) if os.path.isdir(f)] | ||
codecs = [ | ||
f | ||
for f in glob.glob( | ||
path + "/*", | ||
) | ||
if os.path.isdir(f) | ||
] | ||
for i in range(len(clean_audio)): | ||
return_dict = {"clean": os.path.abspath(clean_audio[i])} | ||
for codec in codecs: | ||
return_dict = { | ||
"clean": os.path.abspath(clean_audio[i]), | ||
"compressed": os.path.abspath(glob.glob(str(codec) + "/*.mp3")[i]), | ||
"type": codec.split('/')[-1] | ||
|
||
} | ||
yield return_dict | ||
|
||
return Dataset.from_generator(gen) | ||
try: | ||
return_dict = { | ||
"clean": os.path.abspath(clean_audio[i]), | ||
"compressed": os.path.abspath( | ||
os.path.join(codec, os.path.basename(clean_audio[i])) | ||
), | ||
} | ||
yield return_dict | ||
except IndexError: | ||
pass | ||
|
||
features = Features( | ||
{ | ||
"clean": Audio(sampling_rate=sampling_rate), | ||
"compressed": Audio(sampling_rate=sampling_rate), | ||
} | ||
) | ||
return Dataset.from_generator(gen, features=features) |