From 62b7be04c1f20cbd9f68897444afe697c03f2224 Mon Sep 17 00:00:00 2001 From: Sravani Nanduri Date: Fri, 15 Mar 2024 14:58:31 -0700 Subject: [PATCH] adding changes, no seaborn --- src/pathogen_embed/pathogen_embed.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/src/pathogen_embed/pathogen_embed.py b/src/pathogen_embed/pathogen_embed.py index d2c1190..3e7b9c6 100644 --- a/src/pathogen_embed/pathogen_embed.py +++ b/src/pathogen_embed/pathogen_embed.py @@ -337,12 +337,9 @@ def embed(args): } plot_df = pd.DataFrame(plot_data) - ax = sns.scatterplot( - data=plot_df, - x="x", - y="y", - alpha=0.5, - ) + plt.scatter(plot_df["x"], plot_df["y"], alpha=0.5) + plt.xlabel("x") + plt.ylabel("y") plt.savefig(args.output_figure) plt.close() @@ -371,13 +368,15 @@ def cluster(args): plot_data["cluster"] = clusterer.labels_.astype(str) plot_df = pd.DataFrame(plot_data) - ax = sns.scatterplot( - data=plot_df, - x="x", - y="y", - hue="cluster", - alpha=0.5, - ) + clusters = plot_df['cluster'].unique() + colors = plt.cm.tab10.colors[:len(clusters)] + for i, cluster in enumerate(clusters): + cluster_data = plot_df[plot_df['cluster'] == cluster] + plt.scatter(cluster_data["x"], cluster_data["y"], color=colors[i], label=f'Cluster {cluster}', alpha=0.5) + + plt.xlabel("x") + plt.ylabel("y") + plt.legend() plt.savefig(args.output_figure) plt.close()