Modular phi #34361
base: main
Conversation
Awesome PR! 🥳 CC @Cyrilvallez
Force-pushed from 48303e6 to d942994
Rebased for failing CIs
Hey~
As a first review, let's try to lean more on inheritance. If you look at PhiAttention, a lot of it can be skimmed out because it already exists in Gemma, for example. That way you only add what's required and use del to remove what has to go (for example, dense is o_proj I think). PhiForCausalLM and the other heads should inherit from Gemma rather than Llama, as they can be 1-to-1 the same this way!
Congrats on the PR 🤗
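A minimal sketch of what this suggestion could look like in modular_phi.py, assuming GemmaAttention is close enough to PhiAttention that only the differences need spelling out; the class names come from the discussion above, while the layer sizes are simplified for illustration and are not the merged code:

```python
# Hedged sketch, not the final PR code: only the Phi-specific differences are
# spelled out, everything else is inherited from Gemma.
from torch import nn

from transformers.models.gemma.modeling_gemma import GemmaAttention, GemmaForCausalLM


class PhiAttention(GemmaAttention):
    def __init__(self, config, layer_idx=None):
        super().__init__(config, layer_idx)
        # Phi checkpoints call the output projection "dense", so expose it under
        # that name and drop the inherited o_proj (the forward would also need to
        # reference self.dense instead of self.o_proj).
        self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
        del self.o_proj


class PhiForCausalLM(GemmaForCausalLM):
    # Identical to Gemma once the attention-level differences are handled.
    pass
```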
Thanks for the PR! 🤗 I left some comments to use more inheritance; you could try to go even further if you can!
class PhiForTokenClassification(LlamaForTokenClassification):
    def forward(
This cannot use Llama here, as it causes self.classifier to become self.score, which may break weight initialization of pretrained models. You could try to find another model with the same names in the init 🤗
LlamaForTokenClassification is the closest match to PhiForTokenClassification. It is itself copied from MptForTokenClassification, but subclassing Mpt directly causes issues with build_phi_alibi_tensor (i.e. build_mpt_alibi_tensor). I've addressed this by inheriting from LlamaForTokenClassification, deleting self.score and adding self.classifier (basically a rename); see the sketch below.
@Cyrilvallez let me know if changes are required.
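A sketch of that rename, assuming the attribute names from the discussion (the forward in the PR would also need to reference self.classifier instead of the inherited self.score):

```python
# Hedged sketch of the rename: inherit Llama's token-classification head but
# expose it under Phi's original attribute name so pretrained weights still load.
from torch import nn

from transformers.models.llama.modeling_llama import LlamaForTokenClassification


class PhiForTokenClassification(LlamaForTokenClassification):
    def __init__(self, config):
        super().__init__(config)
        # Llama names the head "score"; Phi checkpoints expect "classifier".
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        del self.score
```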
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Force-pushed from d942994 to 12b3901
Feel free to ping us again once ready! 🤗
@ArthurZucker Oh, I think that's all. Let me know if any changes are required.
Nice! Just some final details to take care of 🤗 Sorry for the delay, the whole team was in Martinique for a big offsite and we just got back.
def __init__(self, config):
    super().__init__(config)
    self.model = PhiModel(config)
    self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
    self.post_init()

def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
    **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    ```python
    >>> from transformers import AutoTokenizer, PhiForCausalLM

    >>> model = PhiForCausalLM.from_pretrained("microsoft/phi-1")
    >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")

    >>> prompt = "This is an example script ."
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str'
    ```"""

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

    loss = None
    if labels is not None:
        loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
All this code should not be needed; a single pass should be enough, as this is identical to GemmaForCausalLM.
We can't remove the forward method completely, as it is required at least for the documentation. Changed it to delegate to super().forward(...); a sketch is below. @Cyrilvallez
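For reference, the slimmed-down override might look roughly like this (a sketch only; the Phi docstring example quoted above would stay attached to it):

```python
# Sketch: keep a thin forward so the Phi-specific docstring example survives,
# but delegate all of the actual work to the parent implementation.
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    position_ids=None,
    past_key_values=None,
    inputs_embeds=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    cache_position=None,
    num_logits_to_keep=0,
    **loss_kwargs,
):
    r"""(Phi docstring with the microsoft/phi-1 generation example goes here.)"""
    return super().forward(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        labels=labels,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        num_logits_to_keep=num_logits_to_keep,
        **loss_kwargs,
    )
```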
Force-pushed from d67d5b8 to 89199ac
Rebased. @Cyrilvallez
Hey! Sorry for the delay!! Last small comments to cut even more code, then we're good to go! 🤗
def __init__(self, config):
    super().__init__(config)
    self.model = PhiModel(config)
    self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
    self.post_init()
This can be removed
Hi @Cyrilvallez, I'm concerned that removing this would result in the lm_head having bias=False, as it does in Gemma, and I suspect that would affect the model initialization, or at least the output. I don't think we can neglect the bias, so I kept the lm_head with bias=True in PhiForCausalLM.
Let me know if I'm correct or not.
Oh yes sorry, did not notice the bias! Then of course we need to keep it. You can still remove the self.model = ... and the post_init() lines though.
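Combining both comments, the remaining override might boil down to something like this (a sketch; the parent class is assumed from the earlier discussion, and the modular converter re-expands super().__init__ in the generated modeling file):

```python
# Sketch: keep only the Phi-specific difference (bias=True on lm_head) and let
# the parent class provide self.model and the post_init() call.
from torch import nn

from transformers.models.gemma.modeling_gemma import GemmaForCausalLM


class PhiForCausalLM(GemmaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
```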
Force-pushed from 77d1fd3 to 67ebbaf
All right, LGTM! Great job thanks!
If #34858 gets merged before this one, you'll just need to make sure to modify accordingly, but it will be extremely straightforward!
🤗
Force-pushed from 67ebbaf to 9bf4eae (commits: …amline initialization, …e past_key_values type, …osition embeddings handling)
Done! @Cyrilvallez
Nice! Do you want to include the recent changes in #35235? 🤗
@@ -1089,8 +1082,8 @@ def _prepare_4d_causal_attention_mask_with_cache_position(

class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
For TP, the config needs a TP plan as well!
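If I read this correctly, the model-level _tp_plan above needs a matching base_model_tp_plan on PhiConfig, mirroring the pattern LlamaConfig uses; a hedged sketch, where the per-module names are assumptions based on Phi's original layer names (dense, fc1, fc2), not copied from this PR:

```python
# Hypothetical sketch: declare a tensor-parallel plan on the config so the
# model-level _tp_plan has something to resolve against. Module names are
# assumed from the original Phi implementation.
from transformers import PretrainedConfig


class PhiConfig(PretrainedConfig):
    model_type = "phi"
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.dense": "rowwise",
        "layers.*.mlp.fc1": "colwise",
        "layers.*.mlp.fc2": "rowwise",
    }
```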
    config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
)

self.rotary_emb = PhiRotaryEmbedding(config=self.config)
should be removed!
What does this PR do?
Adds Modular Phi #33916
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@ArthurZucker @LysandreJik
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.