Skip to content

Commit

Permalink
Automatically adjust the number of subplots based on the number of li…
Browse files Browse the repository at this point in the history
…braries
  • Loading branch information
DoraDong-2023 committed Apr 9, 2024
1 parent 3d011e4 commit bed91ca
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 76 deletions.
66 changes: 37 additions & 29 deletions src/scripts/step4_analysis_retriever.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,49 @@
import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv('plot/retriever_topk_results.csv')

columns_to_keep = ['retrieved_api_nums', 'Validation Accuracy', 'Test Accuracy', 'val ambiguous Accuracy', 'test ambiguous Accuracy'] # 'Training Accuracy', 'Training ambiguous Accuracy',

labels = [
'Synthetic instruction', 'Annotation instruction',
'Synthetic instruction w/ ambiguity removal', 'Annotation instruction w/ ambiguity removal'
]

num_rows = 2
num_cols = 2

libs = df['LIB'].unique()[:num_rows * num_cols]

fig, axs = plt.subplots(num_rows, num_cols, figsize=(15, 10))
plt.subplots_adjust(bottom=0.25) # 调整底部空间
"""
Author: Zhengyuan Dong
Created Date: 2024-02-01
Last Edited Date: 2024-04-08
Description:
Plot the comparison of the retriever accuracy results for different libraries and models.
Automatically adjust the number of subplots based on the number of libraries.
"""

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import math

df = pd.read_csv('./output/retriever_topk_results.csv')
columns_to_keep = ['retrieved_api_nums', 'Validation Accuracy', 'Test Accuracy', 'val ambiguous Accuracy', 'test ambiguous Accuracy']
labels = ['Synthetic instruction', 'Annotation instruction',
'Synthetic instruction w/ ambiguity removal', 'Annotation instruction w/ ambiguity removal']
libs = df['LIB'].unique()
num_libs = len(libs)
if num_libs <= 2:
fig, axs = plt.subplots(1, num_libs, figsize=(15 * num_libs, 6))
if num_libs == 1:
axs = np.array([axs])
else:
num_rows = np.ceil(num_libs / 2).astype(int)
fig, axs = plt.subplots(num_rows, 2, figsize=(15, 6 * num_rows))
axs = axs.flatten()
for index, lib in enumerate(libs):
ax = axs.flatten()[index]
lib_df = df[df['LIB'] == lib][columns_to_keep]
ax = axs[index]
lib_df = df[df['LIB'] == lib]
for col, label in zip(columns_to_keep[1:], labels):
ax.plot(lib_df['retrieved_api_nums'], lib_df[col], label=label)
ax.set_title(f'Fine-tuned Retriever Accuracy v.s. Topk for {lib}')
ax.set_xlabel('Topk')
ax.set_ylabel('Accuracy')

# 用flatten()将axs从2D数组转化为1D数组以便简单索引
ax.grid(True)
if num_libs > 1:
for ax in axs[num_libs:]:
ax.set_visible(False)
handles, labels = axs.flatten()[0].get_legend_handles_labels()

# 在整个Figure的底部绘制一个共享的图例
fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 0.1), ncol=4)

# Tighten the layout with the rect parameter to fit the new legend
plt.tight_layout(rect=[0, 0.1, 1, 1])

plt.savefig('./plot/retriever_acc_lib_4.jpg')
plt.subplots_adjust(hspace=0.4)
plt_path = "./output/step4_analysis_retriever_acc_lib.jpg"
plt.savefig(plt_path)
plt.show()

plt_path
6 changes: 3 additions & 3 deletions src/scripts/step4_analysis_retriever.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
# usage: bash -x scripts/step4_analysis_retriever.sh

export HUGGINGPATH=./hugging_models
libs=("scanpy" "squidpy" "ehrapy" "snapatac2")
libs=("scanpy" "squidpy" "ehrapy" "snapatac2") # "scanpy_subset"

mkdir -p output
mkdir -p plot/
csv_file="plot/retriever_topk_results.csv"
mkdir -p output/
csv_file="output/retriever_topk_results.csv"

echo "LIB,retrieved_api_nums,Training Accuracy,Validation Accuracy,Test Accuracy,Training ambiguous Accuracy,val ambiguous Accuracy,test ambiguous Accuracy" > $csv_file

Expand Down
90 changes: 47 additions & 43 deletions src/scripts/step5_analysis_compare_retriever.py
Original file line number Diff line number Diff line change
@@ -1,75 +1,79 @@

"""
Author: Zhengyuan Dong
Created Date: 2024-02-01
Last Edited Date: 2024-04-08
Description:
Plot the comparison of the retriever accuracy results for different libraries and models.
Automatically adjust the number of subplots based on the number of libraries.
"""

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
# Based on the provided data structure and the requirements for subplots, let's create the desired bar graph.
import math

# Load the CSV data
df = pd.read_csv('output/retriever_accuracy_results.csv')
default_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
color_b = default_colors[:4]
#color_b = ['blue', 'orange', 'green', 'red']

legend_labels = ['Synthetic instruction', 'Annotated instruction',
'Synthetic instruction w/ ambiguity removal', 'Annotated instruction w/ ambiguity removal']
# Define the function for plotting
def create_subplot_bar_graphs(df, ax, lib, title, colors, display_test=False):
# Columns for BM25, Un-Finetuned, and Finetuned (Val and Test)
cols = [f'{lib} Val', f'{lib} Test', f'{lib} Ambiguous Val', f'{lib} Ambiguous Test']

x_labels = ['BM25', 'SENTENCE-BERT \nw/o \nfine-tuning', 'SENTENCE-BERT \nw/ \nfine-tuning']

