Review notes (text before the first "diff --git" header is ignored by patch tools):
- requirements.txt now names installable PyPI distributions (scikit-learn, plotly)
  and declares numpy, which every script imports directly.
- The "index" blob ids are those of the original patch; the post-image ids no
  longer match the edited content and are informational only.

diff --git a/requirements.txt b/requirements.txt
index 663bd1f..d5d80e1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1,4 @@
-requests
\ No newline at end of file
+requests
+numpy
+scikit-learn
+plotly
diff --git a/scripts/algorithm.py b/scripts/algorithm.py
new file mode 100644
index 0000000..0403dba
--- /dev/null
+++ b/scripts/algorithm.py
@@ -0,0 +1,56 @@
+"""Random-forest classifier wrapper operating on per-class sample arrays."""
+
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.metrics import accuracy_score, confusion_matrix
+import numpy as np
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def make_labels(X):
+    """Build the integer label vector matching ``np.vstack(X)``.
+
+    ``X`` is a sequence holding one array of samples per class; every sample
+    of ``X[k]`` receives the label ``k``.
+
+    Returns a 1-D integer array with ``sum(len(samples) for samples in X)``
+    entries.
+    """
+    return np.hstack([np.full(len(samples), label, dtype=int)
+                      for label, samples in enumerate(X)])
+
+
+class Model:
+    """Thin wrapper around a scikit-learn ``RandomForestClassifier``."""
+
+    def __init__(self, nb_tree_per_forest=50, max_depth=10):
+        """Create the (untrained) forest.
+
+        nb_tree_per_forest: number of trees (``n_estimators``).
+        max_depth: maximum depth of each tree.
+        """
+        # Fixed seed so that training is reproducible from run to run.
+        self.model = RandomForestClassifier(n_estimators=nb_tree_per_forest,
+                                            max_depth=max_depth,
+                                            random_state=0)
+
+    def train(self, X):
+        """Fit the forest on ``X``, a sequence of per-class sample arrays.
+
+        The class label of each sample is its array's position in ``X``.
+        """
+        # Stack all classes into a single feature matrix.
+        X_features = np.vstack(list(X))
+        y = make_labels(X)
+        self.model.fit(X_features, y)
+
+    def predict(self, X):
+        """Return the predicted label of every sample in ``X``.
+
+        ``X`` has the same layout as for ``train``; predictions are returned
+        in stacked order (all samples of ``X[0]``, then ``X[1]``, ...).
+        """
+        X_features = np.vstack(list(X))
+        return self.model.predict(X_features)
diff --git a/scripts/analysis.py b/scripts/analysis.py
new file mode 100644
index 0000000..168dda7
--- /dev/null
+++ b/scripts/analysis.py
@@ -0,0 +1,45 @@
+"""Helpers to compute and plot a confusion matrix."""
+
+import plotly.express as px
+from sklearn.metrics import confusion_matrix
+import numpy as np
+import itertools
+
+
+def get_confusion_matrix(prediction_label, true_label):
+    """Return the confusion matrix of the predictions.
+
+    Note the argument order: predictions first, ground truth second.  Rows of
+    the result are true labels and columns are predicted labels.
+    """
+    return confusion_matrix(true_label, prediction_label)
+
+
+def plot_confusion_matrix(cm, classes,
+                          normalize=False,
+                          title='Confusion matrix'):
+    """Display the confusion matrix ``cm`` as a Plotly heat map.
+
+    cm: matrix with rows = true labels and columns = predicted labels.
+    classes: tick labels, one per row/column of ``cm``.
+    normalize: if True, divide every cell by the total count so that the
+        whole matrix sums to 1.
+    title: figure title.
+    """
+    cm = np.asarray(cm)
+
+    if normalize:
+        total = cm.sum()
+        # An all-zero matrix (e.g. no test samples) is left untouched rather
+        # than turned into NaNs by a division by zero.
+        if total:
+            cm = cm / total
+
+    fig = px.imshow(cm,
+                    labels=dict(x="Predicted label", y="True label",
+                                color="value"),
+                    x=classes,
+                    y=classes,
+                    color_continuous_scale='Blues',
+                    title=title)
+    fig.show()
diff --git a/scripts/test.py b/scripts/test.py
new file mode 100644
index 0000000..93f926f
--- /dev/null
+++ b/scripts/test.py
@@ -0,0 +1,54 @@
+"""Manual smoke tests for the classification pipeline (run as a script)."""
+
+import algorithm
+import database
+import analysis
+import numpy as np
+
+
+def test1():
+    """Plot a synthetic confusion matrix to check the plotting code alone."""
+    x = np.arange(len(database.CLASSES))
+    # Outer product i * j: a deterministic matrix with one row per class.
+    cm = np.outer(x, x)
+    print(cm)
+    print(database.CLASSES)
+    analysis.plot_confusion_matrix(cm, database.CLASSES)
+
+
+def test():
+    """Train on the 'proto1' train set and plot the test confusion matrix."""
+    # Import train data
+    train_data = database.get('proto1', 'train')
+
+    # TODO: feature normalization is currently disabled:
+    #   norm = preprocessor.estimate_norm(numpy.vstack(train))
+    #   train_normed = preprocessor.normalize(train, norm)
+    #   test_normed = preprocessor.normalize(test, norm)
+
+    # Train algo
+    model = algorithm.Model()
+    model.train(train_data)
+
+    # Import test data (not named `test`, which would shadow this function)
+    test_data = database.get('proto1', 'test')
+
+    # Make prediction
+    test_predictions = model.predict(test_data)
+    print(test_predictions)
+
+    # Get real labels
+    test_labels = algorithm.make_labels(test_data)
+
+    # Get confusion matrix
+    cm = analysis.get_confusion_matrix(test_predictions, test_labels)
+
+    # Plot confusion matrix
+    analysis.plot_confusion_matrix(cm, database.CLASSES, normalize=True)
+
+    # Plot confusion matrix (ignore other activities)
+    # analysis.plot_confusion_matrix(cm[1:, 1:], database.CLASSES[1:], normalize=True)
+
+
+if __name__ == '__main__':
+    test()