From a8d049432e46bfe162af561997b9aaea4d2c550f Mon Sep 17 00:00:00 2001 From: Mark O'Connor Date: Wed, 20 Nov 2024 00:38:24 -0800 Subject: [PATCH] #0: Update vision model to work with sharded text decoder --- models/demos/llama3/demo/demo.py | 4 ++-- models/demos/llama3/tests/test_llama_accuracy.py | 2 +- models/demos/llama3/tests/test_llama_decoder.py | 4 ++-- models/demos/llama3/tests/test_llama_model.py | 2 +- models/demos/llama3/tests/test_llama_perf.py | 2 +- models/demos/llama3/tests/test_llama_rms_norm.py | 2 +- models/demos/llama3/tt/llama_attention.py | 6 +++--- models/demos/llama3/tt/llama_decoder.py | 9 +-------- models/demos/llama3/tt/llama_mlp.py | 2 +- models/demos/llama3/tt/llama_model.py | 3 +++ models/demos/llama3/tt/model_config.py | 4 ++-- .../demos/llama3/tt/multimodal/llama_cross_attention.py | 9 ++++----- models/demos/llama3/tt/multimodal/llama_cross_block.py | 5 +++++ models/demos/llama3/tt/multimodal/llama_vision_model.py | 4 +++- 14 files changed, 30 insertions(+), 28 deletions(-) diff --git a/models/demos/llama3/demo/demo.py b/models/demos/llama3/demo/demo.py index a35ea5274c04..837e03a5dbc8 100644 --- a/models/demos/llama3/demo/demo.py +++ b/models/demos/llama3/demo/demo.py @@ -378,7 +378,7 @@ def run_llama3_demo( # Compile logger.info(f"Compiling model trace...") decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok)) - decode_input = ttnn.to_memory_config(decode_input, tt_model.args.model_config["DEC_SKIP_OUTPUT_MEMCFG"]) + decode_input = ttnn.to_memory_config(decode_input, tt_model.args.model_config["DECODE_RESIDUAL_MEMCFG"]) tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat) if tt_model.args.num_devices > 1: tt_out_gathered = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear) @@ -400,7 +400,7 @@ def run_llama3_demo( trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0) decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok)) - decode_input = ttnn.to_memory_config(decode_input, tt_model.args.model_config["DEC_SKIP_OUTPUT_MEMCFG"]) + decode_input = ttnn.to_memory_config(decode_input, tt_model.args.model_config["DECODE_RESIDUAL_MEMCFG"]) tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat) if tt_model.args.num_devices > 1: tt_out_gathered = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear) diff --git a/models/demos/llama3/tests/test_llama_accuracy.py b/models/demos/llama3/tests/test_llama_accuracy.py index 01ef54b34695..acdcc2579016 100644 --- a/models/demos/llama3/tests/test_llama_accuracy.py +++ b/models/demos/llama3/tests/test_llama_accuracy.py @@ -161,7 +161,7 @@ def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cac # Prepare input for TT model decode_input = model_args.prepare_inputs_ttnn_decode( pt_decode_input, - model_args.model_config["DEC_SKIP_OUTPUT_MEMCFG"], + model_args.model_config["DECODE_RESIDUAL_MEMCFG"], ) # Run TT model tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat) diff --git a/models/demos/llama3/tests/test_llama_decoder.py b/models/demos/llama3/tests/test_llama_decoder.py index dc607f312921..1fad070640b2 100644 --- a/models/demos/llama3/tests/test_llama_decoder.py +++ b/models/demos/llama3/tests/test_llama_decoder.py @@ -34,7 +34,7 @@ def test_llama_decoder_inference(mesh_device, use_program_cache, reset_seeds, ensure_gc): dtype = ttnn.bfloat8_b - mesh_device.enable_async(False) # NOCOMMIT + mesh_device.enable_async(True) model_args = TtModelArgs(mesh_device) model_args.n_layers = 1 @@ -93,7 +93,7 @@ def test_llama_decoder_inference(mesh_device, use_program_cache, reset_seeds, en decode_input = model_args.prepare_inputs_ttnn_decode( tt_decode_input, # ttnn.DRAM_MEMORY_CONFIG, - model_args.model_config["DEC_SKIP_OUTPUT_MEMCFG"], + model_args.model_config["DECODE_RESIDUAL_MEMCFG"], ) # Run TT model diff --git a/models/demos/llama3/tests/test_llama_model.py b/models/demos/llama3/tests/test_llama_model.py index d359a9f6afc3..afd8aa833126 100644 --- a/models/demos/llama3/tests/test_llama_model.py +++ b/models/demos/llama3/tests/test_llama_model.py @@ -176,7 +176,7 @@ def test_llama_model_inference(mesh_device, weights, layers, use_program_cache, decode_input = model_args.prepare_inputs_ttnn_decode( tt_decode_input, - model_args.model_config["DEC_SKIP_OUTPUT_MEMCFG"], + model_args.model_config["DECODE_RESIDUAL_MEMCFG"], ) current_pos_tensor = ttnn.from_torch( torch.tensor([current_pos] * batch), diff --git a/models/demos/llama3/tests/test_llama_perf.py b/models/demos/llama3/tests/test_llama_perf.py index a6999269ca0b..422c239051fe 100644 --- a/models/demos/llama3/tests/test_llama_perf.py +++ b/models/demos/llama3/tests/test_llama_perf.py @@ -184,7 +184,7 @@ def run_inference(tt_model, tt_embd, embd, encoded_prompts, generation_start_pos profiler.start(f"model_run_for_inference_{i}") decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok)) - decode_input = ttnn.to_memory_config(decode_input, tt_model.args.model_config["DEC_SKIP_OUTPUT_MEMCFG"]) + decode_input = ttnn.to_memory_config(decode_input, tt_model.args.model_config["DECODE_RESIDUAL_MEMCFG"]) tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat) tt_out_rm = ttnn.untilize(tt_out, use_multicore=True) ttnn.deallocate(tt_out) diff --git a/models/demos/llama3/tests/test_llama_rms_norm.py b/models/demos/llama3/tests/test_llama_rms_norm.py index 52a630096499..bf0ce8289006 100644 --- a/models/demos/llama3/tests/test_llama_rms_norm.py +++ b/models/demos/llama3/tests/test_llama_rms_norm.py @@ -73,7 +73,7 @@ def test_llama_rms_norm_inference(mesh_device, use_program_cache, reset_seeds, e dtype=dtype, layout=ttnn.TILE_LAYOUT, mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), - memory_config=model_args.get_model_config()["DEC_SKIP_OUTPUT_MEMCFG"] + memory_config=model_args.get_model_config()["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG, ) diff --git a/models/demos/llama3/tt/llama_attention.py b/models/demos/llama3/tt/llama_attention.py index 02d9ae5ec7f8..86c3865c57d0 100644 --- a/models/demos/llama3/tt/llama_attention.py +++ b/models/demos/llama3/tt/llama_attention.py @@ -338,7 +338,7 @@ def forward_decode( program_config=self.model_config["ATTN_ALL_GATHER_MATMUL_PROGCFG"], compute_kernel_config=self.compute_kernel_config_hifi2, memory_config_ag=self.model_config["ATTN_ALL_GATHER_MATMUL_OUTPUT_MEMCFG"], - memory_config_mm=self.model_config["DEC_SKIP_OUTPUT_MEMCFG"], + memory_config_mm=self.model_config["DECODE_RESIDUAL_MEMCFG"], ) else: # program config matched to output of nlp_concat_heads_decode @@ -360,13 +360,13 @@ def forward_decode( math_op=ttnn.ReduceType.Sum, num_links=1, memory_config=self.model_config[ - "DEC_SKIP_OUTPUT_MEMCFG" + "DECODE_RESIDUAL_MEMCFG" ], # Unlike matmuls, CCL ops can reshard to any valid output sharding for free ) ttnn.deallocate(dense_out_sharded) return dense_out_reduced else: - dense_out_sharded = ttnn.to_memory_config(dense_out_sharded, self.model_config["DEC_SKIP_OUTPUT_MEMCFG"]) + dense_out_sharded = ttnn.to_memory_config(dense_out_sharded, self.model_config["DECODE_RESIDUAL_MEMCFG"]) return dense_out_sharded def forward_prefill(self, x_11SH, rot_mats, transformation_mats, user_id: int = 0, page_table=None): diff --git a/models/demos/llama3/tt/llama_decoder.py b/models/demos/llama3/tt/llama_decoder.py index 06774f620ebe..578e0bf81a63 100644 --- a/models/demos/llama3/tt/llama_decoder.py +++ b/models/demos/llama3/tt/llama_decoder.py @@ -89,14 +89,7 @@ def forward( page_table=None, ) -> ttnn.Tensor: # x is fractured across devices and interleaved in DRAM (for prefill) and sharded in L1 (for decode) - # FIXME: move to sharded residuals once support for this is added - # FIXME: Currently, for decode mode, we are using DRAM intereleaved as L1 interleaved results in h being corrupted in MLP - skip_mem_cfg = ( - # ttnn.DRAM_MEMORY_CONFIG - self.model_config["DEC_SKIP_OUTPUT_MEMCFG"] - if mode == "decode" - else ttnn.DRAM_MEMORY_CONFIG - ) + skip_mem_cfg = self.model_config["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG assert ( x.memory_config() == skip_mem_cfg ), f"decoder input memcfg mismatch: {x.memory_config()} != {skip_mem_cfg}" diff --git a/models/demos/llama3/tt/llama_mlp.py b/models/demos/llama3/tt/llama_mlp.py index 8982f137521b..c36f0a0845bb 100644 --- a/models/demos/llama3/tt/llama_mlp.py +++ b/models/demos/llama3/tt/llama_mlp.py @@ -145,5 +145,5 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: # reshard to residual, no-op if already correct if mode == "decode": - result = ttnn.to_memory_config(result, self.model_config["DEC_SKIP_OUTPUT_MEMCFG"]) + result = ttnn.to_memory_config(result, self.model_config["DECODE_RESIDUAL_MEMCFG"]) return result diff --git a/models/demos/llama3/tt/llama_model.py b/models/demos/llama3/tt/llama_model.py index 7ae34f174c2c..5b1b5a49b3bd 100644 --- a/models/demos/llama3/tt/llama_model.py +++ b/models/demos/llama3/tt/llama_model.py @@ -83,6 +83,9 @@ def forward( page_table=None, get_last_token=-1, ): + # No-op if callers already provide the right memory config + x = ttnn.to_memory_config(x, self.model_config["DECODE_RESIDUAL_MEMCFG"]) + for layer in self.layers: x = layer(x, current_pos, rot_mat, transformation_mats, user_id, mode, page_table) diff --git a/models/demos/llama3/tt/model_config.py b/models/demos/llama3/tt/model_config.py index 605e0b93ef92..091eeb67fb1f 100644 --- a/models/demos/llama3/tt/model_config.py +++ b/models/demos/llama3/tt/model_config.py @@ -56,7 +56,7 @@ class TtModelArgs: "ATTN_OUTPUT", "ATTN_W_LAYOUT", # Decoder - "DEC_SKIP_OUTPUT", + "DECODE_RESIDUAL", "OUTPUT_MM", ) @@ -229,7 +229,7 @@ def __init__(self, mesh_device, instruct=False, dummy_weights=False, max_batch_s self.model_config["COMPUTE_KERNEL_CONFIG_HIFI2"] = self.compute_kernel_config_hifi2 residual_grid = self.dram_shard_core_grid_for_k(self.dim // self.num_devices) - self.model_config["DEC_SKIP_OUTPUT_MEMCFG"] = ttnn.create_sharded_memory_config( + self.model_config["DECODE_RESIDUAL_MEMCFG"] = ttnn.create_sharded_memory_config( ( self.tile_padded_batch_rows, self.dim // residual_grid.num_cores // self.num_devices, diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention.py b/models/demos/llama3/tt/multimodal/llama_cross_attention.py index 63f87fbeb731..71fe78f6de93 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_attention.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention.py @@ -269,16 +269,15 @@ def forward_decode(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, # All reduce if self.is_multichip: - dense_out_reduced = ttnn.reduce_scatter( + output = ttnn.reduce_scatter( output, scatter_dim=3, math_op=ttnn.ReduceType.Sum, num_links=1, - memory_config=ttnn.L1_MEMORY_CONFIG, + memory_config=ttnn.DRAM_MEMORY_CONFIG, ) - return dense_out_reduced - else: - return output + + return ttnn.to_memory_config(output, self.model_config["DECODE_RESIDUAL_MEMCFG"]) def forward_prefill( self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xattn_cache, user_id, vision_tokens diff --git a/models/demos/llama3/tt/multimodal/llama_cross_block.py b/models/demos/llama3/tt/multimodal/llama_cross_block.py index 3ba172a7d39b..1761bc7ac664 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_block.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_block.py @@ -135,10 +135,15 @@ def forward( user_id=user_id, vision_tokens=vision_tokens, ) + # FIXME: DRAM workaround for No circular buffer with id error + attn_out = ttnn.to_memory_config(attn_out, memory_config=ttnn.DRAM_MEMORY_CONFIG) attn_out = ttnn.mul(attn_out, ttnn.tanh(self.gate_attn)) res = ttnn.add(x_11SH, attn_out) mlp_out = self.feed_forward(self.ffn_norm(res, mode=mode), mode=mode) + # FIXME: DRAM workaround for No circular buffer with id error + mlp_out = ttnn.to_memory_config(mlp_out, memory_config=ttnn.DRAM_MEMORY_CONFIG) + if mode == "prefill": mlp_out = ttnn.mul(mlp_out, full_text_row_masked_out_mask_11SD) mlp_out = ttnn.mul(mlp_out, ttnn.tanh(self.gate_ffwd)) diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py index 2611a43582c3..fc75adad00d2 100644 --- a/models/demos/llama3/tt/multimodal/llama_vision_model.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py @@ -394,6 +394,8 @@ def prepare_inputs_decode(self, tokens, cross_attention_masks, full_text_row_mas B=tokens.shape[0], ) + tt_h = ttnn.to_memory_config(tt_h, self.configuration.model_config["DECODE_RESIDUAL_MEMCFG"]) + return ( tt_h, tt_xattn_mask, @@ -413,7 +415,7 @@ def prepare_decode_inputs_host(self, tokens, cross_attention_masks, full_text_ro h = self.prepare_inputs_common(position_ids, tokens) tt_h = self.configuration.prepare_inputs_ttnn_decode( h, - ttnn.DRAM_MEMORY_CONFIG, + ttnn.DRAM_MEMORY_CONFIG, # L1 memory_configs are not respected for on_host tensors on_host=True, )