From a31cab7556f540b558b0b454b4a4b9b438542566 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 6 Jun 2024 18:12:00 -0700 Subject: [PATCH] [Core] Avoid copying prompt/output tokens if no penalties are used (#5289) --- vllm/model_executor/sampling_metadata.py | 80 +++++++++++++++--------- 1 file changed, 50 insertions(+), 30 deletions(-) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 0b3b41e69d6bc..7ad84f51b7e4c 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -386,16 +386,18 @@ def from_sampling_metadata( presence_penalties += [0] * prefill_len frequency_penalties += [0] * prefill_len repetition_penalties += [1] * prefill_len - prompt_tokens.extend([] for _ in range(prefill_len)) - output_tokens.extend([] for _ in range(prefill_len)) + if do_penalties: + prompt_tokens.extend([] for _ in range(prefill_len)) + output_tokens.extend([] for _ in range(prefill_len)) if seq_group.do_sample: sample_lens = len(seq_group.sample_indices) assert sample_lens == len(seq_ids) for seq_id in seq_ids: seq_data = seq_group.seq_data[seq_id] - prompt_tokens.append(seq_data.prompt_token_ids) - output_tokens.append(seq_data.output_token_ids) + if do_penalties: + prompt_tokens.append(seq_data.prompt_token_ids) + output_tokens.append(seq_data.output_token_ids) temperatures += [temperature] * len(seq_ids) top_ps += [top_p] * len(seq_ids) top_ks += [top_k] * len(seq_ids) @@ -443,18 +445,22 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], # Note that the performance will be very bad without # pinned memory. pin_memory = is_pin_memory_available() - prompt_max_len = max([len(tokens) for tokens in prompt_tokens], - default=0) - prompt_padded_tokens = [ - tokens + [vocab_size] * (prompt_max_len - len(tokens)) - for tokens in prompt_tokens - ] - output_max_len = max([len(tokens) for tokens in output_tokens], - default=0) - output_padded_tokens = [ - tokens + [vocab_size] * (output_max_len - len(tokens)) - for tokens in output_tokens - ] + + do_penalties = prompt_tokens or output_tokens + + if do_penalties: + prompt_max_len = max([len(tokens) for tokens in prompt_tokens], + default=0) + prompt_padded_tokens = [ + tokens + [vocab_size] * (prompt_max_len - len(tokens)) + for tokens in prompt_tokens + ] + output_max_len = max([len(tokens) for tokens in output_tokens], + default=0) + output_padded_tokens = [ + tokens + [vocab_size] * (output_max_len - len(tokens)) + for tokens in output_tokens + ] temperatures_t = torch.tensor( temperatures, @@ -504,18 +510,22 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], dtype=torch.long, pin_memory=pin_memory, ) - prompt_tensor = torch.tensor( - prompt_padded_tokens, - device="cpu", - dtype=torch.long, - pin_memory=pin_memory, - ) - output_tensor = torch.tensor( - output_padded_tokens, - device="cpu", - dtype=torch.long, - pin_memory=pin_memory, - ) + if do_penalties: + prompt_tensor = torch.tensor( + prompt_padded_tokens, + device="cpu", + dtype=torch.long, + pin_memory=pin_memory, + ) + output_tensor = torch.tensor( + output_padded_tokens, + device="cpu", + dtype=torch.long, + pin_memory=pin_memory, + ) + else: + prompt_tensor = None + output_tensor = None # need to transpose and make contiguous to # copy the tensor correctly. # [batch_size, n_seeds] -> [n_seeds, batch_size] @@ -538,6 +548,16 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], extra_seeds_gpu = None sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds] + if do_penalties: + prompt_tokens_gpu = prompt_tensor.to(device=device, + non_blocking=True) + output_tokens_gpu = output_tensor.to(device=device, + non_blocking=True) + else: + empty_tensor = torch.empty(0, device=device, dtype=torch.long) + prompt_tokens_gpu = empty_tensor + output_tokens_gpu = empty_tensor + return cls( temperatures=temperatures_t.to(device=device, non_blocking=True), top_ps=top_ps_t.to(device=device, non_blocking=True), @@ -549,8 +569,8 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], non_blocking=True), repetition_penalties=repetition_penalties_t.to(device=device, non_blocking=True), - prompt_tokens=prompt_tensor.to(device=device, non_blocking=True), - output_tokens=output_tensor.to(device=device, non_blocking=True), + prompt_tokens=prompt_tokens_gpu, + output_tokens=output_tokens_gpu, sampling_seeds=sampling_seeds_gpu, sample_indices=sample_indices_t.to(device=device, non_blocking=True),