#12330: updated model and test cases that use sdpa decode to take the new input kv format and fixed logging in sdpa decode
caixunshiren committed Sep 25, 2024
1 parent 72cac85 commit 8312d0a
Showing 8 changed files with 41 additions and 40 deletions.
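The common thread across these diffs: scaled_dot_product_attention_decode previously expected the KV cache with batch in dim 1, [n_kv_heads, batch, seq_len, head_dim] (see the removed comment "Have to reshape back since sdpa expects batch in dim 1"), and every caller reshaped around that. The new input kv format is batch-first, [batch, n_kv_heads, seq_len, head_dim]. A minimal torch sketch of the layout change (shape values are made up; this is an illustration, not the ttnn API):

    import torch

    n_kv_heads, batch, seq_len, head_dim = 8, 32, 64, 128

    # Old input format: batch in dim 1
    kv_old = torch.randn(n_kv_heads, batch, seq_len, head_dim)

    # New input format: batch in dim 0, matching the K.permute(1, 0, 2, 3)
    # and V.permute(1, 0, 2, 3) calls added to the tests below
    kv_new = kv_old.permute(1, 0, 2, 3).contiguous()
    assert kv_new.shape == (batch, n_kv_heads, seq_len, head_dim)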
7 changes: 2 additions & 5 deletions models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py
@@ -312,13 +312,10 @@ def attn_mqa(self, query_layer, key_layer, value_layer, start_pos: int, cache_id
  )

  else:
-     # Have to reshape back since sdpa expects batch in dim 1
-     keys_reshaped = ttnn.reshape(keys, [self.n_local_kv_heads, self.max_batch_size, -1, self.head_dim])
-     values_reshaped = ttnn.reshape(values, [self.n_local_kv_heads, self.max_batch_size, -1, self.head_dim])
      attn_output = ttnn.transformer.scaled_dot_product_attention_decode(
          query_layer,
-         keys_reshaped,
-         values_reshaped,
+         keys,
+         values,
          # [start_pos for _ in range(self.max_batch_size)],
          cur_pos_tensor=cache_idxs,
          scale=self.scale,
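With the reshape gone, scaled_dot_product_attention_decode consumes the batch-first cache directly, attending each user's single new query token to that user's cache up to its current position (cur_pos_tensor=cache_idxs above). A plain-torch sketch of that semantics, assuming n_heads equals n_kv_heads for brevity; this is an illustration, not the ttnn kernel:

    import torch

    def sdpa_decode_reference(q, k_cache, v_cache, cur_pos, scale):
        # q:       [batch, n_heads, 1, head_dim], one new token per user
        # k_cache: [batch, n_kv_heads, max_seq, head_dim], new batch-first format
        # v_cache: [batch, n_kv_heads, max_seq, head_dim]
        # cur_pos: list of ints, last valid cache position per user
        outs = []
        for b in range(q.shape[0]):
            valid = cur_pos[b] + 1                  # attend to positions [0, cur_pos[b]]
            k = k_cache[b, :, :valid, :]            # [n_kv_heads, valid, head_dim]
            v = v_cache[b, :, :valid, :]
            attn = torch.softmax(q[b] @ k.transpose(-2, -1) * scale, dim=-1)
            outs.append(attn @ v)                   # [n_heads, 1, head_dim]
        return torch.stack(outs)                    # [batch, n_heads, 1, head_dim]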
3 changes: 0 additions & 3 deletions models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py
@@ -240,9 +240,6 @@ def forward_decode(
  k_heads_1B1D.deallocate(True)
  v_heads_1B1D.deallocate(True)

-     keys_1BPD = ttnn.reshape(keys_1BPD, [self.n_local_kv_heads, self.max_batch_size, -1, self.head_dim])
-     values_1BPD = ttnn.reshape(values_1BPD, [self.n_local_kv_heads, self.max_batch_size, -1, self.head_dim])
-
  attn_output_1B4D = ttnn.transformer.scaled_dot_product_attention_decode(
      q_heads_1B4D,
      keys_1BPD,
@@ -395,9 +395,9 @@ def run_test_LlamaAttention_inference(
  tt_layer_present_all = [ttnn.from_device(lp) for lp in tt_LlamaAttention_model.layer_past]

  tt_layer_present_all = [
-     ttnn.to_torch(
-         lp, mesh_composer=ConcatMesh2DToTensor(mesh_device, dims=(1, 0), cluster_shape=cluster_shape)
-     ).transpose(0, 1)[:batch, ...]
+     ttnn.to_torch(lp, mesh_composer=ConcatMesh2DToTensor(mesh_device, dims=(0, 1), cluster_shape=cluster_shape))[
+         :batch, ...
+     ]
      for lp in tt_layer_present_all
  ]
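This test change is the host-side mirror of the layout swap: with the cache batch-first on device, composing the 2D mesh shards with dims=(0, 1) already yields a batch-first host tensor, where the old dims=(1, 0) composition needed a trailing .transpose(0, 1). A torch sketch of why the transpose drops out, using explicit torch.cat as a stand-in for ConcatMesh2DToTensor (the composer's exact dims convention is an assumption here):

    import torch

    rows, cols = 2, 2  # cluster_shape stand-in
    # Old on-device shards: [heads_frac, batch_frac, seq, head_dim]
    old_shards = [[torch.randn(2, 4, 16, 8) for _ in range(cols)] for _ in range(rows)]
    # New on-device shards hold the same data, batch-first
    new_shards = [[s.transpose(0, 1) for s in row] for row in old_shards]

    # Old: concat heads (dim 0) across columns, batch (dim 1) across rows, then transpose
    old_host = torch.cat(
        [torch.cat(row, dim=0) for row in old_shards], dim=1
    ).transpose(0, 1)

    # New: concat heads (dim 1) across columns, batch (dim 0) across rows, no transpose
    new_host = torch.cat(
        [torch.cat(row, dim=1) for row in new_shards], dim=0
    )

    assert torch.equal(old_host, new_host)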
6 changes: 3 additions & 3 deletions models/demos/tg/llama3_70b/tests/test_llama_decoder_galaxy.py
@@ -359,9 +359,9 @@ def run_test_LlamaDecoder_inference(

  tt_layer_present_all = [ttnn.from_device(lp) for lp in tt_LlamaDecoder_model.attention.layer_past]
  tt_layer_present_all = [
-     ttnn.to_torch(
-         lp, mesh_composer=ConcatMesh2DToTensor(mesh_device, dims=(1, 0), cluster_shape=cluster_shape)
-     ).transpose(0, 1)[:batch, ...]
+     ttnn.to_torch(lp, mesh_composer=ConcatMesh2DToTensor(mesh_device, dims=(0, 1), cluster_shape=cluster_shape))[
+         :batch, ...
+     ]
      for lp in tt_layer_present_all
  ]

4 changes: 2 additions & 2 deletions models/demos/tg/llama3_70b/tests/test_llama_model_galaxy.py
@@ -199,8 +199,8 @@ def run_test_LlamaModel_inference(
  tt_layer_present_all = [ttnn.from_device(lp) for lp in tt_model.layers[layer_id].attention.layer_past]
  tt_layer_present_all = [
      ttnn.to_torch(
-         lp, mesh_composer=ConcatMesh2DToTensor(mesh_device, dims=(1, 0), cluster_shape=cluster_shape)
-     ).transpose(0, 1)[:batch, ...]
+         lp, mesh_composer=ConcatMesh2DToTensor(mesh_device, dims=(0, 1), cluster_shape=cluster_shape)
+     )[:batch, ...]
      for lp in tt_layer_present_all
  ]

31 changes: 19 additions & 12 deletions models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py
@@ -277,6 +277,11 @@ def init_kv_cache(self):
          )
          for lp in layer_past
      ]
+     # work around for CI error
+     self.layer_past = [
+         ttnn.reshape(lp, [self.batch_size_per_device_group, self.n_local_kv_heads, -1, self.head_dim])
+         for lp in self.layer_past
+     ]

  def load_weights(self):
      assert not hasattr(self, "qkv_list"), "qkv_list is already an attribute of this object"
@@ -463,12 +468,22 @@ def attn_mqa(
  ):
      # K CACHE UPDATE
      keys = self.layer_past[0]
-     ttnn.update_cache(keys, key_layer, start_pos, batch_offset=batch_offset)
+     ttnn.experimental.paged_update_cache(
+         keys,
+         key_layer,
+         update_idxs=[start_pos for _ in range(self.batch_size_per_device_group)],
+         batch_offset=batch_offset,
+     )
      key_layer.deallocate(True)

      # V CACHE UPDATE
      values = self.layer_past[1]
-     ttnn.update_cache(values, value_layer, start_pos, batch_offset=batch_offset)
+     ttnn.experimental.paged_update_cache(
+         values,
+         value_layer,
+         update_idxs=[start_pos for _ in range(self.batch_size_per_device_group)],
+         batch_offset=batch_offset,
+     )
      value_layer.deallocate(True)

      program_config = ttnn.SDPAProgramConfig(
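Where ttnn.update_cache took a single start_pos shared by all users, the paged variant takes one write position per user; the calls above pass [start_pos] for every user, so behavior is unchanged for now. A torch stand-in for the per-user update, with shapes assumed and batch_offset handling omitted:

    import torch

    def update_cache_reference(cache, new_kv, update_idxs):
        # cache:       [batch, n_kv_heads, max_seq, head_dim], batch-first KV cache
        # new_kv:      [batch, n_kv_heads, 1, head_dim], this step's K or V
        # update_idxs: one write position per user
        for b, pos in enumerate(update_idxs):
            cache[b, :, pos, :] = new_kv[b, :, 0, :]
        return cache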
@@ -629,28 +644,20 @@ def prefill_attn_mqa(
  ):
      # FILL K CACHE
      keys = self.layer_past[0]
-     # Fill cache expects batch in dim0
-     keys_reshaped = ttnn.reshape(keys, [self.batch_size_per_device_group, self.n_local_kv_heads, -1, self.head_dim])

      single_user_key_layer = self.prefill_prepare_tensor_for_kv_cache(key_layer, user_id)

      # Fill cache with multi-device tensor
-     ttnn.fill_cache(
+     ttnn.experimental.paged_fill_cache(
          keys_reshaped,
          ttnn.experimental.typecast(single_user_key_layer, ttnn.bfloat8_b, memory_config=ttnn.DRAM_MEMORY_CONFIG),
          user_id % self.batch_size_per_device_group,
      )

      # FILL V CACHE
      values = self.layer_past[1]
-     # Fill cache expects batch in dim0
-     values_reshaped = ttnn.reshape(
-         values, [self.batch_size_per_device_group, self.n_local_kv_heads, -1, self.head_dim]
-     )

      single_user_value_layer = self.prefill_prepare_tensor_for_kv_cache(value_layer, user_id)

-     ttnn.fill_cache(
+     ttnn.experimental.paged_fill_cache(
          values_reshaped,
          ttnn.experimental.typecast(single_user_value_layer, ttnn.bfloat8_b, memory_config=ttnn.DRAM_MEMORY_CONFIG),
          user_id % self.batch_size_per_device_group,
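paged_fill_cache writes one user's entire prefill K or V into that user's row of the batch-first cache, the row being user_id modulo the per-device-group batch as at the call sites above. A torch stand-in under assumed shapes (illustrative only, not the ttnn op):

    import torch

    def fill_cache_reference(cache, user_kv, batch_idx):
        # cache:   [batch, n_kv_heads, max_seq, head_dim], batch-first KV cache
        # user_kv: [1, n_kv_heads, prefill_len, head_dim], one user's prefilled K or V
        prefill_len = user_kv.shape[2]
        cache[batch_idx, :, :prefill_len, :] = user_kv[0]
        return cache

    # Mirrors the call sites above, e.g.:
    # fill_cache_reference(keys, single_user_key, user_id % batch_size_per_device_group)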
4 changes: 2 additions & 2 deletions tests/ttnn/multichip_unit_tests/test_multidevice_TG.py
@@ -924,7 +924,7 @@ def run_test_sdpa_decode_single_iter(
  V = torch.randn(nkv, b, s, d)

  tt_K = ttnn.from_torch(
-     K,
+     K.permute(1, 0, 2, 3),
      device=mesh_device,
      dtype=dtype,
      layout=ttnn.TILE_LAYOUT,
@@ -933,7 +933,7 @@
  )

  tt_V = ttnn.from_torch(
-     V,
+     V.permute(1, 0, 2, 3),
      device=mesh_device,
      dtype=dtype,
      layout=ttnn.TILE_LAYOUT,
@@ -600,16 +600,16 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core(

uint32_t cur_pos = use_cur_pos_tensor ? -1 : cur_pos_ids.at(cur_batch);

log_trace("---- core_id: {}, coord: {} ----", i, core);
log_trace("worker_id_for_reduce: {}", worker_id_for_reduce);
log_trace("worker_id_for_output: {}", worker_id_for_output);
log_trace("do_reduce: {}", do_reduce);
log_trace("do_output: {}", do_output);
log_trace("cur_head: {}", cur_head);
log_trace("cur_batch: {}", cur_batch);
log_trace("core_num_in_reduce: {}", core_num_in_reduce);
log_trace("core_num_in_output: {}", core_num_in_output);
log_trace("cur_pos: {}", cur_pos);
log_debug("---- core_id: {}, coord: {} ----", i, core);
log_debug("worker_id_for_reduce: {}", worker_id_for_reduce);
log_debug("worker_id_for_output: {}", worker_id_for_output);
log_debug("do_reduce: {}", do_reduce);
log_debug("do_output: {}", do_output);
log_debug("cur_head: {}", cur_head);
log_debug("cur_batch: {}", cur_batch);
log_debug("core_num_in_reduce: {}", core_num_in_reduce);
log_debug("core_num_in_output: {}", core_num_in_output);
log_debug("cur_pos: {}", cur_pos);

// reader runtime args
std::vector<uint32_t> reader_rt_args = { q_addr, k_addr, v_addr, pos_addr, page_table_addr, do_reduce, do_output, cur_head, cur_batch, core_num_in_reduce, core_num_in_output, cur_pos};
Expand Down
