Skip to content

Commit

Permalink
Update plot.py
Browse files Browse the repository at this point in the history
Fix missing labels problems and taking number as categorical for visualization problems.
  • Loading branch information
GPZ-Bioinfo authored Aug 27, 2018
1 parent 3501c0a commit 4f10b23
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions tmap/tda/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,13 @@ def __init__(self, target, dtype="numerical", target_by="sample"):

if ((type(target[0][0]) != int)
and (type(target[0][0]) != float)
and (not isinstance(target[0][0],np.number))
and (not isinstance(target[0][0], np.number))
):
self.label_encoder = LabelEncoder()
self.target = self.label_encoder.fit_transform(target)
elif dtype == "categorical":
self.label_encoder = LabelEncoder()
self.target = self.label_encoder.fit_transform(target.astype(str))
else:
self.label_encoder = None
self.target = target
Expand Down Expand Up @@ -191,14 +194,15 @@ def show(data, graph, color=None, fig_size=(10, 10), node_size=10, edge_width=2,
legend_lookup = dict(zip(node_target_values.reshape(-1,), node_colors))

# add categorical legend
if isinstance(color,Color):
if isinstance(color, Color):
if color.dtype == "categorical":
for label in set([it[0] for it in color.labels]):
if color.label_encoder:
label_color = legend_lookup[color.label_encoder.transform([label])[0]]
label_color = legend_lookup.get(color.label_encoder.transform([label])[0], None)
else:
label_color = legend_lookup[label]
ax.plot([], [], 'o', color=label_color, label=label, markersize=10)
label_color = legend_lookup.get(label, None)
if label_color is not None:
ax.plot([], [], 'o', color=label_color, label=label, markersize=10)
legend = ax.legend(numpoints=1, loc="upper right")
legend.get_frame().set_facecolor('#bebebe')

Expand Down

0 comments on commit 4f10b23

Please sign in to comment.