
Commit dfbe60d

[Misc] Simplify code and fix type annotations in conftest.py (vllm-…
DarkLight1337 authored Jun 2, 2024
1 parent a66cf40 commit dfbe60d
Showing 1 changed file with 42 additions and 50 deletions.
92 changes: 42 additions & 50 deletions tests/conftest.py
@@ -5,16 +5,17 @@

 import pytest
 import torch
+import torch.nn.functional as F
 from PIL import Image
 from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
                           LlavaConfig, LlavaForConditionalGeneration)

 from vllm import LLM, SamplingParams
 from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
 from vllm.distributed import destroy_model_parallel
-from vllm.inputs import PromptInputs
+from vllm.inputs import TextPrompt
 from vllm.logger import init_logger
-from vllm.sequence import MultiModalData
+from vllm.sequence import MultiModalData, SampleLogprobs

 logger = init_logger(__name__)

@@ -188,10 +189,11 @@ def generate(
         prompts: List[str],
         images: Optional[List[Image.Image]] = None,
         **kwargs,
-    ) -> List[Tuple[List[int], str]]:
-        outputs: List[Tuple[List[int], str]] = []
+    ) -> List[Tuple[List[List[int]], List[str]]]:
         if images:
             assert len(prompts) == len(images)
+
+        outputs: List[Tuple[List[List[int]], List[str]]] = []
         for i, prompt in enumerate(prompts):
             processor_kwargs: Dict[str, Any] = {
                 "text": prompt,
@@ -201,17 +203,13 @@ def generate(
                 processor_kwargs["images"] = images[i]

             inputs = self.processor(**processor_kwargs)
-            inputs = {
-                key: value.cuda() if value is not None else None
-                for key, value in inputs.items()
-            }

             output_ids = self.model.generate(
-                **inputs,
+                **inputs.to("cuda"),
                 use_cache=True,
                 **kwargs,
             )
-            output_str = self.tokenizer.batch_decode(
+            output_str = self.processor.batch_decode(
                 output_ids,
                 skip_special_tokens=True,
                 clean_up_tokenization_spaces=False,
@@ -224,23 +222,22 @@ def generate_greedy(
         self,
         prompts: List[str],
         max_tokens: int,
-        images: Optional["torch.Tensor"] = None,
+        images: Optional[List[Image.Image]] = None,
     ) -> List[Tuple[List[int], str]]:
         outputs = self.generate(prompts,
                                 do_sample=False,
                                 max_new_tokens=max_tokens,
                                 images=images)
-        for i in range(len(outputs)):
-            output_ids, output_str = outputs[i]
-            outputs[i] = (output_ids[0], output_str[0])
-        return outputs
+
+        return [(output_ids[0], output_str[0])
+                for output_ids, output_str in outputs]

     def generate_beam_search(
         self,
         prompts: List[str],
         beam_width: int,
         max_tokens: int,
-    ) -> List[Tuple[List[int], str]]:
+    ) -> List[Tuple[List[List[int]], List[str]]]:
         outputs = self.generate(prompts,
                                 do_sample=False,
                                 max_new_tokens=max_tokens,
@@ -282,9 +279,7 @@ def generate_greedy_logprobs(
                 if self.model.get_output_embeddings().bias is not None:
                     logits += self.model.get_output_embeddings(
                     ).bias.unsqueeze(0)
-                logprobs = torch.nn.functional.log_softmax(logits,
-                                                           dim=-1,
-                                                           dtype=torch.float32)
+                logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                 seq_logprobs.append(logprobs)
             all_logprobs.append(seq_logprobs)
         return all_logprobs
@@ -294,10 +289,10 @@ def generate_greedy_logprobs_limit(
         prompts: List[str],
         max_tokens: int,
         num_logprobs: int,
-    ) -> List[Tuple[List[int], str]]:
-        all_logprobs = []
-        all_output_ids = []
-        all_output_strs = []
+    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
+        all_logprobs: List[List[Dict[int, float]]] = []
+        all_output_ids: List[List[int]] = []
+        all_output_strs: List[str] = []

         for prompt in prompts:
             input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
@@ -310,7 +305,7 @@
                 return_dict_in_generate=True,
             )

-            seq_logprobs = []
+            seq_logprobs: List[torch.Tensor] = []
             for _, hidden_states in enumerate(output.hidden_states):
                 last_hidden_states = hidden_states[-1][0]
                 logits = torch.matmul(
@@ -321,13 +316,11 @@
                            None) is not None:
                     logits += self.model.get_output_embeddings(
                     ).bias.unsqueeze(0)
-                logprobs = torch.nn.functional.log_softmax(logits,
-                                                           dim=-1,
-                                                           dtype=torch.float32)
+                logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                 seq_logprobs.append(logprobs)

             # convert to dict
-            seq_logprobs_lst = []
+            seq_logprobs_lst: List[Dict[int, float]] = []
             for tok_idx, tok_logprobs in enumerate(seq_logprobs):
                 # drop prompt logprobs
                 if tok_idx == 0:
@@ -372,13 +365,13 @@ def __init__(
         tokenizer_name: Optional[str] = None,
         # Use smaller max model length, otherwise bigger model cannot run due
         # to kv cache size limit.
-        max_model_len=1024,
+        max_model_len: int = 1024,
         dtype: str = "half",
         disable_log_stats: bool = True,
         tensor_parallel_size: int = 1,
         block_size: int = 16,
         enable_chunked_prefill: bool = False,
-        swap_space=4,
+        swap_space: int = 4,
         **kwargs,
     ) -> None:
         self.model = LLM(
@@ -399,32 +392,31 @@ def generate(
         self,
         prompts: List[str],
         sampling_params: SamplingParams,
-        images: Optional["torch.Tensor"] = None,
-    ) -> List[Tuple[List[int], str]]:
+        images: Optional[torch.Tensor] = None,
+    ) -> List[Tuple[List[List[int]], List[str]]]:
         if images is not None:
-            assert len(prompts) == images.shape[0]
+            assert len(prompts) == len(images)

-        prompt_inputs: List[PromptInputs] = []
+        prompt_inputs: List[TextPrompt] = []
         for i, prompt in enumerate(prompts):
-            image = None if images is None else images[i:i + 1]
-            mm_data = None if image is None else MultiModalData(
-                type=MultiModalData.Type.IMAGE,
-                data=image,
-            )
+            prompt = TextPrompt(prompt=prompt)
+            if images is not None:
+                prompt["multi_modal_data"] = MultiModalData(
+                    type=MultiModalData.Type.IMAGE,
+                    data=images[i:i + 1],
+                )

-            prompt_inputs.append({
-                "prompt": prompt,
-                "multi_modal_data": mm_data,
-            })
+            prompt_inputs.append(prompt)

         req_outputs = self.model.generate(prompt_inputs,
                                           sampling_params=sampling_params)
-        outputs = []
+
+        outputs: List[Tuple[List[List[int]], List[str]]] = []
         for req_output in req_outputs:
             prompt_str = req_output.prompt
             prompt_ids = req_output.prompt_token_ids
-            req_sample_output_ids = []
-            req_sample_output_strs = []
+            req_sample_output_ids: List[List[int]] = []
+            req_sample_output_strs: List[str] = []
             for sample in req_output.outputs:
                 output_str = sample.text
                 output_ids = sample.token_ids
@@ -437,12 +429,12 @@ def generate_w_logprobs(
         self,
         prompts: List[str],
         sampling_params: SamplingParams,
-    ) -> List[Tuple[List[int], str]]:
+    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         assert sampling_params.logprobs is not None

         req_outputs = self.model.generate(prompts,
                                           sampling_params=sampling_params)
-        outputs = []
+        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
         for req_output in req_outputs:
             for sample in req_output.outputs:
                 output_str = sample.text
@@ -467,7 +459,7 @@ def generate_greedy_logprobs(
         prompts: List[str],
         max_tokens: int,
         num_logprobs: int,
-    ) -> List[Tuple[List[int], str]]:
+    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                 max_tokens=max_tokens,
                                                 logprobs=num_logprobs)
@@ -481,7 +473,7 @@ def generate_beam_search(
         prompts: List[str],
         beam_width: int,
         max_tokens: int,
-    ) -> List[Tuple[List[int], str]]:
+    ) -> List[Tuple[List[List[int]], List[str]]]:
         beam_search_params = SamplingParams(n=beam_width,
                                             use_beam_search=True,
                                             temperature=0.0,
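
For reference, a minimal usage sketch of the input-construction pattern this commit adopts in VllmRunner.generate: each prompt string becomes a TextPrompt, and multi_modal_data is attached only when image data is actually supplied, rather than always building a dict with a possibly-None field. The helper function, model name, and prompt below are illustrative placeholders, not part of the commit, and the calls assume the vLLM API as of this change.

from typing import List, Optional

import torch

from vllm import LLM, SamplingParams
from vllm.inputs import TextPrompt
from vllm.sequence import MultiModalData


def build_inputs(prompts: List[str],
                 images: Optional[torch.Tensor] = None) -> List[TextPrompt]:
    # One TextPrompt per input string; attach image data only when present
    # (hypothetical helper mirroring the loop in VllmRunner.generate).
    inputs: List[TextPrompt] = []
    for i, prompt in enumerate(prompts):
        text_prompt = TextPrompt(prompt=prompt)
        if images is not None:
            text_prompt["multi_modal_data"] = MultiModalData(
                type=MultiModalData.Type.IMAGE,
                data=images[i:i + 1],
            )
        inputs.append(text_prompt)
    return inputs


llm = LLM(model="facebook/opt-125m")  # placeholder model for illustration
req_outputs = llm.generate(build_inputs(["Hello, my name is"]),
                           sampling_params=SamplingParams(max_tokens=16))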