diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 2fd313b..0e4bae9 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -27,6 +27,7 @@ jobs:
       - name: Install library
         run: |
           python -m pip install .[keras_tf]
+          python -m pip install .[torch]
           python -m pip install .[openl3]
           python -m pip install .[autopool]
           python -m pip install .[tests]
@@ -102,3 +103,27 @@ jobs:
       - name: Test with pytest
         run: |
           pytest --cov-report term-missing --cov-report=xml --cov dcase_models ./tests
+
+  build_torch:
+
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: [3.6, 3.7, 3.8]
+
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        sudo apt-get install -y wget libsndfile-dev sox
+    - name: Install library
+      run: |
+        python -m pip install .[torch]
+        python -m pip install .[tests]
+    - name: Test with pytest
+      run: |
+        pytest --cov-report term-missing --cov-report=xml --cov dcase_models ./tests
diff --git a/dcase_models/backend.py b/dcase_models/backend.py
new file mode 100644
index 0000000..112572a
--- /dev/null
+++ b/dcase_models/backend.py
@@ -0,0 +1,25 @@
+backends = []
+
+try:
+    import tensorflow as tf
+
+    tensorflow_version = '2' if tf.__version__.split(".")[0] == "2" else '1'
+
+    if tensorflow_version == '2':
+        backends.append('tensorflow2')
+    else:
+        backends.append('tensorflow1')
+except ImportError:
+    tensorflow_version = None
+
+try:
+    import torch
+    backends.append('torch')
+except ImportError:
+    torch = None
+
+try:
+    import sklearn
+    backends.append('sklearn')
+except ImportError:
+    sklearn = None
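# Editor note, with an illustrative usage of the new backend module (this
# sketch is not part of the patch): consumers guard optional imports on the
# detected backends list instead of importing a framework unconditionally.
#
#     from dcase_models.backend import backends
#
#     if 'torch' in backends:
#         from dcase_models.data.data_generator import PyTorchDataGenerator
#     else:
#         raise ImportError('PyTorch backend missing; pip install .[torch]')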
diff --git a/dcase_models/data/data_generator.py b/dcase_models/data/data_generator.py
index 91cbd64..bb6e638 100644
--- a/dcase_models/data/data_generator.py
+++ b/dcase_models/data/data_generator.py
@@ -3,13 +3,26 @@
 import inspect
 import random
 
-import tensorflow as tf
-tensorflow2 = tf.__version__.split('.')[0] == '2'
+from dcase_models.backend import backends
 
-if tensorflow2:
-    from tensorflow.keras.utils import Sequence
+if 'torch' in backends:
+    import torch
+    from torch.utils.data import Dataset as TorchDataset
 else:
-    from keras.utils import Sequence
+    class TorchDataset:
+        pass
+
+if ('tensorflow1' in backends) | ('tensorflow2' in backends):
+    import tensorflow as tf
+    tensorflow2 = tf.__version__.split('.')[0] == '2'
+
+    if tensorflow2:
+        from tensorflow.keras.utils import Sequence
+    else:
+        from keras.utils import Sequence
+else:
+    class Sequence:
+        pass
 
 from dcase_models.data.feature_extractor import FeatureExtractor
 from dcase_models.data.dataset_base import Dataset
@@ -523,21 +536,62 @@ def set_scaler_outputs(self, scaler_outputs):
         self.scaler_outputs = scaler_outputs
 
 
-class KerasDataGenerator(Sequence):
+if ('tensorflow1' in backends) | ('tensorflow2' in backends):
+    class KerasDataGenerator(Sequence):
 
-    def __init__(self, data_generator):
-        self.data_gen = data_generator
-        self.data_gen.shuffle_list()
+        def __init__(self, data_generator):
+            self.data_gen = data_generator
+            self.data_gen.shuffle_list()
 
-    def __len__(self):
-        'Denotes the number of batches per epoch'
-        return len(self.data_gen)
+        def __len__(self):
+            'Denotes the number of batches per epoch'
+            return len(self.data_gen)
 
-    def __getitem__(self, index):
-        'Generate one batch of data'
-        # Generate indexes of the batch
-        return self.data_gen.get_data_batch(index)
+        def __getitem__(self, index):
+            'Generate one batch of data'
+            # Generate indexes of the batch
+            return self.data_gen.get_data_batch(index)
 
-    def on_epoch_end(self):
-        'Updates indexes after each epoch'
-        self.data_gen.shuffle_list()
+        def on_epoch_end(self):
+            'Updates indexes after each epoch'
+            self.data_gen.shuffle_list()
+else:
+    class KerasDataGenerator():
+        def __init__(self, data_generator):
+            raise ImportError("Tensorflow is not installed")
+
+
+if 'torch' in backends:
+    class PyTorchDataGenerator(TorchDataset):
+
+        def __init__(self, data_generator):
+            self.data_gen = data_generator
+            self.data_gen.shuffle_list()
+
+        def __len__(self):
+            'Denotes the number of batches per epoch'
+            return len(self.data_gen)
+
+        def __getitem__(self, index):
+            'Generate one batch of data'
+            # Each item is a full batch from the wrapped DataGenerator,
+            # converted to float (inputs) and long (targets) tensors.
+            X, Y = self.data_gen.get_data_batch(index)
+            if type(X) is not list:
+                X = [X]
+            if type(Y) is not list:
+                Y = [Y]
+            tensor_X = []
+            tensor_Y = []
+            for j in range(len(X)):
+                tensor_X.append(torch.tensor(X[j], dtype=torch.float))
+            for j in range(len(Y)):
+                tensor_Y.append(torch.tensor(Y[j], dtype=torch.long))
+            return tensor_X, tensor_Y
+
+        def shuffle_list(self):
+            'Updates indexes after each epoch'
+            self.data_gen.shuffle_list()
+else:
+    class PyTorchDataGenerator():
+        def __init__(self, data_generator):
+            raise ImportError("Pytorch is not installed")
\ No newline at end of file
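# Editor sketch (not part of the patch): since each __getitem__ of
# PyTorchDataGenerator already returns a full batch, it is wrapped in a
# DataLoader with batch_size=1, which is what PyTorchModelContainer.train()
# does internally. Instance names below are hypothetical.
#
#     from dcase_models.data.data_generator import DataGenerator, PyTorchDataGenerator
#     import torch
#
#     data_gen = DataGenerator(dataset, features, train=True)
#     loader = torch.utils.data.DataLoader(PyTorchDataGenerator(data_gen), batch_size=1)
#     for X, Y in loader:
#         pass  # X and Y are lists of tensors with a leading dimension of 1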
diff --git a/dcase_models/model/container.py b/dcase_models/model/container.py
index bcc9b36..487bc1c 100644
--- a/dcase_models/model/container.py
+++ b/dcase_models/model/container.py
@@ -1,27 +1,37 @@
+from dcase_models.util.files import save_json, save_pickle, load_pickle
+from dcase_models.util.metrics import evaluate_metrics
+from dcase_models.util.callbacks import ClassificationCallback, SEDCallback, TaggingCallback
+from dcase_models.util.callbacks import PyTorchCallback
+from dcase_models.data.data_generator import DataGenerator, KerasDataGenerator, PyTorchDataGenerator
+
 import numpy as np
 import os
 import json
+import inspect
 
-import tensorflow as tf
-tensorflow2 = tf.__version__.split('.')[0] == '2'
-
-if tensorflow2:
-    import tensorflow.keras.backend as K
-    from tensorflow.keras.callbacks import CSVLogger, ModelCheckpoint
-    from tensorflow.keras.models import model_from_json, Model
-    from tensorflow.keras.layers import Dense, Input
-else:
-    import keras.backend as K
-    from keras.callbacks import CSVLogger, ModelCheckpoint
-    from keras.models import model_from_json, Model
-    from keras.layers import Dense, Input
-
-from dcase_models.util.files import save_json
-from dcase_models.util.metrics import evaluate_metrics
-from dcase_models.util.callbacks import ClassificationCallback, SEDCallback
-from dcase_models.util.callbacks import TaggingCallback
-from dcase_models.data.data_generator import DataGenerator, KerasDataGenerator
+from dcase_models.backend import backends
+
+if 'torch' in backends:
+    import torch
+    from torch import nn
+if ('tensorflow1' in backends) | ('tensorflow2' in backends):
+    import tensorflow as tf
+    tensorflow2 = tf.__version__.split('.')[0] == '2'
+
+    if tensorflow2:
+        import tensorflow.keras.backend as K
+        from tensorflow.keras.callbacks import CSVLogger, ModelCheckpoint
+        from tensorflow.keras.models import model_from_json, Model
+        from tensorflow.keras.layers import Dense, Input
+    else:
+        import keras.backend as K
+        from keras.callbacks import CSVLogger, ModelCheckpoint
+        from keras.models import model_from_json, Model
+        from keras.layers import Dense, Input
+
+if 'sklearn' in backends:
+    import sklearn
 
 
 class ModelContainer():
     """ Abstract base class to store and manage models.
@@ -534,3 +544,759 @@ def get_intermediate_output(self, output_ix_name, inputs):
             return None
 
         return output
+
+
+class PyTorchModelContainer(ModelContainer):
+    """ ModelContainer for pytorch models.
+
+    A class that contains a pytorch model, the methods to train, evaluate,
+    save and load the model. Descendants of this class can be specialized for
+    specific models (i.e. see the SB_CNN class).
+
+    Parameters
+    ----------
+    model : nn.Module or None, default=None
+        If model is None, the model is created with `build()`.
+
+    model_path : str or None, default=None
+        Path to the model. If it is not None, the model is loaded from this path.
+
+    model_name : str, default=PyTorchModelContainer
+        Model name.
+
+    metrics : list of str, default=['classification']
+        List of metrics used for evaluation.
+        See `dcase_models.utils.metrics`.
+
+    kwargs
+        Additional keyword arguments to `load_model_from_json()`.
+
+    """
+
+    if 'torch' in backends:
+        class Model(nn.Module):
+            # Define your model here
+            def __init__(self, params):
+                super().__init__()
+
+            def forward(self, x):
+                pass
+
+    def __init__(self, model=None,
+                 model_name="PyTorchModelContainer",
+                 metrics=['classification'], use_cuda=True, **kwargs):
+
+        if 'torch' not in backends:
+            raise ImportError('Pytorch is not installed')
+
+        self.use_cuda = use_cuda & torch.cuda.is_available()
+
+        super().__init__(model=model, model_path=None,
+                         model_name=model_name,
+                         metrics=metrics)
+
+        # Build the model
+        if model is None:
+            self.build()
+
+    def build(self):
+        """
+        Define your model here
+        """
+        self.model = self.Model(self)
+
+    def train(self, data_train, data_val, weights_path='./',
+              optimizer='Adam', learning_rate=0.001, early_stopping=100,
+              considered_improvement=0.01, losses='BCELoss',
+              loss_weights=[1], batch_size=32, sequence_time_sec=0.5,
+              metric_resolution_sec=1.0, label_list=[],
+              shuffle=True, epochs=10):
+        """
+        Trains the pytorch model using the data and parameters of arguments.
+
+        Parameters
+        ----------
+        data_train : tuple of ndarray or DataGenerator
+            Tuple or DataGenerator with the training data.
+            Example of tuple: (X_train, Y_train) whose shapes
+            are (N_instances, N_hops, N_mel_bands) and (N_instances, N_classes) respectively.
+            Example of DataGenerator: data_gen = DataGenerator(..., train=True)
+        data_val : tuple of ndarray or DataGenerator
+            Idem for validation set.
+        weights_path : str
+            Path where to save the best weights of the model
+            in the training process
+        optimizer : str or torch.optim.Optimizer
+            Optimizer used to train the model. A string argument should coincide
+            with the class name in torch.optim
+        learning_rate : float
+            Learning rate used to train the model
+        early_stopping : int
+            Number of epochs without improvement after which training is stopped
+        considered_improvement : float
+            Improvement in the performance metric considered to save a checkpoint.
+        losses : (list of) torch loss functions (see https://pytorch.org/docs/stable/nn.html#loss-functions)
+            Loss function(s) used for training.
+        loss_weights : list
+            List of weights for each loss function. Should be of the same length as losses
+        batch_size : int
+            Batch size used in the training process. Ignored if data_train is a DataGenerator
+        sequence_time_sec : float
+            Used for SED evaluation. Time resolution (in seconds) of the output
+            (i.e. features.sequence_hop_time)
+        metric_resolution_sec: float
+            Used for SED evaluation. Time resolution (in seconds) of evaluation
+        label_list: list
+            List of class labels (i.e. dataset.label_list).
+            This is needed for model evaluation.
+        shuffle : bool
+            If True, data_train is shuffled after each epoch.
+            Ignored if data_train is a DataGenerator
+        epochs : int
+            Number of training epochs
+
+        """
+
+        if type(losses) is not list:
+            losses = [losses]
+
+        for j, loss in enumerate(losses):
+            if type(loss) is str:
+                try:
+                    loss_class = getattr(torch.nn.modules.loss, loss)
+                except AttributeError:
+                    raise AttributeError(
+                        ("Loss {} not available. See the list of losses at "
+                         "https://pytorch.org/docs/stable/nn.html#loss-functions").format(loss)
+                    )
+                losses[j] = loss_class()
+            else:
+                # Loss instances are kept as they are; instantiating them
+                # again (as the string branch does) would fail.
+                if torch.nn.modules.loss._Loss not in inspect.getmro(loss.__class__):
+                    raise AttributeError('loss should be a string or a torch.nn.modules.loss._Loss instance')
+
+        if type(loss_weights) is not list:
+            loss_weights = [loss_weights]
+
+        if len(loss_weights) != len(losses):
+            raise AttributeError(
+                ("loss_weights and losses should have the same length. Received: lengths {:d} and {:d} "
+                 "respectively").format(len(loss_weights), len(losses))
+            )
+
+        if type(optimizer) is str:
+            try:
+                optimizer_function = getattr(torch.optim, optimizer)
+            except AttributeError:
+                raise AttributeError(
+                    ("Optimizer {} not available. See the list of optimizers at "
+                     "https://pytorch.org/docs/stable/optim.html").format(optimizer)
+                )
+        else:
+            # Accept optimizer classes; an instance cannot be re-instantiated
+            # with the model parameters below.
+            if inspect.isclass(optimizer) and issubclass(optimizer, torch.optim.Optimizer):
+                optimizer_function = optimizer
+            else:
+                raise AttributeError('optimizer should be a string or a torch.optim.Optimizer subclass')
+
+        opt = optimizer_function(self.model.parameters(), lr=learning_rate)
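#         Editor example (not part of the patch): with the checks above, the
#         loss can be passed by name or as a loss instance, and the optimizer
#         by name or as an optimizer class:
#
#             container.train(data_train, data_val,
#                             optimizer='SGD', losses='CrossEntropyLoss')
#             container.train(data_train, data_val,
#                             optimizer=torch.optim.Adam,
#                             losses=torch.nn.BCELoss())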
+
+        if self.use_cuda:
+            self.model.to('cuda')
+
+        train_from_numpy = False
+        if type(data_train) is tuple:
+            train_from_numpy = True
+            X_train, Y_train = data_train
+            if type(X_train) is not list:
+                X_train = [X_train]
+            if type(Y_train) is not list:
+                Y_train = [Y_train]
+
+            tensors_X = []
+            tensors_Y = []
+            for j in range(len(X_train)):
+                tensors_X.append(torch.Tensor(X_train[j]))
+                tensors_Y.append(torch.Tensor(Y_train[j]))
+
+            torch_dataset = torch.utils.data.TensorDataset(*(tensors_X + tensors_Y))
+            data_loader = torch.utils.data.DataLoader(torch_dataset, batch_size=batch_size, shuffle=shuffle)
+
+            n_inputs = len(tensors_X)
+        else:
+            torch_data_train = PyTorchDataGenerator(data_train)
+            data_loader = torch.utils.data.DataLoader(torch_data_train, batch_size=1)
+
+        current_metrics = [0]
+        best_metrics = -np.inf
+        epoch_best = 0
+        epochs_since_improvement = 0
+
+        if self.metrics[0] == 'sed':
+            callback = SEDCallback(
+                data_val, best_F1=-np.Inf, early_stopping=early_stopping, file_weights=weights_path,
+                considered_improvement=considered_improvement, sequence_time_sec=sequence_time_sec,
+                metric_resolution_sec=metric_resolution_sec, label_list=label_list
+            )
+        elif self.metrics[0] == 'classification':
+            callback = ClassificationCallback(
+                data_val, best_acc=-np.Inf, early_stopping=early_stopping, file_weights=weights_path,
+                considered_improvement=considered_improvement,
+                label_list=label_list
+            )
+        elif self.metrics[0] == 'tagging':
+            callback = TaggingCallback(
+                data_val, best_F1=-np.Inf, early_stopping=early_stopping, file_weights=weights_path,
+                considered_improvement=considered_improvement, label_list=label_list
+            )
+        else:
+            raise AttributeError("{} metric is not allowed".format(self.metrics[0]))
+
+        callback = PyTorchCallback(self, callback)
+
+        for epoch in range(epochs):
+            # train
+            for batch_ix, batch in enumerate(data_loader):
+                # Compute prediction and loss
+                if train_from_numpy:
+                    X = batch[:n_inputs]
+                    Y = batch[n_inputs:]
+                else:
+                    X, Y = batch
+                    for j in range(len(X)):
+                        X[j] = torch.squeeze(X[j], dim=0)
+                        Y[j] = torch.squeeze(Y[j], dim=0)
+                if self.use_cuda:
+                    for j in range(len(X)):
+                        X[j] = X[j].cuda()
+                pred = self.model(*X)
+                if type(pred) is not list:
+                    preds = [pred]
+                else:
+                    preds = pred
+                assert len(preds) == len(losses)
+                loss = 0
+                for loss_fn, loss_weight, pred, gt in zip(losses, loss_weights, preds, Y):
+                    if self.use_cuda:
+                        gt = gt.cuda()
+                    loss += loss_weight*loss_fn(pred.float(), gt.float())
+
+                # Backpropagation
+                opt.zero_grad()
+                loss.backward()
+                opt.step()
+
+                if batch_ix == len(data_loader) - 1:
+                    loss_value, current = loss.item(), batch_ix * len(X[0])
+                    print(f"loss: {loss_value:>7f}  [{current:>5d}/{len(data_loader.dataset):>5d}]")
+
+            if shuffle & (not train_from_numpy):
+                torch_data_train.shuffle_list()
+
+            # validation
+            with torch.no_grad():
+                callback.on_epoch_end(epoch)
+                if callback.stop_training:
+                    break
+
+    def evaluate(self, data_test, **kwargs):
+        """
+        Evaluates the pytorch model using the given test data.
+
+        Parameters
+        ----------
+        data_test : tuple of ndarray or DataGenerator
+            Tuple (X_test, Y_test) or DataGenerator for the test set.
+        kwargs
+            Additional keyword arguments for `evaluate_metrics()`
+            (e.g. label_list).
+
+        Returns
+        -------
+        dict
+            Results of the evaluation for each metric in self.metrics.
+
+        """
+        return evaluate_metrics(self, data_test, self.metrics, **kwargs)
+
+    def load_model_from_json(self, folder, **kwargs):
+        """
+        Loads a model from a model.json file in the path given by folder.
+        The model is loaded into the self.model attribute.
+
+        Parameters
+        ----------
+        folder : str
+            Path to the folder that contains the model.json file
+        """
+        raise NotImplementedError()
+
+    def save_model_json(self, folder):
+        """
+        Saves the model to a model.json file in the given folder path.
+
+        Parameters
+        ----------
+        folder : str
+            Path to the folder to save the model.json file
+        """
+        raise NotImplementedError()
+
+    def save_model_weights(self, weights_folder):
+        """
+        Saves self.model weights in weights_folder/best_weights.pth.
+
+        Parameters
+        ----------
+        weights_folder : str
+            Path to save the weights file
+        """
+        weights_file = 'best_weights.pth'
+        weights_path = os.path.join(weights_folder, weights_file)
+        torch.save(self.model.state_dict(), weights_path)
+
+    def load_model_weights(self, weights_folder):
+        """
+        Loads self.model weights from weights_folder/best_weights.pth.
+
+        Parameters
+        ----------
+        weights_folder : str
+            Path to load the weights file from.
+
+        """
+        weights_file = 'best_weights.pth'
+        weights_path = os.path.join(weights_folder, weights_file)
+        self.model.load_state_dict(torch.load(weights_path))
+        self.model.eval()
+
+    def load_pretrained_model_weights(self,
+                                      weights_folder='./pretrained_weights'):
+        """
+        Loads pretrained weights to self.model weights.
+
+        Parameters
+        ----------
+        weights_folder : str
+            Path to load the weights file
+
+        """
+        raise NotImplementedError()
+
+    def get_number_of_parameters(self):
+        return sum(p.numel() for p in self.model.parameters())
+
+    def check_if_model_exists(self, folder, **kwargs):
+        """ Checks if the model already exists in the path.
+
+        Check if the folder/best_weights.pth file exists and is
+        compatible with self.model.
+
+        Parameters
+        ----------
+        folder : str
+            Path to the folder to check.
+
+        """
+        weights_file = 'best_weights.pth'
+        weights_path = os.path.join(folder, weights_file)
+        new_model = self.Model(self)
+        try:
+            new_model.load_state_dict(torch.load(weights_path))
+        except Exception:
+            return False
+
+        return True
+
+    def cut_network(self, layer_where_to_cut):
+        """ Cuts the network at the layer passed as argument.
+
+        Parameters
+        ----------
+        layer_where_to_cut : str or int
+            Layer name (str) or index (int) where to cut the model.
+
+        Returns
+        -------
+        nn.Module
+            Cut model.
+
+        """
+        raise NotImplementedError()
+
+    def fine_tuning(self, layer_where_to_cut, new_number_of_classes=10,
+                    new_activation='softmax',
+                    freeze_source_model=True, new_model=None):
+        """ Create a new model for fine-tuning.
+
+        Cut the model at the layer_where_to_cut layer
+        and add a new fully-connected layer.
+
+        Parameters
+        ----------
+        layer_where_to_cut : str or int
+            Name (str) or index (int) of the layer where to cut the model.
+            This layer is included in the new model.
+
+        new_number_of_classes : int
+            Number of units in the new fully-connected layer
+            (number of classes).
+
+        new_activation : str
+            Activation of the new fully-connected layer.
+
+        freeze_source_model : bool
+            If True, the source model is set to not be trainable.
+
+        new_model : nn.Module
+            If it is not None, this model is added after the cut model.
+            This is useful if you want to add more than
+            a fully-connected layer.
+
+        """
+        raise NotImplementedError()
+
+    def get_available_intermediate_outputs(self):
+        """ Return a list of available intermediate outputs.
+
+        Return a list of model's layers.
+
+        Returns
+        -------
+        list of str
+            List of layers names.
+
+        """
+        raise NotImplementedError()
+
+    def get_intermediate_output(self, output_ix_name, inputs):
+        """ Return the output of the model in a given layer.
+
+        Cut the model in the given layer and predict the output
+        for the given inputs.
+
+        Returns
+        -------
+        ndarray
+            Output of the model in the given layer.
+
+        """
+        raise NotImplementedError()
+
+    def predict(self, x):
+        # Imitate keras.predict() function
+        """ Imitates the output of the keras.predict() function.
+
+        Converts the inputs to tensors, runs the model and
+        returns the output(s) as numpy arrays.
+
+        Parameters
+        ----------
+        x: (list of) ndarray
+            Model's input(s)
+
+        Returns
+        -------
+        (list of) ndarray
+            Model's output(s)
+
+        """
+        if type(x) is not list:
+            x = [x]
+        for j in range(len(x)):
+            x[j] = torch.tensor(x[j].astype(float), dtype=torch.float)
+            if self.use_cuda:
+                x[j] = x[j].cuda()
+        if self.use_cuda:
+            self.model.cuda()
+
+        y = self.model(*x)
+
+        if (type(y) is list) or (type(y) is tuple):
+            y_np = []
+            for j in range(len(y)):
+                y_np.append(y[j].cpu().detach().numpy())
+        else:
+            y_np = y.cpu().detach().numpy()
+
+        return y_np
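# Editor sketch (not part of the patch): a concrete container subclasses
# PyTorchModelContainer and defines its network in the nested Model class;
# build() then instantiates it as self.model. The names below are
# hypothetical.
#
#     class MLPContainer(PyTorchModelContainer):
#         class Model(nn.Module):
#             def __init__(self, params):
#                 super().__init__()
#                 self.layer = nn.Linear(64, 10)
#
#             def forward(self, x):
#                 return torch.softmax(self.layer(x), dim=-1)
#
#     container = MLPContainer(use_cuda=False)
#     container.train((X_train, Y_train), (X_val, Y_val),
#                     losses='BCELoss', label_list=label_list)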
+
+
+class SklearnModelContainer(ModelContainer):
+    """ ModelContainer for scikit-learn models.
+
+    A class that contains a scikit-learn classifier, the methods to train,
+    evaluate, save and load the classifier.
+
+    Parameters
+    ----------
+    model : scikit-learn model, default=None
+        If model is None, the model is loaded from model_path.
+
+    model_path : str or None, default=None
+        Path to the model. If model is None and model_path is not None, the model is loaded from this path.
+
+    model_name : str, default=SklearnModelContainer
+        Model name.
+
+    metrics : list of str, default=['classification']
+        List of metrics used for evaluation.
+        See `dcase_models.utils.metrics`.
+
+    """
+
+    def __init__(self, model=None, model_path=None,
+                 model_name="SklearnModelContainer",
+                 metrics=['classification']):
+
+        if (model is None) & (model_path is None):
+            raise AttributeError("model or model_path should be passed as argument")
+
+        super().__init__(model=model, model_path=None,
+                         model_name=model_name,
+                         metrics=metrics)
+
+        if (model is None) & (model_path is not None):
+            self.load_model_weights(model_path)
+
+    def build(self):
+        """
+        Not used
+        """
+        raise NotImplementedError()
+
+    def train(self, data_train, data_val=None, weights_path='./',
+              sequence_time_sec=0.5, metric_resolution_sec=1.0, label_list=[],
+              **kwargs):
+        """
+        Trains the scikit-learn model using the data and parameters passed as arguments.
+
+        Parameters
+        ----------
+        data_train : tuple of ndarray or DataGenerator
+            Tuple or DataGenerator with the training data.
+            A tuple should include the inputs and outputs for training: (X_train, Y_train), whose shapes
+            are, for instance, (N_instances, N_mel_bands) and (N_instances, N_classes) or (N_instances,) respectively.
+            Example of DataGenerator: data_gen = DataGenerator(..., train=True)
+        data_val : None or tuple of ndarray or DataGenerator
+            Idem for validation set. If None, no validation is performed when the training ends.
+        weights_path : str
+            Path where to save the best weights of the model
+            after the training process
+        sequence_time_sec : float
+            Used for SED evaluation. Time resolution (in seconds) of the output. Ignored if data_val is None.
+            (i.e. features.sequence_hop_time)
+        metric_resolution_sec: float
+            Used for SED evaluation. Time resolution (in seconds) of evaluation. Ignored if data_val is None.
+        label_list: list
+            List of class labels (i.e. dataset.label_list). Ignored if data_val is None.
+            This is needed for model evaluation.
+        kwargs: kwargs
+            Additional keyword arguments to sklearn's fit() function
+
+        """
+        if type(data_train) is not tuple:
+            # DataGenerator, check if partial_fit is available
+            if 'partial_fit' not in dir(self.model):
+                raise AttributeError(
+                    ("This model does not allow partial_fit, and therefore data_train should be a numpy array. "
+                     "Please call DataGenerator.get_data() before."))
+            for batch in range(len(data_train)):
+                X, Y = data_train.get_data_batch(batch)
+                if type(X) is not np.ndarray:
+                    raise AttributeError("Multi-input is not allowed")
+                if type(Y) is not np.ndarray:
+                    raise AttributeError("Multi-output is not allowed")
+                if len(X.shape) != 2:
+                    raise AttributeError("The input should be a 2D array. Received shape {}".format(X.shape))
+                if len(Y.shape) > 2:
+                    raise AttributeError("The output should be a 1D or 2D array. Received shape {}".format(Y.shape))
+                classes = np.arange(len(label_list))
+                self.model.partial_fit(X, Y, classes)
+        else:
+            if type(data_train[0]) is not np.ndarray:
+                raise AttributeError("Multi-input is not allowed")
+            if type(data_train[1]) is not np.ndarray:
+                raise AttributeError("Multi-output is not allowed")
+            if len(data_train[0].shape) != 2:
+                raise AttributeError("The input should be a 2D array. Received shape {}".format(data_train[0].shape))
+            if len(data_train[1].shape) > 2:
+                raise AttributeError(
+                    "The output should be a 1D or 2D array. "
+                    "Received shape {}".format(data_train[1].shape))
+
+            self.model.fit(data_train[0], data_train[1], **kwargs)
+
+        self.save_model_weights(weights_path)
+
+        kwargs = {}
+        if self.metrics[0] == 'sed':
+            kwargs = {
+                'sequence_time_sec': sequence_time_sec,
+                'metric_resolution_sec': metric_resolution_sec
+            }
+        if data_val is not None:
+            results = evaluate_metrics(
+                self,
+                data_val,
+                self.metrics,
+                label_list=label_list,
+                **kwargs
+            )
+            return results[self.metrics[0]]
+
+    def evaluate(self, data_test, **kwargs):
+        """
+        Evaluates the scikit-learn model using the given test data.
+
+        Parameters
+        ----------
+        data_test : tuple of ndarray or DataGenerator
+            Tuple (X_test, Y_test) or DataGenerator for the test set.
+        kwargs
+            Additional keyword arguments for `evaluate_metrics()`
+            (e.g. label_list).
+
+        Returns
+        -------
+        dict
+            Results of the evaluation for each metric in self.metrics.
+
+        """
+        return evaluate_metrics(self, data_test, self.metrics, **kwargs)
+
+    def load_model_from_json(self, folder, **kwargs):
+        """
+        Loads a model from a model.json file in the path given by folder.
+        The model is loaded into the self.model attribute.
+
+        Parameters
+        ----------
+        folder : str
+            Path to the folder that contains the model.json file
+        """
+        raise NotImplementedError()
+
+    def save_model_json(self, folder):
+        """
+        Saves the model to a model.json file in the given folder path.
+
+        Parameters
+        ----------
+        folder : str
+            Path to the folder to save the model.json file
+        """
+        raise NotImplementedError()
+
+    def save_model_weights(self, weights_folder):
+        """
+        Saves self.model in weights_folder/model.skl.
+
+        Parameters
+        ----------
+        weights_folder : str
+            Path to save the model file
+        """
+        weights_file = 'model.skl'
+        weights_path = os.path.join(weights_folder, weights_file)
+        save_pickle(self.model, weights_path)
+
+    def load_model_weights(self, weights_folder):
+        """
+        Loads self.model from weights_folder/model.skl.
+
+        Parameters
+        ----------
+        weights_folder : str
+            Path to load the model file from.
+
+        """
+        weights_file = 'model.skl'
+        weights_path = os.path.join(weights_folder, weights_file)
+        self.model = load_pickle(weights_path)
+
+    def load_pretrained_model_weights(self,
+                                      weights_folder='./pretrained_weights'):
+        raise NotImplementedError()
+
+    def get_number_of_parameters(self):
+        return len(self.model.get_params())
+
+    def check_if_model_exists(self, folder, **kwargs):
+        """ Checks if the model already exists in the path.
+
+        Check if the folder/model.skl file exists and includes
+        the same model as self.model.
+
+        Parameters
+        ----------
+        folder : str
+            Path to the folder to check.
+ + """ + weights_file = 'model.skl' + weights_path = os.path.join(folder, weights_file) + if not os.path.exists(weights_path): + return False + + new_model = load_pickle(weights_path) + + if new_model.__class__.__name__ != self.model.__class__.__name__: + return False + + new_params = new_model.get_params() + for key, value in self.model.get_params().items(): + if value != new_params[key]: + return False + return True + + def cut_network(self, layer_where_to_cut): + raise NotImplementedError() + + def fine_tuning(self, layer_where_to_cut, new_number_of_classes=10, + new_activation='softmax', + freeze_source_model=True, new_model=None): + raise NotImplementedError() + + def get_available_intermediate_outputs(self): + raise NotImplementedError() + + def get_intermediate_output(self, output_ix_name, inputs): + raise NotImplementedError() + + def predict(self, x): + # Imitate keras.predict() function + """ Imitates the output of the keras.predict() function. + + Parameters + ---------- + x: ndarray + Model's input + + Returns + ------- + ndarray + Model's output + + """ + pred = self.model.predict(x) + if len(pred.shape) == 1: + # if single output, apply one-hot encoder (needed for evaluation) + y = np.zeros((len(pred), len(self.model.classes_))) + for j in range(len(pred)): + y[j, int(pred[j])] = 1 + else: + y = pred + return y diff --git a/dcase_models/model/models.py b/dcase_models/model/models.py index 490769f..8d145bf 100644 --- a/dcase_models/model/models.py +++ b/dcase_models/model/models.py @@ -3,38 +3,43 @@ import sys import os -import tensorflow as tf -tensorflow2 = tf.__version__.split('.')[0] == '2' - -if tensorflow2: - from tensorflow.keras.layers import GRU, Bidirectional - from tensorflow.keras.layers import TimeDistributed, Activation, Reshape - from tensorflow.keras.layers import GlobalAveragePooling2D - from tensorflow.keras.layers import GlobalMaxPooling2D - from tensorflow.keras.layers import Input, Lambda, Conv2D, MaxPooling2D - from tensorflow.keras.layers import Conv1D - from tensorflow.keras.layers import Dropout, Dense, Flatten - from tensorflow.keras.layers import BatchNormalization - from tensorflow.keras.layers import Layer - from tensorflow.keras.models import Model - from tensorflow.keras.regularizers import l2 - import tensorflow.keras.backend as K +from dcase_models.backend import backends + +if ('tensorflow1' in backends) | ('tensorflow2' in backends): + import tensorflow as tf + tensorflow2 = tf.__version__.split('.')[0] == '2' + + if tensorflow2: + from tensorflow.keras.layers import GRU, Bidirectional + from tensorflow.keras.layers import TimeDistributed, Activation, Reshape + from tensorflow.keras.layers import GlobalAveragePooling2D + from tensorflow.keras.layers import GlobalMaxPooling2D + from tensorflow.keras.layers import Input, Lambda, Conv2D, MaxPooling2D + from tensorflow.keras.layers import Conv1D + from tensorflow.keras.layers import Dropout, Dense, Flatten + from tensorflow.keras.layers import BatchNormalization + from tensorflow.keras.layers import Layer + from tensorflow.keras.models import Model + from tensorflow.keras.regularizers import l2 + import tensorflow.keras.backend as K + else: + from keras.layers import GRU, Bidirectional + from keras.layers import TimeDistributed, Activation, Reshape + from keras.layers import GlobalAveragePooling2D + from keras.layers import GlobalMaxPooling2D + from keras.layers import Input, Lambda, Conv2D, MaxPooling2D + from keras.layers import Conv1D + from keras.layers import Dropout, Dense, Flatten + from 
keras.layers import BatchNormalization + from keras.layers import Layer + from keras.models import Model + from keras.regularizers import l2 + import keras.backend as K + + from tensorflow import clip_by_value else: - from keras.layers import GRU, Bidirectional - from keras.layers import TimeDistributed, Activation, Reshape - from keras.layers import GlobalAveragePooling2D - from keras.layers import GlobalMaxPooling2D - from keras.layers import Input, Lambda, Conv2D, MaxPooling2D - from keras.layers import Conv1D - from keras.layers import Dropout, Dense, Flatten - from keras.layers import BatchNormalization - from keras.layers import Layer - from keras.models import Model - from keras.regularizers import l2 - import keras.backend as K - - -from tensorflow import clip_by_value + class Layer: + pass from dcase_models.model.container import KerasModelContainer diff --git a/dcase_models/util/callbacks.py b/dcase_models/util/callbacks.py index 8c31b6b..b8dfec1 100644 --- a/dcase_models/util/callbacks.py +++ b/dcase_models/util/callbacks.py @@ -3,13 +3,19 @@ from dcase_models.util.metrics import evaluate_metrics -import tensorflow as tf -tensorflow2 = tf.__version__.split('.')[0] == '2' +from dcase_models.backend import backends -if tensorflow2: - from tensorflow.keras.callbacks import Callback +if ('tensorflow1' in backends) | ('tensorflow2' in backends): + import tensorflow as tf + tensorflow2 = tf.__version__.split('.')[0] == '2' + + if tensorflow2: + from tensorflow.keras.callbacks import Callback + else: + from keras.callbacks import Callback else: - from keras.callbacks import Callback + class Callback: + pass eps = 1e-6 @@ -21,7 +27,7 @@ class ClassificationCallback(Callback): def __init__(self, data, file_weights=None, best_acc=0, early_stopping=0, considered_improvement=0.01, - label_list=[]): + label_list=[], model=None): """ Initialize the keras callback Parameters @@ -146,7 +152,9 @@ def on_epoch_end(self, epoch, logs={}): """ results = evaluate_metrics(self.model, self.data, ['sed'], - label_list=self.label_list) + label_list=self.label_list, + sequence_time_sec=self.sequence_time_sec, + metric_resolution_sec=self.metric_resolution_sec) results = results['sed'].results() F1 = results['overall']['f_measure']['f_measure'] @@ -250,3 +258,24 @@ def on_epoch_end(self, epoch, logs={}): print('Not improvement for %d epochs, stopping the training' % self.early_stopping) self.model.stop_training = True + +class PyTorchCallback(): + class ToyKerasModel(): + def __init__(self, model_container): + self.model_container = model_container + self.stop_training = False + + def save_weights(self, file_weights): + self.model_container.save_model_weights(file_weights) + + def predict(self, X): + return self.model_container.predict(X) + + def __init__(self, model_container, callback): + self.model = self.ToyKerasModel(model_container) + self.callback = callback + self.callback.model = self.model + + def on_epoch_end(self, epoch, logs={}): + self.callback.on_epoch_end(epoch, logs=logs) + self.stop_training = self.model.stop_training diff --git a/dcase_models/util/metrics.py b/dcase_models/util/metrics.py index ae91ff5..0029e52 100644 --- a/dcase_models/util/metrics.py +++ b/dcase_models/util/metrics.py @@ -148,48 +148,41 @@ def _check_lists_for_evaluation(Y_val, Y_predicted): True if checks passed. """ - + if type(Y_val) is not list: raise AttributeError( 'Y_val type is invalid. 
It should be a list of 2D array and received {}'.format( - type(Y_val) - ) + type(Y_val)) ) if type(Y_predicted) is not list: raise AttributeError( 'Y_predicted type is invalid. It should be a list of 2D array and received {}'.format( - type(Y_predicted) - ) + type(Y_predicted)) ) if len(Y_val) != len(Y_predicted): raise AttributeError('Y_val and Y_predicted should have the same length (received {:d} and {:d})'.format( - len(Y_val), len(Y_predicted) - ) + len(Y_val), len(Y_predicted)) ) for j in range(len(Y_val)): if type(Y_val[j]) is not np.ndarray: raise AttributeError('Each element of Y_val should be a 2D numpy array and received {}'.format( - type(Y_val[j]) + type(Y_val[j])) ) - ) if len(Y_val[j].shape) != 2: raise AttributeError('Each element of Y_val should be a 2D array and received an array of shape {}'.format( - str(Y_val[j].shape) + str(Y_val[j].shape)) ) - ) if type(Y_predicted[j]) is not np.ndarray: raise AttributeError('Each element of Y_predicted should be a 2D numpy array and received {}'.format( - type(Y_predicted[j]) + type(Y_predicted[j])) ) - ) if len(Y_predicted[j].shape) != 2: raise AttributeError('Each element of Y_predicted should be a 2D array and received an array of shape {}'.format( - str(Y_predicted[j].shape) + str(Y_predicted[j].shape)) ) - ) def sed(Y_val, Y_predicted, sequence_time_sec=0.5, metric_resolution_sec=1.0, label_list=[]): diff --git a/setup.py b/setup.py index 7b9000e..cc9b7eb 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ long_description_content_type="text/markdown", url="https://github.com/pzinemanas/DCASE-models", download_url='http://github.com/pzinemanas/DCASE-models/releases', - packages=setuptools.find_packages(), + packages=setuptools.find_packages(exclude=["test", "*.test", "*.test.*"]), install_requires=[ 'numpy>=1.1', 'pandas>=0.25', @@ -33,6 +33,7 @@ 'keras_tf': ['tensorflow<1.14', 'keras==2.2.4'], 'keras_tf_gpu': ['tensorflow-gpu<1.14', 'keras==2.2.4'], 'tf2': ['tensorflow>2.0'], + 'torch': ['torch>1.1'], 'openl3': ['openl3==0.3.1'], 'autopool': ['autopool==0.1.0'], 'docs': ['numpydoc', 'sphinx!=1.3.1', 'sphinx_rtd_theme'], @@ -45,9 +46,12 @@ ] }, classifiers=[ - "Programming Language :: Python :: 3", + "Programming Language :: Python", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8" ], - python_requires='>=3.6', + #python_requires='>=3.6, <3.9', ) diff --git a/tests/test_container.py b/tests/test_container.py index ce23cfa..495f28a 100644 --- a/tests/test_container.py +++ b/tests/test_container.py @@ -1,22 +1,63 @@ from dcase_models.model.container import ModelContainer, KerasModelContainer +from dcase_models.model.container import PyTorchModelContainer, SklearnModelContainer from dcase_models.data.data_generator import KerasDataGenerator - -import tensorflow as tf - -tensorflow2 = tf.__version__.split(".")[0] == "2" - -if tensorflow2: - from tensorflow.keras.layers import Input, Dense - from tensorflow.keras.models import Model -else: - from keras.layers import Input, Dense - from keras.models import Model +from dcase_models.util.files import save_pickle import os import numpy as np import pytest import shutil +from dcase_models.backend import backends + +if 'torch' in backends: + import torch + from torch import nn + + class TorchModel(PyTorchModelContainer): + class Model(nn.Module): + def __init__(self, params): + super().__init__() + self.layer = nn.Linear(10, 2) 
+ + def forward(self, x): + y = self.layer(x) + return y + + def __init__(self, model=None, model_path=None, + metrics=['classification']): + super().__init__(model=model, model_path=model_path, + model_name='MLP', metrics=metrics) + + torch_container = TorchModel() +else: + torch = None + +if ('tensorflow1' in backends) | ('tensorflow2' in backends): + import tensorflow as tf + + tensorflow_version = '2' if tf.__version__.split(".")[0] == "2" else '1' + + if tensorflow_version == '2': + from tensorflow.keras.layers import Input, Dense + from tensorflow.keras.models import Model + else: + from keras.layers import Input, Dense + from keras.models import Model + + x = Input(shape=(10,), dtype="float32", name="input") + y = Dense(2)(x) + model = Model(x, y) +else: + tensorflow_version = None + +if 'sklearn' in backends: + import sklearn + from sklearn.ensemble import RandomForestClassifier + from sklearn.svm import SVC + from sklearn.linear_model import SGDClassifier +else: + sklearn = None def _clean(path): if os.path.isdir(path): @@ -25,9 +66,7 @@ def _clean(path): os.remove(path) -x = Input(shape=(10,), dtype="float32", name="input") -y = Dense(2)(x) -model = Model(x, y) +os.environ['CUDA_LAUNCH_BLOCKING'] = "1" # ModelContainer @@ -58,6 +97,7 @@ def test_model_container(): # KerasModelContainer +@pytest.mark.skipif(tensorflow_version is None, reason="Tensorflow is not installed") def test_init(): _clean("./model.json") model_container = KerasModelContainer(model) @@ -70,7 +110,7 @@ def test_init(): assert len(model_container.model.layers) == 2 _clean("./model.json") - +@pytest.mark.skipif(tensorflow_version is None, reason="Tensorflow is not installed") def test_load_model_from_json(): _clean("./model.json") model_container = KerasModelContainer(model) @@ -80,7 +120,7 @@ def test_load_model_from_json(): assert len(model_container.model.layers) == 2 _clean("./model.json") - +@pytest.mark.skipif(tensorflow_version is None, reason="Tensorflow is not installed") def test_save_model_from_json(): _clean("./model.json") model_container = KerasModelContainer(model) @@ -88,7 +128,7 @@ def test_save_model_from_json(): assert os.path.exists("./model.json") _clean("./model.json") - +@pytest.mark.skipif(tensorflow_version is None, reason="Tensorflow is not installed") def test_save_model_weights(): weights_file = "./best_weights.hdf5" _clean(weights_file) @@ -97,7 +137,7 @@ def test_save_model_weights(): assert os.path.exists(weights_file) _clean(weights_file) - +@pytest.mark.skipif(tensorflow_version is None, reason="Tensorflow is not installed") def test_load_model_weights(): weights_file = "./best_weights.hdf5" _clean(weights_file) @@ -111,7 +151,7 @@ def test_load_model_weights(): assert np.allclose(new_weights[1], weights[1]) _clean(weights_file) - +@pytest.mark.skipif(tensorflow_version is None, reason="Tensorflow is not installed") def test_check_if_model_exists(): model_container = KerasModelContainer(model) model_file = "./model.json" @@ -128,7 +168,7 @@ def test_check_if_model_exists(): _clean(model_file) assert not model_container.check_if_model_exists("./") - +@pytest.mark.skipif(tensorflow_version is None, reason="Tensorflow is not installed") def test_train(): x = Input(shape=(4,), dtype="float32", name="input") y = Dense(2)(x) @@ -228,3 +268,607 @@ def shuffle_list(self): # results = model_container.evaluate(([X_val, X_val2], [Y_val, Y_val2]), label_list=['1', '2']) # assert results['classification'].results()['overall']['accuracy'] > 0.25 + + +# PyTorchModelContainer 
+@pytest.mark.skipif(torch is None, reason="PyTorch is not installed")
+def test_pytorch_init():
+    model_container = TorchModel()
+    assert len(list(model_container.model.children())) == 1
+    assert model_container.model_name == "MLP"
+    assert model_container.metrics == ["classification"]
+
+    model_container = PyTorchModelContainer()
+    assert len(list(model_container.model.children())) == 0
+
+@pytest.mark.skipif(torch is None, reason="PyTorch is not installed")
+def test_pytorch_load_model_from_json():
+    model_container = PyTorchModelContainer()
+    with pytest.raises(NotImplementedError):
+        model_container.load_model_from_json("./")
+
+@pytest.mark.skipif(torch is None, reason="PyTorch is not installed")
+def test_pytorch_save_model_from_json():
+    model_container = PyTorchModelContainer()
+    with pytest.raises(NotImplementedError):
+        model_container.save_model_json("./")
+
+@pytest.mark.skipif(torch is None, reason="PyTorch is not installed")
+def test_pytorch_save_model_weights():
+    weights_file = "./best_weights.pth"
+    _clean(weights_file)
+    model_container = TorchModel()
+    model_container.save_model_weights("./")
+    assert os.path.exists(weights_file)
+    _clean(weights_file)
+
+@pytest.mark.skipif(torch is None, reason="PyTorch is not installed")
+def test_pytorch_load_model_weights():
+    weights_file = "./best_weights.pth"
+    _clean(weights_file)
+    model_container = TorchModel()
+    # nn.Linear(10, 2) stores its weight as a (2, 10) tensor
+    model_container.model.layer.weight.data = torch.full((2, 10), 0.5)
+    model_container.model.layer.bias.data = torch.full((2, ), 0.5)
+    # Clone the parameters; otherwise the comparisons below would run
+    # against the same (already zeroed) live tensors.
+    weights = [param.clone() for param in model_container.model.parameters()]
+    print(weights)
+    model_container.save_model_weights("./")
+    with torch.no_grad():
+        model_container.model.layer.weight = nn.Parameter(torch.zeros_like(model_container.model.layer.weight))
+        model_container.model.layer.bias = nn.Parameter(torch.zeros_like(model_container.model.layer.bias))
+    new_weights = model_container.model.parameters()
+    for param1, param2 in zip(weights, new_weights):
+        print(param1, param2)
+        assert not torch.allclose(param1, param2)
+
+    model_container.load_model_weights("./")
+    new_weights = model_container.model.parameters()
+    for param1, param2 in zip(weights, new_weights):
+        assert torch.allclose(param1, param2)
+
+@pytest.mark.skipif(torch is None, reason="PyTorch is not installed")
+def test_pytorch_check_if_model_exists():
+    model_container = TorchModel()
+    model_file = "./best_weights.pth"
+    _clean(model_file)
+    print(model_container.model)
+    model_container.save_model_weights("./")
+    assert model_container.check_if_model_exists("./")
+
+    class TorchModel2(PyTorchModelContainer):
+        class Model(nn.Module):
+            def __init__(self, params):
+                super().__init__()
+                self.layer = nn.Linear(11, 2)
+
+            def forward(self, x):
+                y = self.layer(x)
+                return y
+
+        def __init__(self, model=None, model_path=None,
+                     metrics=['classification']):
+            super().__init__(model=model, model_path=model_path,
+                             model_name='MLP', metrics=metrics)
+
+    model_container = TorchModel2()
+
+    assert not model_container.check_if_model_exists("./")
+
+    _clean(model_file)
+    assert not model_container.check_if_model_exists("./")
+
+@pytest.mark.skipif(torch is None, reason="PyTorch is not installed")
+def test_pytorch_train():
+    class ToyModel(PyTorchModelContainer):
+        class Model(nn.Module):
+            def __init__(self, params):
+                super().__init__()
+                self.layer = nn.Linear(4, 2)
+                self.act = nn.Softmax(-1)
+
+            def forward(self, x):
+                y = self.layer(x)
+                y = self.act(y)
+                return y
+
+        def __init__(self, model=None, model_path=None,
+                     metrics=['classification'], use_cuda=True):
+            super().__init__(model=model, model_path=model_path,
+                             model_name='MLP', metrics=metrics, use_cuda=use_cuda)
+
+    model_container = ToyModel(use_cuda=False)
+
+    X_train = np.concatenate((-np.ones((100, 4)), np.ones((100, 4))), axis=0)
+    Y_train = np.zeros((200, 2))
+    Y_train[:100, 0] = 1
+    Y_train[100:, 1] = 1
+    X_val = -np.ones((1, 4))
+    Y_val = np.zeros((1, 2))
+    Y_val[0, 0] = 1
+
+    X_val2 = np.ones((1, 4))
+    Y_val2 = np.zeros((1, 2))
+    Y_val2[0, 1] = 1
+
+    file_weights = "./best_weights.pth"
+    # file_log = "./training.log"
+    _clean(file_weights)
+    with torch.no_grad():
+        model_container.model.layer.weight = nn.Parameter(torch.zeros_like(model_container.model.layer.weight))
+        model_container.model.layer.weight[:2, 0] = -0.5
+        model_container.model.layer.weight[2:, 1] = 0.5
+
+    model_container.train(
+        (X_train, Y_train),
+        ([X_val, X_val2], [Y_val, Y_val2]),
+        epochs=3,
+        label_list=["1", "2"],
+    )
+    assert os.path.exists(file_weights)
+    _clean(file_weights)
+
+    results = model_container.evaluate(
+        ([X_val, X_val2], [Y_val, Y_val2]), label_list=["1", "2"]
+    )
+    assert results["classification"].results()["overall"]["accuracy"] >= 0.0
+
+    # DataGenerator
+    class ToyDataGenerator:
+        def __init__(self, X_val, Y_val, train=True):
+            self.X_val = X_val
+            self.Y_val = Y_val
+            self.train = train
+
+        def __len__(self):
+            return 3
+
+        def get_data_batch(self, index):
+            if self.train:
+                return self.X_val, self.Y_val
+            else:
+                return [self.X_val], [self.Y_val]
+
+        def shuffle_list(self):
+            pass
+
+    data_generator = ToyDataGenerator(X_train, Y_train)
+    data_generator_val = ToyDataGenerator(X_val, Y_val, train=False)
+
+    model_container = ToyModel(use_cuda=False)
+
+    with torch.no_grad():
+        model_container.model.layer.weight = nn.Parameter(torch.zeros_like(model_container.model.layer.weight))
+        model_container.model.layer.weight[:2, 0] = -0.5
+        model_container.model.layer.weight[2:, 1] = 0.5
+
+    model_container.train(
+        data_generator,
+        data_generator_val,
+        epochs=3,
+        batch_size=None,
+        label_list=["1", "2"],
+    )
+    assert os.path.exists(file_weights)
+    _clean(file_weights)
+
+    results = model_container.evaluate(
+        data_generator_val, label_list=["1", "2"]
+    )
+    assert results["classification"].results()["overall"]["accuracy"] >= 0.0
+
+    # Other callbacks
+    for metric in ["tagging", "sed"]:
+        model_container = ToyModel(metrics=[metric])
+
+        file_weights = "./best_weights.pth"
+        _clean(file_weights)
+        model_container.train(
+            data_generator,
+            data_generator_val,
+            epochs=3,
+            label_list=["1", "2"],
+        )
+        _clean(file_weights)
+
+        results = model_container.evaluate(data_generator_val, label_list=['1', '2'])
+        assert results[metric].results()['overall']['f_measure']['f_measure'] >= 0
+
+@pytest.mark.skipif(torch is None, reason="PyTorch is not installed")
+def test_pytorch_predict():
+    class ToyModel(PyTorchModelContainer):
+        class Model(nn.Module):
+            def __init__(self, params):
+                super().__init__()
+
+            def forward(self, x):
+                x = 3*x + 2
+                return x
+
+        def __init__(self, model=None, model_path=None,
+                     metrics=['classification'], use_cuda=True):
+            super().__init__(model=model, model_path=model_path,
+                             model_name='MLP', metrics=metrics, use_cuda=use_cuda)
+
+    model_container = ToyModel(use_cuda=False)
+
+    x = np.ones((3, 2))
+    pred = model_container.predict(x)
+    assert np.allclose(pred, x*3 + 2)
+
+    # multi output
+    class ToyModel(PyTorchModelContainer):
+        class Model(nn.Module):
+            def __init__(self, params):
+                super().__init__()
+
+            def forward(self, x):
+                y1 = 3*x + 2
+                y2 = 4*x
+                return [y1, y2]
+
+        def __init__(self, model=None, model_path=None,
+                     metrics=['classification'], use_cuda=True):
+            super().__init__(model=model, model_path=model_path,
+                             model_name='MLP', metrics=metrics, use_cuda=use_cuda)
+
+    model_container = ToyModel(use_cuda=False)
+
+    x = np.ones((3, 2))
+    pred = model_container.predict(x)
+    assert np.allclose(pred[0], x*3 + 2)
+    assert np.allclose(pred[1], x*4)
+
+    # multi input
+    class ToyModel(PyTorchModelContainer):
+        class Model(nn.Module):
+            def __init__(self, params):
+                super().__init__()
+
+            def forward(self, x1, x2):
+                y = x1 + x2
+                return y
+
+        def __init__(self, model=None, model_path=None,
+                     metrics=['classification'], use_cuda=True):
+            super().__init__(model=model, model_path=model_path,
+                             model_name='MLP', metrics=metrics, use_cuda=use_cuda)
+
+    model_container = ToyModel(use_cuda=False)
+
+    x1 = np.ones((3, 2))
+    x2 = np.ones((3, 2))
+    pred = model_container.predict([x1, x2])
+    assert np.allclose(pred, x1 + x2)
+
+@pytest.mark.skipif(torch is None, reason="PyTorch is not installed")
+def test_pytorch_get_number_of_parameters():
+    class ToyModel(PyTorchModelContainer):
+        class Model(nn.Module):
+            def __init__(self, params):
+                super().__init__()
+                self.layer = nn.Linear(4, 2)
+                self.act = nn.Softmax(-1)
+
+            def forward(self, x):
+                y = self.layer(x)
+                y = self.act(y)
+                return y
+
+        def __init__(self, model=None, model_path=None,
+                     metrics=['classification'], use_cuda=True):
+            super().__init__(model=model, model_path=model_path,
+                             model_name='MLP', metrics=metrics, use_cuda=use_cuda)
+
+    model_container = ToyModel(use_cuda=False)
+
+    assert model_container.get_number_of_parameters() == 10
+
+@pytest.mark.skipif(torch is None, reason="PyTorch is not installed")
+def test_pytorch_cut_network():
+    with pytest.raises(NotImplementedError):
+        torch_container.cut_network(None)
+
+@pytest.mark.skipif(torch is None, reason="PyTorch is not installed")
+def test_pytorch_fine_tuning():
+    with pytest.raises(NotImplementedError):
+        torch_container.fine_tuning(None)
+
+@pytest.mark.skipif(torch is None, reason="PyTorch is not installed")
+def test_pytorch_get_available_intermediate_outputs():
+    with pytest.raises(NotImplementedError):
+        torch_container.get_available_intermediate_outputs()
+
+@pytest.mark.skipif(torch is None, reason="PyTorch is not installed")
+def test_pytorch_get_intermediate_output():
+    with pytest.raises(NotImplementedError):
+        torch_container.get_intermediate_output(None, None)
+
+@pytest.mark.skipif(torch is None, reason="PyTorch is not installed")
+def test_pytorch_load_pretrained_model_weights():
+    with pytest.raises(NotImplementedError):
+        torch_container.load_pretrained_model_weights()
+
+# SklearnModelContainer
+@pytest.mark.skipif(sklearn is None, reason="sklearn is not installed")
+def test_sklearn_init():
+    with pytest.raises(AttributeError):
+        model_container = SklearnModelContainer()
+    model = RandomForestClassifier()
+    model_container = SklearnModelContainer(model)
+    assert len(model_container.model.get_params()) > 0
+    assert model_container.model_name == "SklearnModelContainer"
+    assert model_container.metrics == ["classification"]
+
+    model_path = './'
+    model_file = os.path.join(model_path, 'model.skl')
+    _clean(model_file)
+    save_pickle(model, model_file)
+    model_container = SklearnModelContainer(model_path=model_path)
+    assert len(model_container.model.get_params()) > 0
+    _clean(model_file)
+
+@pytest.mark.skipif(sklearn is None, reason="sklearn is not installed")
+def test_sklearn_load_model_from_json():
+    model = RandomForestClassifier()
+    model_container = SklearnModelContainer(model)
+    with pytest.raises(NotImplementedError):
+        model_container.load_model_from_json("./")
+
+@pytest.mark.skipif(sklearn is None, reason="sklearn is not installed")
+def test_sklearn_save_model_from_json():
+    model = RandomForestClassifier()
+    model_container = SklearnModelContainer(model)
+    with pytest.raises(NotImplementedError):
+        model_container.save_model_json("./")
+
+@pytest.mark.skipif(sklearn is None, reason="sklearn is not installed")
+def test_sklearn_save_model_weights():
+    model_path = './'
+    model_file = os.path.join(model_path, 'model.skl')
+    _clean(model_file)
+    model = RandomForestClassifier()
+    model_container = SklearnModelContainer(model)
+    model_container.save_model_weights(model_path)
+    assert os.path.exists(model_file)
+    _clean(model_file)
+
+@pytest.mark.skipif(sklearn is None, reason="sklearn is not installed")
+def test_sklearn_load_model_weights():
+    model_path = './'
+    model_file = os.path.join(model_path, 'model.skl')
+    _clean(model_file)
+    model = RandomForestClassifier(random_state=0)
+
+    model_container = SklearnModelContainer(model)
+    model_container.save_model_weights(model_path)
+    model_container.train((np.zeros((2, 2)), np.zeros(2)))
+    params = model.get_params()
+
+    model_container = SklearnModelContainer(model_path=model_path)
+    new_params = model_container.model.get_params()
+
+    assert len(params) == len(new_params)
+    for key, value in params.items():
+        assert value == new_params[key]
+
+    _clean(model_file)
+
+@pytest.mark.skipif(sklearn is None, reason="sklearn is not installed")
+def test_sklearn_check_if_model_exists():
+    model_path = './'
+    model_file = os.path.join(model_path, 'model.skl')
+    _clean(model_file)
+    model = RandomForestClassifier()
+    model_container = SklearnModelContainer(model)
+
+    model_container.save_model_weights(model_path)
+    assert model_container.check_if_model_exists("./")
+
+    model = RandomForestClassifier(n_estimators=10)
+    model_container = SklearnModelContainer(model)
+    assert not model_container.check_if_model_exists("./")
+
+    model = SVC()
+    model_container = SklearnModelContainer(model)
+    assert not model_container.check_if_model_exists("./")
+
+    _clean(model_file)
+    assert not model_container.check_if_model_exists("./")
+
+
+@pytest.mark.skipif(sklearn is None, reason="sklearn is not installed")
+def test_sklearn_train():
+    model_path = './'
+    model_file = os.path.join(model_path, 'model.skl')
+    _clean(model_file)
+
+    model = RandomForestClassifier()
+    model_container = SklearnModelContainer(model)
+
+    X_train = np.concatenate((-np.ones((100, 4)), np.ones((100, 4))), axis=0)
+    Y_train = np.zeros((200, ))
+    Y_train[:100] = 0
+    Y_train[100:] = 1
+    X_val1 = -np.ones((1, 4))
+    Y_val1 = np.zeros((1, 2))
+    Y_val1[0, 0] = 1
+
+    X_val2 = np.ones((1, 4))
+    Y_val2 = np.zeros((1, 2))
+    Y_val2[0, 1] = 1
+
+    results = model_container.train((X_train, Y_train), ([X_val1, X_val2], [Y_val1, Y_val2]), label_list=['0', '1'])
+    assert results.results()["overall"]["accuracy"] >= 0.9
+
+    # DataGenerator
+    class ToyDataGenerator:
+        def __init__(self, X_val, Y_val, train=True):
+            self.X_val = X_val
+            self.Y_val = Y_val
+            self.train = train
+
+        def __len__(self):
+            return 3
+
+        def get_data_batch(self, index):
+            if self.train:
+                return self.X_val, self.Y_val
+            else:
+                return [self.X_val], [self.Y_val]
+
+        def shuffle_list(self):
+            pass
+
+    data_generator = ToyDataGenerator(X_train, Y_train)
+    data_generator_val = ToyDataGenerator(X_val1, Y_val1, train=False)
+
+    # RandomForest does not include partial_fit
+    with pytest.raises(AttributeError):
+        model_container.train(data_generator, data_generator_val)
+
+    model = SGDClassifier()
+    model_container = SklearnModelContainer(model)
+
+    results = model_container.train(data_generator, data_generator_val, label_list=['0', '1'])
+    assert results.results()["overall"]["accuracy"] >= 0.9
+
+    with pytest.raises(AttributeError):
+        model_container.train(([X_train], Y_train))
+    with pytest.raises(AttributeError):
+        model_container.train((X_train, [Y_train]))
+    X_train = np.zeros((10, 2, 2))
+    with pytest.raises(AttributeError):
+        model_container.train((X_train, Y_train))
+    X_train = np.zeros((10, 2))
+    Y_train = np.zeros((10, 2, 2))
+    with pytest.raises(AttributeError):
+        model_container.train((X_train, Y_train))
+
+    X_train = np.zeros((10, 2, 2))
+    Y_train = np.zeros((10, 2))
+    data_generator = ToyDataGenerator(X_train, Y_train)
+    with pytest.raises(AttributeError):
+        model_container.train(data_generator)
+
+    X_train = np.zeros((10, 2))
+    Y_train = np.zeros((10, 2, 2))
+    data_generator = ToyDataGenerator(X_train, Y_train)
+    with pytest.raises(AttributeError):
+        model_container.train(data_generator)
+
+    X_train = np.zeros((10, 2))
+    Y_train = []
+    data_generator = ToyDataGenerator(X_train, Y_train)
+    with pytest.raises(AttributeError):
+        model_container.train(data_generator)
+
+    X_train = []
+    Y_train = np.zeros((10, 2))
+    data_generator = ToyDataGenerator(X_train, Y_train)
+    with pytest.raises(AttributeError):
+        model_container.train(data_generator)
+
+@pytest.mark.skipif(sklearn is None, reason="sklearn is not installed")
+def test_sklearn_evaluate():
+    model_path = './'
+    model_file = os.path.join(model_path, 'model.skl')
+    _clean(model_file)
+    model = RandomForestClassifier()
+    model_container = SklearnModelContainer(model, metrics=['sed'])
+
+    X_train = np.concatenate((-np.ones((100, 4)), np.ones((100, 4))), axis=0)
+    Y_train = np.zeros((200, ))
+    Y_train[:100] = 0
+    Y_train[100:] = 1
+    X_val1 = -np.ones((1, 4))
+    Y_val1 = np.zeros((1, 2))
+    Y_val1[0, 0] = 1
+
+    X_val2 = np.ones((1, 4))
+    Y_val2 = np.zeros((1, 2))
+    Y_val2[0, 1] = 1
+
+    model_container.train((X_train, Y_train))
+    results = model_container.evaluate(([X_val1, X_val2], [Y_val1, Y_val2]), label_list=['0', '1'])
+    assert results['sed'].results()["overall"]["f_measure"]["f_measure"] >= 0.0
+
+
+@pytest.mark.skipif(sklearn is None, reason="sklearn is not installed")
+def test_sklearn_predict():
+    model_path = './'
+    model_file = os.path.join(model_path, 'model.skl')
+    _clean(model_file)
+    model = RandomForestClassifier()
+    model_container = SklearnModelContainer(model)
+
+    X_train = np.concatenate((-np.ones((100, 4)), np.ones((100, 4))), axis=0)
+    Y_train = np.zeros((200, ))
+    Y_train[:100] = 0
+    Y_train[100:] = 1
+    X_val1 = -np.ones((1, 4))
+    Y_val1 = np.zeros((1, 2))
+    Y_val1[0, 0] = 1
+
+    model_container.train((X_train, Y_train))
+    pred = model_container.predict(X_val1)
+    assert pred.shape == (1, 2)
+
+    model = RandomForestClassifier()
+    model_container = SklearnModelContainer(model)
+
+    X_train = np.concatenate((-np.ones((100, 4)), np.ones((100, 4))), axis=0)
+    Y_train = np.zeros((200, 2))
+    Y_train[:100, 0] = 1
+    Y_train[100:, 1] = 1
+
+    model_container.train((X_train, Y_train))
+
+    pred = model_container.predict(X_val1)
+    assert pred.shape == (1, 2)
+
+@pytest.mark.skipif(sklearn is None, reason="sklearn is not installed")
+def test_sklearn_get_number_of_parameters():
+    model = RandomForestClassifier()
+    data_generator = ToyDataGenerator(X_train, Y_train)
+    data_generator_val = ToyDataGenerator(X_val1, Y_val1, train=False)
+
+    # RandomForest does not implement partial_fit
+    with pytest.raises(AttributeError):
+        model_container.train(data_generator, data_generator_val)
+
+    model = SGDClassifier()
+    model_container = SklearnModelContainer(model)
+
+    results = model_container.train(data_generator, data_generator_val, label_list=['0', '1'])
+    assert results.results()["overall"]["accuracy"] >= 0.9
+
+    with pytest.raises(AttributeError):
+        model_container.train(([X_train], Y_train))
+    with pytest.raises(AttributeError):
+        model_container.train((X_train, [Y_train]))
+    X_train = np.zeros((10, 2, 2))
+    with pytest.raises(AttributeError):
+        model_container.train((X_train, Y_train))
+    X_train = np.zeros((10, 2))
+    Y_train = np.zeros((10, 2, 2))
+    with pytest.raises(AttributeError):
+        model_container.train((X_train, Y_train))
+
+    X_train = np.zeros((10, 2, 2))
+    Y_train = np.zeros((10, 2))
+    data_generator = ToyDataGenerator(X_train, Y_train)
+    with pytest.raises(AttributeError):
+        model_container.train(data_generator)
+
+    X_train = np.zeros((10, 2))
+    Y_train = np.zeros((10, 2, 2))
+    data_generator = ToyDataGenerator(X_train, Y_train)
+    with pytest.raises(AttributeError):
+        model_container.train(data_generator)
+
+    X_train = np.zeros((10, 2))
+    Y_train = []
+    data_generator = ToyDataGenerator(X_train, Y_train)
+    with pytest.raises(AttributeError):
+        model_container.train(data_generator)
+
+    X_train = []
+    Y_train = np.zeros((10, 2))
+    data_generator = ToyDataGenerator(X_train, Y_train)
+    with pytest.raises(AttributeError):
+        model_container.train(data_generator)
+
+@pytest.mark.skipif(sklearn is None, reason="sklearn is not installed")
+def test_sklearn_evaluate():
+    model_path = './'
+    model_file = os.path.join(model_path, 'model.skl')
+    _clean(model_file)
+    model = RandomForestClassifier()
+    model_container = SklearnModelContainer(model, metrics=['sed'])
+
+    X_train = np.concatenate((-np.ones((100, 4)), np.ones((100, 4))), axis=0)
+    Y_train = np.zeros((200, ))
+    Y_train[:100] = 0
+    Y_train[100:] = 1
+    X_val1 = -np.ones((1, 4))
+    Y_val1 = np.zeros((1, 2))
+    Y_val1[0, 0] = 1
+
+    X_val2 = np.ones((1, 4))
+    Y_val2 = np.zeros((1, 2))
+    Y_val2[0, 1] = 1
+
+    model_container.train((X_train, Y_train))
+    results = model_container.evaluate(([X_val1, X_val2], [Y_val1, Y_val2]), label_list=['0', '1'])
+    assert results['sed'].results()["overall"]["f_measure"]["f_measure"] >= 0.0
+
+
+@pytest.mark.skipif(sklearn is None, reason="sklearn is not installed")
+def test_sklearn_predict():
+    model_path = './'
+    model_file = os.path.join(model_path, 'model.skl')
+    _clean(model_file)
+    model = RandomForestClassifier()
+    model_container = SklearnModelContainer(model)
+
+    X_train = np.concatenate((-np.ones((100, 4)), np.ones((100, 4))), axis=0)
+    Y_train = np.zeros((200, ))
+    Y_train[:100] = 0
+    Y_train[100:] = 1
+    X_val1 = -np.ones((1, 4))
+    Y_val1 = np.zeros((1, 2))
+    Y_val1[0, 0] = 1
+
+    model_container.train((X_train, Y_train))
+    pred = model_container.predict(X_val1)
+    assert pred.shape == (1, 2)
+
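+    # Whether Y_train holds 1-D class indices (above) or one-hot rows (below),
+    # predictions come back with shape (n_samples, n_classes), presumably via
+    # the estimator's predict_proba().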
is not installed") def test_mlp(): model_container = MLP() assert len(model_container.model.layers) == 7 @@ -53,6 +54,7 @@ def test_mlp(): assert outputs.shape == (3, 10) +@pytest.mark.skipif(tensorflow is None, reason="TensorFlow is not installed") def test_sb_cnn(): model_container = SB_CNN() assert len(model_container.model.layers) == 15 @@ -67,6 +69,7 @@ def test_sb_cnn(): assert outputs.shape == (3, 10) +@pytest.mark.skipif(tensorflow is None, reason="TensorFlow is not installed") def test_sb_cnn_sed(): model_container = SB_CNN_SED() assert len(model_container.model.layers) == 15 @@ -81,6 +84,7 @@ def test_sb_cnn_sed(): assert outputs.shape == (3, 10) +@pytest.mark.skipif(tensorflow is None, reason="TensorFlow is not installed") def test_a_crnn(): model_container = A_CRNN() assert len(model_container.model.layers) == 25 @@ -101,6 +105,7 @@ def test_a_crnn(): assert outputs.shape == (3, 64, 10) +@pytest.mark.skipif(tensorflow is None, reason="TensorFlow is not installed") def test_vggish(): model_container = VGGish() assert len(model_container.model.layers) == 13 @@ -109,6 +114,7 @@ def test_vggish(): assert outputs.shape == (3, 512) +@pytest.mark.skipif(tensorflow is None, reason="TensorFlow is not installed") def test_smel(): model_container = SMel() assert len(model_container.model.layers) == 6 @@ -117,6 +123,7 @@ def test_smel(): assert outputs.shape == (3, 64, 128) +@pytest.mark.skipif(tensorflow is None, reason="TensorFlow is not installed") def test_mst(): model_container = MST() assert len(model_container.model.layers) == 11 @@ -125,6 +132,7 @@ def test_mst(): assert outputs.shape == (3, 44, 128) +@pytest.mark.skipif(tensorflow is None, reason="TensorFlow is not installed") def test_concatenated_model(): model_mst = MST() model_cnn = SB_CNN_SED(n_frames_cnn=44, n_freq_cnn=128)