Added detailed documentation to dataset.py
JialinWang28 committed Nov 19, 2024
1 parent 732257e commit 1194501
Showing 1 changed file with 235 additions and 20 deletions.
pyha_analyzer/dataset.py
@@ -5,7 +5,6 @@
get_datasets returns the train and validation datasets as BirdCLEFDataset objects.
If this module is run directly, it tests that the dataloader works
"""
import logging
import os
@@ -34,8 +33,96 @@
# pylint: disable=too-many-instance-attributes
class PyhaDFDataset(Dataset):
"""
This class represents a dataset designed to work with pyha output, and saves
unchunked data. Contains methods for loading the dataset and preparing it for
training and validation.
Attributes:
samples (pandas.DataFrame):
- filtered dataframe which contains non-null values
in the "FILE NAME" column.
num_samples (int):
- how many samples the dataset contains.
train (bool):
- whether the dataset is a training set or not.
device (str):
- the device on which the computations will occur.
onehot (bool):
- whether the dataset has been one-hot encoded.
cfg (pyha_analyzer.config.Config):
- configuration settings of the dataset.
data_dir (set[str]):
- collection of paths of audio files.
bad_files (list[int]):
- indices of bad files.
classes (list[str]):
- list of all species to be classified.
class_to_idx (dict[str, int]):
- dictionary with each string in classes as a key and its
assigned integer index as the value.
num_classes (int):
- the number of species (classes) in the dataset.
convert_to_mel (torchaudio.transforms.MelSpectrogram):
- transformation object that converts raw waveforms into mel spectrograms.
decibel_convert (torchaudio.transforms.AmplitudeToDB):
- transformation that converts amplitude values to the decibel scale.
mixup (pyha_analyzer.augmentations.Mixup):
- torch.nn.Module object that applies mixup augmentation to the data.
audio_augmentations (torch.nn.Sequential):
- pipeline for augmenting audio files.
image_augmentations (torch.nn.Sequential):
- pipeline for augmenting spectrogram images.
Methods:
__init__(self, df, train, species, cfg, onehot = False):
- Initializes a PyhaDFDataset with the given attributes.
calc_class_distribution(self):
- Returns class distribution (number of samples per class).
verify_audio(self):
- Checks that the files referenced in the input df exist.
process_audio_file(self, file_name: str):
- Saves the waveform of an audio file as a tensor and writes that tensor to a .pt file.
serialize_data(self):
- For each file, checks whether the file is already a presaved tensor.
If the file is not a presaved tensor and is an audio file, converts it
to a tensor to make future training faster.
__len__(self):
- Returns how many elements are in the sample DataFrame.
to_image(self, audio):
- Converts an audio clip to a 3-channel spectrogram image.
__getitem__(self, index):
- Takes an index and returns a tuple of the spectrogram image and
its corresponding label.
get_num_classes(self):
- Returns the number of classes.
get_sample_weights(self):
- Returns the weights as computed by the first place winner of
BirdCLEF 2023.
"""

# df, train, and species decided outside of config, so those cannot be added in there
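For orientation, a minimal construction sketch follows. The config accessor and the way species is derived are assumptions for illustration, not a verbatim call site from this repository:

import pandas as pd
from pyha_analyzer import config
from pyha_analyzer.dataset import PyhaDFDataset

cfg = config.cfg  # assumed accessor; substitute however the project exposes its Config
df = pd.read_csv(cfg.dataframe_csv)               # annotation dataframe produced by pyha
species = sorted(df[cfg.manual_id_col].unique())  # all species to be classified
train_ds = PyhaDFDataset(df, train=True, species=species, cfg=cfg)
print(len(train_ds), train_ds.get_num_classes())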
@@ -47,6 +134,30 @@ def __init__(self,
cfg: config.Config,
onehot:bool = False,
) -> None:
"""
Initializes a PyhaDFDataset with the given attributes.
Args:
df (pandas.DataFrame):
- dataframe of data contained in this object.
train (bool):
- whether the data is the training set data or not.
species (list[str]):
- a list of strings representing each species identified
in this dataset.
cfg (pyha_analyzer.config.Config):
- configuration settings of the dataset.
onehot (bool):
- whether the data has been one-hot encoded or not; defaults
to False.
Returns:
None
"""
self.samples = df[~(df[cfg.file_name_col].isnull())]
if onehot:
if self.samples.iloc[0][species].shape[0] != len(species):
@@ -80,6 +191,7 @@ def __init__(self,
self.num_classes = len(species)
self.serialize_data()


self.class_dist = self.calc_class_distribution()

#Data augmentations
@@ -107,7 +219,14 @@ def __init__(self,
RandomApply([audtr.TimeMasking(cfg.time_mask_param)], p=cfg.time_mask_p))

def calc_class_distribution(self) -> torch.Tensor:
""" Returns class distribution (number of samples per class) """
"""
Returns class distribution (number of samples per class).
Returns:
class_dist (torch.Tensor):
- a 1-d torch Tensor giving the number of samples
in each class.
"""
class_dist = []
if self.onehot:
for class_name in self.classes:
@@ -125,7 +244,10 @@ def calc_class_distribution(self) -> torch.Tensor:

def verify_audio(self) -> None:
"""
Checks that the files referenced in the input df exist.
Returns:
None
"""
missing_files = pd.Series(self.samples[self.cfg.file_name_col].unique()) \
.progress_apply(
@@ -143,7 +265,19 @@

def process_audio_file(self, file_name: str) -> pd.Series:
"""
Saves the waveform of an audio file as a tensor and writes that tensor to a .pt file.
Args:
file_name (str):
- name of an audio file
Returns:
- Pandas Series of the original file name and the new file name.
- If the audio file has already been processed, nothing is saved
and the original file name and the expected new file name are
simply returned.
- If an exception occurs, returns the original file location and
"bad" as the new location.
"""
exts = "." + file_name.split(".")[-1]
new_name = file_name.replace(exts, ".pt")
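The essence of this conversion can be sketched as below, assuming torchaudio as the loader; the real method additionally handles the configured data directory, error cases (returning "bad", as described above), and column bookkeeping:

import torch
import torchaudio

def waveform_to_pt(file_name: str) -> str:
    # Sketch: cache an audio file's waveform as a serialized .pt tensor
    new_name = file_name.rsplit(".", 1)[0] + ".pt"
    audio, sample_rate = torchaudio.load(file_name)  # (waveform tensor, sample rate)
    torch.save(audio, new_name)
    return new_name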
@@ -193,9 +327,12 @@ def process_audio_file(self, file_name: str) -> pd.Series:

def serialize_data(self) -> None:
"""
For each file, check whether the file is already a presaved tensor.
If the file is not a presaved tensor and is an audio file, convert it
to a tensor to make future training faster.
Returns:
None
"""
self.verify_audio()
files = pd.DataFrame(self.samples[self.cfg.file_name_col].unique(),
@@ -225,11 +362,24 @@ def serialize_data(self) -> None:
self.samples["original_file_path"] = self.samples[self.cfg.file_name_col]

def __len__(self):
"""
Returns how many elements are in the sample DataFrame.
Returns:
- The number of elements in the sample DataFrame.
"""
return self.samples.shape[0]

def to_image(self, audio):
"""
Convert an audio clip to a 3-channel spectrogram image.
Args:
audio (torch.Tensor):
- torch tensor that represents the audio clip as a raw waveform
Returns:
- torch tensor that represents the audio clip as a mel spectrogram.
"""
# Mel spectrogram
# Pylint complains this is not callable, but it is a torch.nn.Module
@@ -250,7 +400,20 @@
return torch.stack([mel, mel, mel])
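For reference, the full waveform-to-image pipeline can be sketched as follows; the spectrogram parameters are placeholders standing in for whatever cfg supplies:

import torch
import torchaudio.transforms as audtr

convert_to_mel = audtr.MelSpectrogram(sample_rate=32_000, n_mels=128)  # placeholder params
decibel_convert = audtr.AmplitudeToDB(stype="power")

def to_image_sketch(audio: torch.Tensor) -> torch.Tensor:
    # Raw waveform -> mel spectrogram on the decibel scale
    mel = decibel_convert(convert_to_mel(audio))
    # Min-max normalize so values land in [0, 1]
    mel = (mel - mel.min()) / (mel.max() - mel.min() + 1e-6)
    # Duplicate the single channel three times for RGB-style model inputs
    return torch.stack([mel, mel, mel])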

def __getitem__(self, index): #-> Any:
""" Takes an index and returns tuple of spectrogram image with corresponding label
"""
Takes an index and returns tuple of spectrogram image with corresponding label
Args:
index (int):
- index of the item
Returns:
tuple: a tuple containing:
image (torch.Tensor):
- torch tensor representing the mel spectrogram
image at the index
target (torch.Tensor):
- torch tensor representing the one-hot encoded label of
the image
"""
assert isinstance(index, int)
audio, target = utils.get_annotation(
@@ -282,14 +445,24 @@ def __getitem__(self, index): #-> Any:
return image, target
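A quick sketch of the indexing contract, reusing the hypothetical train_ds from the construction sketch above:

image, target = train_ds[0]  # index must be a plain int
print(image.shape)           # e.g. (3, n_mels, time_frames)
print(target.shape)          # one-hot label vector of length num_classes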

def get_num_classes(self) -> int:
""" Returns number of classes
"""
Returns number of classes
Returns:
- the "num_classes" attribute, which represents the number of classes
in the dataset.
"""
return self.num_classes

def get_sample_weights(self) -> pd.Series:
""" Returns the weights as computed by the first place winner of BirdCLEF 2023
See https://www.kaggle.com/competitions/birdclef-2023/discussion/412808
Congrats on your win!
"""
Returns the weights as computed by the first place winner of BirdCLEF 2023
See https://www.kaggle.com/competitions/birdclef-2023/discussion/412808
Congrats on your win!
Returns:
- a pandas.Series object that represents the weights as computed by
the first place winner of BirdCLEF 2023.
"""
manual_id = self.cfg.manual_id_col
all_primary_labels = self.samples[manual_id]
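The weighting can be sketched as inverse-square-root class frequency; this is a hedged reconstruction of the linked BirdCLEF 2023 recipe, not a verbatim copy of this method's body:

import pandas as pd

def sample_weights_sketch(labels: pd.Series) -> pd.Series:
    # Weight each sample by its class's relative frequency raised to -0.5,
    # so rarer classes receive larger weights
    freq = labels.value_counts() / len(labels)
    return labels.map(freq ** -0.5)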
@@ -302,9 +475,20 @@


def get_datasets(cfg) -> Tuple[PyhaDFDataset, PyhaDFDataset, Optional[PyhaDFDataset]]:
""" Returns train and validation datasets
does random sampling for train/valid split
adds transforms to dataset
"""
Returns train and validation datasets
does random sampling for train/valid split
adds transforms to dataset
Args:
cfg (pyha_analyzer.config.Config):
- configuration settings of the dataset.
Returns:
tuple: a tuple containing:
- train_ds (PyhaDFDataset): the training dataset
- valid_ds (PyhaDFDataset): the validation dataset
- infer_ds (Optional[PyhaDFDataset]): the inference/test dataset
"""
train_p = cfg.train_test_split
path = cfg.dataframe_csv
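A hypothetical call site, reusing the cfg object from the construction sketch above:

train_ds, valid_ds, infer_ds = get_datasets(cfg)
print(len(train_ds), len(valid_ds))
if infer_ds is not None:  # inference split is optional
    print(len(infer_ds))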
@@ -392,7 +576,13 @@ def get_datasets(cfg) -> Tuple[PyhaDFDataset, PyhaDFDataset, Optional[PyhaDFData

def set_torch_file_sharing(_) -> None:
"""
Sets torch.multiprocessing to use file sharing.
Args:
_ (any): placeholder parameter.
Returns:
None
"""
torch.multiprocessing.set_sharing_strategy("file_system")
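The single ignored parameter suggests this function is shaped for use as a DataLoader worker_init_fn; that pairing is an assumption here, since the call site is not shown in this diff:

from torch.utils.data import DataLoader

loader = DataLoader(train_ds, batch_size=32, num_workers=4,
                    worker_init_fn=set_torch_file_sharing)  # assumed pairing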

@@ -401,6 +591,28 @@ def make_dataloaders(train_dataset, val_dataset, infer_dataset, cfg
)-> Tuple[DataLoader, DataLoader, Optional[DataLoader]]:
"""
Loads datasets and dataloaders for train and validation
Args:
train_dataset (PyhaDFDataset):
- the training dataset
val_dataset (PyhaDFDataset):
- the validation dataset
infer_dataset (PyhaDFDataset):
- the inference/test dataset
cfg (pyha_analyzer.config.Config):
- configuration settings of the dataset
Returns:
tuple: a tuple containing:
- train_dataloader (torch.utils.data.DataLoader): dataloader for the
training set
- val_dataloader (torch.utils.data.DataLoader): dataloader for the
validation set
- infer_dataloader (Optional[torch.utils.data.DataLoader]): dataloader
for the inference/test dataset
"""


@@ -452,7 +664,10 @@ def make_dataloaders(train_dataset, val_dataset, infer_dataset, cfg

def main() -> None:
"""
Testing function.
Returns:
None
"""
# run = wandb.init(
# entity=cfg.wandb_entity,
