Commit
refactored grad variance callback
egorkrash committed Feb 20, 2024
1 parent 6f66801 commit 56759a0
Showing 1 changed file with 15 additions and 50 deletions.
65 changes: 15 additions & 50 deletions src/callbacks.py
@@ -191,6 +191,10 @@ def on_epoch_end(self,


class GradientVarianceCallback(EvaluationCallbackBase):
"""Calculates gradient variance and distance between definitions and corresponding questions.
Requires a tokenized eval dataset with keys: [<d1 definitions dataset>, '<d2 definitions dataset>', '<d1 questions>', '<d2 questions>'].
Example: ['train_defs_d1consis', 'train_defs_d2consis', 'd1consis', 'd2consis']
"""
def __init__(self, eval_dataset_tokenized,
keys,
tb_writer=None,
@@ -205,79 +209,48 @@ def __init__(self, eval_dataset_tokenized,


def evaluate_fn(self, args, state, model, tokenizer):
"""Compute gradient distance between definitions and corresponding questions."""
def compute_mean_distance(eval_dataset_questions, eval_dataset_defs, tag, mean_grad=None):
"""Compute mean distance between definitions and corresponding questions as well as mean gradient norms."""
# assuming eval_dataset_questions and eval_dataset_defs are already tokenized, on device and sorted
step_size = len(eval_dataset_questions) // len(eval_dataset_defs) # number of questions per definition
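# e.g. 20 definitions and 100 questions give step_size = 5; questions with
# indices [5*i, 5*i + 5) belong to definition i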

l1_d_norms, l1_q_norms = [], []
l2_d_norms, l2_q_norms = [], []
linf_d_norms, linf_q_norms = [], []
distances = []
sim_cos = []

for i in tqdm(range(len(eval_dataset_defs))):
# for every definition, compute distances and cosine similarities with corresponding questions
d = eval_dataset_defs[i]
d_grad = get_gradient(model, d)
# update mean_grad (used for variance calculation)
if mean_grad is None:
mean_grad = d_grad
else:
mean_grad += d_grad
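# (mean_grad is a running sum at this point; it is divided by the total
# number of datapoints in evaluate_fn to become an actual mean)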

# update gradient norms (definitions)
l1_d_norms.append(torch.norm(d_grad, p=1).item())
l2_d_norms.append(torch.norm(d_grad, p=2).item())
linf_d_norms.append(torch.norm(d_grad, p=float('inf')).item())

for j in range(step_size):
# for every question, compute distances and cosine similarities with definition
n = i * step_size + j # index of question
q = eval_dataset_questions[n]
# get gradient of question
q_grad = get_gradient(model, q)
# update distance and cosine similarity using current question
distances.append(torch.sqrt(torch.sum((d_grad - q_grad)**2)).item()) # l2 distance between gradient of definition and gradient of question
sim_cos.append(torch.cosine_similarity(d_grad, q_grad, dim=0).item())
# update gradient norms (questions)
l1_q_norms.append(torch.norm(q_grad, p=1).item())
l2_q_norms.append(torch.norm(q_grad, p=2).item())
linf_q_norms.append(torch.norm(q_grad, p=float('inf')).item())
mean_grad += q_grad

return distances, sim_cos, mean_grad, {f'grad_mean_l1_q_norm_{tag}': l1_q_norms, f'grad_mean_l2_q_norm_{tag}': l2_q_norms,
f'grad_mean_linf_q_norm_{tag}': linf_q_norms, f'grad_mean_l1_d_norm_{tag}': l1_d_norms,
@@ -288,7 +261,8 @@ def compute_mean_distance(eval_dataset_questions, eval_dataset_defs, tag, mean_g
self._init_summary_writer(args)

model.train()
keys = self.keys
# keys = ['train_defs_d1consis', 'train_defs_d2consis', 'd1consis', 'd2consis']
# keys = ['train_defs_qd1consis', 'train_defs_qd2incons', 'train_questions_qd1consis', 'train_questions_qd2incons']
tag1 = keys[0].split('_')[-1]
tag2 = keys[1].split('_')[-1]
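# e.g. keys[0] = 'train_defs_d1consis' gives tag1 = 'd1consis'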
@@ -308,10 +282,6 @@ def compute_mean_distance(eval_dataset_questions, eval_dataset_defs, tag, mean_g

mean_grad /= n_datapoints


# Calculate variance
logger.info('*** Computing gradient variance ***')
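# presumed form of the variance computed here: for per-example gradients g_i,
# variance = mean_i ||g_i - mean_grad||^2 (an assumption -- the computation
# itself is not shown in this diff)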
@@ -332,7 +302,7 @@ def compute_mean_distance(eval_dataset_questions, eval_dataset_defs, tag, mean_g

# delete eval datasets
del eval_dataset_d1cons, eval_dataset_d2cons, eval_dataset_d1defs, eval_dataset_d2defs

# log metrics
self.tb_writer.add_tensor(f"eval/grad_mean_dist_{tag1}", torch.tensor(distances_d1), state.global_step)
self.tb_writer.add_tensor(f"eval/grad_mean_dist_{tag2}", torch.tensor(distances_d2), state.global_step)
self.tb_writer.add_scalar("eval/grad_variance", variance, state.global_step)
@@ -345,6 +315,8 @@ def compute_mean_distance(eval_dataset_questions, eval_dataset_defs, tag, mean_g
self.tb_writer.add_tensor(f"eval/{norm}", torch.tensor(norms1[norm]), state.global_step)


# wandb logging is currently turned off for this callback

# wandb.log({f"eval/grad_mean_dist_d1": mean_dist_d1}, state.global_step)
# wandb.log({f"eval/grad_mean_dist_d2": mean_dist_d2}, state.global_step)
# wandb.log({f"eval/grad_variance": variance}, state.global_step)
@@ -366,15 +338,8 @@ def get_gradient(model, input_dict):
loss = outputs.loss
loss.backward()
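# assumes parameter grads are zeroed before each call (e.g. via model.zero_grad());
# otherwise per-example gradients would accumulate across calls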

grad = []
for _, param in model.named_parameters():
if param.requires_grad:
grad.append(param.grad.view(-1))
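
For orientation, a minimal sketch of how this callback might be wired up. The dataset keys follow the example in the class docstring; the trainer integration is an assumption, not part of this commit:

# hypothetical usage sketch -- 'tokenized_datasets' and 'trainer' are assumed names
callback = GradientVarianceCallback(
    eval_dataset_tokenized=tokenized_datasets,  # dict-like: key -> tokenized dataset
    keys=['train_defs_d1consis', 'train_defs_d2consis', 'd1consis', 'd2consis'],
)
trainer.add_callback(callback)  # assumes a HuggingFace Trainer and that
                                # EvaluationCallbackBase subclasses TrainerCallback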
