Refactor: lint checker
Signed-off-by: Dahai Tang <[email protected]>
Dahai Tang committed Dec 4, 2024
1 parent d4bba11 commit 17b0aa5
Showing 8 changed files with 72 additions and 77 deletions.
47 changes: 26 additions & 21 deletions vllm/_custom_ops.py
@@ -923,30 +923,35 @@ def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                 block_mapping: torch.Tensor) -> None:
     torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping)
 
-def kv_store_copy_incomplete_blocks(src: torch.Tensor, dst: torch.Tensor,
-                                    layer_id: int,
-                                    incomplete_block_mapping: torch.Tensor) -> None:
-    torch.ops._C_cache_ops.kv_store_copy_incomplete_blocks(src, dst,
-                                                           layer_id,
-                                                           incomplete_block_mapping)
+
+def kv_store_copy_incomplete_blocks(
+        src: torch.Tensor, dst: torch.Tensor, layer_id: int,
+        incomplete_block_mapping: torch.Tensor) -> None:
+    torch.ops._C_cache_ops.kv_store_copy_incomplete_blocks(
+        src, dst, layer_id, incomplete_block_mapping)
 
 
 def kv_store_copy_blocks2CPU(src: torch.Tensor, dst: torch.Tensor,
-                            layer_id: int,
-                            block_mapping: torch.Tensor) -> None:
+                             layer_id: int,
+                             block_mapping: torch.Tensor) -> None:
     torch.ops._C_cache_ops.kv_store_copy_blocks2CPU(src, dst, layer_id,
                                                     block_mapping)
 
-def kv_store_copy_blocks2GPU(src: torch.Tensor, dst: List[torch.Tensor],
-                             num_layers: int,
-                             block_mapping: torch.Tensor,
-                             block_offsets: torch.Tensor,
-                             req_ids: torch.Tensor,
-                             events: List[int],  # the pointer of cudaEvent_t
-                             is_batch_layer: bool) -> None:
-    torch.ops._C_cache_ops.kv_store_copy_blocks2GPU(
-        src, dst, num_layers,
-        block_mapping, block_offsets,
-        req_ids, events, is_batch_layer)
+
+def kv_store_copy_blocks2GPU(
+        src: torch.Tensor,
+        dst: List[torch.Tensor],
+        num_layers: int,
+        block_mapping: torch.Tensor,
+        block_offsets: torch.Tensor,
+        req_ids: torch.Tensor,
+        events: List[int],  # the pointer of cudaEvent_t
+        is_batch_layer: bool) -> None:
+    torch.ops._C_cache_ops.kv_store_copy_blocks2GPU(src, dst, num_layers,
+                                                    block_mapping,
+                                                    block_offsets, req_ids,
+                                                    events, is_batch_layer)
 
 
 def convert_fp8(output: torch.Tensor,
                 input: torch.Tensor,
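These ops live in a compiled extension, so the snippet below is only a pure-PyTorch sketch of the block-mapping contract they share, assuming the upstream swap_blocks convention of an (N, 2) tensor of (src_block, dst_block) pairs; swap_blocks_reference is a hypothetical stand-in, not the _C_cache_ops kernel.

import torch

def swap_blocks_reference(src: torch.Tensor, dst: torch.Tensor,
                          block_mapping: torch.Tensor) -> None:
    # Each row of block_mapping is a (src_block, dst_block) pair; the op
    # copies whole cache blocks (the leading dimension) between tensors.
    for src_block, dst_block in block_mapping.tolist():
        dst[dst_block].copy_(src[src_block])

src = torch.randn(8, 16, 64)  # 8 cache blocks
dst = torch.zeros_like(src)
swap_blocks_reference(src, dst, torch.tensor([[0, 3], [5, 1]]))
assert torch.equal(dst[3], src[0]) and torch.equal(dst[1], src[5])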
47 changes: 22 additions & 25 deletions vllm/core/scheduler.py
@@ -462,8 +462,9 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
         if isinstance(request_id, str):
             request_id = (request_id, )
         request_ids = set(request_id)
-        for state_queue in [self.waiting, self.running,
-                            self.swapped, self.kv_store_waiting]:
+        for state_queue in [
+                self.waiting, self.running, self.swapped, self.kv_store_waiting
+        ]:
             aborted_groups: List[SequenceGroup] = []
             for seq_group in state_queue:
                 if not request_ids:
@@ -931,22 +932,19 @@ def _schedule_prefills(
         leftover_waiting_sequences: Deque[SequenceGroup] = deque()
         kv_store_leftover_waiting_sequences: Deque[SequenceGroup] = deque()
 
-        def _stop_schedule_prefill(num_new_tokens_uncached,
-                                   num_new_seqs,
-                                   max_num_batched_tokens,
-                                   budget):
+        def _stop_schedule_prefill(num_new_tokens_uncached, num_new_seqs,
+                                   max_num_batched_tokens, budget):
             ret = False
             if (budget.num_batched_tokens >=
                     self.scheduler_config.max_num_batched_tokens):
                 ret = True
-            if (num_new_tokens_uncached == 0 or
-                    not budget.can_schedule(
-                        num_new_tokens=num_new_tokens_uncached,
-                        num_new_seqs=num_new_seqs)):
+            if (num_new_tokens_uncached == 0 or not budget.can_schedule(
+                    num_new_tokens=num_new_tokens_uncached,
+                    num_new_seqs=num_new_seqs)):
                 ret = True
             return ret
 
-        kv_store_tmp_queue : Deque[SequenceGroup] = deque()
+        kv_store_tmp_queue: Deque[SequenceGroup] = deque()
         while self._passed_delay(time.time()) and kv_store_waiting_queue:
 
             seq_group = kv_store_waiting_queue[0]
@@ -977,8 +975,8 @@ def _stop_schedule_prefill(num_new_tokens_uncached,
                 kv_store_waiting_queue.popleft()
                 continue
 
-            if (_stop_schedule_prefill(num_new_tokens_uncached,
-                                       num_new_seqs,
+            if (_stop_schedule_prefill(
+                    num_new_tokens_uncached, num_new_seqs,
                     self.scheduler_config.max_num_batched_tokens, budget)):
                 break

@@ -1071,7 +1069,7 @@ def _stop_schedule_prefill(num_new_tokens_uncached,
 
             if (self.kv_store_manager is not None):
                 block_ids = self.block_manager.get_block_table(
-                                seq_group.get_seqs()[0])
+                    seq_group.get_seqs()[0])
                 block_mapping_from_cpu = \
                     self.kv_store_manager.get_block_mapping_from_python(
                         block_ids)
@@ -1097,18 +1095,18 @@ def _stop_schedule_prefill(num_new_tokens_uncached,
                 if (len(block_mapping_from_cpu) > 0):
                     waiting_queue.popleft()
                     kv_store_leftover_waiting_sequences.appendleft(seq_group)
-                    kv_store_block_mapping.extend(
-                        block_mapping_from_cpu)
-                    kv_store_block_mapping_offset.append(kv_store_block_mapping_cnt)
+                    kv_store_block_mapping.extend(block_mapping_from_cpu)
+                    kv_store_block_mapping_offset.append(
+                        kv_store_block_mapping_cnt)
                     kv_store_block_mapping_req_ids.append(
-                            seq_group.get_seqs()[0].seq_id)
+                        seq_group.get_seqs()[0].seq_id)
                     kv_store_block_mapping_cnt += len(block_mapping_from_cpu)
                     continue
 
             num_new_seqs = seq_group.get_max_num_running_seqs()
-            if (_stop_schedule_prefill(num_new_tokens_uncached, num_new_seqs,
-                    self.scheduler_config.max_num_batched_tokens,
-                    budget)):
+            if (_stop_schedule_prefill(
+                    num_new_tokens_uncached, num_new_seqs,
+                    self.scheduler_config.max_num_batched_tokens, budget)):
                 # let it to the next running one
                 waiting_queue.popleft()
                 kv_store_leftover_waiting_sequences.appendleft(seq_group)
@@ -1141,13 +1139,12 @@ def _stop_schedule_prefill(num_new_tokens_uncached,
         if (self.kv_store_manager is not None) and \
            (len(kv_store_block_mapping) > 0):
             self.kv_store_manager.close_send_flags(
-                [items[1]
-                 for items in kv_store_block_mapping])
+                [items[1] for items in kv_store_block_mapping])
 
         kv_store_block_mapping_offset.append(kv_store_block_mapping_cnt)
         kv_store_block_mapping_from_cpu = BlockMappingFromCPU(
-                kv_store_block_mapping, kv_store_block_mapping_offset,
-                kv_store_block_mapping_req_ids)
+            kv_store_block_mapping, kv_store_block_mapping_offset,
+            kv_store_block_mapping_req_ids)
 
         return SchedulerPrefillOutputs(
             seq_groups=seq_groups,
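_stop_schedule_prefill bundles two exit conditions: the batch already holds max_num_batched_tokens, or the candidate group either adds no uncached tokens (fully prefix-cached) or would overflow the token/sequence budget. A self-contained sketch of that logic, with ToyBudget as a hypothetical stand-in for vLLM's SchedulingBudget:

from dataclasses import dataclass

@dataclass
class ToyBudget:
    # Hypothetical stand-in for vLLM's SchedulingBudget.
    token_budget: int
    max_num_seqs: int
    num_batched_tokens: int = 0
    num_curr_seqs: int = 0

    def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int) -> bool:
        return (self.num_batched_tokens + num_new_tokens <= self.token_budget
                and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs)

def stop_schedule_prefill(num_new_tokens_uncached: int, num_new_seqs: int,
                          max_num_batched_tokens: int,
                          budget: ToyBudget) -> bool:
    # Stop once the batch is full ...
    if budget.num_batched_tokens >= max_num_batched_tokens:
        return True
    # ... or when the group adds nothing (fully cached) or does not fit.
    return (num_new_tokens_uncached == 0
            or not budget.can_schedule(num_new_tokens=num_new_tokens_uncached,
                                       num_new_seqs=num_new_seqs))

budget = ToyBudget(token_budget=8192, max_num_seqs=256,
                   num_batched_tokens=8000)
print(stop_schedule_prefill(512, 1, 8192, budget))  # True: 8000 + 512 > 8192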
2 changes: 1 addition & 1 deletion vllm/engine/arg_utils.py
@@ -120,7 +120,7 @@ class EngineArgs:
     use_v2_block_manager: bool = True
     swap_space: float = 4  # GiB
     cpu_offload_gb: float = 0  # GiB
-    kv_store_space: float = 0 # GiB
+    kv_store_space: float = 0  # GiB
     gpu_memory_utilization: float = 0.90
     max_num_batched_tokens: Optional[int] = None
     max_num_seqs: int = 256
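kv_store_space sits next to swap_space and cpu_offload_gb, so it is presumably also a host-memory size in GiB, with the default 0 disabling the store; a usage sketch under that assumption (the field exists only in this fork):

from vllm.engine.arg_utils import EngineArgs

# Assumed semantics: GiB of CPU memory reserved for the KV block store,
# alongside the existing swap space; 0 (the default) disables it.
args = EngineArgs(model="Qwen/Qwen2-7B-Instruct",
                  swap_space=4,
                  kv_store_space=8)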
11 changes: 5 additions & 6 deletions vllm/model_executor/models/qwen2.py
@@ -341,14 +341,14 @@ def forward(
         if (self.kv_store is not None) and \
            (self.kv_store.batch_layers_to_GPU):
             self.kv_store.get_stream_sync(
-                    attn_metadata.kv_store_meta.request_ids)
+                attn_metadata.kv_store_meta.request_ids)
 
         for i in range(self.start_layer, self.end_layer):
             layer_id = (i - self.start_layer)
             if (self.kv_store is not None) and \
                (not self.kv_store.batch_layers_to_GPU):
                 self.kv_store.get_stream_layer_sync(
-                        layer_id, attn_metadata.kv_store_meta.request_ids)
+                    layer_id, attn_metadata.kv_store_meta.request_ids)
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
@@ ... @@
 
             if (self.kv_store is not None):
                 self.kv_store.put_block_layer(
-                        attn_metadata.kv_store_meta.incomplete_put_block_ids,
-                        attn_metadata.kv_store_meta.put_block_ids_mapping,
-                        layer_id, kv_caches[layer_id],
-                        torch.cuda.current_stream())
+                    attn_metadata.kv_store_meta.incomplete_put_block_ids,
+                    attn_metadata.kv_store_meta.put_block_ids_mapping,
+                    layer_id, kv_caches[layer_id], torch.cuda.current_stream())
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
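The forward pass shown above syncs in one of two ways: one bulk wait before the layer loop when whole-batch uploads are enabled, or a per-layer wait so the upload of a later layer's KV blocks overlaps the current layer's compute. A stripped-down sketch of the per-layer pattern; the event plumbing here is hypothetical, not this repo's KVBlockStore API:

import torch
import torch.nn as nn

def run_layers(layers: nn.ModuleList, fetch_events, x: torch.Tensor):
    # fetch_events[i] would be a torch.cuda.Event recorded on the copy
    # stream once layer i's KV blocks are resident; passing None skips
    # the sync (e.g. on CPU), mirroring the batch-vs-layerwise switch.
    for i, layer in enumerate(layers):
        if fetch_events is not None:
            torch.cuda.current_stream().wait_event(fetch_events[i])
        x = layer(x)
    return x

layers = nn.ModuleList(nn.Linear(16, 16) for _ in range(4))
out = run_layers(layers, None, torch.randn(2, 16))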
6 changes: 1 addition & 5 deletions vllm/store/__init__.py
@@ -1,7 +1,3 @@
 from vllm.store.kv_store import KVBlockStore, KVBlockStoreManager, KVStoreMeta
 
-__all__ = [
-    "KVBlockStore",
-    "KVBlockStoreManager",
-    "KVStoreMeta"
-]
+__all__ = ["KVBlockStore", "KVBlockStoreManager", "KVStoreMeta"]
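The flattened __all__ is behavior-neutral; callers keep importing from the package root as before:

from vllm.store import KVBlockStore, KVBlockStoreManager, KVStoreMeta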
7 changes: 3 additions & 4 deletions vllm/store/kv_store.py
@@ -41,8 +41,8 @@ def __init__(self, block_mapping: list[list[int, int]],
 
     @staticmethod
     def null():
-        return BlockMappingFromCPU(
-            torch.Tensor(), torch.Tensor(), torch.Tensor())
+        return BlockMappingFromCPU(torch.Tensor(), torch.Tensor(),
+                                   torch.Tensor())
 
     def __str__(self):
         return "block_mapping: " + str(self.block_mapping) + \
@@ -417,8 +417,7 @@ def put_block_layer(self, incomplete_block_ids: torch.Tensor,
                                                  layer_id,
                                                  incomplete_block_ids)
 
-    def get_blocks(self,
-                   block_mapping_from_cpu: BlockMappingFromCPU,
+    def get_blocks(self, block_mapping_from_cpu: BlockMappingFromCPU,
                    kv_caches: list[torch.Tensor]):
         block_mapping_tensor = block_mapping_from_cpu.block_mapping
         block_offset_tensor = block_mapping_from_cpu.block_offset
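From its call sites in the scheduler and worker, BlockMappingFromCPU flattens every request's block pairs into one list, with block_offset as CSR-style prefix offsets and req_ids naming each slice. A hypothetical illustration of that layout; the (cpu_block, gpu_block) pair order is an assumption:

# Request 7 owns the first three pairs, request 9 the last one.
block_mapping = [[0, 12], [1, 13], [2, 14], [5, 40]]  # (cpu_block, gpu_block)
block_offset = [0, 3, 4]  # request i owns pairs [offset[i], offset[i + 1])
req_ids = [7, 9]

for i, req_id in enumerate(req_ids):
    pairs = block_mapping[block_offset[i]:block_offset[i + 1]]
    print(f"req {req_id}: CPU blocks {[p[0] for p in pairs]} "
          f"-> GPU blocks {[p[1] for p in pairs]}")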
8 changes: 4 additions & 4 deletions vllm/worker/worker.py
@@ -414,10 +414,10 @@ def issue_blocks_copy(self, worker_input: WorkerInput) -> None:
             return
         kv_caches = (self.kv_cache[worker_input.virtual_engine]
                      if self.kv_cache is not None else None)
-        self.kv_store.get_blocks(BlockMappingFromCPU(
-            worker_input.kv_store_block_mapping,
-            worker_input.kv_store_block_offsets,
-            worker_input.kv_store_block_req_ids),
+        self.kv_store.get_blocks(
+            BlockMappingFromCPU(worker_input.kv_store_block_mapping,
+                                worker_input.kv_store_block_offsets,
+                                worker_input.kv_store_block_req_ids),
             kv_caches)
 
     def _get_cached_seq_group_metadata(
21 changes: 10 additions & 11 deletions vllm/worker/worker_base.py
@@ -163,17 +163,16 @@ def from_broadcasted_tensor_dict(
         Pop fields from the given tensor_dict and populate a new instance of
         WorkerInput.
         """
-        return cls(num_seq_groups=tensor_dict.pop("num_seq_groups"),
-                   blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"),
-                   blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
-                   blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
-                   virtual_engine=tensor_dict["virtual_engine"],
-                   num_steps=tensor_dict.pop("num_steps"),
-                   kv_store_block_mapping=tensor_dict.pop("kv_block_mapping"),
-                   kv_store_block_offsets=tensor_dict.pop(
-                       "kv_block_mapping_offsets"),
-                   kv_store_block_req_ids=tensor_dict.pop(
-                       "kv_block_mapping_req_ids"),
+        return cls(
+            num_seq_groups=tensor_dict.pop("num_seq_groups"),
+            blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"),
+            blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
+            blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
+            virtual_engine=tensor_dict["virtual_engine"],
+            num_steps=tensor_dict.pop("num_steps"),
+            kv_store_block_mapping=tensor_dict.pop("kv_block_mapping"),
+            kv_store_block_offsets=tensor_dict.pop("kv_block_mapping_offsets"),
+            kv_store_block_req_ids=tensor_dict.pop("kv_block_mapping_req_ids"),
         )
 
     def as_broadcastable_tensor_dict(
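The refactored constructor keeps the existing pop()-based pattern: each field is consumed out of the broadcast dict, presumably so the remaining keys can be routed to the next consumer. A minimal round-trip sketch with a hypothetical two-field input:

from dataclasses import dataclass
from typing import Any, Dict

@dataclass
class ToyWorkerInput:
    num_seq_groups: int
    num_steps: int

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        return {"num_seq_groups": self.num_seq_groups,
                "num_steps": self.num_steps}

    @classmethod
    def from_broadcasted_tensor_dict(cls, tensor_dict: Dict[str, Any]):
        # pop() consumes each field, mirroring WorkerInput above.
        return cls(num_seq_groups=tensor_dict.pop("num_seq_groups"),
                   num_steps=tensor_dict.pop("num_steps"))

d = ToyWorkerInput(2, 1).as_broadcastable_tensor_dict()
assert ToyWorkerInput.from_broadcasted_tensor_dict(d) == ToyWorkerInput(2, 1)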
