
[OLD] New PR: #35029. [[Universal Speculative Decoding CandidateGenerator]] #34760

Closed

Conversation


@keyboardAnt keyboardAnt commented Nov 16, 2024

Please see the new PR: #35029


This PR is ready for initial review, though some aspects are still a work-in-progress.

What does this PR do?

This PR introduces the UniversalSpeculativeDecodingGenerator class, enabling speculative decoding with assistant models whose tokenizers differ slightly from the target's. The key addition is a pair of logits processors (LogitsProcessor) that ensure the assistant generates tokens exclusively from the target vocabulary, maintaining alignment and preserving the target distribution without altering the verification method. In theory, it is agnostic to the do_sample choice. This avoids issues like #32867 and #33534 and sets the stage for more advanced universal speculative decoding techniques that we are currently researching and have not yet published.
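As a rough illustration of the suppression idea (a sketch only, not the PR's actual implementation; the class name and the `assistant_to_target_input_ids` mapping below are assumptions for the example), a logits processor that confines the drafter to the shared vocabulary could look like this:

```python
import torch
from transformers import LogitsProcessor


class SuppressNonTargetTokensProcessor(LogitsProcessor):
    """Sketch: force the assistant to sample only token ids that also exist in
    the target vocabulary by setting every other logit to -inf."""

    def __init__(self, assistant_to_target_input_ids: dict, assistant_vocab_size: int):
        shared_ids = torch.tensor(sorted(assistant_to_target_input_ids.keys()))
        # Boolean mask over the assistant vocabulary: True = token is shared with the target.
        self.keep_mask = torch.zeros(assistant_vocab_size, dtype=torch.bool)
        self.keep_mask[shared_ids] = True

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        scores = scores.clone()
        scores[:, ~self.keep_mask.to(scores.device)] = float("-inf")
        return scores
```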


Motivation and Context

This update resolves prior inconsistencies in speculative decoding caused by misaligned vocabularies. Key benefits include:

  • Ensuring the assistant generates only tokens present in the target vocabulary.
  • Lossless preservation of the target distribution.
  • Compatibility with future speculative decoding advancements.

This PR is a step toward advancements in Universal Assisted Generation, in collaboration with @danielkorat, @orenpereg, @mosheber, @jmamou, @gante, @lewtun, and @MosheWasserb.



Dependencies

No additional dependencies.



@keyboardAnt keyboardAnt marked this pull request as ready for review November 16, 2024 20:34
gauravjain14 commented Nov 20, 2024

https://arxiv.org/pdf/2404.09492

Found this paper, which attempts to align different vocabularies across LLM families and then builds a projection matrix that maps the different LLMs' outputs into the same embedding domain.



keyboardAnt commented Nov 21, 2024

> https://arxiv.org/pdf/2404.09492
>
> Found this paper, which attempts to align different vocabularies across LLM families and then builds a projection matrix that maps the different LLMs' outputs into the same embedding domain.

Thanks for sharing, Gaurav.

My takeaways from the paper:

  1. Out of the box, without additional training, we can test the PR on the TigerBot / Chinese Alpaca pair. Their vocabularies differ from each other yet, according to the paper, have the maximum overlap. The Chinese Alpaca repo (https://github.com/ymcui/Chinese-LLaMA-Alpaca) has more than 18k stars, so speeding it up could be an impactful example. What are your thoughts? One caveat: the TigerBot and Chinese Alpaca model families only offer models with ≥7B params, which might be too large for a drafter.
  2. With additional training we can do some cool stuff. :)


jmamou commented Nov 21, 2024

> My takeaways from the paper:
>
>   1. Out of the box, without additional training, we can test the PR on the TigerBot / Chinese Alpaca pair. [...] What are your thoughts?
>   2. With additional training we can do some cool stuff. :)

I have run a similar analysis on the model pairs used in the Universal Assisted Generation blog.

The last two columns of the table below give the overlap as a percentage of the target vocab size (overlap D/T %) and as a percentage of the draft vocab size (overlap T/D %), respectively.

| target model | target vocab size | draft model | draft vocab size | vocab overlap | overlap D/T % | overlap T/D % |
|---|---|---|---|---|---|---|
| codellama/CodeLlama-13b-Instruct-hf | 32016 | bigcode/tiny_starcoder_py | 49152 | 8481 | 26 | 17 |
| mistralai/Mixtral-8x22B-Instruct-v0.1 | 32768 | double7/vicuna-68m | 32000 | 24184 | 74 | 76 |
| google/gemma-2-9b | 256000 | double7/vicuna-68m | 32000 | 30489 | 12 | 95 |
| mistralai/Mixtral-8x22B-Instruct-v0.1 | 32768 | Qwen/Qwen2-0.5B-Instruct | 151646 | 10566 | 32 | 7 |
| meta-llama/Llama-3.1-70B | 128256 | Qwen/Qwen2-0.5B-Instruct | 151646 | 109566 | 85 | 72 |
| microsoft/Phi-3-medium-128k-instruct | 32011 | Qwen/Qwen2-0.5B-Instruct | 151646 | 9588 | 30 | 6 |
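For reference, a minimal sketch of how such overlap numbers can be reproduced from the public tokenizers (exact counts may differ slightly depending on how added/special tokens are handled; any target/draft pair from the table works the same way):

```python
from transformers import AutoTokenizer

target_tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-70B")
draft_tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

# Overlap is computed on the token strings, not on the token ids.
target_vocab = set(target_tok.get_vocab())
draft_vocab = set(draft_tok.get_vocab())
overlap = target_vocab & draft_vocab

print("target vocab size:", len(target_vocab))
print("draft vocab size: ", len(draft_vocab))
print("overlap:          ", len(overlap))
print(f"overlap D/T %: {100 * len(overlap) / len(target_vocab):.0f}")
print(f"overlap T/D %: {100 * len(overlap) / len(draft_vocab):.0f}")
```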

candidate_ids = assistant_output.sequences
device = candidate_ids.device
# Tensor.apply_ only works on CPU tensors, so move the ids off the device first.
candidate_ids = candidate_ids.cpu()
# Translate each assistant token id to its counterpart in the target vocabulary.
candidate_ids.apply_(lambda x: self._assistant_to_target_input_ids[x])
@jmamou jmamou Nov 21, 2024

We need to fix a bug here: some x values are missing from the _assistant_to_target_input_ids dict.

@keyboardAnt (Contributor Author)

It might indicate a deeper problem because the drafter shouldn't generate tokens that are not in self._assistant_to_target_input_ids (representing the intersection between the vocabularies). I expect the suppress processor to zero out the probability of generating such tokens, but I might have missed something.

@jmamou (Contributor)

The bug occurs when some prompt draft token ids are missing from _assistant_to_target_input_ids. I believe I already addressed this bug in the get_target_ids function.
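For illustration, a defensive variant of the id translation could substitute a fallback id instead of raising a KeyError for ids outside the intersection. This is only a sketch of the idea, not the PR's get_target_ids; the function name and the fallback_id choice are assumptions:

```python
import torch


def map_assistant_ids_to_target(candidate_ids: torch.LongTensor,
                                assistant_to_target: dict,
                                fallback_id: int) -> torch.LongTensor:
    """Sketch: translate assistant token ids to target token ids, substituting
    fallback_id (e.g. the target unk/pad id) for ids outside the intersection."""
    mapped = candidate_ids.cpu().clone()  # Tensor.apply_ only works on CPU tensors
    mapped.apply_(lambda x: assistant_to_target.get(int(x), fallback_id))
    return mapped.to(candidate_ids.device)
```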

@gauravjain14

I'm seeing this error on the latest commit to the universal-speculatie-decoding branch:

Traceback (most recent call last):
  File "/disk/universal_assisted_generation/llama_qwen.py", line 18, in <module>
    outputs = model.generate(**inputs, assistant_model=assistant_model, tokenizer=tokenizer, assistant_tokenizer=assistant_tokenizer)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk/anaconda3/envs/uag/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/disk/universal_assisted_generation/transformers/src/transformers/generation/utils.py", line 2191, in generate
    result = self._assisted_decoding(
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk/universal_assisted_generation/transformers/src/transformers/generation/utils.py", line 4298, in _assisted_decoding
    valid_tokens, n_matches = _speculative_sampling(
                              ^^^^^^^^^^^^^^^^^^^^^^
  File "/disk/universal_assisted_generation/transformers/src/transformers/generation/utils.py", line 4450, in _speculative_sampling
    q_i = q[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
          ~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: index 5 is out of bounds for dimension 0 with size 5

This happens when using the following models:

prompt = "Alice and Bob"
checkpoint = "meta-llama/Llama-3.1-8B-Instruct"
assistant_checkpoint = "Qwen/Qwen2-0.5B-Instruct"

@keyboardAnt (Contributor Author)

Update: I added caching and some tests. The current issue is the dimensions of the output logits.

@gauravjain14

Update:

I pulled all the recent commits and ran the same example as above (Qwen2-0.5B and Llama-3.1-8B) and hit the following error:

Traceback (most recent call last):
  File "/disk/universal_assisted_generation/llama_qwen.py", line 18, in <module>
    outputs = model.generate(**inputs, assistant_model=assistant_model, tokenizer=tokenizer, assistant_tokenizer=assistant_tokenizer)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk/universal_assisted_generation/transformers/src/transformers/generation/utils.py", line 2191, in generate
    result = self._assisted_decoding(
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk/universal_assisted_generation/transformers/src/transformers/generation/utils.py", line 4247, in _assisted_decoding
    candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk/universal_assisted_generation/transformers/src/transformers/generation/candidate_generator.py", line 782, in get_candidates
    target_logits = self._atm_translator.get_target_logits(candidate_logits)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk/universal_assisted_generation/transformers/src/transformers/generation/candidate_generator.py", line 620, in get_target_logits
    target_logits_supported_indices: torch.IntTensor = assistant_logits_supported_indices.apply_(
                                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk/universal_assisted_generation/transformers/src/transformers/generation/candidate_generator.py", line 621, in <lambda>
    lambda x: self._assistant_to_target_input_ids[x]
              ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^
KeyError: 151646

This looks expected: the vocab size of the target (Llama-3.1-8B) is 128256, while the assistant's (Qwen2-0.5B) is 151646.

Some debugging tells me this is a valid error if the assistant model is still generating tokens from its own vocab space rather than being confined to the target space.
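One way out-of-range ids like 151646 can appear is when a model's logits dimension (config.vocab_size) is larger than its tokenizer vocabulary, so unsuppressed sampling can pick ids that have no tokenizer entry and hence no mapping entry. A quick check, assuming the checkpoints above:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

assistant_checkpoint = "Qwen/Qwen2-0.5B-Instruct"
tok = AutoTokenizer.from_pretrained(assistant_checkpoint)
model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)

# If the logits dimension exceeds the tokenizer size, ids in the gap can never
# be keys of _assistant_to_target_input_ids and must be suppressed before sampling.
print("assistant tokenizer size:", len(tok))
print("assistant logits size:   ", model.config.vocab_size)
```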

@keyboardAnt keyboardAnt force-pushed the universal-speculatie-decoding branch from c8def75 to f163339 on November 25, 2024 12:22
@keyboardAnt keyboardAnt requested a review from jmamou November 26, 2024 05:00

keyboardAnt commented Nov 26, 2024

> Update:
>
> I pulled all the recent commits and ran the same example as above (Qwen2-0.5B and Llama-3.1-8B) and hit the following error: [...] KeyError: 151646
>
> Some debugging tells me this is a valid error if the assistant model is still generating tokens from its own vocab space rather than being confined to the target space.

Thanks @gauravjain14.
I fixed the code to pass the tests and added more tests.
Could you please retry? Also, please feel free to add your test to tests/generation/test_candidate_generator.py.


jmamou commented Nov 26, 2024

@gauravjain14
The code below works for me:

  prompt = "Alice and Bob"
  checkpoint = "meta-llama/Llama-3.1-8B-Instruct"
  assistant_checkpoint = "Qwen/Qwen2-0.5B-Instruct"

  assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_checkpoint)
  tokenizer = AutoTokenizer.from_pretrained(checkpoint)
  inputs = tokenizer(prompt, return_tensors="pt")

  model = AutoModelForCausalLM.from_pretrained(checkpoint)
  assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)

  generate_kwargs = {
      "do_sample": True,
      "assistant_model": assistant_model,
      "assistant_tokenizer": assistant_tokenizer,
      "tokenizer": tokenizer,
  }

  outputs = model.generate(**inputs, **generate_kwargs)
  
  print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

@keyboardAnt (Contributor Author)

To make this easier to review, I've split off a smaller PR (#35009) that focuses purely on refactoring the existing code, without introducing the new Universal SD features.

The refactor aims to:

  • Reduce technical debt and improve maintainability.
  • Prepare the code for the upcoming Universal SD (this PR).
  • Make it more approachable for community contributions.

@ArthurZucker, @gante, I’d love your feedback and review when you have a moment. Thanks so much!

@keyboardAnt keyboardAnt changed the title from "[WIP] Universal speculatie decoding" to "Universal Speculative Decoding CandidateGenerator" on Nov 28, 2024
@keyboardAnt keyboardAnt changed the title from "Universal Speculative Decoding CandidateGenerator" to "[OLD] Universal Speculative Decoding CandidateGenerator" on Nov 30, 2024
@keyboardAnt keyboardAnt force-pushed the universal-speculatie-decoding branch from bdff66d to 3e23690 on November 30, 2024 19:03
@keyboardAnt keyboardAnt changed the title from "[OLD] Universal Speculative Decoding CandidateGenerator" to "[OLD] ~~Universal Speculative Decoding CandidateGenerator~~ New PR: #35029" on Nov 30, 2024
@keyboardAnt keyboardAnt changed the title from "[OLD] ~~Universal Speculative Decoding CandidateGenerator~~ New PR: #35029" to "[OLD] New PR: #35029. [[Universal Speculative Decoding CandidateGenerator]]" on Nov 30, 2024
@keyboardAnt keyboardAnt reopened this Nov 30, 2024
@keyboardAnt (Contributor Author)

This branch has diverged from main. To make it easier, I opened a new PR: #35029

Thanks @gauravjain14 for spotting it.
