diff --git a/pyha_analyzer/dataset.py b/pyha_analyzer/dataset.py index c4d3521..fc80f4d 100644 --- a/pyha_analyzer/dataset.py +++ b/pyha_analyzer/dataset.py @@ -5,7 +5,6 @@ get_datasets returns the train and validation datasets as BirdCLEFDataset objects. If this module is run directly, it tests that the dataloader works - """ import logging import os @@ -34,8 +33,96 @@ # pylint: disable=too-many-instance-attributes class PyhaDFDataset(Dataset): """ - Dataset designed to work with pyha output - Save unchunked data + A dataset class that loads the data used for training and validation + + This class represents a dataset designed to work with pyha output, and saves + unchunked data. Contains methods for loading the dataset. + + Attributes: + samples (pandas.DataFrame): + - filtered dataframe which contains non-null values + in the file name column (cfg.file_name_col). + + num_samples (int): + - how many samples the dataset contains. + + train (bool): + - whether the dataset is a training set or not. + + device (str): + - on which device the computations will occur. + + onehot (bool): + - whether the dataset has been one-hot encoded. + + cfg (pyha_analyzer.config.Config): + - configuration settings of the dataset. + + data_dir (set[str]): + - collection of paths of audio files. + + bad_files (list[int]): + - indices of bad files. + + classes (list[str]): + - list of all species to be classified. + + class_to_idx (dict[str, int]): + - dictionary with each string in classes as keys + and an assigned int index value as value + + num_classes (int): + - the number of species (classes) in the dataset + + convert_to_mel (torchaudio.transforms.MelSpectrogram): + - transformation object that converts raw waveforms into mel spectrograms. + + decibel_convert (torchaudio.transforms.AmplitudeToDB): + - transformation that converts values from the power/amplitude scale + to the decibel scale. + + mixup (pyha_analyzer.augmentations.Mixup): + - torch.nn.Module object that mixes up the dataset for data augmentation. 
+ + audio_augmentations (torch.nn.Sequential): + - pipeline for augmenting audio files. + + image_augmentations (torch.nn.Sequential): + - pipeline for augmenting spectrogram images. + + Methods: + __init__(self, df, train, species, cfg, onehot = False): + - Initializes a PyhaDFDataset with the given attributes. + + calc_class_distribution(self): + - Returns class distribution (number of samples per class). + + verify_audio(self): + - Checks to make sure files exist that are referenced in input df. + + process_audio_file(self, file_name: str): + - Save waveform of audio file as a tensor and save that tensor to .pt. + + serialize_data(self): + - For each file, check to see if the file is already a presaved tensor + If the file is not a presaved tensor and is an audio file, convert + to tensor to make future training faster + + __len__(self): + - Returns how many elements are in the sample DataFrame. + + to_image(self, audio): + - Convert audio clip to 3-channel spectrogram image + + __getitem__(self, index): + - Takes an index and returns tuple of spectrogram image with + corresponding label + + get_num_classes(self): + - Returns number of classes + + get_sample_weights(self): + - Returns the weights as computed by the first place winner of + BirdCLEF 2023 """ # df, train, and species decided outside of config, so those cannot be added in there @@ -47,6 +134,30 @@ def __init__(self, cfg: config.Config, onehot:bool = False, ) -> None: + """ + Initializes a PyhaDFDataset with the given attributes. + + Args: + df (pandas.DataFrame): + - dataframe of data contained in this object. + + train (bool): + - whether the data is the training set data or not. + + species (list[str]): + - a list of strings representing each species identified + in this dataset. + + cfg (pyha_analyzer.config.Config): + - configuration settings of the dataset. + + onehot (bool): + - whether the data has been one-hot encoded or not, defaults + to False. 
+ + Returns: + None + """ self.samples = df[~(df[cfg.file_name_col].isnull())] if onehot: if self.samples.iloc[0][species].shape[0] != len(species): @@ -80,6 +191,7 @@ def __init__(self, self.num_classes = len(species) self.serialize_data() + self.class_dist = self.calc_class_distribution() #Data augmentations @@ -107,7 +219,14 @@ def __init__(self, RandomApply([audtr.TimeMasking(cfg.time_mask_param)], p=cfg.time_mask_p)) def calc_class_distribution(self) -> torch.Tensor: - """ Returns class distribution (number of samples per class) """ + """ + Returns class distribution (number of samples per class). + + Returns: + class_dist (torch.Tensor): + - a 1d Torch Tensor representing the amount of samples + in each class. + """ class_dist = [] if self.onehot: for class_name in self.classes: @@ -125,7 +244,10 @@ def calc_class_distribution(self) -> torch.Tensor: def verify_audio(self) -> None: """ - Checks to make sure files exist that are referenced in input df + Checks to make sure files exist that are referenced in input df. + + Returns: + None """ missing_files = pd.Series(self.samples[self.cfg.file_name_col].unique()) \ .progress_apply( @@ -143,7 +265,19 @@ def verify_audio(self) -> None: def process_audio_file(self, file_name: str) -> pd.Series: """ - Save waveform of audio file as a tensor and save that tensor to .pt + Save waveform of audio file as a tensor and save that tensor to .pt. + + Args: + file_name (str): + - name of an audio file + + Returns: + - Pandas Series of the original file name and the new file name. + - If the audio file has already been processed, does not save and + simply returns the original file name and the "supposed" new file + name. + - If an exception occurs, returns the original file location and + "bad" as the new location. """ exts = "." 
+ file_name.split(".")[-1] new_name = file_name.replace(exts, ".pt") @@ -193,9 +327,12 @@ def process_audio_file(self, file_name: str) -> pd.Series: def serialize_data(self) -> None: """ - For each file, check to see if the file is already a presaved tensor - If the files is not a presaved tensor and is an audio file, convert to tensor to make - Future training faster + For each file, check to see if the file is already a presaved tensor + If the file is not a presaved tensor and is an audio file, convert + to tensor to make future training faster + + Returns: + None """ self.verify_audio() files = pd.DataFrame(self.samples[self.cfg.file_name_col].unique(), @@ -225,11 +362,24 @@ def serialize_data(self) -> None: self.samples["original_file_path"] = self.samples[self.cfg.file_name_col] def __len__(self): + """ + Returns how many elements are in the sample DataFrame. + + Returns: + - The number of elements in the sample DataFrame + """ return self.samples.shape[0] def to_image(self, audio): """ - Convert audio clip to 3-channel spectrogram image + Convert audio clip to 3-channel spectrogram image + + Args: + audio (torch.Tensor): + - torch tensor that represents the audio clip as a raw waveform + + Returns: + - torch tensor that represents the audio clip as a mel spectrogram """ # Mel spectrogram # Pylint complains this is not callable, but it is a torch.nn.Module @@ -250,7 +400,20 @@ def to_image(self, audio): return torch.stack([mel, mel, mel]) def __getitem__(self, index): #-> Any: - """ Takes an index and returns tuple of spectrogram image with corresponding label + """ + Takes an index and returns tuple of spectrogram image with corresponding label + Args: + index (int): + - index of the item + + Returns: + tuple: a tuple containing: + image (torch.Tensor): + - torch tensor representing the mel spectrogram + image at the index + target (torch.Tensor): + - torch tensor representing the one-hot encoded label of + the image """ assert isinstance(index, int) audio, 
target = utils.get_annotation( @@ -282,14 +445,24 @@ def __getitem__(self, index): #-> Any: return image, target def get_num_classes(self) -> int: - """ Returns number of classes + """ + Returns number of classes + + Returns: + - the "num_classes" attribute, which represents the number of classes + in the dataset """ return self.num_classes def get_sample_weights(self) -> pd.Series: - """ Returns the weights as computed by the first place winner of BirdCLEF 2023 - See https://www.kaggle.com/competitions/birdclef-2023/discussion/412808 - Congrats on your win! + """ + Returns the weights as computed by the first place winner of BirdCLEF 2023 + See https://www.kaggle.com/competitions/birdclef-2023/discussion/412808 + Congrats on your win! + + Returns: + - a pandas.Series object that represents the weights as computed by + the first place winner of BirdCLEF 2023. """ manual_id = self.cfg.manual_id_col all_primary_labels = self.samples[manual_id] @@ -302,9 +475,20 @@ def get_sample_weights(self) -> pd.Series: def get_datasets(cfg) -> Tuple[PyhaDFDataset, PyhaDFDataset, Optional[PyhaDFDataset]]: - """ Returns train and validation datasets - does random sampling for train/valid split - adds transforms to dataset + """ + Returns train and validation datasets + does random sampling for train/valid split + adds transforms to dataset + + Args: + cfg (pyha_analyzer.config.Config): + - configuration settings of the dataset. + + Returns: + tuple: a tuple containing: + - train_ds (PyhaDFDataset): the training dataset + - valid_ds (PyhaDFDataset): the validation dataset + - infer_ds (Optional[PyhaDFDataset]): the inference/test dataset """ train_p = cfg.train_test_split path = cfg.dataframe_csv @@ -392,7 +576,13 @@ def get_datasets(cfg) -> Tuple[PyhaDFDataset, PyhaDFDataset, Optional[PyhaDFData def set_torch_file_sharing(_) -> None: """ - Sets torch.multiprocessing to use file sharing + Sets torch.multiprocessing to use file sharing + + Args: + _ (any): placeholder parameter. 
+ + Returns: + None """ torch.multiprocessing.set_sharing_strategy("file_system") @@ -401,6 +591,28 @@ def make_dataloaders(train_dataset, val_dataset, infer_dataset, cfg )-> Tuple[DataLoader, DataLoader, Optional[DataLoader]]: """ Loads datasets and dataloaders for train and validation + + Args: + train_dataset (PyhaDFDataset): + - the training dataset + + val_dataset (PyhaDFDataset): + - the validation dataset + + infer_dataset (PyhaDFDataset): + - the inference/test dataset + + cfg (pyha_analyzer.config.Config): + - configuration settings of the dataset + + Returns: + tuple: a tuple containing: + - train_dataloader (torch.utils.data.DataLoader): dataloader for the + training set + - val_dataloader (torch.utils.data.DataLoader): dataloader for the + validation set + - infer_dataloader (Optional[torch.utils.data.DataLoader]): dataloader + for the inference/test dataset """ @@ -452,7 +664,10 @@ def make_dataloaders(train_dataset, val_dataset, infer_dataset, cfg def main() -> None: """ - testing function. + testing function. + + Returns: + None """ # run = wandb.init( # entity=cfg.wandb_entity,