Skip to content

Commit

Permalink
Merge pull request #13 from sdevenes/feature/algorithm
Browse files Browse the repository at this point in the history
Feature/algorithm
  • Loading branch information
spanoamara authored Sep 20, 2020
2 parents a018adc + 2bdaf16 commit 1f3742b
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 1 deletion.
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
requests
requests
scikit-learn
plotly
39 changes: 39 additions & 0 deletions scripts/algorithm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
import numpy as np

import logging
# Module-level logger named after this module, so its records can be filtered
# and configured independently of the root logger.
logger = logging.getLogger(__name__)


def make_labels(X):
    """Return one integer label per sample, derived from group position.

    Args:
        X: sequence of sized groups; every sample in the group at position
            ``k`` is given the label ``k``.

    Returns:
        1-D integer array of length ``sum(len(group) for group in X)``.
        An empty ``X`` yields an empty integer array instead of raising.
    """
    # Explicit integer dtype keeps an empty input from defaulting to float,
    # which np.repeat would reject as a repeat count.
    group_sizes = np.array([len(group) for group in X], dtype=int)
    return np.repeat(np.arange(len(group_sizes)), group_sizes)


class Model:
    """Random-forest classifier fed with per-class groups of samples.

    Data is passed as a sequence of 2-D arrays that are stacked row-wise into
    a single feature matrix. For training, the label of every sample is taken
    from the position of its group (see ``make_labels``).
    """

    def __init__(self, nb_tree_per_forest=50, max_depth=10):
        """Create the underlying, not yet trained, random forest.

        Args:
            nb_tree_per_forest: forwarded to the forest as ``n_estimators``.
            max_depth: forwarded to the forest as ``max_depth``.
        """
        # random_state is pinned to 0 so repeated runs use the same seed.
        self.model = RandomForestClassifier(
            n_estimators=nb_tree_per_forest,
            max_depth=max_depth,
            random_state=0,
        )

    @staticmethod
    def _stack_features(X):
        """Stack the arrays contained in *X* row-wise into one matrix."""
        return np.vstack(list(X))

    def train(self, X):
        """Fit the forest on *X*.

        Args:
            X: sequence of per-class sample arrays; labels are generated by
                ``make_labels`` from each group's position and size.
        """
        features = self._stack_features(X)
        labels = make_labels(X)
        self.model.fit(features, labels)

    def predict(self, X):
        """Predict a label for every sample in *X*.

        Args:
            X: sequence of sample arrays, stacked the same way as in ``train``.

        Returns:
            Whatever the underlying model's ``predict`` returns for the
            stacked feature matrix (one prediction per row).
        """
        return self.model.predict(self._stack_features(X))
28 changes: 28 additions & 0 deletions scripts/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import plotly.express as px
from sklearn.metrics import confusion_matrix
import numpy as np
import itertools

def get_confusion_matrix(prediction_label, true_label):
    """Compute the confusion matrix of predictions against ground truth.

    Mind the argument order: predictions come first, true labels second.
    They are handed to ``sklearn.metrics.confusion_matrix`` as ``y_pred``
    and ``y_true`` respectively.
    """
    matrix = confusion_matrix(y_true=true_label, y_pred=prediction_label)
    return matrix

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix'):
    """Display a confusion matrix as a Plotly heat map.

    Args:
        cm: square confusion matrix (array-like); rows are shown on the
            "True label" axis and columns on the "Predicted label" axis.
        classes: tick labels used for both axes.
        normalize: when true, every cell is divided by the sum of the whole
            matrix, so cells become fractions of all samples. An all-zero
            matrix is left as zeros rather than turned into NaNs.
        title: figure title.

    Returns:
        None. The figure is shown via ``fig.show()``.
    """
    if normalize:
        # Accept array-likes (e.g. nested lists) and force true division.
        cm = np.asarray(cm, dtype=float)
        total = cm.sum()
        # Nothing to normalize when the matrix holds no counts at all.
        if total != 0:
            cm = cm / total

    fig = px.imshow(
        cm,
        labels=dict(x="Predicted label", y="True label", color="value"),
        x=classes,
        y=classes,
        color_continuous_scale='Blues',
        title=title,
    )
    fig.show()
51 changes: 51 additions & 0 deletions scripts/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import algorithm
import database
import analysis
import numpy as np

def test1():
    """Plot a synthetic matrix to eyeball the confusion-matrix figure.

    The matrix is the outer product of ``0..n-1`` with itself, where ``n`` is
    the number of entries in ``database.CLASSES``.
    """
    indices = np.arange(len(database.CLASSES))
    cm = np.outer(indices, indices)
    print(cm)
    print(database.CLASSES)
    analysis.plot_confusion_matrix(cm, database.CLASSES)

def test():
    """Train on the 'proto1' train split and evaluate on its test split.

    Prints the predictions and shows the normalized confusion matrix.
    No feature normalization is applied (the preprocessor step is disabled).
    """
    # Fit the model on the training split.
    train_set = database.get("proto1", 'train')
    classifier = algorithm.Model()
    classifier.train(train_set)

    # Predict on the held-out split.
    test_set = database.get('proto1', 'test')
    predicted = classifier.predict(test_set)
    print(predicted)

    # Ground-truth labels follow from the grouping of the test data.
    expected = algorithm.make_labels(test_set).astype(int)

    # Compare predictions with the ground truth and display the result.
    # NOTE: to leave the first class out of the plot, pass cm[1:, 1:] and
    # database.CLASSES[1:] instead.
    cm = analysis.get_confusion_matrix(predicted, expected)
    analysis.plot_confusion_matrix(cm, database.CLASSES, normalize=True)


# Run the end-to-end evaluation only when executed as a script.
if __name__ == "__main__":
    test()

0 comments on commit 1f3742b

Please sign in to comment.