bugfix: torch.export failure caused by _make_causal_mask (#35291)
* bugfix: torch.export failure caused by `_make_causal_mask`

Recent changes in torch dynamo prevent mutation of tensors converted with `aten::_to_copy`. To address this, clone such a tensor before performing the in-place operation `masked_fill_`, but only when the code is being compiled by torch dynamo.
(relevant issue: pytorch/pytorch#127571)

* chore: use `is_torchdynamo_compiling` instead of `torch._dynamo.is_compiling`
jiwoong-choi authored Dec 20, 2024
1 parent 05de764 commit 40292aa
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/transformers/modeling_attn_mask_utils.py
@@ -169,6 +169,10 @@ def _make_causal_mask(
         diagonal = past_key_values_length - sliding_window - 1

         context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
+        # Recent changes in PyTorch prevent mutations on tensors converted with aten::_to_copy
+        # See https://github.com/pytorch/pytorch/issues/127571
+        if is_torchdynamo_compiling():
+            mask = mask.clone()
         mask.masked_fill_(context_mask, torch.finfo(dtype).min)

     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
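
To make the geometry of `context_mask` concrete, here is a plain-Python sketch (no torch required) of which positions the hunk fills. Everything in it is illustrative: `tril_bool` and `apply_sliding_window` are hypothetical helpers that mimic `torch.tril` on an all-True boolean matrix and the subsequent fill, the all-zero input is an assumed stand-in for the mask built earlier in the function, and the sizes are arbitrary.

```python
import sys

# Equals torch.finfo(torch.float64).min; used here as the fill value.
FILL = -sys.float_info.max


def tril_bool(rows, cols, diagonal):
    """Stand-in for torch.tril on an all-True matrix: keep entries with j - i <= diagonal."""
    return [[(j - i) <= diagonal for j in range(cols)] for i in range(rows)]


def apply_sliding_window(mask, past_key_values_length, sliding_window):
    """Mirror the hunk: fill positions on or below the shifted diagonal, working on a copy."""
    rows, cols = len(mask), len(mask[0])
    diagonal = past_key_values_length - sliding_window - 1
    context_mask = tril_bool(rows, cols, diagonal)
    out = [row[:] for row in mask]  # copy first, analogous to mask.clone()
    for i in range(rows):
        for j in range(cols):
            if context_mask[i][j]:
                out[i][j] = FILL
    return out


# Illustrative values only: 4 query positions, no cached keys, window of 1.
tgt_len, past, window = 4, 0, 1
base = [[0.0] * (tgt_len + past) for _ in range(tgt_len)]  # assumed all-zero input
result = apply_sliding_window(base, past, window)
blocked = [[value == FILL for value in row] for row in result]
```

With these values `diagonal` is -2, so only entries at least two columns to the left of the main diagonal are filled, and `base` itself is left untouched because the fill runs on the copy.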