Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support inputs_embeds #687

Merged
merged 13 commits into from
Dec 1, 2023
Merged

Support inputs_embeds #687

merged 13 commits into from
Dec 1, 2023

Conversation

samhavens
Copy link
Contributor

This allows users to pass in embeddings directly instead of looking them up based on input_ids. This is useful for PEFT (prompt/prefix-tuning) and multimodal models.

@samhavens samhavens requested review from dakinggg and vchiley October 20, 2023 20:37
Copy link
Contributor

@vchiley vchiley left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should input_ids default to None with this change?

input_ids: Optional[torch.LongTensor] = None,

@samhavens samhavens requested a review from vchiley October 23, 2023 16:46
in MPTForCausalLM. It is checked in the base model, and it is actually a common practice to pass both during autoregressive generation. Embeds are used first, then once the kvcache is nonempty, iids are used instead
@dakinggg
Copy link
Collaborator

@samhavens
Copy link
Contributor Author

@samhavens here is what HF expects wrt these two args: https://github.com/huggingface/transformers/blob/9333bf0769561c048700377c2e0813221ab9d2c9/src/transformers/models/llama/modeling_llama.py#L955-L963

@dakinggg That validation happens in the case model, but it the CausalLM model they can both be there https://github.com/huggingface/transformers/blob/9333bf0769561c048700377c2e0813221ab9d2c9/src/transformers/models/llama/modeling_llama.py#L1140-L1150

I think this is because in prepare_inputs_for_generation they use them on the first decoding step then ignore them https://github.com/huggingface/transformers/blob/9333bf0769561c048700377c2e0813221ab9d2c9/src/transformers/models/llama/modeling_llama.py#L1209-L1213

Copy link
Collaborator

@dakinggg dakinggg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approving cause LGTM, but please add a test for the edge cases:
(1) both input_ids and input_embeds are None
(2a) both input_ids and input_embeds are specified without kv cache
(2b) both input_ids and input_embeds are specified with kv cache

@samhavens
Copy link
Contributor Author

Approving cause LGTM, but please add a test for the edge cases: (1) both input_ids and input_embeds are None (2a) both input_ids and input_embeds are specified without kv cache (2b) both input_ids and input_embeds are specified with kv cache

@dakinggg do you mean new tests or add these edge cases to the 2 tests that have input embeds

@dakinggg
Copy link
Collaborator

dakinggg commented Nov 30, 2023

@samhavens add these edge cases to existing tests is fine (assuming that fits the tests). Anything that tests the right thing happens for those combinations is sufficient.

@samhavens samhavens merged commit 22ae919 into main Dec 1, 2023
10 checks passed
@samhavens samhavens deleted the support-inp-emb branch December 1, 2023 01:47
aspfohl pushed a commit that referenced this pull request Dec 1, 2023
* support inputs_embeds

* update tests to test inputs_embeds

* make iids optional inputs to fwd

* remove check for both iids and inputs_embeds

in MPTForCausalLM. It is checked in the base model, and it is actually a common practice to pass both during autoregressive generation. Embeds are used first, then once the kvcache is nonempty, iids are used instead

* reorder kwargs

* add more tests

* fix device merge artifact in test_model.oy

* fix generate test

* yapf
aspfohl added a commit that referenced this pull request Dec 2, 2023
* Add eval loader to eval script

* small input tests

* updates

* fix typing and formatting

* fixes, add tests

* remove circular dependency

* tests pass

* nits + small fixes

* add metrics at the end, refactor to put icl/gauntlet as helpers

* NOT

* metrics instead of models, add unit tests

* Move tests into directories

* add copyright to inits

* fix relative paths

* fixes

* revert gauntlet test change

* Support inputs_embeds (#687)

* support inputs_embeds

* update tests to test inputs_embeds

* make iids optional inputs to fwd

* remove check for both iids and inputs_embeds

in MPTForCausalLM. It is checked in the base model, and it is actually a common practice to pass both during autoregressive generation. Embeds are used first, then once the kvcache is nonempty, iids are used instead

* reorder kwargs

* add more tests

* fix device merge artifact in test_model.oy

* fix generate test

* yapf

* Better error message when test does not complete (#769)

* run script tests first

* comment out

* ascripts -> scripts

* bad dirs

* try this

* hacks

* add a note about a_scripts

---------

Co-authored-by: Sam Havens <[email protected]>
dakinggg pushed a commit to dakinggg/llm-foundry that referenced this pull request Dec 2, 2023
* Add eval loader to eval script

* small input tests

* updates

* fix typing and formatting

* fixes, add tests

* remove circular dependency

* tests pass

* nits + small fixes

* add metrics at the end, refactor to put icl/gauntlet as helpers

* NOT

* metrics instead of models, add unit tests

* Move tests into directories

* add copyright to inits

* fix relative paths

* fixes

* revert gauntlet test change

* Support inputs_embeds (mosaicml#687)

* support inputs_embeds

* update tests to test inputs_embeds

* make iids optional inputs to fwd

* remove check for both iids and inputs_embeds

in MPTForCausalLM. It is checked in the base model, and it is actually a common practice to pass both during autoregressive generation. Embeds are used first, then once the kvcache is nonempty, iids are used instead

* reorder kwargs

* add more tests

* fix device merge artifact in test_model.oy

* fix generate test

* yapf

* Better error message when test does not complete (mosaicml#769)

* run script tests first

* comment out

* ascripts -> scripts

* bad dirs

* try this

* hacks

* add a note about a_scripts

---------

Co-authored-by: Sam Havens <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants