#0: Fix performance drop when running Mamba decode after prefill
This change also adds a missing symbolic link for the Mamba convolution test and removes some dead code.
esmalTT committed Jul 27, 2024
1 parent 1dbdb17 commit 2229a3b
Showing 4 changed files with 36 additions and 61 deletions.
10 changes: 10 additions & 0 deletions models/demos/wormhole/mamba/tests/test_cache.py
@@ -46,3 +46,13 @@ def test_cache(E, num_users, num_entries, on_host, device):
         expected = torch.concat(values[i], dim=2)
         did_pass, output_pcc = comp_pcc(expected, ttnn.to_torch(conv_states), 1.0)
         assert did_pass
+
+    cache.reset()
+
+    for i in range(num_entries):
+        conv_states = cache.concat_users(i)
+        assert list(conv_states.shape) == [1, 1, num_users, E]
+
+        expected = torch.zeros((1, 1, num_users, E), dtype=torch.bfloat16)
+        did_pass, output_pcc = comp_pcc(expected, ttnn.to_torch(conv_states), 1.0)
+        assert did_pass
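Read as a behavioural contract, the new assertions say: after reset(), concatenating any entry across users yields an all-zero [1, 1, num_users, E] tensor. A minimal standalone sketch of that contract, with illustrative sizes and the TensorCache constructor signature and module path inferred from elsewhere in this diff:

    # A hedged sketch, not the repo's test harness; sizes are made up and the
    # import path/constructor signature are assumptions based on this diff.
    import torch
    import ttnn
    from models.demos.wormhole.mamba.tt.cache import TensorCache

    device = ttnn.open_device(device_id=0)
    num_users, num_entries, E = 32, 4, 5120
    cache = TensorCache(num_users, num_entries, E, device)

    cache.reset()                        # zeroes every (entry, user) slot in place
    conv_states = cache.concat_users(0)  # shape: [1, 1, num_users, E]
    expected = torch.zeros((1, 1, num_users, E), dtype=torch.bfloat16)
    assert torch.equal(ttnn.to_torch(conv_states), expected)
    ttnn.close_device(device)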
14 changes: 3 additions & 11 deletions models/demos/wormhole/mamba/tt/cache.py
@@ -52,31 +52,23 @@ def get(self, user_idx: int, entry_idx: int) -> ttnn.Tensor:
             self.cache[entry_idx][user_idx], device=self.device, memory_config=self.cache_memory_config
         )
 
-    def concat_users(self, entry_idx: int, layout=ttnn.TILE_LAYOUT):
+    def concat_users(self, entry_idx: int, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG):
         assert entry_idx < len(self.cache), f"Expected key {entry_idx} to exist in cache"
         values = self.cache[entry_idx]
         if self.on_host:
             values = [
-                ttnn.to_device(values[i], device=self.device, memory_config=self.cache_memory_config)
+                ttnn.to_device(values[i], device=self.device, memory_config=memory_config)
                 for i in range(self.num_users)
             ]
         return ttnn.to_layout(ttnn.concat(values, dim=2), layout)
 
     def reset(self):
         for entry_idx in range(len(self.cache)):
             for user_idx in range(self.num_users):
                 ttnn.deallocate(self.cache[entry_idx][user_idx])
-
-        self.cache = [
-            [
-                ttnn.from_torch(
+                self.cache[entry_idx][user_idx] = ttnn.from_torch(
                     torch.zeros(self.entry_shape),
                     device=self.cache_device,
                     layout=ttnn.ROW_MAJOR_LAYOUT,
                     memory_config=self.cache_memory_config,
                     dtype=ttnn.bfloat16,
                 )
-                for _ in range(self.num_users)
-            ]
-            for _ in range(self.entries_per_user)
-        ]
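Two things change in cache.py. First, concat_users() gains a memory_config keyword (defaulting to DRAM) so the caller decides where host-side tensors land when moved on device. A hypothetical call site:

    # Hypothetical usage; `cache` is a populated TensorCache. DRAM is the new
    # default, while a hot path could opt into L1 instead.
    states = cache.concat_users(entry_idx=0)
    states_l1 = cache.concat_users(entry_idx=0, memory_config=ttnn.L1_MEMORY_CONFIG)

Second, reset() now deallocates and re-creates each slot in place inside the same loop, rather than deallocating everything and then rebuilding the whole nested list with a comprehension, so the cache keeps a stable [entries_per_user][num_users] structure across resets.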
72 changes: 22 additions & 50 deletions models/demos/wormhole/mamba/tt/mamba_block.py
@@ -55,8 +55,12 @@ def __init__(self, args: ModelArgs, device, configs, load_fn: Callable):
             self.conv1d_weights.append(
                 load_fn(
                     conv1d_weight_name,
-                    lambda x: x[:, :, i].transpose(-1, -2).repeat(self.batch_size, 1).unsqueeze(0).unsqueeze(0),
-                    postfix=f"{i}_{args.batch_size}",
+                    lambda x: x[:, :, i]
+                    .transpose(-1, -2)
+                    .repeat(self.configs["num_users"], 1)
+                    .unsqueeze(0)
+                    .unsqueeze(0),
+                    postfix=f"{i}_{self.configs['num_users']}",
                 )
             )
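The substantive change in this hunk is that the per-tap conv weights are repeated (and cache-keyed via postfix) by configs["num_users"] instead of args.batch_size, which can differ between prefill and decode. A PyTorch-only sketch of the reshaping chain in the lambda, with made-up sizes:

    # Illustrative sketch of the lambda above; d_inner=8, d_conv=4 and
    # num_users=2 are made-up sizes, not values from the repo.
    import torch

    num_users = 2
    x = torch.randn(8, 1, 4)      # depthwise Conv1d weights: (d_inner, 1, d_conv)
    i = 0                         # one tap of the kernel per cached weight tensor
    w = (
        x[:, :, i]                # (d_inner, 1)
        .transpose(-1, -2)        # (1, d_inner)
        .repeat(num_users, 1)     # (num_users, d_inner)
        .unsqueeze(0)
        .unsqueeze(0)             # (1, 1, num_users, d_inner)
    )
    assert w.shape == (1, 1, num_users, 8)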

@@ -71,7 +75,6 @@ def __init__(self, args: ModelArgs, device, configs, load_fn: Callable):
             lambda x: x.repeat(self.configs["num_users"], 1),
             postfix=f"{self.configs['num_users']}",
         )
-
         self.conv1d_bias = self.conv1d_bias_prefill
 
         if self.configs["mode"] == ModelMode.DECODE:
@@ -87,25 +90,13 @@ def __init__(self, args: ModelArgs, device, configs, load_fn: Callable):
         elif self.configs["mode"] == ModelMode.PREFILL:
             self.convolution_cache = TensorCache(configs["num_users"], 4, self.args.d_inner, device)
 
-        self.use_torch_conv = False
-        if self.use_torch_conv:
-            self.torch_depthwise_conv1d = torch.nn.Conv1d(
-                in_channels=self.args.d_inner,
-                out_channels=self.args.d_inner,
-                kernel_size=4,
-                padding=0,
-                groups=self.args.d_inner,
-                bias=True,
-            )
-            self.torch_depthwise_conv1d.weight.data = load_fn(conv1d_weight_name, return_as_torch=True)
-            self.torch_depthwise_conv1d.bias.data = load_fn(conv1d_bias_name, return_as_torch=True)
-        else:
-            mamba_conv_config = MambaConvConfig(
-                input_length=self.configs["outer_dim"] + (args.d_conv - 1),
-                weights_dtype=ttnn.bfloat16,
-                output_dtype=ttnn.bfloat16,
-            )
-            self.mamba_conv = MambaConv(device, load_fn, mamba_conv_config)
+        mamba_conv_config = MambaConvConfig(
+            input_length=self.configs["outer_dim"] + (args.d_conv - 1),
+            weights_dtype=ttnn.bfloat16,
+            output_dtype=ttnn.bfloat16,
+        )
+        self.mamba_conv = MambaConv(device, load_fn, mamba_conv_config)
 
         self.tt_ssm = TtMambaSSM(self.args, self.device, configs, load_fn)
 
         self.compute_kernel_config = ttl.tensor.WormholeComputeKernelConfig(
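A small arithmetic aside on the retained config: a causal depthwise conv with kernel size d_conv needs (d_conv - 1) extra leading positions, so input_length = outer_dim + d_conv - 1. Assuming outer_dim = 32 (an assumption, not a value read from the repo), this reproduces the 35 in the "1, 1, 35, 2E" shape comment of the torch path removed in the next hunk:

    # Hedged arithmetic check; outer_dim = 32 is assumed for illustration.
    outer_dim, d_conv = 32, 4
    input_length = outer_dim + (d_conv - 1)
    assert input_length == 35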
@@ -214,34 +205,15 @@ def forward(self, x):
                 self.convolution_cache.set(self.configs["current_user"], i, entry)
                 ttnn.deallocate(entry)
 
-        if self.use_torch_conv:
-            x_ssm_torch = ttnn.to_torch(x_ssm).to(torch.float32)  # 1, 1, 35, 2E
-            ttnn.deallocate(x_ssm)
-            x_ssm_torch = x_ssm_torch.squeeze(0).permute(0, 2, 1)
-            conv_out_with_bias = self.torch_depthwise_conv1d(x_ssm_torch)
-            x_ssm_torch.data = torch.tensor([])
-            conv_out_with_bias = conv_out_with_bias.squeeze(0).permute(1, 0).unsqueeze(0).unsqueeze(0)
-            conv_out_with_bias = ttnn.from_torch(
-                conv_out_with_bias,
-                device=self.device,
-                layout=ttnn.TILE_LAYOUT,
-                memory_config=ttnn.L1_MEMORY_CONFIG,
-                dtype=self.configs["dtype"]["activations"],
-            )
-
-        else:
-            conv_out_without_bias = self.mamba_conv(x_ssm)
-            ttnn.deallocate(x_ssm)
-            conv_out_with_bias = ttnn.add(
-                conv_out_without_bias,
-                self.conv1d_bias,
-                memory_config=ttnn.L1_MEMORY_CONFIG,
-                dtype=self.configs["dtype"]["activations"],
-            )
-            ttnn.deallocate(conv_out_without_bias)
-
-        # omit the padding at the end
-        # conv_out_with_bias = conv_out_with_bias[:, :, :-3]
+        conv_out_without_bias = self.mamba_conv(x_ssm)
+        ttnn.deallocate(x_ssm)
+        conv_out_with_bias = ttnn.add(
+            conv_out_without_bias,
+            self.conv1d_bias,
+            memory_config=ttnn.L1_MEMORY_CONFIG,
+            dtype=self.configs["dtype"]["activations"],
+        )
+        ttnn.deallocate(conv_out_without_bias)
 
         conv_out_after_silu = ttnn.silu(conv_out_with_bias, memory_config=ttnn.L1_MEMORY_CONFIG)
         ttnn.deallocate(conv_out_with_bias)
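The surviving device path is conv → bias add → SiLU, all as ttnn ops in L1. A plain-PyTorch sketch of the same dataflow, with illustrative shapes rather than values from the model configs:

    # PyTorch-only sketch of conv_out -> + bias -> SiLU; 35 and 16 are
    # illustrative sizes, not taken from the repo.
    import torch
    import torch.nn.functional as F

    conv_out_without_bias = torch.randn(1, 1, 35, 16)  # (1, 1, padded_seq_len, d_inner)
    conv1d_bias = torch.randn(1, 1, 1, 16)             # broadcasts across the sequence
    conv_out_with_bias = conv_out_without_bias + conv1d_bias
    conv_out_after_silu = F.silu(conv_out_with_bias)
    print(conv_out_after_silu.shape)                   # torch.Size([1, 1, 35, 16])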
