-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #13 from sdevenes/feature/algorithm
Feature/algorithm
- Loading branch information
Showing
4 changed files
with
121 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
requests | ||
requests | ||
sklearn | ||
plotly.express |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from sklearn.ensemble import RandomForestClassifier | ||
from sklearn.metrics import accuracy_score, confusion_matrix | ||
import numpy as np | ||
|
||
import logging | ||
logger = logging.getLogger() | ||
|
||
|
||
def make_labels(X): | ||
return np.hstack([k*np.ones(len(X[k]), dtype=int) for k in range(len(X))]) | ||
|
||
|
||
class Model: | ||
def __init__(self, nb_tree_per_forest=50, max_depth=10): | ||
# Create a random forest model | ||
self.model = RandomForestClassifier(n_estimators=nb_tree_per_forest, max_depth=max_depth, | ||
random_state=0) | ||
|
||
|
||
def train(self, X): | ||
|
||
# Get features | ||
X_features = np.vstack([k for k in X]) | ||
|
||
# Get labels | ||
y = make_labels(X) | ||
|
||
# Train the model | ||
self.model.fit(X_features, y) | ||
|
||
|
||
def predict(self, X): | ||
# Get features | ||
X_features = np.vstack([k for k in X]) | ||
|
||
# Predict using the trained model | ||
prediction = self.model.predict(X_features) | ||
|
||
return prediction |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import plotly.express as px | ||
from sklearn.metrics import confusion_matrix | ||
import numpy as np | ||
import itertools | ||
|
||
def get_confusion_matrix(prediction_label, true_label): | ||
return confusion_matrix(true_label, prediction_label) | ||
|
||
def plot_confusion_matrix(cm, classes, | ||
normalize=False, | ||
title='Confusion matrix'): | ||
""" | ||
This function prints and plots the confusion matrix. | ||
Normalization can be applied by setting `normalize=True`. | ||
""" | ||
# Normalize if wanted | ||
if normalize: | ||
cm = cm / np.sum(cm) | ||
|
||
# Create confusion matrix graph | ||
fig = px.imshow(cm, | ||
labels=dict(x="Predicted label", y="True label", color="value"), | ||
x=classes, | ||
y=classes, | ||
color_continuous_scale='Blues', | ||
title=title | ||
) | ||
fig.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import algorithm | ||
import database | ||
import analysis | ||
import numpy as np | ||
|
||
def test1(): | ||
x = np.arange(len(database.CLASSES)) | ||
cm = np.dot(x.reshape(len(x),1), x.reshape(1,len(x))) | ||
print(cm) | ||
print(database.CLASSES) | ||
analysis.plot_confusion_matrix(cm, database.CLASSES) | ||
|
||
def test(): | ||
|
||
# Import train data | ||
train = database.get("proto1", 'train') | ||
#print(train) | ||
|
||
# Prepare train data | ||
# norm = preprocessor.estimate_norm(numpy.vstack(train)) | ||
# train_normed = preprocessor.normalize(train, norm) | ||
|
||
# Train algo | ||
model = algorithm.Model() | ||
model.train(train) | ||
|
||
# Import test data | ||
test = database.get('proto1', 'test') | ||
|
||
# Prepare test data | ||
# test_normed = preprocessor.normalize(test, norm) | ||
|
||
# Make prediction | ||
test_predictions = model.predict(test) | ||
print(test_predictions) | ||
|
||
# Get real labels | ||
test_labels = algorithm.make_labels(test).astype(int) | ||
|
||
# Get confusion matrix | ||
cm = analysis.get_confusion_matrix(test_predictions, test_labels) | ||
|
||
# Plot confusion matrix | ||
analysis.plot_confusion_matrix(cm, database.CLASSES, normalize=True) | ||
|
||
# Plot confusion matrix (ignore other activities) | ||
# analysis.plot_confusion_matrix(cm[1:, 1:], database.CLASSES[1:], normalize=True) | ||
|
||
|
||
if __name__ == '__main__': | ||
test() |