Skip to content

Commit

Permalink
replace get_seq_length
Browse files Browse the repository at this point in the history
  • Loading branch information
Cyrilvallez committed Oct 24, 2024
1 parent f4718e6 commit 456eda6
Show file tree
Hide file tree
Showing 9 changed files with 11 additions and 11 deletions.
4 changes: 2 additions & 2 deletions src/transformers/models/moshi/modeling_moshi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1160,7 +1160,7 @@ def forward(
if use_cache and past_key_values is None and not self.training:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)

past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
past_seen_tokens = 0 if past_key_values is None else past_key_values.get_past_seen_tokens()
if cache_position is None:
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + input_ids.shape[1], device=input_ids.device
Expand Down Expand Up @@ -1496,7 +1496,7 @@ def forward(
past_key_values = DynamicCache.from_legacy_cache(past_key_values)

if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/mt5/modeling_mt5.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,7 +1016,7 @@ def forward(
# it messes indexing later in decoder-stack because cache object is modified in-place
past_key_values = None

past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
past_key_values_length = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/phimoe/modeling_phimoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,7 +1089,7 @@ def forward(
inputs_embeds = self.embed_tokens(input_ids)

if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/pix2struct/modeling_pix2struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -1411,7 +1411,7 @@ def forward(
if cache_position is not None:
past_key_values_length = cache_position[0]
elif past_key_values is not None:
past_key_values_length = past_key_values.get_seq_length()
past_key_values_length = past_key_values.get_past_seen_tokens()

if cache_position is None:
cache_position = torch.arange(
Expand All @@ -1421,7 +1421,7 @@ def forward(
if attention_mask is None:
# required mask seq length can be calculated via length of past
mask_seq_length = (
past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length
past_key_values.get_past_seen_tokens() + seq_length if past_key_values is not None else seq_length
)
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/pop2piano/modeling_pop2piano.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,7 +854,7 @@ def forward(
# it messes indexing later in decoder-stack because cache object is modified in-place
past_key_values = None

past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
past_key_values_length = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -976,7 +976,7 @@ def forward(
# it messes indexing later in decoder-stack because cache object is modified in-place
past_key_values = None

past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
past_key_values_length = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/t5/modeling_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,7 @@ def forward(
# it messes indexing later in decoder-stack because cache object is modified in-place
past_key_values = None

past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
past_key_values_length = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/udop/modeling_udop.py
Original file line number Diff line number Diff line change
Expand Up @@ -1410,7 +1410,7 @@ def forward(
# it messes indexing later in decoder-stack because cache object is modified in-place
past_key_values = None

past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
past_key_values_length = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/umt5/modeling_umt5.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ def forward(
# it messes indexing later in decoder-stack because cache object is modified in-place
past_key_values = None

past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
past_key_values_length = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
Expand Down

0 comments on commit 456eda6

Please sign in to comment.