#8322: integrate ssm_eltwise_mul cases with mamba and verify perf
kpaigwar committed May 10, 2024
1 parent 892399c commit 48a2c6d
Showing 3 changed files with 22 additions and 19 deletions.
models/demos/mamba/tests/test_full_model.py (5 changes: 3 additions & 2 deletions)

@@ -39,6 +39,7 @@ def forward(self, x):
         x = self.lm_head(x)
         return x
 
+
 def run_inference(
     device: ttnn.Device,
     use_program_cache,
@@ -109,7 +110,7 @@ def test_inference(
     iterations: int,
 ):
     run_inference(device, use_program_cache, model_version, batch, pcc, cache_dir, num_layers, iterations)
-
+
 @skip_for_grayskull("Not supported on Grayskull")
 @pytest.mark.parametrize(
@@ -123,7 +124,7 @@ def test_device_perf(
     model_version="state-spaces/mamba-2.8b",
     batch=32,
     pcc=0.97,
-    cache_dir="/tmp",
+    cache_dir=None,
     num_layers=1,
 ):
     run_inference(device, use_program_cache, model_version, batch, pcc, cache_dir, num_layers, iterations)
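
Note: cache_dir=None defers to the loader's default cache location rather than pinning /tmp. A minimal sketch of that fallback pattern, with a hypothetical helper name (resolve_cache_dir is not part of this repo):

    import os
    from typing import Optional

    def resolve_cache_dir(cache_dir: Optional[str]) -> str:
        # None falls back to a per-user cache instead of a shared,
        # fixed path like /tmp (hypothetical default shown here).
        return cache_dir or os.path.join(os.path.expanduser("~"), ".cache", "mamba-demo")

    print(resolve_cache_dir(None))    # ~/.cache/mamba-demo
    print(resolve_cache_dir("/tmp"))  # /tmp
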
models/demos/mamba/tests/test_mamba_perf.py (2 changes: 1 addition & 1 deletion)

@@ -115,7 +115,7 @@ def test_mamba_e2e_perf(
 @pytest.mark.models_device_performance_bare_metal
 @pytest.mark.parametrize(
     "batch, warmup, expected_device_fw_duration_ms",
-    ((32, True, 8.02),),
+    ((32, True, 2.81),),
 )
 def test_mamba_perf_device(batch, warmup, expected_device_fw_duration_ms, reset_seeds):
     subdir = "ttnn_mamba"
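
Note: the pinned device forward duration drops from 8.02 ms to 2.81 ms with the fused ops. Device-perf gates of this shape typically compare a profiled measurement against the pinned value within a small margin; the sketch below is illustrative only (the helper, its placeholder return value, and the 3% margin are assumptions, not this repo's harness):

    import pytest

    def read_device_fw_ms(subdir: str) -> float:
        # Hypothetical stand-in for the profiler post-processing the
        # real test performs on the device logs under `subdir`.
        return 2.79  # placeholder measurement, in milliseconds

    @pytest.mark.parametrize("batch, warmup, expected_ms", ((32, True, 2.81),))
    def test_device_fw_budget(batch, warmup, expected_ms):
        measured = read_device_fw_ms("ttnn_mamba")
        margin = 0.03  # tolerate small run-to-run drift before failing
        assert measured <= expected_ms * (1 + margin)
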
models/demos/mamba/tt/mamba_one_step_ssm.py (34 changes: 18 additions & 16 deletions)

@@ -179,15 +179,17 @@ def forward(self, x):
         ttnn.deallocate(delta_t2)
         ttnn.deallocate(B0)
 
-        x0 = self.transformer.repeat_interleave(
+        # bbar * x
+        bmulx0 = ttnn.experimental.operations.primary.transformers.ssm_eltwise_mul(
+            bbar0,
             x,
-            memory_config=ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
+            output_mem_config=ttl.tensor.MemoryConfig(
+                ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1
+            ),
         )
-        bmulx0 = ttnn.mul(bbar0, x0, memory_config=ttnn.L1_MEMORY_CONFIG)
 
         # deallocate bbar
         ttnn.deallocate(bbar0)
-        ttnn.deallocate(x0)
 
         # add amulh and bmulx
         hidden_state1 = ttnn.add(amulh0, bmulx0, memory_config=ttnn.L1_MEMORY_CONFIG)
@@ -206,34 +208,34 @@ def forward(self, x):
             core_grid=ttnn.CoreGrid(y=self.core_grid_row, x=self.core_grid_col),
         )  # b,n
 
-        # repeat using mask+matmul instead of ttnn.repeat to avoid fallback
-        C1 = self.transformer.repeat(
+        # c * hidden_state
+        C1 = ttnn.experimental.operations.primary.transformers.ssm_eltwise_mul(
             C0,
-            memory_config=ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
+            hidden_state1,
+            output_mem_config=ttl.tensor.MemoryConfig(
+                ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1
+            ),
         )
-        ttnn.deallocate(C0)
-
-        C2 = ttnn.mul(hidden_state1, C1, memory_config=ttnn.L1_MEMORY_CONFIG)
         ttnn.deallocate(hidden_state1)
-        ttnn.deallocate(C1)
+        ttnn.deallocate(C0)
 
         # Reduction matmul
-        C3 = ttnn.experimental.operations.primary.transformers.ssm_1d_sum_reduce(
-            C2,
+        C2 = ttnn.experimental.operations.primary.transformers.ssm_1d_sum_reduce(
+            C1,
             output_mem_config=ttl.tensor.MemoryConfig(
                 ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1
             ),
         )
-        ttnn.deallocate(C2)
+        ttnn.deallocate(C1)
 
         # x * D
         D = ttnn.to_memory_config(self.D, memory_config=ttnn.L1_MEMORY_CONFIG)
         xD = ttnn.mul(x, D, memory_config=ttnn.L1_MEMORY_CONFIG)
         ttnn.deallocate(x)
 
         # add xD and x
-        output = ttnn.add(xD, C3, memory_config=ttnn.L1_MEMORY_CONFIG)
-        ttnn.deallocate(C3)
+        output = ttnn.add(xD, C2, memory_config=ttnn.L1_MEMORY_CONFIG)
+        ttnn.deallocate(C2)
         ttnn.deallocate(xD)
 
         return output
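
Note: the commit swaps two repeat-then-multiply sequences for the fused ssm_eltwise_mul op. Below is a NumPy reference sketch of the intended semantics, under the assumption that the (d_inner, n) state layout is flattened with n fastest; batch 32 matches the tests, while d_inner and n are illustrative decode dimensions, not authoritative:

    import numpy as np

    batch, d_inner, n = 32, 5120, 16  # illustrative shapes

    x = np.random.randn(batch, d_inner).astype(np.float32)         # activations
    bbar = np.random.randn(batch, d_inner * n).astype(np.float32)  # discretized B, (d_inner, n) flattened
    C = np.random.randn(batch, n).astype(np.float32)               # per-state output weights

    # Case 1: ssm_eltwise_mul(bbar0, x) replaces repeat_interleave + mul.
    # Each hidden value x[:, i] is repeated across its n state columns.
    bmulx = bbar * np.repeat(x, n, axis=1)

    h = bmulx  # stand-in for hidden_state1 after the A*h + B*x update

    # Case 2: ssm_eltwise_mul(C0, hidden_state1) replaces repeat + mul.
    # The length-n vector C is tiled across all d_inner hidden channels.
    ch = h * np.tile(C, (1, d_inner))

    # ssm_1d_sum_reduce then sums over the n states of each hidden
    # channel, collapsing d_inner * n back down to d_inner.
    y = ch.reshape(batch, d_inner, n).sum(axis=-1)
    assert y.shape == (batch, d_inner)

Fusing the repeat into the multiply avoids materializing the repeated tensor in L1, which is consistent with the tightened device-perf expectation above.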