# The width of the bars
bar_width = 0.15

# Set the positions of the bars
indices = np.arange(len(x_labels))

# Plot bars
for i, col_prefix in enumerate(['BM25', 'Un-Finetuned', 'Finetuned']):
#print('----------', col_prefix)
val_acc = df.loc[df['LIB'] == lib, f'{col_prefix} Val'].values[0]
test_acc = df.loc[df['LIB'] == lib, f'{col_prefix} Test'].values[0]
#if col_prefix!='BM25':
val_am_acc = df.loc[df['LIB'] == lib, f'{col_prefix} Ambiguous Val'].values[0]
test_am_acc = df.loc[df['LIB'] == lib, f'{col_prefix} Ambiguous Test'].values[0]

ax.bar(indices[i] - 3*bar_width/2, val_acc, bar_width, label='Synthetic instruction' if i==0 else "", color=colors[0])
ax.bar(indices[i] - 3*bar_width/2, val_acc, bar_width, color=colors[0])
if display_test:
ax.bar(indices[i] - bar_width/2, test_acc, bar_width, label='Annotated instruction' if i==0 else "", color=colors[1])
#if col_prefix!='BM25':
ax.bar(indices[i] + bar_width/2, val_am_acc, bar_width, label='Synthetic instruction w/ ambiguity removal' if i==1 else "", color=colors[2])
ax.bar(indices[i] - bar_width/2, test_acc, bar_width, color=colors[1])
ax.bar(indices[i] + bar_width/2, val_am_acc, bar_width, color=colors[2])
if display_test:
ax.bar(indices[i] + 3*bar_width/2, test_am_acc, bar_width, label='Annotated instruction w/ ambiguity removal' if i==1 else "", color=colors[3])

# Add data labels
ax.bar(indices[i] + 3*bar_width/2, test_am_acc, bar_width, color=colors[3])
ax.text(indices[i] - 3*bar_width/2, val_acc, f'{val_acc:.2f}', ha='center', va='bottom', fontsize=8)
if display_test:
ax.text(indices[i] - bar_width/2, test_acc, f'{test_acc:.2f}', ha='center', va='bottom', fontsize=8)
#if col_prefix!='BM25':
ax.text(indices[i] + bar_width/2, val_am_acc, f'{val_am_acc:.2f}', ha='center', va='bottom', fontsize=8)
if display_test:
ax.text(indices[i] + 3*bar_width/2, test_am_acc, f'{test_am_acc:.2f}', ha='center', va='bottom', fontsize=8)

# Set the title, x-ticks, and labels
ax.set_title(f'{lib} {title}')
ax.set_xticks(indices)
ax.set_xticklabels(x_labels)

# Set the legend only for the first subplot
#if lib == df['LIB'].unique()[0]:
# ax.legend()

# Set up the subplots
#fig, axs = plt.subplots(2, 2, figsize=(12, 10))
fig, axs = plt.subplots(2, 2, figsize=(12, 10), constrained_layout=True)
axs = axs.ravel() # Flatten the array of axes for easier indexing

fig.subplots_adjust(right=0.8)
for i, lib in enumerate(df['LIB'].unique()):
libs = df['LIB'].unique()
num_libs = len(libs)
# Calculate the number of rows and columns for the subplots
num_cols = 2
num_rows = math.ceil(num_libs / num_cols)
if num_libs <= 2:
# Adjust subplot layout for 1 or 2 libs
fig, axs = plt.subplots(1, num_libs, figsize=(12 * num_libs, 5), constrained_layout=True)
if num_libs == 1:
axs = np.array([axs])
else:
# Calculate the number of rows and columns for the subplots for more than 2 libs
num_cols = 2
num_rows = math.ceil(num_libs / num_cols)
fig, axs = plt.subplots(num_rows, num_cols, figsize=(12, num_rows * 5), constrained_layout=True)
axs = axs.ravel()
for i, lib in enumerate(libs):
ax = axs[i]
create_subplot_bar_graphs(df, axs[i], lib, 'Prediction Accuracy', color_b, display_test=True)

handles, labels = axs[0].get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, 0), ncol=4, frameon=False)
plt.subplots_adjust(bottom=0.2)
#plt.tight_layout()
plt.tight_layout(rect=[0, 0.05, 1, 1])
plt.savefig('./output/step5_analysis_compare_retriever.jpg')
for j in range(i + 1, len(axs)):
fig.delaxes(axs[j])
# Adjust layout and add a global legend at the bottom
plt.subplots_adjust(bottom=0.25, top=0.9)
# Add a global legend at the bottom
handles = [plt.Rectangle((0,0),1,1, color=color) for color in color_b[:len(legend_labels)]]
fig.legend(handles, legend_labels, loc='upper center', bbox_to_anchor=(0.5, 0.1), ncol=4)
plt.tight_layout(rect=[0, 0.1, 1, 1])
plt.subplots_adjust(hspace=0.4) # Adjust this value to make space for the legend
plt.savefig('output/step5_analysis_compare_retriever.jpg')
plt.show()


2 changes: 1 addition & 1 deletion src/scripts/step5_analysis_compare_retriever.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Usage: bash -x scripts/step5_analysis_compare_retriever.sh

export HUGGINGPATH=./hugging_models
libs=("scanpy" "squidpy" "ehrapy" "snapatac2")
libs=("scanpy" "squidpy" "ehrapy" "snapatac2") # scanpy_subset
csv_file="output/retriever_accuracy_results.csv"

# Header for CSV file
Expand Down

0 comments on commit bed91ca

Please sign in to comment.