Skip to content

Commit

Permalink
replace confusion matrix plot using plotly.express
Browse files Browse the repository at this point in the history
  • Loading branch information
sdevenes committed Sep 14, 2020
1 parent 66cecef commit 2bdaf16
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 30 deletions.
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
requests
requests
sklearn
plotly.express
40 changes: 14 additions & 26 deletions scripts/analysis.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import matplotlib.pyplot as plt
import plotly.express as px
from sklearn.metrics import confusion_matrix
import numpy as np
import itertools
Expand All @@ -8,33 +8,21 @@ def get_confusion_matrix(prediction_label, true_label):

def plot_confusion_matrix(cm, classes,
normalize=False,
title='Confusion matrix',
cmap=plt.cm.Blues):
title='Confusion matrix'):
"""
This function prints and plots the confusion matrix.
Normalization can be applied by setting `normalize=True`.
"""
# Create plot
plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
plt.colorbar()
# Normalize if wanted
if normalize:
cm = cm / np.sum(cm)

# Update axis value
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)

# Normalize if needed
fmt = '.2f' if normalize else 'd'
thresh = cm.max() *2./3
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
plt.text(j, i, format(cm[i, j], fmt),
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")

# Update axis label
# plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')

plt.show()
# Create confusion matrix graph
fig = px.imshow(cm,
labels=dict(x="Predicted label", y="True label", color="value"),
x=classes,
y=classes,
color_continuous_scale='Blues',
title=title
)
fig.show()
8 changes: 5 additions & 3 deletions scripts/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ def test1():
x = np.arange(len(database.CLASSES))
cm = np.dot(x.reshape(len(x),1), x.reshape(1,len(x)))
print(cm)
print(database.CLASSES)
analysis.plot_confusion_matrix(cm, database.CLASSES)

def test():
Expand Down Expand Up @@ -40,10 +41,11 @@ def test():
cm = analysis.get_confusion_matrix(test_predictions, test_labels)

# Plot confusion matrix
analysis.plot_confusion_matrix(cm, database.CLASSES)

analysis.plot_confusion_matrix(cm, database.CLASSES, normalize=True)

# Plot confusion matrix (ignore other activities)
# analysis.plot_confusion_matrix(cm[1:, 1:], database.CLASSES[1:], normalize=True)


if __name__ == '__main__':
test1()
test()

0 comments on commit 2bdaf16

Please sign in to comment.