From bcf781e6888eb8b9009600b4c2471711a22ae6e0 Mon Sep 17 00:00:00 2001 From: Dmitrii Krasheninnikov Date: Fri, 19 Jan 2024 16:10:16 +0000 Subject: [PATCH] Better plotting --- utils/linear_probes.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/utils/linear_probes.py b/utils/linear_probes.py index b0eb88f..9c8ae70 100644 --- a/utils/linear_probes.py +++ b/utils/linear_probes.py @@ -109,20 +109,20 @@ def plot_score_grid(scores, tokens: List[str], title=None, vmin=0.49, vmax=1.01, # larger font size and times new roman font plt.rc('font', size=14, family='Times New Roman') - ax = plt.figure(figsize=(6, 3.8)).gca() + ax = plt.figure(figsize=(6, 2.6)).gca() # brainstorming cmaps; some to try: with blues and reds: 'PuOr', 'RdBu', 'RdYlBu', 'RdYlGn', 'Spectral', 'coolwarm' - sns.heatmap(scores.T, cmap=cmap, vmin=vmin, vmax=vmax) + sns.heatmap(scores.T, cmap=cmap, vmin=vmin, vmax=vmax, cbar_kws={'ticks': [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]}) plt.xticks(np.arange(len(tokens)), tokens, rotation=60) # set x ticks to the actual str tokens plt.gca().set_xticks(np.arange(len(tokens))+0.5, minor=True) # position x ticks in the middle of the cell without extra ticks plt.gca().tick_params(which='major', length=0) # remove major ticks and keep only minor ticks plt.gca().invert_yaxis() # make y axis go from bottom to top - plt.xlabel('Token') + plt.xlabel('Token', labelpad=-10) plt.ylabel('Layer') if title is not None: - plt.title(title) + plt.title(title, y=1.08, x=0.53) # plt.yticks(np.arange(len(scores.T))+0.5, range(1, len(scores.T)+1)) # add 1 to every y tick without changing its position plt.yticks(np.arange(len(scores.T)), range(1, len(scores.T)+1)) # add 1 to every y tick without changing its position @@ -130,9 +130,11 @@ def plot_score_grid(scores, tokens: List[str], title=None, vmin=0.49, vmax=1.01, # thin grid lines plt.grid(which='major', color='gray', linestyle='-', linewidth=0.3) - # # remove every 2nd y tick label - for label in plt.gca().yaxis.get_ticklabels()[::2]: - label.set_visible(False) + # leave only every 4th y tick + for i, label in enumerate(plt.gca().yaxis.get_ticklabels()): + if (i+1) % 4 != 0: + label.set_visible(False) + for label in plt.gca().yaxis.get_ticklabels(): label.set_verticalalignment('bottom') # label.set_position((0, 0.5))