From 335a8bac777273ab2048f5b993005c0b43ec5e2b Mon Sep 17 00:00:00 2001 From: flux9665 Date: Mon, 7 Oct 2024 14:51:05 +0200 Subject: [PATCH] overhaul boxplots once more --- .../eval_lang_emb_approximation.py | 49 +++++++++---------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/Preprocessing/multilinguality/eval_lang_emb_approximation.py b/Preprocessing/multilinguality/eval_lang_emb_approximation.py index ca372add..9e3108ea 100644 --- a/Preprocessing/multilinguality/eval_lang_emb_approximation.py +++ b/Preprocessing/multilinguality/eval_lang_emb_approximation.py @@ -7,12 +7,13 @@ import torch from huggingface_hub import hf_hub_download -# matplotlib.rcParams['mathtext.fontset'] = 'stix' -# matplotlib.rcParams['font.family'] = 'STIXGeneral' +matplotlib.rcParams['mathtext.fontset'] = 'stix' +matplotlib.rcParams['font.family'] = 'STIXGeneral' matplotlib.rcParams['font.size'] = 7 import matplotlib.pyplot as plt from Utility.utils import load_json_from_path + def compute_loss_for_approximated_embeddings(csv_path, iso_lookup, language_embeddings, weighted_avg=False, min_n_langs=5, max_n_langs=30, threshold_percentile=95, loss_fn="MSE"): df = pd.read_csv(csv_path, sep="|") @@ -23,7 +24,7 @@ def compute_loss_for_approximated_embeddings(csv_path, iso_lookup, language_embe features_per_closest_lang = 2 # for combined, df has up to 5 features (if containing individual distances) per closest lang + 1 target lang column - if "combined_dist_0" in df.columns: + if "combined_dist_0" in df.columns: if "map_dist_0" in df.columns: features_per_closest_lang += 1 if "asp_dist_0" in df.columns: @@ -77,7 +78,7 @@ def compute_loss_for_approximated_embeddings(csv_path, iso_lookup, language_embe lang_emb = language_embeddings[iso_lookup[-1][lang]] avg_emb += lang_emb normalization_factor = len(langs) - avg_emb /= normalization_factor # normalize + avg_emb /= normalization_factor # normalize current_loss = loss_fn(avg_emb, y).item() all_losses.append(current_loss) @@ -111,36 +112,34 @@ def compute_loss_for_approximated_embeddings(csv_path, iso_lookup, language_embe os.makedirs(OUT_DIR, exist_ok=True) fig, ax = plt.subplots(figsize=(6, 4)) - plt.ylabel(f"{args.loss_fn} between Approximated and Real") + plt.ylabel(args.loss_fn) for i, csv_path in enumerate(csv_paths): print(f"csv_path: {os.path.basename(csv_path)}") for condition in weighted: - losses = compute_loss_for_approximated_embeddings(csv_path, - iso_lookup, - lang_embs, - condition, - min_n_langs=args.min_n_langs, - max_n_langs=args.max_n_langs, - threshold_percentile=args.threshold_percentile, - loss_fn=args.loss_fn) + losses = compute_loss_for_approximated_embeddings(csv_path, + iso_lookup, + lang_embs, + condition, + min_n_langs=args.min_n_langs, + max_n_langs=args.max_n_langs, + threshold_percentile=args.threshold_percentile, + loss_fn=args.loss_fn) print(f"weighted average: {condition} | mean loss: {np.mean(losses)}") losses_of_multiple_datasets.append(losses) bp_dict = ax.boxplot(losses_of_multiple_datasets, - labels =[ - "Random Neighbors", - "Nearest according \nto inverse ASPF", - "Nearest according \nto Map Distance", - "Nearest according \nto Tree Distance", - "Nearest according \nto Learned Distance", - "Actual Nearest\n(Oracle)", - ], + labels=["Random", + "Inverse ASP", + "Map Distance", + "Tree Distance", + "Learned Distance", + "Oracle"], patch_artist=True, - boxprops=dict(facecolor = "lightblue", + boxprops=dict(facecolor="lightblue", ), - showfliers=False, + showfliers=False, widths=0.55 - ) + ) # major ticks every 0.1, minor ticks every 0.05, between 0.0 and 0.6 major_ticks = np.arange(0, 1.0, 0.1) minor_ticks = np.arange(0, 1.0, 0.05) @@ -148,7 +147,7 @@ def compute_loss_for_approximated_embeddings(csv_path, iso_lookup, language_embe ax.set_yticks(minor_ticks, minor=True) # horizontal grid lines for minor and major ticks ax.grid(which='both', linestyle='-', color='lightgray', linewidth=0.3, axis='y') - plt.title(f"Using between {args.min_n_langs} and {args.max_n_langs} Nearest Neighbors to approximate an unseen Embedding") + # plt.title(f"Using between {args.min_n_langs} and {args.max_n_langs} Nearest Neighbors to approximate an unseen Embedding") plt.xticks(rotation=45) plt.tight_layout() plt.show()