-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Automatically adjust the number of subplots based on the number of li…
…braries
- Loading branch information
1 parent
3d011e4
commit bed91ca
Showing
4 changed files
with
88 additions
and
76 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,41 +1,49 @@ | ||
import pandas as pd | ||
import matplotlib.pyplot as plt | ||
|
||
df = pd.read_csv('plot/retriever_topk_results.csv') | ||
|
||
columns_to_keep = ['retrieved_api_nums', 'Validation Accuracy', 'Test Accuracy', 'val ambiguous Accuracy', 'test ambiguous Accuracy'] # 'Training Accuracy', 'Training ambiguous Accuracy', | ||
|
||
labels = [ | ||
'Synthetic instruction', 'Annotation instruction', | ||
'Synthetic instruction w/ ambiguity removal', 'Annotation instruction w/ ambiguity removal' | ||
] | ||
|
||
num_rows = 2 | ||
num_cols = 2 | ||
|
||
libs = df['LIB'].unique()[:num_rows * num_cols] | ||
|
||
fig, axs = plt.subplots(num_rows, num_cols, figsize=(15, 10)) | ||
plt.subplots_adjust(bottom=0.25) # 调整底部空间 | ||
""" | ||
Author: Zhengyuan Dong | ||
Created Date: 2024-02-01 | ||
Last Edited Date: 2024-04-08 | ||
Description: | ||
Plot the comparison of the retriever accuracy results for different libraries and models. | ||
Automatically adjust the number of subplots based on the number of libraries. | ||
""" | ||
|
||
import pandas as pd | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import math | ||
|
||
df = pd.read_csv('./output/retriever_topk_results.csv') | ||
columns_to_keep = ['retrieved_api_nums', 'Validation Accuracy', 'Test Accuracy', 'val ambiguous Accuracy', 'test ambiguous Accuracy'] | ||
labels = ['Synthetic instruction', 'Annotation instruction', | ||
'Synthetic instruction w/ ambiguity removal', 'Annotation instruction w/ ambiguity removal'] | ||
libs = df['LIB'].unique() | ||
num_libs = len(libs) | ||
if num_libs <= 2: | ||
fig, axs = plt.subplots(1, num_libs, figsize=(15 * num_libs, 6)) | ||
if num_libs == 1: | ||
axs = np.array([axs]) | ||
else: | ||
num_rows = np.ceil(num_libs / 2).astype(int) | ||
fig, axs = plt.subplots(num_rows, 2, figsize=(15, 6 * num_rows)) | ||
axs = axs.flatten() | ||
for index, lib in enumerate(libs): | ||
ax = axs.flatten()[index] | ||
lib_df = df[df['LIB'] == lib][columns_to_keep] | ||
ax = axs[index] | ||
lib_df = df[df['LIB'] == lib] | ||
for col, label in zip(columns_to_keep[1:], labels): | ||
ax.plot(lib_df['retrieved_api_nums'], lib_df[col], label=label) | ||
ax.set_title(f'Fine-tuned Retriever Accuracy v.s. Topk for {lib}') | ||
ax.set_xlabel('Topk') | ||
ax.set_ylabel('Accuracy') | ||
|
||
# 用flatten()将axs从2D数组转化为1D数组以便简单索引 | ||
ax.grid(True) | ||
if num_libs > 1: | ||
for ax in axs[num_libs:]: | ||
ax.set_visible(False) | ||
handles, labels = axs.flatten()[0].get_legend_handles_labels() | ||
|
||
# 在整个Figure的底部绘制一个共享的图例 | ||
fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 0.1), ncol=4) | ||
|
||
# Tighten the layout with the rect parameter to fit the new legend | ||
plt.tight_layout(rect=[0, 0.1, 1, 1]) | ||
|
||
plt.savefig('./plot/retriever_acc_lib_4.jpg') | ||
plt.subplots_adjust(hspace=0.4) | ||
plt_path = "./output/step4_analysis_retriever_acc_lib.jpg" | ||
plt.savefig(plt_path) | ||
plt.show() | ||
|
||
plt_path |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,75 +1,79 @@ | ||
|
||
""" | ||
Author: Zhengyuan Dong | ||
Created Date: 2024-02-01 | ||
Last Edited Date: 2024-04-08 | ||
Description: | ||
Plot the comparison of the retriever accuracy results for different libraries and models. | ||
Automatically adjust the number of subplots based on the number of libraries. | ||
""" | ||
|
||
import pandas as pd | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
# Based on the provided data structure and the requirements for subplots, let's create the desired bar graph. | ||
import math | ||
|
||
# Load the CSV data | ||
df = pd.read_csv('output/retriever_accuracy_results.csv') | ||
default_colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] | ||
color_b = default_colors[:4] | ||
#color_b = ['blue', 'orange', 'green', 'red'] | ||
|
||
legend_labels = ['Synthetic instruction', 'Annotated instruction', | ||
'Synthetic instruction w/ ambiguity removal', 'Annotated instruction w/ ambiguity removal'] | ||
# Define the function for plotting | ||
def create_subplot_bar_graphs(df, ax, lib, title, colors, display_test=False): | ||
# Columns for BM25, Un-Finetuned, and Finetuned (Val and Test) | ||
cols = [f'{lib} Val', f'{lib} Test', f'{lib} Ambiguous Val', f'{lib} Ambiguous Test'] | ||
|
||
x_labels = ['BM25', 'SENTENCE-BERT \nw/o \nfine-tuning', 'SENTENCE-BERT \nw/ \nfine-tuning'] | ||
|
||
# The width of the bars | ||
bar_width = 0.15 | ||
|
||
# Set the positions of the bars | ||
indices = np.arange(len(x_labels)) | ||
|
||
# Plot bars | ||
for i, col_prefix in enumerate(['BM25', 'Un-Finetuned', 'Finetuned']): | ||
#print('----------', col_prefix) | ||
val_acc = df.loc[df['LIB'] == lib, f'{col_prefix} Val'].values[0] | ||
test_acc = df.loc[df['LIB'] == lib, f'{col_prefix} Test'].values[0] | ||
#if col_prefix!='BM25': | ||
val_am_acc = df.loc[df['LIB'] == lib, f'{col_prefix} Ambiguous Val'].values[0] | ||
test_am_acc = df.loc[df['LIB'] == lib, f'{col_prefix} Ambiguous Test'].values[0] | ||
|
||
ax.bar(indices[i] - 3*bar_width/2, val_acc, bar_width, label='Synthetic instruction' if i==0 else "", color=colors[0]) | ||
ax.bar(indices[i] - 3*bar_width/2, val_acc, bar_width, color=colors[0]) | ||
if display_test: | ||
ax.bar(indices[i] - bar_width/2, test_acc, bar_width, label='Annotated instruction' if i==0 else "", color=colors[1]) | ||
#if col_prefix!='BM25': | ||
ax.bar(indices[i] + bar_width/2, val_am_acc, bar_width, label='Synthetic instruction w/ ambiguity removal' if i==1 else "", color=colors[2]) | ||
ax.bar(indices[i] - bar_width/2, test_acc, bar_width, color=colors[1]) | ||
ax.bar(indices[i] + bar_width/2, val_am_acc, bar_width, color=colors[2]) | ||
if display_test: | ||
ax.bar(indices[i] + 3*bar_width/2, test_am_acc, bar_width, label='Annotated instruction w/ ambiguity removal' if i==1 else "", color=colors[3]) | ||
|
||
# Add data labels | ||
ax.bar(indices[i] + 3*bar_width/2, test_am_acc, bar_width, color=colors[3]) | ||
ax.text(indices[i] - 3*bar_width/2, val_acc, f'{val_acc:.2f}', ha='center', va='bottom', fontsize=8) | ||
if display_test: | ||
ax.text(indices[i] - bar_width/2, test_acc, f'{test_acc:.2f}', ha='center', va='bottom', fontsize=8) | ||
#if col_prefix!='BM25': | ||
ax.text(indices[i] + bar_width/2, val_am_acc, f'{val_am_acc:.2f}', ha='center', va='bottom', fontsize=8) | ||
if display_test: | ||
ax.text(indices[i] + 3*bar_width/2, test_am_acc, f'{test_am_acc:.2f}', ha='center', va='bottom', fontsize=8) | ||
|
||
# Set the title, x-ticks, and labels | ||
ax.set_title(f'{lib} {title}') | ||
ax.set_xticks(indices) | ||
ax.set_xticklabels(x_labels) | ||
|
||
# Set the legend only for the first subplot | ||
#if lib == df['LIB'].unique()[0]: | ||
# ax.legend() | ||
|
||
# Set up the subplots | ||
#fig, axs = plt.subplots(2, 2, figsize=(12, 10)) | ||
fig, axs = plt.subplots(2, 2, figsize=(12, 10), constrained_layout=True) | ||
axs = axs.ravel() # Flatten the array of axes for easier indexing | ||
|
||
fig.subplots_adjust(right=0.8) | ||
for i, lib in enumerate(df['LIB'].unique()): | ||
libs = df['LIB'].unique() | ||
num_libs = len(libs) | ||
# Calculate the number of rows and columns for the subplots | ||
num_cols = 2 | ||
num_rows = math.ceil(num_libs / num_cols) | ||
if num_libs <= 2: | ||
# Adjust subplot layout for 1 or 2 libs | ||
fig, axs = plt.subplots(1, num_libs, figsize=(12 * num_libs, 5), constrained_layout=True) | ||
if num_libs == 1: | ||
axs = np.array([axs]) | ||
else: | ||
# Calculate the number of rows and columns for the subplots for more than 2 libs | ||
num_cols = 2 | ||
num_rows = math.ceil(num_libs / num_cols) | ||
fig, axs = plt.subplots(num_rows, num_cols, figsize=(12, num_rows * 5), constrained_layout=True) | ||
axs = axs.ravel() | ||
for i, lib in enumerate(libs): | ||
ax = axs[i] | ||
create_subplot_bar_graphs(df, axs[i], lib, 'Prediction Accuracy', color_b, display_test=True) | ||
|
||
handles, labels = axs[0].get_legend_handles_labels() | ||
fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, 0), ncol=4, frameon=False) | ||
plt.subplots_adjust(bottom=0.2) | ||
#plt.tight_layout() | ||
plt.tight_layout(rect=[0, 0.05, 1, 1]) | ||
plt.savefig('./output/step5_analysis_compare_retriever.jpg') | ||
for j in range(i + 1, len(axs)): | ||
fig.delaxes(axs[j]) | ||
# Adjust layout and add a global legend at the bottom | ||
plt.subplots_adjust(bottom=0.25, top=0.9) | ||
# Add a global legend at the bottom | ||
handles = [plt.Rectangle((0,0),1,1, color=color) for color in color_b[:len(legend_labels)]] | ||
fig.legend(handles, legend_labels, loc='upper center', bbox_to_anchor=(0.5, 0.1), ncol=4) | ||
plt.tight_layout(rect=[0, 0.1, 1, 1]) | ||
plt.subplots_adjust(hspace=0.4) # Adjust this value to make space for the legend | ||
plt.savefig('output/step5_analysis_compare_retriever.jpg') | ||
plt.show() | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters