Add Sequence-Level KD #2220
Conversation
Add the third issue fixes
thanks @mst272, can you also kindly add these options to the docstrings and the documentation of the `GKDTrainer`?
# Conflicts:
#	examples/scripts/dpo.py
#	trl/commands/cli_utils.py
Hi @kashif, I've added these to the docstrings and the documentation.
LGTM, I leave the last word to @kashif
Co-authored-by: Quentin Gallouédec <[email protected]>
Hi there, I don't really understand how this PR adds `seq_kd`. To my understanding, `seq_kd` computes the standard cross-entropy loss on the teacher-generated output. In this PR we are simply generating the teacher output and then computing the same generalized JSD loss. In the documentation it says:

> `seq_kd`: controls whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated output)

But this is the definition of supervised KD, no?
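(For reference, sequence-level KD as defined in Kim & Rush (2016) trains the student with plain cross-entropy on the teacher's decoded output:)

```latex
% Sequence-level KD (Kim & Rush, 2016): the intractable sum over all output
% sequences y is approximated by the teacher's decoded output \hat{y},
% e.g. obtained with beam search.
\mathcal{L}_{\mathrm{SEQ\text{-}KD}}
  = -\sum_{y \in \mathcal{Y}} p_{\mathrm{T}}(y \mid x)\,\log p_{\mathrm{S}}(y \mid x)
  \;\approx\; -\log p_{\mathrm{S}}(\hat{y} \mid x)
```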
@moussaKam So recall there are two other parameters apart from the `seq_kd` flag.
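(For context, a sketch of how those knobs appear in `GKDConfig`; the values and `output_dir` below are illustrative, not recommendations:)

```python
from trl import GKDConfig

config = GKDConfig(
    output_dir="gkd-output",  # placeholder path
    lmbda=0.0,    # fraction of steps whose completions are generated on-policy by the student
    beta=0.5,     # generalized JSD interpolation (0 -> KL, 1 -> inverse KL)
    seq_kd=True,  # this PR: additionally generate completions with the teacher and train on them
)
```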
@moussaKam also note that in this case the KL-div is the same as the CE plus a constant term, i.e. the entropy of the target, which we assume does not change.
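(The identity being invoked: for a fixed target distribution p and student distribution q,)

```latex
% KL decomposes into cross-entropy minus the target's entropy; since p is
% fixed, H(p) is constant and minimizing KL is equivalent to minimizing CE.
\mathrm{KL}(p \,\|\, q)
  = \sum_x p(x) \log \frac{p(x)}{q(x)}
  = \underbrace{-\sum_x p(x) \log q(x)}_{\mathrm{CE}(p,\,q)} \;-\; \underbrace{H(p)}_{\text{const. in } q}
```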
@kashif thanks for your reply, yes, according to the definition from the paper
I understand that we compute the cross-entropy in the case of `seq_kd`. This is what we do in standard SFT, no?
@moussaKam so in the first step we generate completions and in the second we calculate the logits of the completions... I suppose we could do that once and then keep track of it with a bunch of if-else, but I opted for some cleaner logic here that works for any of the different hyperparameters... any ideas on how to make it a bit more DRY?
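(A minimal sketch of the control flow being described, simplified from the actual `GKDTrainer.training_step`; `generate_completions` and `compute_gkd_loss` are placeholder helpers, not TRL functions:)

```python
import random
import torch

def gkd_training_step(trainer, model, inputs):
    """Simplified sketch of the two-step flow discussed above (not the real TRL code)."""
    if trainer.seq_kd:
        # step 1: generate completions with the teacher; they become the new targets
        with torch.no_grad():
            inputs = generate_completions(trainer.teacher_model, inputs)
    if random.random() <= trainer.lmbda:
        # with probability lmbda, generate the completions with the student (on-policy)
        inputs = generate_completions(model, inputs)
    # step 2: forward both models on whichever completions we ended up with and
    # compute the generalized JSD between the student and teacher logits
    return compute_gkd_loss(trainer, model, inputs)
```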
@moussaKam Mind you, there is an orthogonal abstraction I have been working on where, instead of the logits (which are assumed to come from the same vocabulary size for both the student and teacher), we allow the student and teacher to have different vocabs: see #2263. I would welcome any thoughts on whether this should be a separate class.
@kashif, we don't need to compute the logits: we generate the output with the teacher, which becomes the new labels, then we run the forward pass of the student and compute the cross-entropy using just the teacher output tokens. I can implement it in the afternoon if that sounds good to you.
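(A minimal sketch of that proposal, assuming standard 🤗 Transformers causal-LM APIs; `teacher_generate` and `seq_kd_loss` are hypothetical helpers, not TRL code:)

```python
import torch

@torch.no_grad()
def teacher_generate(teacher, input_ids, attention_mask, max_new_tokens=128):
    # The teacher decodes completions; these become the labels for the student.
    return teacher.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
    )

def seq_kd_loss(student, teacher, input_ids, attention_mask, pad_token_id):
    # Prompt + teacher-generated completion.
    sequences = teacher_generate(teacher, input_ids, attention_mask)
    # Mask the prompt and padding so the CE is taken only over the
    # teacher-generated completion tokens.
    labels = sequences.clone()
    labels[:, : input_ids.shape[1]] = -100
    labels[labels == pad_token_id] = -100
    # Single student forward pass; HF causal LMs compute the (shifted)
    # cross-entropy internally when `labels` is passed.
    out = student(
        input_ids=sequences,
        attention_mask=(sequences != pad_token_id).long(),
        labels=labels,
    )
    return out.loss
```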
What does this PR do?
In the original paper, they compared Sequence-Level KD, Supervised KD, and GKD (on-policy). In TRL's `GKDTrainer`, Supervised KD and GKD are already implemented, so I add Sequence-Level KD to `GKDTrainer`, controlled by a `seq_kd` parameter.
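(A usage sketch, loosely following the TRL docs' GKD example; model names, the dataset, and hyperparameter values are placeholders:)

```python
from datasets import load_dataset
from transformers import AutoTokenizer
from trl import GKDConfig, GKDTrainer

# Placeholder choices; any student/teacher pair sharing a tokenizer should work.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
training_args = GKDConfig(output_dir="gkd-seq-kd", seq_kd=True)
trainer = GKDTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    teacher_model="Qwen/Qwen2-1.5B-Instruct",
    args=training_args,
    processing_class=tokenizer,
    train_dataset=load_dataset("trl-lib/chatbot_arena_completions", split="train"),
)
trainer.train()
```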
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.