Use free_table as a mask tensor (#1086)
* use free_table as a mask tensor

Signed-off-by: Wang, Yi A <[email protected]>

* fix beamsearch issue

Signed-off-by: Wang, Yi A <[email protected]>

---------

Signed-off-by: Wang, Yi A <[email protected]>
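
The gist of the change, as a standalone sketch (illustrative names and a toy pool of 8 blocks; free_mask plays the role of self.free_blocks): instead of keeping the free list as a shrinking tensor of block ids, the cache keeps a fixed-size 0/1 mask over all block ids, so allocation is a nonzero() lookup and freeing a block is a single in-place write.

    import torch

    num_blocks = 8

    # Old scheme: free blocks kept as a shrinking tensor of ids.
    free_blocks = torch.arange(num_blocks)
    allocated = free_blocks[0]        # take the first free id
    free_blocks = free_blocks[1:]     # re-slice on every allocation

    # New scheme: a fixed-size mask, 1 = free, 0 = in use.
    free_mask = torch.ones([num_blocks], dtype=torch.int32)
    taken = free_mask.nonzero().view(-1)[0:2]  # ids of the first two free blocks
    free_mask[taken] = 0                       # mark them in use
    free_mask[taken[0]] = 1                    # freeing a block is one write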
sywangyi authored Jan 2, 2025
1 parent 1733791 commit 753f84d
Showing 1 changed file with 21 additions and 20 deletions.
41 changes: 21 additions & 20 deletions optimum/exporters/ipex/cache_utils.py
@@ -49,7 +49,7 @@ def __init__(
         self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
             batch_size, -1
         )
-        self.free_blocks = torch.arange(self.num_blocks, device=device)
+        self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=device)
         self.max_cache_len = max_cache_len
         self.num_kv_heads = config.num_key_value_heads
         self.num_hidden_layers = config.num_hidden_layers
@@ -88,12 +88,10 @@ def update_for_prefill(
         all_slot_offsets = []
         num_blocks = (input_lens + self.block_size - 1) // self.block_size
         for i in range(batch_size):
-            for b_idx in range(num_blocks[i]):
-                if self.block_tables[i][b_idx] == -1:
-                    # need a free block
-                    self.block_tables[i][b_idx] = self.free_blocks[0]
-                    self.free_blocks = self.free_blocks[1:]
-
+            nb = num_blocks[i]
+            block_table = self.free_blocks.nonzero().view(-1)[0:nb]
+            self.block_tables[i][0:nb] = block_table
+            self.free_blocks[block_table] = 0
             slots_range = torch.arange(input_lens[i], device=key_states.device)
             block_indices = slots_range // self.block_size
             slot_offsets = slots_range % self.block_size
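
Taken out of context, the new prefill path claims every block a sequence needs in one vectorized step and derives a flat slot id per token. A toy sketch with standalone names (block_size and pool size are arbitrary here):

    import torch

    block_size = 4
    free_blocks = torch.ones([16], dtype=torch.int32)   # 1 = free, 0 = in use
    input_len = 6                                       # prompt tokens for one sequence

    nb = (input_len + block_size - 1) // block_size     # ceil-div: 2 blocks needed
    block_table = free_blocks.nonzero().view(-1)[0:nb]  # first nb free block ids
    free_blocks[block_table] = 0                        # claim them in one write

    slots_range = torch.arange(input_len)
    block_indices = block_table[slots_range // block_size]  # physical block per token
    slot_offsets = slots_range % block_size
    slots = block_indices * block_size + slot_offsets   # flat cache slot per token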
@@ -103,7 +101,6 @@
         all_block_indices = torch.cat(all_block_indices)
         all_slot_offsets = torch.cat(all_slot_offsets)
         self.slots = all_block_indices * self.block_size + all_slot_offsets
-
         # Update the cache
         PagedAttention.reshape_and_cache(
             key_states,
@@ -127,16 +124,16 @@ def update_for_decode(
     ):
         if layer_idx == 0:
             start_block_idx = self._seen_tokens // self.block_size
-            num_blocks = (self._seen_tokens + self.block_size) // self.block_size
             slot_offset_in_block = (self._seen_tokens) % self.block_size
             self.slots = torch.zeros([batch_size], device=key_states.device, dtype=torch.int32)
             for i in range(batch_size):
-                for b_idx in range(start_block_idx[i], num_blocks[i]):
+                if slot_offset_in_block[i] == 0:
+                    # need a new block:
+                    b_idx = start_block_idx[i]
                     if self.block_tables[i][b_idx] == -1:
                         # need a free block
-                        self.block_tables[i][b_idx] = self.free_blocks[0]
-                        self.free_blocks = self.free_blocks[1:]
-
+                        self.block_tables[i][b_idx] = self.free_blocks.nonzero().view(-1)[0:1]
+                        self.free_blocks[self.block_tables[i][b_idx]] = 0
                 self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i]
         # Update the cache
         PagedAttention.reshape_and_cache(
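
The decode path needs at most one new block per sequence per step, exactly when the slot offset wraps to zero, which is why the inner for b_idx loop could go. A toy sketch of that single-allocation path (standalone names, arbitrary sizes):

    import torch

    block_size = 4
    seen_tokens = torch.tensor([8])                     # tokens already cached for sequence 0
    free_blocks = torch.ones([16], dtype=torch.int32)   # 1 = free
    block_tables = -1 * torch.ones([1, 4], dtype=torch.int32)
    block_tables[0, :2] = torch.tensor([0, 1], dtype=torch.int32)
    free_blocks[:2] = 0                                 # blocks 0 and 1 already in use

    i = 0
    start_block_idx = seen_tokens[i] // block_size      # logical block of the next token
    slot_offset = seen_tokens[i] % block_size
    if slot_offset == 0 and block_tables[i][start_block_idx] == -1:
        new_block = free_blocks.nonzero().view(-1)[0:1] # first free id (2 here)
        block_tables[i][start_block_idx] = new_block
        free_blocks[new_block] = 0
    slot = block_tables[i][start_block_idx] * block_size + slot_offset  # 2*4 + 0 = 8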
@@ -196,7 +193,7 @@ def reset(self):
         """Resets the cache values while preserving the objects"""
         self._seen_tokens = torch.zeros([self.batch_size], dtype=torch.int32, device=self.block_tables.device)
         self.block_tables.fill_(-1)
-        self.free_blocks = torch.arange(self.num_blocks, device=self.block_tables.device)
+        self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.block_tables.device)
         self.max_seq_len = 0

     def reorder_cache(self, beam_idx: torch.LongTensor):
@@ -206,16 +203,18 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
         updated_block_tables = self.block_tables.index_select(0, beam_idx.to(device))
         mask = self.block_tables.masked_fill(self.block_tables != -1, 1).masked_fill(self.block_tables == -1, 0)
         num_blocks = mask.cumsum(-1)[:, -1]
-        updated_table = []
+        updated_table = torch.zeros_like(beam_idx)
         for i in range(beam_idx.shape[0]):
-            self.block_tables[i, 0 : num_blocks[i] - 1] = updated_block_tables[i, 0 : num_blocks[i] - 1]
-            updated_table.append(self.block_tables[i : i + 1, num_blocks[i] - 1 : num_blocks[i]])
-        updated_table = torch.cat(tuple(updated_table), dim=0)
+            nb = num_blocks[i]
+            self.block_tables[i, 0 : nb - 1] = updated_block_tables[i, 0 : nb - 1]
+            updated_table[i] = self.block_tables[i][nb - 1]
         for layer_idx in range(self.num_hidden_layers):
             self.key_cache[layer_idx][updated_table] = self.key_cache[layer_idx][updated_table[beam_idx]]
             self.value_cache[layer_idx][updated_table] = self.value_cache[layer_idx][updated_table[beam_idx]]
         free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
-        self.free_blocks = torch.cat((self.free_blocks, free_table))
+        for i in free_table:
+            if not (self.block_tables == i).any():
+                self.free_blocks[i] = 1

     def crop(self, maximum_length: int):
         """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
@@ -235,4 +234,6 @@
             self._seen_tokens[bs] = new_tokens
         self.max_seq_len, _ = self._seen_tokens.max(dim=0)
         free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
-        self.free_blocks = torch.cat((self.free_blocks, free_table))
+        for i in free_table:
+            if not (self.block_tables == i).any():
+                self.free_blocks[i] = 1
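
Both reorder_cache and crop now return a block to the pool only if no row of block_tables still references it. After beams are reordered, several rows can share the same physical block, and the old torch.cat-based free list could hand such a block out again while it was still live. A toy sketch of the guard (standalone names):

    import torch

    free_blocks = torch.zeros([8], dtype=torch.int32)   # all blocks currently in use
    block_tables = torch.tensor([[0, 1], [0, 3]])       # two beams sharing block 0
    free_table = torch.tensor([1, 2])                   # ids dropped from some row

    for b in free_table:
        if not (block_tables == b).any():               # still referenced by any beam?
            free_blocks[b] = 1

    # Only block 2 returns to the pool; block 1 is still live in row 0.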
