Add unit tests for Universal Assisted Generation #8

Merged: 4 commits, Dec 19, 2024
90 changes: 90 additions & 0 deletions tests/generation/test_candidate_generator.py
@@ -1,18 +1,24 @@
import gc
import logging
import threading
import unittest
import weakref
from unittest.mock import MagicMock

from zmq import device
Owner comment: @gauravjain14 please remove this line


import numpy as np
import torch

from transformers.generation.candidate_generator import (
    AssistantToTargetTranslator,
    AssistantVocabTranslatorCache,
    AssistedCandidateGeneratorDifferentTokenizers,
    UniversalSpeculativeDecodingGenerator,
)

from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig


class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase):
    def test_no_intersection(self):
@@ -216,3 +222,87 @@ def get_translator():
        # All translators should be the same instance
        for translator in translators:
            self.assertIs(translators[0], translator, "All translators should be identical across threads")


class TestUniversalSpeculativeDecoding(unittest.TestCase):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    @classmethod
    def setUpClass(cls):
        cls.assistant_model = AutoModelForCausalLM.from_pretrained(
            "hf-internal-testing/tiny-random-gpt2").to(cls.device)
        cls.main_tokenizer = AutoTokenizer.from_pretrained(
            "meta-llama/Llama-3.2-1B-Instruct")
        cls.assistant_tokenizer = AutoTokenizer.from_pretrained(
            "hf-internal-testing/tiny-random-gpt2")
        cls.generation_config = GenerationConfig()

        # Ensure required tokens exist
        if cls.main_tokenizer.pad_token_id is None:
            cls.main_tokenizer.pad_token_id = cls.main_tokenizer.eos_token_id
        if cls.main_tokenizer.bos_token_id is None:
            cls.main_tokenizer.bos_token_id = cls.main_tokenizer.eos_token_id

    def setUp(self):
        self.input_ids = torch.tensor([[1, 2, 3]]).to(self.device)
        self.model_kwargs = {
            "attention_mask": torch.ones_like(self.input_ids).to(self.device),
        }
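        # Build the generator under test directly, bypassing `generate()`.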
        self.generator = UniversalSpeculativeDecodingGenerator(
            input_ids=self.input_ids,
            assistant_model=self.assistant_model,
            target_tokenizer=self.main_tokenizer,
            assistant_tokenizer=self.assistant_tokenizer,
            generation_config=self.generation_config,
            model_kwargs=self.model_kwargs,
            target_vocab_size=self.main_tokenizer.vocab_size,
        )

    def test_basic_generation(self):
        """Test that basic speculative decoding works."""
        input_text = "The quick brown fox"
        input_ids = self.main_tokenizer.encode(input_text, return_tensors="pt").to(self.device)
        self.generator.input_ids = input_ids
        candidates, scores = self.generator.get_candidates(input_ids)

        self.assertIsNotNone(candidates)
        self.assertIsNotNone(scores)
        self.assertTrue(torch.is_tensor(candidates))
        self.assertTrue(torch.is_tensor(scores))

    def test_mismatched_vocabularies(self):
        """Test handling of mismatched vocabularies between models."""
        # Find a token that exists in the main tokenizer's vocabulary
        # but not in the assistant tokenizer's.
        missing_token = next(
            token for token in self.main_tokenizer.get_vocab()
            if token not in self.assistant_tokenizer.get_vocab() and
            token not in self.main_tokenizer.all_special_tokens and
            "reserved_" not in token
        )
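        # The "reserved_" filter skips Llama's "<|reserved_special_token_*|>"
        # placeholder entries, which are not regular text tokens.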
        input_ids = torch.tensor([[self.main_tokenizer.convert_tokens_to_ids(missing_token)]]).to(self.device)
        self.generator.input_ids = input_ids
        candidates, scores = self.generator.get_candidates(input_ids)
        self.assertIsNotNone(candidates)

    def test_speculation_depth(self):
        """Test different speculation depths."""
        input_ids = self.main_tokenizer.encode("Test text", return_tensors="pt").to(self.device)
        self.generator.input_ids = input_ids

        for depth in [1, 8, 17]:
            self.generator.num_assistant_tokens = depth
            candidates, scores = self.generator.get_candidates(input_ids)
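            # The assistant may stop drafting early (e.g. on EOS), so only an
            # upper bound on the number of new candidate tokens is asserted.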
            self.assertLessEqual(
                candidates.shape[1] - input_ids.shape[1], depth
            )

    def test_device_consistency(self):
        """Test that inputs and outputs stay on the same device."""
        if torch.cuda.is_available():
            input_ids = torch.tensor([[1, 2, 3]]).to(
                self.generator.assistant_model.device)
            self.generator.input_ids = input_ids
            candidates, scores = self.generator.get_candidates(input_ids)
            self.assertEqual(candidates.device, input_ids.device)
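
For anyone who wants to exercise the new generator outside the unittest harness, below is a minimal sketch mirroring the fixtures above. The checkpoint names and the constructor signature are copied from the test code; calling get_candidates directly (rather than letting generate() drive it internally) is an assumption of this sketch, not something the PR prescribes.

# Minimal sketch, not part of the PR: mirrors setUpClass/setUp above.
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.generation.candidate_generator import UniversalSpeculativeDecodingGenerator

device = "cuda" if torch.cuda.is_available() else "cpu"
assistant_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)
target_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
assistant_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")

input_ids = target_tokenizer.encode("The quick brown fox", return_tensors="pt").to(device)
generator = UniversalSpeculativeDecodingGenerator(
    input_ids=input_ids,
    assistant_model=assistant_model,
    target_tokenizer=target_tokenizer,
    assistant_tokenizer=assistant_tokenizer,
    generation_config=GenerationConfig(),
    model_kwargs={"attention_mask": torch.ones_like(input_ids)},
    target_vocab_size=target_tokenizer.vocab_size,
)

# Draft tokens come back in the target vocabulary, with per-step scores.
candidates, scores = generator.get_candidates(input_ids)
print(candidates.shape, scores.shape)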