Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
airaria committed Dec 17, 2021
1 parent af4e268 commit 08ef6f4
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/textpruner/pruners/transformer_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,9 +373,10 @@ def get_importance_score(self, dataloader,
batch = move_to_device(batch, device)
if isinstance(batch,abc.Mapping):
outputs = model(**batch)
batch_num_examples = len(list(batch.values())[0])
else:
outputs = model(*batch)

batch_num_examples = len(batch[0])
if adaptor is None:
try:
if isinstance(outputs, torch.Tensor):
Expand Down Expand Up @@ -410,7 +411,7 @@ def get_importance_score(self, dataloader,
ffn_importance[layer_num] += ((weight2.grad * weight2).sum(dim=0)).abs().detach()

model.zero_grad()
num_examples += len(batch["attention_mask"])
num_examples += batch_num_examples

head_importance /= num_examples
ffn_importance /= num_examples
Expand Down

0 comments on commit 08ef6f4

Please sign in to comment.