Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding pytorch and scikit-klearn as backends #24

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ jobs:
- name: Install library
run: |
python -m pip install .[keras_tf]
python -m pip install .[torch]
python -m pip install .[openl3]
python -m pip install .[autopool]
python -m pip install .[tests]
Expand Down Expand Up @@ -102,3 +103,27 @@ jobs:
- name: Test with pytest
run: |
pytest --cov-report term-missing --cov-report=xml --cov dcase_models ./tests

build_torch:

runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.6, 3.7, 3.8]

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
sudo apt-get install -y wget libsndfile-dev sox
- name: Install library
run: |
python -m pip install .[torch]
python -m pip install .[tests]
- name: Test with pytest
run: |
pytest --cov-report term-missing --cov-report=xml --cov dcase_models ./tests
25 changes: 25 additions & 0 deletions dcase_models/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
backends = []

try:
import tensorflow as tf

tensorflow_version = '2' if tf.__version__.split(".")[0] == "2" else '1'

if tensorflow_version == '2':
backends.append('tensorflow2')
else:
backends.append('tensorflow1')
except:
tensorflow_version = None

try:
import torch
backends.append('torch')
except:
torch = None

try:
import sklearn
backends.append('sklearn')
except:
sklearn = None
92 changes: 73 additions & 19 deletions dcase_models/data/data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,26 @@
import inspect
import random

import tensorflow as tf
tensorflow2 = tf.__version__.split('.')[0] == '2'
from dcase_models.backend import backends

if tensorflow2:
from tensorflow.keras.utils import Sequence
if 'torch' in backends:
import torch
from torch.utils.data import Dataset as TorchDataset
else:
from keras.utils import Sequence
class TorchDataset:
pass

if ('tensorflow1' in backends) | ('tensorflow2' in backends):
import tensorflow as tf
tensorflow2 = tf.__version__.split('.')[0] == '2'

if tensorflow2:
from tensorflow.keras.utils import Sequence
else:
from keras.utils import Sequence
else:
class Sequence:
pass

from dcase_models.data.feature_extractor import FeatureExtractor
from dcase_models.data.dataset_base import Dataset
Expand Down Expand Up @@ -523,21 +536,62 @@ def set_scaler_outputs(self, scaler_outputs):
self.scaler_outputs = scaler_outputs


class KerasDataGenerator(Sequence):
if ('tensorflow1' in backends) | ('tensorflow2' in backends):
class KerasDataGenerator(Sequence):

def __init__(self, data_generator):
self.data_gen = data_generator
self.data_gen.shuffle_list()
def __init__(self, data_generator):
self.data_gen = data_generator
self.data_gen.shuffle_list()

def __len__(self):
'Denotes the number of batches per epoch'
return len(self.data_gen)
def __len__(self):
'Denotes the number of batches per epoch'
return len(self.data_gen)

def __getitem__(self, index):
'Generate one batch of data'
# Generate indexes of the batch
return self.data_gen.get_data_batch(index)
def __getitem__(self, index):
'Generate one batch of data'
# Generate indexes of the batch
return self.data_gen.get_data_batch(index)

def on_epoch_end(self):
'Updates indexes after each epoch'
self.data_gen.shuffle_list()
def on_epoch_end(self):
'Updates indexes after each epoch'
self.data_gen.shuffle_list()
else:
class KerasDataGenerator():
def __init__(self, data_generator):
raise ImportError("Tensorflow is not installed")


if 'torch' in backends:
class PyTorchDataGenerator(TorchDataset):

def __init__(self, data_generator):
self.data_gen = data_generator
self.data_gen.shuffle_list()

def __len__(self):
'Denotes the number of batches per epoch'
return len(self.data_gen)

def __getitem__(self, index):
'Generate one batch of data'
# Generate indexes of the batch
X, Y = self.data_gen.get_data_batch(index)
if type(X) is not list:
X = [X]
if type(Y) is not list:
Y = [Y]
tensor_X = []
tensor_Y = []
for j in range(len(X)):
tensor_X.append(torch.tensor(X[j], dtype=torch.float))
for j in range(len(Y)):
tensor_Y.append(torch.tensor(Y[j], dtype=torch.long))
return tensor_X, tensor_Y

def shuffle_list(self):
'Updates indexes after each epoch'
self.data_gen.shuffle_list()
else:
class PyTorchDataGenerator():
def __init__(self, data_generator):
raise ImportError("Pytorch is not installed")
Loading