diff --git a/tmap/tda/plot.py b/tmap/tda/plot.py index 9aab6e5..4a03946 100644 --- a/tmap/tda/plot.py +++ b/tmap/tda/plot.py @@ -56,10 +56,13 @@ def __init__(self, target, dtype="numerical", target_by="sample"): if ((type(target[0][0]) != int) and (type(target[0][0]) != float) - and (not isinstance(target[0][0],np.number)) + and (not isinstance(target[0][0], np.number)) ): self.label_encoder = LabelEncoder() self.target = self.label_encoder.fit_transform(target) + elif dtype == "categorical": + self.label_encoder = LabelEncoder() + self.target = self.label_encoder.fit_transform(target.astype(str)) else: self.label_encoder = None self.target = target @@ -191,14 +194,15 @@ def show(data, graph, color=None, fig_size=(10, 10), node_size=10, edge_width=2, legend_lookup = dict(zip(node_target_values.reshape(-1,), node_colors)) # add categorical legend - if isinstance(color,Color): + if isinstance(color, Color): if color.dtype == "categorical": for label in set([it[0] for it in color.labels]): if color.label_encoder: - label_color = legend_lookup[color.label_encoder.transform([label])[0]] + label_color = legend_lookup.get(color.label_encoder.transform([label])[0], None) else: - label_color = legend_lookup[label] - ax.plot([], [], 'o', color=label_color, label=label, markersize=10) + label_color = legend_lookup.get(label, None) + if label_color is not None: + ax.plot([], [], 'o', color=label_color, label=label, markersize=10) legend = ax.legend(numpoints=1, loc="upper right") legend.get_frame().set_facecolor('#bebebe')