From 48a2c6d6eddc4c32e5341e606bc5349a4a964435 Mon Sep 17 00:00:00 2001 From: kpaigwar Date: Thu, 9 May 2024 22:27:33 +0000 Subject: [PATCH] #8322: integrate ssm_eltwise_mul cases with mamba and verify perf --- models/demos/mamba/tests/test_full_model.py | 5 +-- models/demos/mamba/tests/test_mamba_perf.py | 2 +- models/demos/mamba/tt/mamba_one_step_ssm.py | 34 +++++++++++---------- 3 files changed, 22 insertions(+), 19 deletions(-) diff --git a/models/demos/mamba/tests/test_full_model.py b/models/demos/mamba/tests/test_full_model.py index c9846a363aa..afbdca353e8 100644 --- a/models/demos/mamba/tests/test_full_model.py +++ b/models/demos/mamba/tests/test_full_model.py @@ -39,6 +39,7 @@ def forward(self, x): x = self.lm_head(x) return x + def run_inference( device: ttnn.Device, use_program_cache, @@ -109,7 +110,7 @@ def test_inference( iterations: int, ): run_inference(device, use_program_cache, model_version, batch, pcc, cache_dir, num_layers, iterations) - + @skip_for_grayskull("Not supported on Grayskull") @pytest.mark.parametrize( @@ -123,7 +124,7 @@ def test_device_perf( model_version="state-spaces/mamba-2.8b", batch=32, pcc=0.97, - cache_dir="/tmp", + cache_dir=None, num_layers=1, ): run_inference(device, use_program_cache, model_version, batch, pcc, cache_dir, num_layers, iterations) diff --git a/models/demos/mamba/tests/test_mamba_perf.py b/models/demos/mamba/tests/test_mamba_perf.py index 86265a5118c..b7dfe365c44 100644 --- a/models/demos/mamba/tests/test_mamba_perf.py +++ b/models/demos/mamba/tests/test_mamba_perf.py @@ -115,7 +115,7 @@ def test_mamba_e2e_perf( @pytest.mark.models_device_performance_bare_metal @pytest.mark.parametrize( "batch, warmup, expected_device_fw_duration_ms", - ((32, True, 8.02),), + ((32, True, 2.81),), ) def test_mamba_perf_device(batch, warmup, expected_device_fw_duration_ms, reset_seeds): subdir = "ttnn_mamba" diff --git a/models/demos/mamba/tt/mamba_one_step_ssm.py b/models/demos/mamba/tt/mamba_one_step_ssm.py index ec355c9790c..a4a5c18acbb 100644 --- a/models/demos/mamba/tt/mamba_one_step_ssm.py +++ b/models/demos/mamba/tt/mamba_one_step_ssm.py @@ -179,15 +179,17 @@ def forward(self, x): ttnn.deallocate(delta_t2) ttnn.deallocate(B0) - x0 = self.transformer.repeat_interleave( + # bbar * x + bmulx0 = ttnn.experimental.operations.primary.transformers.ssm_eltwise_mul( + bbar0, x, - memory_config=ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1), + output_mem_config=ttl.tensor.MemoryConfig( + ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1 + ), ) - bmulx0 = ttnn.mul(bbar0, x0, memory_config=ttnn.L1_MEMORY_CONFIG) # deallocate bbar ttnn.deallocate(bbar0) - ttnn.deallocate(x0) # add amulh and bmulx hidden_state1 = ttnn.add(amulh0, bmulx0, memory_config=ttnn.L1_MEMORY_CONFIG) @@ -206,25 +208,25 @@ def forward(self, x): core_grid=ttnn.CoreGrid(y=self.core_grid_row, x=self.core_grid_col), ) # b,n - # repeat using mask+matmul instead of ttnn.repeat to avoid fallback - C1 = self.transformer.repeat( + # c * hidden_state + C1 = ttnn.experimental.operations.primary.transformers.ssm_eltwise_mul( C0, - memory_config=ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1), + hidden_state1, + output_mem_config=ttl.tensor.MemoryConfig( + ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1 + ), ) - ttnn.deallocate(C0) - - C2 = ttnn.mul(hidden_state1, C1, memory_config=ttnn.L1_MEMORY_CONFIG) ttnn.deallocate(hidden_state1) - ttnn.deallocate(C1) + ttnn.deallocate(C0) # Reduction matmul - C3 = ttnn.experimental.operations.primary.transformers.ssm_1d_sum_reduce( - C2, + C2 = ttnn.experimental.operations.primary.transformers.ssm_1d_sum_reduce( + C1, output_mem_config=ttl.tensor.MemoryConfig( ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1 ), ) - ttnn.deallocate(C2) + ttnn.deallocate(C1) # x * D D = ttnn.to_memory_config(self.D, memory_config=ttnn.L1_MEMORY_CONFIG) @@ -232,8 +234,8 @@ def forward(self, x): ttnn.deallocate(x) # add xD and x - output = ttnn.add(xD, C3, memory_config=ttnn.L1_MEMORY_CONFIG) - ttnn.deallocate(C3) + output = ttnn.add(xD, C2, memory_config=ttnn.L1_MEMORY_CONFIG) + ttnn.deallocate(C2) ttnn.deallocate(xD) return output