Commit
#0: Support any sequence length in Mamba prefill demo
esmalTT authored Jul 18, 2024
1 parent 0479d72 commit 9c81d53
Showing 8 changed files with 136 additions and 65 deletions.
18 changes: 16 additions & 2 deletions models/demos/mamba/demo/demo.py
@@ -15,6 +15,7 @@
from models.demos.mamba.reference.decode_model import MambaPretrainedModelName
from models.demos.mamba.reference.args import ModelMode
from models.demos.mamba.tt import model_config
from models.demos.mamba.tt.preprocessing import split_sequence_length


def get_cpu_reference_model(version: MambaPretrainedModelName, batch_size: int):
@@ -159,11 +160,24 @@ def run_mamba_prefill_decode_demo(

# Prefill
model.to_prefill()
prefill_chunk_size = 32
num_users = sequences.shape[0]

prefill_tokens = sequences[:, :-1] # Omit the last token in the sequence (B, L - 1)

prefill_tokens = ttnn.from_torch(
prefill_tokens.view(1, 1, prefill_tokens.shape[0], prefill_tokens.shape[1]),
device=device,
layout=ttnn.ROW_MAJOR_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
dtype=ttnn.uint32,
)
for user_idx in tqdm(range(num_users), desc="Prefilling the prompt(s)..."):
with torch.no_grad():
model(sequences[user_idx, :-1].unsqueeze(0)) # Omit the last token in the sequence
model.configs["current_user"] += 1
for chunk in split_sequence_length(prefill_tokens, batch=user_idx, chunk_size=prefill_chunk_size):
chunk = ttnn.reshape(chunk, [1, chunk.shape[3]]) # Mamba expects (1, L) in prefill mode
model._forward(chunk)
model.configs["current_user"] += 1

# Decode
decode_model_config = model_config.create_model_config(
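
For orientation, here is a minimal CPU-side sketch of the chunking that the new prefill loop performs, written with plain torch slicing instead of the ttnn calls above. The names prefill_in_chunks and run_chunk are illustrative only (run_chunk stands in for model._forward), and it assumes the tokenized prompts are already padded so that their length minus one is a multiple of the 32-token chunk size, which the loop requires.

import torch

def prefill_in_chunks(sequences: torch.Tensor, run_chunk, chunk_size: int = 32):
    # sequences: (B, L) token ids; the last token of each prompt is withheld
    # so that decode can start from it, as in the demo above.
    prefill_tokens = sequences[:, :-1]
    num_users, prefill_len = prefill_tokens.shape
    assert prefill_len % chunk_size == 0, "prefill length must be a multiple of the chunk size"
    for user_idx in range(num_users):
        for start in range(0, prefill_len, chunk_size):
            chunk = prefill_tokens[user_idx, start : start + chunk_size]
            run_chunk(chunk.unsqueeze(0))  # each call sees a (1, chunk_size) slice

For example, a padded prompt of 129 tokens yields 128 prefill tokens per user, i.e. four 32-token chunks.
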
9 changes: 8 additions & 1 deletion models/demos/mamba/tests/test_mamba_demo.py
@@ -31,7 +31,14 @@ def test_demo(user_input, model_version, device, use_program_cache, get_tt_cache
"Climate change refers to long-term shifts in temperatures and weather patterns. Such shifts can be natural due to changes in the sun's activity or volcanic eruptions."
],
"state-spaces/mamba-2.8b-slimpj",
2,
4,
),
(
[
"The city of Sarnia is located on the eastern shore of Lake Huron at its extreme southern point where it flows into the St. Clair River . Most of the surrounding area is flat , and the elevation ranges from 169 metres ( 554 ft ) and 281 metres ( 922 ft ) above sea level . The soil mostly comprises clay . Despite this high percentage of clay , the soil is remarkably rich for cultivation . Prior to the Ice Age , glaciers covered most of the area , as can be seen not only by the existence of the Great Lakes themselves but also of alluvial sand deposits, terminal moraines, and rich oil reserves."
],
"state-spaces/mamba-2.8b-slimpj",
4,
),
),
)
54 changes: 54 additions & 0 deletions models/demos/mamba/tests/test_preprocessing.py
@@ -0,0 +1,54 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from loguru import logger
import pytest
import torch
import ttnn

from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import (
comp_allclose,
comp_pcc,
)
from models.demos.mamba.tt.preprocessing import split_sequence_length


@pytest.mark.parametrize(
"layout",
[ttnn.ROW_MAJOR_LAYOUT],
)
@pytest.mark.parametrize(
"B, L, chunk_size, num_chunks",
(
(32, 32, 32, 1),
(32, 64, 32, 2),
(32, 128, 32, 4),
(32, 128, 32, 2),
(32, 1024, 32, 8),
),
)
def test_splitting_sequence_length(
B: int, L: int, chunk_size: int, num_chunks: int, layout: ttnn.Layout, device: ttnn.Device
):
expected = torch.randint(0, 255, (1, 1, B, L), dtype=torch.int32)

x = ttnn.from_torch(expected, dtype=ttnn.int32, device=device, layout=layout)

result = []
for batch_idx in range(B):
chunks = []
for chunk in split_sequence_length(x, batch=batch_idx, chunk_size=chunk_size):
assert list(chunk.shape) == [1, 1, 1, chunk_size]
chunks.append(chunk)
result.append(ttnn.to_torch(ttnn.concat(chunks, dim=-1)))

actual = torch.concat(result, dim=-2)
assert actual.shape == x.shape, "Expected input shape to match output shape"

does_pass, output_pcc = comp_pcc(expected, actual, 1.0)
logger.info(f"PCC value: {output_pcc}")
assert does_pass, f"PCC value ({output_pcc}) is lower than 1.0"

does_pass, output_allclose = comp_allclose(expected, actual)
assert does_pass, "Allclose check failed: {output_allclose}"
34 changes: 21 additions & 13 deletions models/demos/mamba/tt/full_model.py
@@ -127,25 +127,20 @@ def to_decode(self, decode_config):
def embedding(self, x):
assert len(x.shape) == 2, f"Mamba expects inputs to be rank 2 (was {len(x.shape)})"
x = ttnn.embedding(
x, self.embedding_weights, output_dtype=ttnn.bfloat16
x, self.embedding_weights, output_dtype=ttnn.bfloat16, memory_config=ttnn.L1_MEMORY_CONFIG
) # ttnn.embedding always returns (B, L, E)
return ttnn.reshape(x, [1, 1, self.configs["outer_dim"], x.shape[2]])

def forward(self, x):
assert len(x.shape) == 2, f"Mamba expects inputs to be rank 2 (was {len(x.shape)})"

x = ttnn.from_torch(
x,
device=self.device,
layout=ttnn.ROW_MAJOR_LAYOUT,
memory_config=ttnn.L1_MEMORY_CONFIG,
dtype=ttnn.uint32,
)
def _forward(self, x):
assert len(x.shape) == 2, f"Expected tensor to be rank 2 (shape was {x.shape})"
assert (
x.shape[-1] <= self.configs["max_seq_length"]
), f"Expected L to be less than or equal to max sequence length (was {x.shape[-1]}, expected <= {self.configs['max_seq_length']})"

x = self.embedding(x)
x = ttnn.typecast(ttnn.to_layout(x, ttnn.TILE_LAYOUT), self.configs["dtype"]["activations"])

assert len(x.shape) == 4, f"Expected embedding to be rank 4 (was {len(x.shape)})"
assert len(x.shape) == 4, f"Expected embedding output to be rank 4 (shape was {x.shape})"
assert x.layout == ttnn.TILE_LAYOUT, f"Expected embedding to be tile layout (was {x.layout})"

for i, layer in enumerate(self.layers):
@@ -169,7 +164,20 @@ def forward(self, x):
compute_kernel_config=self.compute_kernel_config,
dtype=self.configs["dtype"]["activations"],
)

return x

def forward(self, x):
assert len(x.shape) == 2, f"Mamba expects inputs to be rank 2 (was {len(x.shape)})"
x = ttnn.from_torch(
x,
device=self.device,
layout=ttnn.ROW_MAJOR_LAYOUT,
memory_config=ttnn.L1_MEMORY_CONFIG,
dtype=ttnn.uint32,
)
x = self._forward(x)
if self.return_logits or self.configs["mode"] == ModelMode.DECODE:
x = ttnn.to_torch(x).to(torch.float32) # (1, 1, B, E)
x = x.view((self.configs["batch_size"], self.configs["seq_len"], -1))

return x
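
The net effect of the refactor is two entry points: _forward for tensors that are already on device (as in the chunked prefill above) and forward for host-side torch tensors (as in decode). A hypothetical dispatch helper, not part of the commit, summarizes that contract under those assumptions:

import torch
import ttnn

def run_model(model, x):
    # Illustrative only: choose the entry point based on where the input lives.
    if isinstance(x, ttnn.Tensor):
        return model._forward(x)  # device-resident (1, L) chunk in, ttnn tensor out
    assert isinstance(x, torch.Tensor)
    return model(x)  # host (B, L) torch tensor in; torch logits out in decode mode
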
2 changes: 1 addition & 1 deletion models/demos/mamba/tt/model_config.py
@@ -10,6 +10,7 @@
def create_model_config(batch_size, hidden_size, mode=ModelMode.DECODE, seq_len=1):
configs = {}
latent = 32
configs["max_seq_length"] = 128
configs["core_grid_row"] = 5
configs["core_grid_col"] = 8
configs["latent_size"] = latent
@@ -26,7 +27,6 @@ def create_model_config(batch_size, hidden_size, mode=ModelMode.DECODE, seq_len=
outer_dim = seq_len
assert batch_size == 1, "Batch size must be 1 for prefill model"
assert seq_len % 32 == 0, "Sequence length must be a multiple of 32 for prefill model"
assert seq_len <= 128, "Sequence length must be less than 128 for prefill model"
else:
raise ValueError(f"Invalid model mode: {mode}")

34 changes: 34 additions & 0 deletions models/demos/mamba/tt/preprocessing.py
@@ -0,0 +1,34 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import ttnn


def split_sequence_length(x, batch: int = 0, chunk_size: int = 32):
"""
Generator function to yield chunks of a tensor of shape (1, 1, B, L) into (1, 1, 1, chunk_size).
Parameters:
x (ttnn.Tensor): The input tensor of shape (1, 1, B, L).
batch (int): The batch index to select. Default is 0.
chunk_size (int): The size of each chunk along the last (sequence) dimension. Default is 32.
Yields:
ttnn.Tensor: Chunks of the input tensor of shape (1, 1, 1, chunk_size).
"""

assert x.layout == ttnn.ROW_MAJOR_LAYOUT, f"Expected input to be row-major layout (was {x.layout})"
assert len(x.shape) == 4, f"Expected input to be rank 4 (was {x.shape})"
assert x.shape[3] % 32 == 0, "Sequence length must be a multiple of 32"

assert chunk_size % 32 == 0, "Chunk size must be multiple of 32"

_, _, B, L = x.shape

assert batch < B, f"Expected batch index (was {batch}) to be less than the size of batch dimension (was {B})"

for i in range(0, L, chunk_size):
slice_start = (0, 0, batch, i)
slice_end = (0, 0, batch, i + chunk_size - 1)
yield ttnn.slice(x, ttnn.Shape(slice_start), ttnn.Shape(slice_end))
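
As a usage sketch mirroring the new test and the demo loop: wrap the token ids in a row-major (1, 1, B, L) ttnn tensor, then walk one user's sequence chunk by chunk. Opening a device with ttnn.open_device and the final model call are illustrative assumptions, not part of this file; everything else follows the code above.

import torch
import ttnn
from models.demos.mamba.tt.preprocessing import split_sequence_length

device = ttnn.open_device(device_id=0)

tokens = torch.randint(0, 255, (1, 1, 32, 128), dtype=torch.int32)  # (1, 1, B, L)
x = ttnn.from_torch(tokens, dtype=ttnn.uint32, device=device, layout=ttnn.ROW_MAJOR_LAYOUT)

for chunk in split_sequence_length(x, batch=0, chunk_size=32):
    chunk = ttnn.reshape(chunk, [1, chunk.shape[3]])  # (1, chunk_size), as prefill expects
    # e.g. model._forward(chunk) would consume the chunk here

ttnn.close_device(device)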

This file was deleted.
