
Add unit tests for Universal Assisted Generation #8

Merged
merged 4 commits into usd from unit_tests_usd on Dec 19, 2024

Conversation

gauravjain14
Collaborator

Introduce Unit tests for Universal Assisted Generation

tests/test_universal_assisted_generation.py is intended to test the functionality introduced by universal assisted generation.

Note: All but test_basic_generation have been disabled for now.
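
For reference, here is a rough sketch of the shape of test_basic_generation (the generator construction is assumed to happen in setUpClass, and the names and assertions are illustrative, not the file's actual contents):

    # Rough sketch only -- generator setup is assumed to happen in setUpClass;
    # attribute names and assertions are illustrative, not the actual test code.
    def test_basic_generation(self):
        """Test basic speculative decoding works"""
        input_ids = self.main_tokenizer("Hello world", return_tensors="pt").input_ids
        self.generator.input_ids = input_ids
        candidates, scores = self.generator.get_candidates(input_ids)
        # Candidates should keep the batch dimension and at least preserve the prompt length.
        self.assertEqual(candidates.shape[0], 1)
        self.assertGreaterEqual(candidates.shape[1], input_ids.shape[1])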

Who can review?

@keyboardAnt @jmamou

@gauravjain14
Collaborator Author

Proposing to include some unit tests to ensure functionality.

I am, however, encountering some errors in these dummy examples. Any input on what this test might be missing?

The following is the error I am seeing -

======================================================================
ERROR: test_basic_generation (__main__.TestUniversalSpeculativeDecoding.test_basic_generation)
Test basic speculative decoding works
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/disk1/universal_assisted_generation/transformers/tests/test_universal_assisted_generation.py", line 45, in test_basic_generation
    candidates, scores = self.generator.get_candidates(input_ids)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk1/universal_assisted_generation/transformers/src/transformers/generation/candidate_generator.py", line 744, in get_candidates
    target_logits = self._atm_translator.get_target_logits(candidate_logits)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk1/universal_assisted_generation/transformers/src/transformers/generation/candidate_generator.py", line 628, in get_target_logits
    .apply_(lambda x: self._assistant_to_target_input_ids[x])
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk1/universal_assisted_generation/transformers/src/transformers/generation/candidate_generator.py", line 628, in <lambda>
    .apply_(lambda x: self._assistant_to_target_input_ids[x])
                      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^
KeyError: 151665

----------------------------------------------------------------------
Ran 8 tests in 20.816s

FAILED (errors=1)

@jmamou
Collaborator

jmamou commented Dec 12, 2024

Thanks @gauravjain14
For now, I propose running the tests on the #7 branch, which contains bug fixes, rather than on the main branch.

@gauravjain14
Collaborator Author

@jmamou I'll rebase this on that.

However, this error occurs on #7 as well.

@keyboardAnt
Owner

keyboardAnt commented Dec 12, 2024

I am, however, encountering some errors in these dummy examples. Any input on what this test might be missing?
[...]
KeyError: 151665

It seems the drafter sampled a token that the translator does not include. Perhaps that token is not in the target vocabulary?
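
One possible way to make that lookup robust, sketched here only as an illustration (the helper name, the assistant_to_target dict, and the sentinel value are assumptions, not the fix that later landed in #7):

    # Illustrative sketch: map assistant token ids to target ids, sending unmapped
    # ids to a sentinel instead of raising KeyError like the plain dict lookup does.
    # The caller could then mask the sentinel positions' logits to -inf so the target
    # never accepts a token it cannot represent.
    import torch

    def map_assistant_ids_to_target(assistant_ids, assistant_to_target, missing_value=-1):
        device = assistant_ids.device
        # Tensor.apply_ is in-place and CPU-only, so clone on CPU first.
        mapped = assistant_ids.cpu().clone().apply_(
            lambda tok: assistant_to_target.get(int(tok), missing_value)
        )
        return mapped.to(device)

    # Example: token 151665 has no mapping, so it becomes the sentinel -1.
    ids = torch.tensor([[1, 2, 151665]])
    print(map_assistant_ids_to_target(ids, {1: 10, 2: 20}))  # tensor([[10, 20, -1]])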

@jmamou
Collaborator

jmamou commented Dec 12, 2024

@gauravjain14
Fixed in the last push to #7.

@gauravjain14
Collaborator Author

Rebased on @jmamou's changes in #7.

Removed some tests that seemed unnecessary. All tests pass.

I have disabled this test for now -

    @unittest.skip("Disabled for now due to the long context length")
    def test_long_sequence(self):
        """Test handling of very long input sequences"""
        long_input = torch.ones((1, 2048), dtype=torch.long)
        self.generator.input_ids = long_input
        candidates, scores = self.generator.get_candidates(long_input)
        self.assertLessEqual(
            candidates.shape[1],
            self.main_model.config.max_position_embeddings,
        )

Let me know what y'all think about it. It is disabled because of the long context length; if we decide we should have it, I can enable it.

keyboardAnt requested review from keyboardAnt and jmamou and removed the request for jmamou on December 13, 2024 07:37
Owner

keyboardAnt left a comment

@gauravjain14, thank you, it looks good! I added minor comments.

    @classmethod
    def setUpClass(cls):
        # Setup main and assistant models
        cls.main_model = AutoModelForCausalLM.from_pretrained(
Owner

Does it take <5s to load this 1B model? (Please see @gante's comment)

Collaborator Author

No, it takes about 33 seconds on a T4 machine. I think we can just add the tag @slow as mentioned in the comment. Wdyt?

Owner

keyboardAnt commented Dec 15, 2024

What about using smaller models? There are a few examples of fast models used in existing Hugging Face tests.

@slow tests run less frequently, so I suggest striving for faster tests.

Owner

keyboardAnt commented Dec 16, 2024

@gauravjain14
Models for testing: https://huggingface.co/hf-internal-testing. For example, hf-internal-testing/tiny-random-gpt2 as used here.
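
A sketch of what setUpClass could look like with these tiny checkpoints (the checkpoint pairing and attribute names are illustrative assumptions, not a tested combination):

    # Sketch: tiny hf-internal-testing checkpoints keep model loading fast enough
    # that @slow should not be needed. The checkpoint pairing here is illustrative.
    import unittest

    from transformers import AutoModelForCausalLM, AutoTokenizer


    class TestUniversalSpeculativeDecoding(unittest.TestCase):
        @classmethod
        def setUpClass(cls):
            # Set up main and assistant models with different (mismatched) tokenizers.
            cls.main_model = AutoModelForCausalLM.from_pretrained(
                "hf-internal-testing/tiny-random-LlamaForCausalLM"
            )
            cls.main_tokenizer = AutoTokenizer.from_pretrained(
                "hf-internal-testing/tiny-random-LlamaForCausalLM"
            )
            cls.assistant_model = AutoModelForCausalLM.from_pretrained(
                "hf-internal-testing/tiny-random-gpt2"
            )
            cls.assistant_tokenizer = AutoTokenizer.from_pretrained(
                "hf-internal-testing/tiny-random-gpt2"
            )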

Owner

Wdyt about moving the content of this file to tests/generation/test_candidate_generator.py?

@gauravjain14
Collaborator Author

Here's a quick update -

I am running the test `test_mismatched_vocabularies`, and for certain tokens the test is failing with the cache and `past_key_values` being empty (or of size 0). I am looking into it, but in case either of you wants to give it a try, this is one of the tokens I have seen this issue with -

input_ids = torch.tensor([[128245]])

The test -

def test_mismatched_vocabularies(self):

@jmamou
Collaborator

jmamou commented Dec 18, 2024

input_ids = torch.tensor([[128245]])

According to the target tokenizer, token_id 128245 corresponds to the special token '<|reserved_special_token_237|>'. We currently don't handle the case when the original prompt contains only special tokens.
I will try to handle that case.

@gauravjain14
Collaborator Author

input_ids = torch.tensor([[128245]])

According to the target tokenizer, token_id 128245 corresponds to the special token '<|reserved_special_token_237|>'. We currently don't handle the case when the original prompt contains only special tokens. I will try to handle that case.

Got it. Thanks for the quick response on that.

So, how do you propose we handle this for now? Should we skip special tokens in the target vocab, or do you think this will be a quick fix?

@jmamou
Collaborator

jmamou commented Dec 18, 2024

So, how do you propose we handle this for now? Should we skip special tokens in the target vocab, or do you think this will be a quick fix?

Let's skip special tokens for now. Note that UAG does not handle that case either.
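
On the test side, skipping such prompts could look roughly like this (a sketch only; the tokenizer attribute name is an assumption, and all_special_ids comes from the tokenizer):

    # Sketch: skip prompts made up entirely of special tokens, since the translator
    # does not handle them yet. `self.main_tokenizer` is an assumed attribute name.
    special_ids = set(self.main_tokenizer.all_special_ids)
    if all(tok in special_ids for tok in input_ids[0].tolist()):
        self.skipTest("Prompt contains only special tokens; not handled yet")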

Owner

keyboardAnt left a comment

LGTM. I only added two small comments.

gauravjain14 merged commit 7088978 into usd on Dec 19, 2024
keyboardAnt deleted the unit_tests_usd branch on December 19, 2024 19:02
keyboardAnt restored the unit_tests_usd branch on December 19, 2024 19:02
Owner

keyboardAnt left a comment

@gauravjain14, please see the minor comment below.

import threading
import unittest
import weakref
from unittest.mock import MagicMock

from zmq import device
Owner

@gauravjain14, please remove this line (the unused `from zmq import device` import).
