Commit 692e274: Readability improvements
krasheninnikov committed Feb 22, 2024
1 parent bd7ca2f
Showing 3 changed files with 18 additions and 11 deletions.
24 changes: 15 additions & 9 deletions src/callbacks.py
@@ -199,22 +199,27 @@ class GradientVarianceCallback(EvaluationCallbackBase):
     https://colab.research.google.com/drive/1K-bWitUMffNlB1cIl8ELq6jtnyG6_J5b?usp=sharing
     """
     def __init__(self, eval_dataset_tokenized,
-                 keys,
+                 keys : str, # comma separated string of keys. NOTE: order here matters!
                  tb_writer=None,
                  numeric_experiment=False,
                  eval_each_epochs=1,
                  eval_each_steps=False,
                  evaluation_strategy='epoch') -> None:
 
         super().__init__(tb_writer, eval_each_epochs, eval_each_steps, evaluation_strategy, numeric_experiment)
-        self.keys = keys
+        self.keys = keys.split(',')
+        assert len(self.keys) == 4, "There must be exactly 4 keys in the keys argument."
         self.eval_dataset_tokenized = eval_dataset_tokenized
 
 
     def evaluate_fn(self, args, state, model, tokenizer):
         def compute_mean_distance(eval_dataset_questions, eval_dataset_defs, tag, mean_grad=None):
-            """Compute mean distance between definitions and corresponding questions as well as mean gradient norms."""
-            # assuming eval_dataset_questions and eval_dataset_defs are already tokenized, on device and sorted
+            """
+            Compute mean distances between definitions and corresponding questions as well as mean gradient norms.
+            eval_dataset_questions and eval_dataset_defs must be already tokenized, on device,
+            and sorted s.t. for every definition, there are n=step_size questions *in the same order as the definitions*.
+            """
+            assert len(eval_dataset_questions) % len(eval_dataset_defs) == 0, "Number of questions must be a multiple of number of definitions."
             step_size = len(eval_dataset_questions) // len(eval_dataset_defs)  # number of questions per definition
 
             l1_d_norms, l1_q_norms = [], []
@@ -223,6 +228,7 @@ def compute_mean_distance(eval_dataset_questions, eval_dataset_defs, tag, mean_g
             distances = []
             sim_cos = []
 
+            # iterate over definitions
             for i in tqdm(range(len(eval_dataset_defs))):
                 # for every definition, compute distances and cosine similarities with corresponding questions
                 d = eval_dataset_defs[i]
@@ -238,11 +244,11 @@ def compute_mean_distance(eval_dataset_questions, eval_dataset_defs, tag, mean_g
                 l1_d_norms.append(torch.norm(d_grad, p=1).item())
                 l2_d_norms.append(torch.norm(d_grad, p=2).item())
                 linf_d_norms.append(torch.norm(d_grad, p=float('inf')).item())
-
+
+                # iterate over questions corresponding to definition d
                 for j in range(step_size):
                     # for every question, compute distances and cosine similarities with definition
-                    n = i * step_size + j # index of question
-                    q = eval_dataset_questions[n]
+                    q = eval_dataset_questions[i * step_size + j]
                     # get gradient of question
                     q_grad = get_gradient(model, q)
                     # update distance and cosine similarity using current question
@@ -256,8 +262,8 @@ def compute_mean_distance(eval_dataset_questions, eval_dataset_defs, tag, mean_g
 
 
             return distances, sim_cos, mean_grad, {f'grad_mean_l1_q_norm_{tag}': l1_q_norms, f'grad_mean_l2_q_norm_{tag}': l2_q_norms,
-                                       f'grad_mean_linf_q_norm_{tag}': linf_q_norms, f'grad_mean_l1_d_norm_{tag}': l1_d_norms,
-                                       f'grad_mean_l2_d_norm_{tag}': l2_d_norms, f'grad_mean_linf_d_norm_{tag}': linf_d_norms}
+                                                    f'grad_mean_linf_q_norm_{tag}': linf_q_norms, f'grad_mean_l1_d_norm_{tag}': l1_d_norms,
+                                                    f'grad_mean_l2_d_norm_{tag}': l2_d_norms, f'grad_mean_linf_d_norm_{tag}': linf_d_norms}
 
 
         if self.tb_writer is None:
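
The indexing contract that the new docstring and assertion encode is easiest to see on a toy example. The following is a minimal standalone sketch with made-up definition and question strings (not code from the commit): the i-th definition owns the contiguous block of step_size questions starting at index i * step_size.

# Toy illustration (hypothetical data, not part of the commit) of the ordering
# contract used by compute_mean_distance: questions are sorted so that the
# i-th definition owns the step_size questions starting at i * step_size.
defs = ["def_A", "def_B"]                # 2 definitions
questions = ["qA1", "qA2", "qA3",        # 3 questions for def_A
             "qB1", "qB2", "qB3"]        # 3 questions for def_B

assert len(questions) % len(defs) == 0
step_size = len(questions) // len(defs)  # 3 questions per definition

for i, d in enumerate(defs):
    for j in range(step_size):
        q = questions[i * step_size + j]  # same indexing as in the diff
        print(d, "<->", q)                # def_A pairs only with qA*, etc.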
2 changes: 1 addition & 1 deletion src/train_lm.py
@@ -353,7 +353,7 @@ def compute_objective(metrics: Dict[str, float]) -> float:
 
     if training_args.calculate_grad_variance:
         grad_callback = GradientVarianceCallback(eval_dataset_tokenized,
-                                                 keys=training_args.grad_keys.split(','),
+                                                 keys=training_args.grad_keys,
                                                  eval_each_epochs=training_args.eval_each_epochs,
                                                  eval_each_steps=training_args.eval_steps,
                                                  evaluation_strategy=training_args.evaluation_strategy,)
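
With this change the split moves from the call site into the callback itself. A condensed sketch of the new contract (not from the commit), using the default grad_keys value from utils/arguments.py:

# The raw comma-separated string is now passed through unchanged, and the
# callback's __init__ splits and validates it, failing fast on a malformed
# grad_keys instead of silently accepting the wrong number of keys.
grad_keys = 'train_defs_d1consis,train_defs_d2consis,d1consis,d2consis'
keys = grad_keys.split(',')  # now done inside GradientVarianceCallback.__init__
assert len(keys) == 4, "There must be exactly 4 keys in the keys argument."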
3 changes: 2 additions & 1 deletion utils/arguments.py
@@ -83,7 +83,8 @@ class ModelTrainingArguments(Seq2SeqTrainingArguments):
         default=False, metadata={"help": "Whether to calculate gradient variance; note that this slows down training substantially."}
     )
     grad_keys: Optional[str] = field(
-        default='train_defs_d1consis,train_defs_d2consis,d1consis,d2consis', metadata={"help": "Keys to calculate gradient variance for."}
+        default='train_defs_d1consis,train_defs_d2consis,d1consis,d2consis',
+        metadata={"help": "Keys to calculate gradient variance for; NOTE: order matters here. See src/callbacks/GradientVarianceCallback for usage."}
     )
     eval_callback_type: Optional[str] = field(
         default='pipeline', metadata={"help": "Evaluation callback type. Use `pipeline` for clm and `generate` for seq2seq"}
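
Assuming the arguments are parsed with HuggingFace's HfArgumentParser (the usual pattern for Seq2SeqTrainingArguments subclasses; this diff does not show the parsing code), overriding grad_keys could look like the sketch below. Per the new help text, the order of the four keys matters.

# Hypothetical usage sketch; assumes utils/arguments.py is consumed via
# HfArgumentParser, which this diff does not show.
from transformers import HfArgumentParser
from utils.arguments import ModelTrainingArguments  # import path assumed

parser = HfArgumentParser(ModelTrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(args=[
    "--output_dir", "out",
    "--grad_keys", "train_defs_d1consis,train_defs_d2consis,d1consis,d2consis",
])
print(training_args.grad_keys.split(','))  # four keys, order preserved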
