From e1251ade725323356aaf4fc6de98d4c6f2f3c0e3 Mon Sep 17 00:00:00 2001
From: Colman Glagovich
Date: Wed, 9 Oct 2024 05:45:21 -0700
Subject: [PATCH] #0: Push rotary embeddings as row-major for e2e speedup, 10 -> 14.3 t/s/u Llama3 in demo.

---
 .../demos/t3000/llama2_70b/tests/test_llama_attention.py  | 1 -
 models/demos/t3000/llama2_70b/tests/test_llama_decoder.py | 1 -
 models/demos/t3000/llama2_70b/tests/test_llama_model.py   | 3 +++
 .../t3000/llama2_70b/tests/test_llama_perf_decode.py      | 4 +++-
 models/demos/t3000/llama2_70b/tt/llama_generation.py      | 8 ++++++--
 models/demos/t3000/llama2_70b/tt/llama_model_optimized.py | 2 +-
 6 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_attention.py b/models/demos/t3000/llama2_70b/tests/test_llama_attention.py
index 9209e4c22c9..b86d877d4f2 100644
--- a/models/demos/t3000/llama2_70b/tests/test_llama_attention.py
+++ b/models/demos/t3000/llama2_70b/tests/test_llama_attention.py
@@ -194,7 +194,6 @@ def tt_llama_attention_prepare_inputs(llama_attention_model, x, start_pos, mode)
             mesh_mapper=ReplicateTensorToMesh(llama_attention_model.mesh_device),
             device=llama_attention_model.mesh_device,
         )
-        rot_mats = ttnn.to_device(rot_mats, llama_attention_model.mesh_device)

         rot_mats = ttnn.interleaved_to_sharded(rot_mats, llama_attention_model.model_config["ROT_MAT_MM_IN1_MEMCFG"])

diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_decoder.py b/models/demos/t3000/llama2_70b/tests/test_llama_decoder.py
index 287426140b2..93376894540 100644
--- a/models/demos/t3000/llama2_70b/tests/test_llama_decoder.py
+++ b/models/demos/t3000/llama2_70b/tests/test_llama_decoder.py
@@ -188,7 +188,6 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos, mode):
             mesh_mapper=ReplicateTensorToMesh(llama_decoder_model.mesh_device),
             device=llama_decoder_model.mesh_device,
         )
-        rot_mats = ttnn.to_device(rot_mats, llama_decoder_model.mesh_device)

         rot_mats = ttnn.interleaved_to_sharded(rot_mats, llama_decoder_model.model_config["ROT_MAT_MM_IN1_MEMCFG"])

diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_model.py b/models/demos/t3000/llama2_70b/tests/test_llama_model.py
index b53d108a6ca..55deb1e7626 100644
--- a/models/demos/t3000/llama2_70b/tests/test_llama_model.py
+++ b/models/demos/t3000/llama2_70b/tests/test_llama_model.py
@@ -165,6 +165,9 @@ def run_test_LlamaModel_inference(
     if mode == "decode":
         tt_inp_emb = ttnn.to_device(tt_inp_emb, t3k_mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
         tt_inp_emb = tt_model.tt_embd(tt_inp_emb)
+        rot_mat_rm = ttnn.to_device(rot_mat, t3k_mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+        rot_mat = ttnn.to_layout(rot_mat_rm, ttnn.TILE_LAYOUT)
+        rot_mat = ttnn.interleaved_to_sharded(rot_mat, model_config["ROT_MAT_MM_IN1_MEMCFG"])
         tt_inp_emb = ttnn.interleaved_to_sharded(tt_inp_emb, model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])
         rot_mat = ttnn.to_device(rot_mat, t3k_mesh_device, memory_config=model_config["ROT_MAT_MM_IN1_MEMCFG"])
         cache_idxs = ttnn.to_device(cache_idxs, t3k_mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)

diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py b/models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py
index fbccd4176c3..ccc68dcd626 100644
--- a/models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py
+++ b/models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py
@@ -132,7 +132,9 @@ def run_test_LlamaModel_end_to_end(

         tt_inp_emb = tt_model.tt_embd(tt_inp_emb)
         tt_inp_emb = ttnn.interleaved_to_sharded(tt_inp_emb, tt_model.model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])
-        rot_mat = ttnn.to_device(rot_mat, mesh_device, memory_config=tt_model.model_config["ROT_MAT_MM_IN1_MEMCFG"])
+        rot_mat_rm = ttnn.to_device(rot_mat, mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+        rot_mat = ttnn.to_layout(rot_mat_rm, ttnn.TILE_LAYOUT)
+        rot_mat = ttnn.interleaved_to_sharded(rot_mat, model_config["ROT_MAT_MM_IN1_MEMCFG"])
         cache_idxs = ttnn.to_device(cache_idxs, mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)

         ##### Compile Model #####

diff --git a/models/demos/t3000/llama2_70b/tt/llama_generation.py b/models/demos/t3000/llama2_70b/tt/llama_generation.py
index 926cbc23503..f4583fcb1c9 100644
--- a/models/demos/t3000/llama2_70b/tt/llama_generation.py
+++ b/models/demos/t3000/llama2_70b/tt/llama_generation.py
@@ -124,7 +124,9 @@ def capture_trace(self, tokens: torch.Tensor, start_pos: int):
         tt_inp = ttnn.to_device(tt_inp, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
         tt_inp_emb = self.tt_model.tt_embd(tt_inp)
         tt_inp_emb = ttnn.interleaved_to_sharded(tt_inp_emb, self.model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])
-        rot_mat = ttnn.to_device(rot_mat, self.mesh_device, memory_config=self.model_config["ROT_MAT_MM_IN1_MEMCFG"])
+        rot_mat_rm = ttnn.to_device(rot_mat, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+        rot_mat = ttnn.to_layout(rot_mat_rm, ttnn.TILE_LAYOUT)
+        rot_mat = ttnn.interleaved_to_sharded(rot_mat, self.model_config["ROT_MAT_MM_IN1_MEMCFG"])
         cache_idxs_tt = ttnn.to_device(cache_idxs_tt, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)

         tt_logits = self.tt_model(tt_inp_emb, rot_mat, start_pos, cache_idxs=cache_idxs_tt, mode="decode")
@@ -134,12 +136,14 @@ def capture_trace(self, tokens: torch.Tensor, start_pos: int):
         # Run TT model
         tt_inp_emb = self.tt_model.tt_embd(tt_inp)
         tt_inp_emb = ttnn.interleaved_to_sharded(tt_inp_emb, self.model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])
+        rot_mat = ttnn.to_layout(rot_mat_rm, ttnn.TILE_LAYOUT)
+        rot_mat = ttnn.interleaved_to_sharded(rot_mat, self.model_config["ROT_MAT_MM_IN1_MEMCFG"])
         tt_logits = self.tt_model(tt_inp_emb, rot_mat, start_pos, cache_idxs=cache_idxs_tt, mode="decode")

         ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0)
         logger.info("Done Capturing Decode Trace")

-        return trace_id, tt_inp, rot_mat, cache_idxs_tt, tt_logits
+        return trace_id, tt_inp, rot_mat_rm, cache_idxs_tt, tt_logits

     def delete_trace(self, trace_id):
         ttnn.release_trace(self.mesh_device, trace_id)

diff --git a/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py
index 570d9db330d..22604321f6f 100644
--- a/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py
+++ b/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py
@@ -258,7 +258,7 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, mode="decode"):
             rot_mats = ttnn.as_tensor(
                 rot_mat,
                 dtype=ttnn.bfloat16,
-                layout=ttnn.TILE_LAYOUT,
+                layout=ttnn.ROW_MAJOR_LAYOUT,
                 mesh_mapper=ReplicateTensorToMesh(self.mesh_device),
             )
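
Every hunk above applies the same host-to-device pattern for the rotary matrix: build the host tensor row-major, push it to device DRAM, tilize it on device with ttnn.to_layout, and only then shard it into the rotary matmul's memory config, instead of tilizing on host and pushing straight into the sharded config. The sketch below restates that pattern outside the diff. It is illustrative only: the function name push_rot_mat_row_major, the mesh_device and model_config parameters, and the tensor shape are assumptions borrowed from the surrounding test code, not values defined by this patch.

import torch
import ttnn


def push_rot_mat_row_major(rot_mat_torch: torch.Tensor, mesh_device, model_config):
    """Sketch of the transfer pattern used in this patch (not code from the repo):
    keep the rotary matrix row-major on host, then tilize and shard on device."""
    # Host side: no tilization before the push; the tensor stays row-major.
    rot_mat_rm = ttnn.as_tensor(
        rot_mat_torch,
        dtype=ttnn.bfloat16,
        layout=ttnn.ROW_MAJOR_LAYOUT,
        mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
    )
    # Device side: land in DRAM, convert to tile layout on device, then shard
    # into the matmul input memory config expected by the attention kernels.
    rot_mat_rm = ttnn.to_device(rot_mat_rm, mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
    rot_mat = ttnn.to_layout(rot_mat_rm, ttnn.TILE_LAYOUT)
    rot_mat = ttnn.interleaved_to_sharded(rot_mat, model_config["ROT_MAT_MM_IN1_MEMCFG"])
    return rot_mat, rot_mat_rm


The end-to-end gain presumably comes from avoiding host-side tilization of the per-step rotary matrix; consistent with that, capture_trace now returns the row-major handle (rot_mat_rm) rather than the tiled tensor, so each subsequent decode step only needs a cheap row-major write into the traced input.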