From 40292aa4e95abb827847f77318a32efc1e76d973 Mon Sep 17 00:00:00 2001 From: Jiwoong Date: Fri, 20 Dec 2024 22:37:04 +0900 Subject: [PATCH] bugfix: torch.export failure caused by `_make_causal_mask` (#35291) * bugfix: torch.export failure caused by `_make_causal_mask` Recent changes in torch dynamo prevent mutations on tensors converted with aten::_to_copy. To address this, we clone such a tensor before performing the in-place operation `masked_fill_`, but only when the code is being compiled by torch dynamo. (relevant issue: https://github.com/pytorch/pytorch/issues/127571) * chore: use `is_torchdynamo_compiling` instead of `torch._dynamo.is_compiling` --- src/transformers/modeling_attn_mask_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 4319c021cb2bc3..09fc77e46b07ed 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -169,6 +169,10 @@ def _make_causal_mask( diagonal = past_key_values_length - sliding_window - 1 context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal) + # Recent changes in PyTorch prevent mutations on tensors converted with aten::_to_copy + # See https://github.com/pytorch/pytorch/issues/127571 + if is_torchdynamo_compiling(): + mask = mask.clone() mask.masked_fill_(context_mask, torch.finfo(dtype).min) return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)