Skip to content

Commit

Permalink
BUG: Adds missing X and Y axes labels in ClassificationReport (#1210)
Browse files Browse the repository at this point in the history
Co-authored-by: Pkaf <[email protected]>
Co-authored-by: Larry Gray <[email protected]>
  • Loading branch information
3 people authored Feb 19, 2022
1 parent 062ad14 commit d866a21
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 0 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions tests/test_contrib/test_prepredict.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@
from yellowbrick.contrib.prepredict import *
from yellowbrick.regressor import PredictionError
from yellowbrick.classifier import ClassificationReport
import numpy as np

# Set random state
np.random.seed()

##########################################################################
## Fixtures
Expand Down
30 changes: 30 additions & 0 deletions yellowbrick/classifier/classification_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ class ClassificationReport(ClassificationScoreVisualizer):
colorbar : bool, default: True
Specify if the color bar should be present
fontsize : int or None, default: None
Specify the font size of the x and y labels
kwargs : dict
Keyword arguments passed to the visualizer base classes.
Expand Down Expand Up @@ -136,6 +139,7 @@ def __init__(
is_fitted="auto",
force_model=False,
colorbar=True,
fontsize=None,
**kwargs
):
super(ClassificationReport, self).__init__(
Expand All @@ -154,6 +158,7 @@ def __init__(
self.cmap.set_over(color=CMAP_OVERCOLOR)
self.cmap.set_under(color=CMAP_UNDERCOLOR)
self._displayed_scores = [key for key in SCORES_KEYS]
self.fontsize=fontsize

if support not in {None, True, False, "percent", "count"}:
raise YellowbrickValueError(
Expand Down Expand Up @@ -228,6 +233,26 @@ def draw(self):
self.ax.set_ylim(bottom=0, top=cr_display.shape[0])
self.ax.set_xlim(left=0, right=cr_display.shape[1])

# Get the human readable labels
labels = self._labels()
if labels is None:
labels = self.classes_

# Fetch the grid labels from the classes in correct order; set ticks.
xticklabels = self._displayed_scores
yticklabels = labels[::-1]

yticks = np.arange(len(labels)) + 0.5
xticks = np.arange(len(self._displayed_scores)) + 0.5

self.ax.set(yticks=yticks, xticks=xticks)

self.ax.set_xticklabels(
xticklabels, rotation=45, fontsize=self.fontsize
)
self.ax.set_yticklabels(yticklabels, fontsize=self.fontsize)


# Set data labels in the grid, enumerating over class, metric pairs
# NOTE: X and Y are one element longer than the classification report
# so skip the last element to label the grid correctly.
Expand Down Expand Up @@ -309,6 +334,7 @@ def classification_report(
force_model=False,
show=True,
colorbar=True,
fontsize=None,
**kwargs
):
"""Classification Report
Expand Down Expand Up @@ -386,6 +412,9 @@ def classification_report(
colorbar : bool, default: True
Specify if the color bar should be present
fontsize : int or None, default: None
Specify the font size of the x and y labels
kwargs : dict
Keyword arguments passed to the visualizer base classes.
Expand All @@ -405,6 +434,7 @@ def classification_report(
is_fitted=is_fitted,
force_model=force_model,
colorbar=colorbar,
fontsize=fontsize,
**kwargs
)

Expand Down

0 comments on commit d866a21

Please sign in to comment.