diff --git a/project/01_2_performance_plots.ipynb b/project/01_2_performance_plots.ipynb index d22c2c200..0625bfb37 100644 --- a/project/01_2_performance_plots.ipynb +++ b/project/01_2_performance_plots.ipynb @@ -522,7 +522,7 @@ "outputs": [], "source": [ "COLORS_TO_USE = vaep.plotting.defaults.assign_colors(list(k.upper() for k in ORDER_MODELS))\n", - "sns.color_palette(COLORS_TO_USE)" + "vaep.plotting.defaults.ModelColorVisualizer(ORDER_MODELS, COLORS_TO_USE)" ] }, { diff --git a/project/01_2_performance_plots.py b/project/01_2_performance_plots.py index 7150f8e33..71cfab04a 100644 --- a/project/01_2_performance_plots.py +++ b/project/01_2_performance_plots.py @@ -265,7 +265,7 @@ def build_text(s): # %% COLORS_TO_USE = vaep.plotting.defaults.assign_colors(list(k.upper() for k in ORDER_MODELS)) -sns.color_palette(COLORS_TO_USE) +vaep.plotting.defaults.ModelColorVisualizer(ORDER_MODELS, COLORS_TO_USE) # %% TOP_N_ORDER = ORDER_MODELS[:args.plot_to_n] diff --git a/vaep/plotting/defaults.py b/vaep/plotting/defaults.py index d6e2b5b81..f4a470abe 100644 --- a/vaep/plotting/defaults.py +++ b/vaep/plotting/defaults.py @@ -1,4 +1,5 @@ import logging +import matplotlib as mpl import seaborn as sns logger = logging.getLogger(__name__) @@ -22,10 +23,11 @@ # other_colors = sns.color_palette()[8:] other_colors = sns.color_palette("husl", 20) color_model_mapping['IMPSEQ'] = other_colors[0] +color_model_mapping['QRILC'] = other_colors[1] color_model_mapping['IMPSEQROB'] = other_colors[1] color_model_mapping['MICE-NORM'] = other_colors[2] color_model_mapping['SEQKNN'] = other_colors[3] -color_model_mapping['QRILC'] = other_colors[4] +color_model_mapping['IMPSEQROB'] = other_colors[4] color_model_mapping['GSIMP'] = other_colors[5] color_model_mapping['MSIMPUTE'] = other_colors[6] color_model_mapping['MSIMPUTE_MNAR'] = other_colors[7] @@ -49,6 +51,32 @@ def assign_colors(models): return ret_colors +class ModelColorVisualizer: + + def __init__(self, models, palette): + self.models = models + self.palette = map(mpl.colors.colorConverter.to_rgb, palette) + + def as_hex(self): + """Return a color palette with hex codes instead of RGB values.""" + hex = [mpl.colors.rgb2hex(rgb) for rgb in self.palette] + return hex + + def _repr_html_(self): + """Rich display of the color palette in an HTML frontend.""" + s = 55 + n = len(self.models) + html = f'' + for i, (m, c) in enumerate(zip(self.models, self.as_hex())): + html += ( + f'' + ) + html += f'{m}' + html += '' + return html + + labels_dict = {"NA not interpolated valid_collab collab MSE": 'MSE', 'batch_size': 'bs', 'n_hidden_layers': "No. of hidden layers",