-
Notifications
You must be signed in to change notification settings - Fork 27.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Calculate position ids in modeling utils for all generative models #30053
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
About the framework changes: I found that tf/flax has a slightly different way to get I made tf and flax same way as torch is now with a cumsum over attention mask, so that the equivalence over frameworks tests pass. I am not sure if we need similar test for tf/flax to "test_position_ids". Tests should pass now, at least locally it seemed okay |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general looks good, thank you for tackling this refactor 💪
A few notes:
- No TF function to infer the position IDs? 😢 TF feels neglected 💔
- There are CI errors in the model equivalence. Model equivalence is flaky by nature, make sure you run model equivalence for all models with flake finder locally!
- After you're happy with the changes, commit with
[test_all]
and tag me again. I've glanced over the model-level changes after the first few models, I'll do a final check more carefully after the full CI is green 🤗
if length_diff < 0: | ||
position_ids = position_ids[:, :length_diff] | ||
elif length_diff > 0: | ||
new_position_ids = torch.arange(position_ids[0, -1], new_length, device=position_ids.device).unsqueeze(0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a comment briefly explaining when each situation can be triggered, and why we want that operation? Our future selves will probably be happy with that comment
e.g. I'm assuming length_diff > 0
is used when candidates are proposed, and thus we want the corresponding position ids. But I'm not immediately seeing when length_diff < 0
can be triggered :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
^ this function still needs better variable names and/or a docstring
tests/generation/test_utils.py
Outdated
@@ -1189,6 +1189,66 @@ def test_assisted_decoding_matches_greedy_search(self): | |||
for output in (output_greedy, output_assisted): | |||
self._check_outputs(output, input_ids, model.config, use_cache=True) | |||
|
|||
@is_flaky() | |||
def test_assisted_decoding_position_ids(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
like in the other PR you added an assisted generation test: let's make this a parameterization of the original test, since it's a minor variation :)
Okay, will work on it.
|
@gante the comments are addressed now. TF cannot have the "get_position_ids" method in PretrainedModel because all input related preparations in TF happen in a "keras.layers.Layer" class. I am not sure if we can or should be moving the position_id preparation into the "PretrainedModel", since there are only 3 TF models that were needed change. Also, to note for Flax-based encoder-decoder models: the attention mask for decoder part is overriden to be full, because when using decoder-only model as decoder part the position ids are calculated differently (I mean only the unattended part). In random initialized models it is causing logits mismatch, even if the attention masks out unattended positions. In pre-trained models that does not happen 🤔 |
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored. |
Hold it for a while, not stale |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few more pattern fixes and should be ready to go 🤞
if length_diff < 0: | ||
position_ids = position_ids[:, :length_diff] | ||
elif length_diff > 0: | ||
new_position_ids = torch.arange(position_ids[0, -1], new_length, device=position_ids.device).unsqueeze(0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
^ this function still needs better variable names and/or a docstring
seq_length = ( | ||
inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] | ||
) | ||
if position_ids is None: | ||
device = input_ids.device if input_ids is not None else inputs_embeds.device | ||
position_ids = self.get_position_ids_from_attention_mask( | ||
attention_mask, past_length, seq_length=seq_length, device=device | ||
) | ||
else: | ||
position_ids = position_ids[:, -seq_length:] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can remove all this code, actually 👀 I see the following cases:
position_ids is None
-> the forward pass correctly computesposition_ids
, due to the changes in this PRposition_ids is not None
-> the user has definedposition_ids
, it's its own responsibility to pass them correctly
WDYT? (this logic would apply to all models, and would make maintenance easier for us 👼 )
@@ -593,6 +593,7 @@ def forward( | |||
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") | |||
} | |||
|
|||
print(self.encoder) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
print(self.encoder) |
@@ -702,7 +709,9 @@ def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache= | |||
attention_mask = kwargs.get("attention_mask", None) | |||
|
|||
if attention_mask is not None and position_ids is None: | |||
position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this one should be correct, no? 🤔
(the same comment applies to other TF models)
position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1 | ||
# create ones tensor to match dtypes, otherwise we get errors | ||
ones_tensor = tf.ones_like(position_ids, dtype=tf.int64) | ||
position_ids = tf.where(attention_mask == 0, ones_tensor, position_ids) | ||
position_ids = position_ids[..., -input_shape[-1] :] | ||
position_ids = tf.reshape(position_ids, (-1, input_shape[-1])) | ||
else: | ||
position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1 | |
# create ones tensor to match dtypes, otherwise we get errors | |
ones_tensor = tf.ones_like(position_ids, dtype=tf.int64) | |
position_ids = tf.where(attention_mask == 0, ones_tensor, position_ids) | |
position_ids = position_ids[..., -input_shape[-1] :] | |
position_ids = tf.reshape(position_ids, (-1, input_shape[-1])) | |
else: | |
position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0) | |
position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True) |
(see comment below)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the same logic applies to other TF models
# when model weights are random init masking with attn_mask still leads to logits | ||
# mismatch, which does not happen if pre-trained models are used. That causes error in encoder-decoder models | ||
# when decoder_only is used in as backbone (GPT2), because GPT prepares positions depending on attn mask (for torch) | ||
# and as arange in flax. That's why we init attn mask with all `1` | ||
if "decoder_attention_mask" in pt_inputs: | ||
pt_inputs["decoder_attention_mask"] = torch.ones_like(pt_inputs["decoder_attention_mask"]) | ||
inputs_dict["decoder_attention_mask"] = jnp.ones_like(inputs_dict["decoder_attention_mask"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change should no longer be needed, correct?
(as a general rule, we shouldn't fudge these equivalence tests :) )
# make full attn mask since below we are preparing position ids assuming it's all ones | ||
attention_mask = jnp.ones_like(attention_mask) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the other way around: we should update the creation of position_ids
(below) to match the mask
The same comment applies to other FLAX test changes
@@ -149,8 +149,8 @@ def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input | |||
) | |||
|
|||
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length) | |||
position_ids = jnp.broadcast_to( | |||
jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1) | |||
position_ids = model.get_position_ids_from_attention_mask( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, like this!
# when model weights are random init masking with attn_mask still leads to logits | ||
# mismatch, which does not happen if pre-trained models are used. That causes error in encoder-decoder models | ||
# when decoder_only is used in as backbone (GPT2), because GPT prepares positions depending on attn mask (for torch) | ||
# and as arange in flax. That's why we init attn mask with all `1` | ||
if "decoder_attention_mask" in pt_inputs: | ||
pt_inputs["decoder_attention_mask"] = torch.ones_like(pt_inputs["decoder_attention_mask"]) | ||
inputs_dict["decoder_attention_mask"] = jnp.ones_like(inputs_dict["decoder_attention_mask"]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same comment as in the same pattern above, we should remove this
Will reopen this one later, as a new PR. It will need resolving merge conflicts and propagating changes to new models + PR comments. |
What does this PR do?
As it was discussed under this PR, position ids in some models are not calculated/inferred from attn mask in
forward
, which gives incorrect positions when the inputs is left padded.To be consistent and for ease of maintaining, the logic of inferring position ids is moved to "modeling_utils.py" and all generative models call that method in their
forward
andprepare_inputs_for_generation
. I added two tests, to check whether model outputs are same when position ids are passed by a user vs. when inferred from input ids or embeds.Also Fixes #29149.
The newly added tests are passing. Plus slow tests on vision models, because they still do not have GenerationTesterMixin.
Btw, I see that non-generative models already use
create_position_ids_from_input_ids
method which is copied separately in each model's file. The logic is a bit different from generative models, because they start counting from "padding_idx" and not "0". Anyway, I guess it is still possible to merge that method and the one proposed here, to have one "get_position_id" for all models in the "modeling_utils".@gante WDYT ?