diff --git a/tests/baseline_images/test_classifier/test_classification_report/test_quick_method.png b/tests/baseline_images/test_classifier/test_classification_report/test_quick_method.png index 9764abfa8..e6ec1b4fc 100644 Binary files a/tests/baseline_images/test_classifier/test_classification_report/test_quick_method.png and b/tests/baseline_images/test_classifier/test_classification_report/test_quick_method.png differ diff --git a/tests/baseline_images/test_contrib/test_prepredict/test_prepredict_classifier.png b/tests/baseline_images/test_contrib/test_prepredict/test_prepredict_classifier.png index bd810c8b3..fb72408fb 100644 Binary files a/tests/baseline_images/test_contrib/test_prepredict/test_prepredict_classifier.png and b/tests/baseline_images/test_contrib/test_prepredict/test_prepredict_classifier.png differ diff --git a/tests/baseline_images/test_contrib/test_prepredict/test_prepredict_regressor.png b/tests/baseline_images/test_contrib/test_prepredict/test_prepredict_regressor.png index 6ae06abdc..133221d39 100644 Binary files a/tests/baseline_images/test_contrib/test_prepredict/test_prepredict_regressor.png and b/tests/baseline_images/test_contrib/test_prepredict/test_prepredict_regressor.png differ diff --git a/tests/test_contrib/test_prepredict.py b/tests/test_contrib/test_prepredict.py index 22ea7efe7..cc930b4f7 100644 --- a/tests/test_contrib/test_prepredict.py +++ b/tests/test_contrib/test_prepredict.py @@ -29,7 +29,10 @@ from yellowbrick.contrib.prepredict import * from yellowbrick.regressor import PredictionError from yellowbrick.classifier import ClassificationReport +import numpy as np +# Set random state +np.random.seed() ########################################################################## ## Fixtures diff --git a/yellowbrick/classifier/classification_report.py b/yellowbrick/classifier/classification_report.py index 49bd2838c..fca128e1b 100644 --- a/yellowbrick/classifier/classification_report.py +++ b/yellowbrick/classifier/classification_report.py @@ -94,6 +94,9 @@ class ClassificationReport(ClassificationScoreVisualizer): colorbar : bool, default: True Specify if the color bar should be present + fontsize : int or None, default: None + Specify the font size of the x and y labels + kwargs : dict Keyword arguments passed to the visualizer base classes. @@ -136,6 +139,7 @@ def __init__( is_fitted="auto", force_model=False, colorbar=True, + fontsize=None, **kwargs ): super(ClassificationReport, self).__init__( @@ -154,6 +158,7 @@ def __init__( self.cmap.set_over(color=CMAP_OVERCOLOR) self.cmap.set_under(color=CMAP_UNDERCOLOR) self._displayed_scores = [key for key in SCORES_KEYS] + self.fontsize=fontsize if support not in {None, True, False, "percent", "count"}: raise YellowbrickValueError( @@ -228,6 +233,26 @@ def draw(self): self.ax.set_ylim(bottom=0, top=cr_display.shape[0]) self.ax.set_xlim(left=0, right=cr_display.shape[1]) + # Get the human readable labels + labels = self._labels() + if labels is None: + labels = self.classes_ + + # Fetch the grid labels from the classes in correct order; set ticks. + xticklabels = self._displayed_scores + yticklabels = labels[::-1] + + yticks = np.arange(len(labels)) + 0.5 + xticks = np.arange(len(self._displayed_scores)) + 0.5 + + self.ax.set(yticks=yticks, xticks=xticks) + + self.ax.set_xticklabels( + xticklabels, rotation=45, fontsize=self.fontsize + ) + self.ax.set_yticklabels(yticklabels, fontsize=self.fontsize) + + # Set data labels in the grid, enumerating over class, metric pairs # NOTE: X and Y are one element longer than the classification report # so skip the last element to label the grid correctly. @@ -309,6 +334,7 @@ def classification_report( force_model=False, show=True, colorbar=True, + fontsize=None, **kwargs ): """Classification Report @@ -386,6 +412,9 @@ def classification_report( colorbar : bool, default: True Specify if the color bar should be present + fontsize : int or None, default: None + Specify the font size of the x and y labels + kwargs : dict Keyword arguments passed to the visualizer base classes. @@ -405,6 +434,7 @@ def classification_report( is_fitted=is_fitted, force_model=force_model, colorbar=colorbar, + fontsize=fontsize, **kwargs )