From 456eda67aa962e7bd2d82e4514ccee1cc4ceac25 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Thu, 24 Oct 2024 14:06:50 +0200
Subject: [PATCH] replace get_seq_length

---
 src/transformers/models/moshi/modeling_moshi.py           | 4 ++--
 src/transformers/models/mt5/modeling_mt5.py               | 2 +-
 src/transformers/models/phimoe/modeling_phimoe.py         | 2 +-
 src/transformers/models/pix2struct/modeling_pix2struct.py | 4 ++--
 src/transformers/models/pop2piano/modeling_pop2piano.py   | 2 +-
 .../switch_transformers/modeling_switch_transformers.py   | 2 +-
 src/transformers/models/t5/modeling_t5.py                 | 2 +-
 src/transformers/models/udop/modeling_udop.py             | 2 +-
 src/transformers/models/umt5/modeling_umt5.py             | 2 +-
 9 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py
index a7614bebbb80a1..8741011b8bf842 100644
--- a/src/transformers/models/moshi/modeling_moshi.py
+++ b/src/transformers/models/moshi/modeling_moshi.py
@@ -1160,7 +1160,7 @@ def forward(
         if use_cache and past_key_values is None and not self.training:
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-        past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
+        past_seen_tokens = 0 if past_key_values is None else past_key_values.get_past_seen_tokens()
         if cache_position is None:
             cache_position = torch.arange(
                 past_seen_tokens, past_seen_tokens + input_ids.shape[1], device=input_ids.device
             )
@@ -1496,7 +1496,7 @@ def forward(
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)

         if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
             cache_position = torch.arange(
                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
             )
diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py
index bf0cd55551eec3..9cf62df4b90a10 100644
--- a/src/transformers/models/mt5/modeling_mt5.py
+++ b/src/transformers/models/mt5/modeling_mt5.py
@@ -1016,7 +1016,7 @@ def forward(
             # it messes indexing later in decoder-stack because cache object is modified in-place
             past_key_values = None

-        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+        past_key_values_length = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
         if cache_position is None:
             cache_position = torch.arange(
                 past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py
index bed0ef966be1c9..356004b204fdd8 100644
--- a/src/transformers/models/phimoe/modeling_phimoe.py
+++ b/src/transformers/models/phimoe/modeling_phimoe.py
@@ -1089,7 +1089,7 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids)

         if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
             cache_position = torch.arange(
                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
             )
diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py
index 0997ef434572e1..25793f94b2160b 100644
--- a/src/transformers/models/pix2struct/modeling_pix2struct.py
+++ b/src/transformers/models/pix2struct/modeling_pix2struct.py
@@ -1411,7 +1411,7 @@ def forward(
         if cache_position is not None:
             past_key_values_length = cache_position[0]
         elif past_key_values is not None:
-            past_key_values_length = past_key_values.get_seq_length()
+            past_key_values_length = past_key_values.get_past_seen_tokens()

         if cache_position is None:
             cache_position = torch.arange(
@@ -1421,7 +1421,7 @@ def forward(
         if attention_mask is None:
             # required mask seq length can be calculated via length of past
             mask_seq_length = (
-                past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length
+                past_key_values.get_past_seen_tokens() + seq_length if past_key_values is not None else seq_length
             )
             attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py
index f7ddb1b81aa6cf..ea8f7a52f24773 100644
--- a/src/transformers/models/pop2piano/modeling_pop2piano.py
+++ b/src/transformers/models/pop2piano/modeling_pop2piano.py
@@ -854,7 +854,7 @@ def forward(
             # it messes indexing later in decoder-stack because cache object is modified in-place
             past_key_values = None

-        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+        past_key_values_length = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
         if cache_position is None:
             cache_position = torch.arange(
                 past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
index 7a450e6a8c321c..ac52ec0846352a 100644
--- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py
+++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
@@ -976,7 +976,7 @@ def forward(
             # it messes indexing later in decoder-stack because cache object is modified in-place
             past_key_values = None

-        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+        past_key_values_length = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
         if cache_position is None:
             cache_position = torch.arange(
                 past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index 28b849ab06ed20..b143f60fd454c3 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -1029,7 +1029,7 @@ def forward(
             # it messes indexing later in decoder-stack because cache object is modified in-place
             past_key_values = None

-        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+        past_key_values_length = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
         if cache_position is None:
             cache_position = torch.arange(
                 past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py
index 867418b54fb72f..079964635ef49a 100644
--- a/src/transformers/models/udop/modeling_udop.py
+++ b/src/transformers/models/udop/modeling_udop.py
@@ -1410,7 +1410,7 @@ def forward(
             # it messes indexing later in decoder-stack because cache object is modified in-place
             past_key_values = None

-        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+        past_key_values_length = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
         if cache_position is None:
             cache_position = torch.arange(
                 past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py
index 8161a7e488f5ab..d50b1f51536467 100644
--- a/src/transformers/models/umt5/modeling_umt5.py
+++ b/src/transformers/models/umt5/modeling_umt5.py
@@ -715,7 +715,7 @@ def forward(
             # it messes indexing later in decoder-stack because cache object is modified in-place
             past_key_values = None

-        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+        past_key_values_length = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
         if cache_position is None:
             cache_position = torch.arange(
                 past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
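For reference, the call-site pattern this patch standardizes can be exercised with the short sketch below. It is an illustration, not part of the patch: it assumes torch and transformers' DynamicCache are available, and it falls back to the older get_seq_length() accessor when get_past_seen_tokens() (the method the patch switches to) is not present on the installed Cache implementation.

import torch
from transformers import DynamicCache

def compute_cache_position(past_key_values, inputs_embeds):
    # Mirror the pattern used across the patched decoders: find how many tokens
    # the cache has already seen, then build `cache_position` for the current chunk.
    if past_key_values is None:
        past_seen_tokens = 0
    elif hasattr(past_key_values, "get_past_seen_tokens"):
        past_seen_tokens = past_key_values.get_past_seen_tokens()  # accessor adopted by this patch
    else:
        past_seen_tokens = past_key_values.get_seq_length()  # older Cache API
    return torch.arange(
        past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
    )

# An empty cache has seen 0 tokens, so the positions for a 5-token chunk are 0..4.
cache = DynamicCache()
inputs_embeds = torch.zeros(1, 5, 16)  # (batch, seq_len, hidden)
print(compute_cache_position(cache, inputs_embeds))  # tensor([0, 1, 2, 3, 4])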