Skip to content

Commit

Permalink
Update plots and improve their setup (#45)
Browse files (browse the repository at this point in the history)
  • Loading branch information
ljvmiranda921 authored Oct 12, 2024
1 parent 454c66a commit 621fc66
Showing 1 changed file with 18 additions and 8 deletions.
26 changes: 18 additions & 8 deletions analysis/plot_results.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -170,7 +170,7 @@ def plot_eng_drop_line(
for (label, group), color in zip(data.groupby("Model_Type"), colors):
mrewardbench_scores = group["Avg_Multilingual"]
rewardbench_scores = group["eng_Latn"]
ax.scatter(rewardbench_scores, mrewardbench_scores, marker="o", s=30, label=label, color=color)
ax.scatter(rewardbench_scores, mrewardbench_scores, marker="o", s=40, label=label, color=color)

mrewardbench_scores = data["Avg_Multilingual"]
rewardbench_scores = data["eng_Latn"]
Expand All @@ -184,6 +184,7 @@ def plot_eng_drop_line(
ax.plot([min_val, max_val], [min_val, max_val], linestyle="--", color="black", alpha=0.25)
ax.set_xlabel("RewardBench (Lambert et al., 2024)")
ax.set_ylabel("M-RewardBench")
ax.grid(color="gray", alpha=0.2, which="both")
ax.set_aspect("equal")
ax.legend(frameon=False, handletextpad=0.2, fontsize=12)

Expand All @@ -193,7 +194,7 @@ def plot_eng_drop_line(
rewardbench_scores[idx],
mrewardbench_scores[idx],
model_names[idx],
fontsize=12,
fontsize=14,
)
for idx in range(len(data))
]
Expand All @@ -216,8 +217,8 @@ def plot_eng_drop_line(
# # bbox=dict(facecolor="white", edgecolor="black", boxstyle="round,pad=0.5"),
# )

ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
# ax.spines["right"].set_visible(False)
# ax.spines["top"].set_visible(False)
plt.tight_layout()
fig.savefig(output_path, bbox_inches="tight")

Expand All @@ -238,14 +239,17 @@ def plot_ling_dims(
raw = pd.read_csv(input_path).set_index("Model")
if top_n:
raw = raw.head(top_n)
raw = raw[[col for col in raw.columns if col not in ("Model_Type", "eng_Latn", "Avg_Multilingual")]]
raw = raw[[col for col in raw.columns if col not in ("Model_Type", "eng_Latn", "Avg_Multilingual", "Family")]]
raw = raw.T
langdata = pd.read_csv(langdata).set_index("Language")
combined = raw.merge(langdata, left_index=True, right_index=True)
combined["Avg"] = raw.mean(axis=1) * 100
combined["Std"] = raw.std(axis=1) * 100

combined = combined.rename(columns={"Resource_Type": "Resource Availability"})
# Remove Class 0 because it's misleading
combined = combined[combined["Resource Availability"] != "Class-0"].reset_index()

linguistic_dims = [
"Resource Availability",
"Family",
Expand All @@ -254,22 +258,28 @@ def plot_ling_dims(
fig, axs = plt.subplots(1, len(linguistic_dims), figsize=figsize, sharex=True)
for ax, dim in zip(axs, linguistic_dims):
lingdf = combined.groupby(dim).agg({"Avg": "mean", "Std": "mean"}).reset_index()
if dim != "Resource Availability":
lingdf = lingdf.sort_values(by="Avg", ascending=False)
else:
lingdf = lingdf[::-1]

ax.grid(color="gray", alpha=0.2, which="both", axis="x")
ax.set_axisbelow(True)
sns.barplot(
x="Avg",
y=dim,
data=lingdf,
ax=ax,
color="green",
width=0.5 if dim == "Resource Availability" else 0.7,
width=0.4 if dim == "Resource Availability" else 0.7,
)
ax.set_title(dim)
ax.set_xlim([60, 70])
ax.set_ylabel("")
ax.set_xlabel("M-RewardBench Score")

ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
# ax.spines["right"].set_visible(False)
# ax.spines["top"].set_visible(False)

plt.tight_layout()
fig.savefig(output_path, bbox_inches="tight")
Expand Down

0 comments on commit 621fc66

Please sign in to comment.