Skip to content

Commit

Permalink
Better plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
krasheninnikov committed Jan 19, 2024
1 parent bc27c7d commit bcf781e
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions utils/linear_probes.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,30 +109,32 @@ def plot_score_grid(scores, tokens: List[str], title=None, vmin=0.49, vmax=1.01,
# larger font size and times new roman font
plt.rc('font', size=14, family='Times New Roman')

ax = plt.figure(figsize=(6, 3.8)).gca()
ax = plt.figure(figsize=(6, 2.6)).gca()

# brainstorming cmaps; some to try: with blues and reds: 'PuOr', 'RdBu', 'RdYlBu', 'RdYlGn', 'Spectral', 'coolwarm'
sns.heatmap(scores.T, cmap=cmap, vmin=vmin, vmax=vmax)
sns.heatmap(scores.T, cmap=cmap, vmin=vmin, vmax=vmax, cbar_kws={'ticks': [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]})

plt.xticks(np.arange(len(tokens)), tokens, rotation=60) # set x ticks to the actual str tokens
plt.gca().set_xticks(np.arange(len(tokens))+0.5, minor=True) # position x ticks in the middle of the cell without extra ticks
plt.gca().tick_params(which='major', length=0) # remove major ticks and keep only minor ticks
plt.gca().invert_yaxis() # make y axis go from bottom to top

plt.xlabel('Token')
plt.xlabel('Token', labelpad=-10)
plt.ylabel('Layer')
if title is not None:
plt.title(title)
plt.title(title, y=1.08, x=0.53)

# plt.yticks(np.arange(len(scores.T))+0.5, range(1, len(scores.T)+1)) # add 1 to every y tick without changing its position
plt.yticks(np.arange(len(scores.T)), range(1, len(scores.T)+1)) # add 1 to every y tick without changing its position

# thin grid lines
plt.grid(which='major', color='gray', linestyle='-', linewidth=0.3)

# # remove every 2nd y tick label
for label in plt.gca().yaxis.get_ticklabels()[::2]:
label.set_visible(False)
# leave only every 4th y tick
for i, label in enumerate(plt.gca().yaxis.get_ticklabels()):
if (i+1) % 4 != 0:
label.set_visible(False)

for label in plt.gca().yaxis.get_ticklabels():
label.set_verticalalignment('bottom')
# label.set_position((0, 0.5))
Expand Down

0 comments on commit bcf781e

Please sign in to comment.