#0: Move Mamba demo to models/demos/wormhole
esmalTT committed Jul 18, 2024
1 parent 9c81d53 commit 7f57b73
Showing 47 changed files with 76 additions and 75 deletions.
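Every hunk in this commit applies the same mechanical substitution: the dotted package `models.demos.mamba` becomes `models.demos.wormhole.mamba`, and the path `models/demos/mamba` becomes `models/demos/wormhole/mamba`. A minimal sketch of that rewrite follows. The function name `rewrite_paths` and the regex boundaries are assumptions for illustration, not the tooling the author used.

```python
import re

OLD_PKG = "models.demos.mamba"
NEW_PKG = "models.demos.wormhole.mamba"
OLD_DIR = "models/demos/mamba"
NEW_DIR = "models/demos/wormhole/mamba"

# Match the old dotted package only when it is not embedded in a longer name,
# e.g. leave a hypothetical "models.demos.mamba_v2" untouched.
_PKG_RE = re.compile(r"(?<![\w.])" + re.escape(OLD_PKG) + r"(?!\w)")
# Same idea for filesystem-style paths used in docs, CODEOWNERS and test commands.
_DIR_RE = re.compile(r"(?<!\w)" + re.escape(OLD_DIR) + r"(?![\w-])")


def rewrite_paths(text: str) -> str:
    """Return text with old Mamba demo import paths and file paths relocated."""
    text = _PKG_RE.sub(NEW_PKG, text)
    return _DIR_RE.sub(NEW_DIR, text)


if __name__ == "__main__":
    print(rewrite_paths("from models.demos.mamba.tt import model_config"))
    print(rewrite_paths("pytest -svv models/demos/mamba/tests/test_mamba_ssm.py"))
```

Neither old string is a substring of its replacement, so running the rewrite a second time leaves the text unchanged.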
2 changes: 1 addition & 1 deletion CODEOWNERS
@@ -141,7 +141,7 @@ models/demos/metal_BERT_large_11 @tt-aho @TT-BrianLiu
models/demos/wormhole @uaydonat @eyonland @AleksKnezevic @nsmithtt
models/demos/t3000 @uaydonat @AleksKnezevic @nsmithtt
models/demos/falcon7b @skhorasganiTT @djordje-tt @uaydonat @pavlejosipovic @pavlepopovic @s-jovic
models/demos/mamba @esmalTT @uaydonat @kpaigwar
models/demos/wormhole/mamba @esmalTT @uaydonat @kpaigwar
models/demos/wormhole/falcon7b @skhorasganiTT @djordje-tt @uaydonat @pavlejosipovic @pavlepopovic @s-jovic
models/demos/wormhole/mistral7b @yieldthought @uaydonat @mtairum
models/demos/t3000/falcon40b @johanna-rock-tt @uaydonat @s-jovic
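The Mamba entry has to move along with the directory because GitHub applies the last matching CODEOWNERS pattern. The broader `models/demos/wormhole` rule now also matches the relocated files, and the more specific `models/demos/wormhole/mamba` rule listed after it is the one that takes effect. The sketch below illustrates that precedence with a simplified directory-prefix matcher. `owners_for` is a hypothetical helper and does not implement the full CODEOWNERS glob syntax.

```python
from typing import List, Optional, Tuple

Rule = Tuple[str, List[str]]

# Two of the rules visible in this hunk, in file order.
RULES: List[Rule] = [
    ("models/demos/wormhole", ["@uaydonat", "@eyonland", "@AleksKnezevic", "@nsmithtt"]),
    ("models/demos/wormhole/mamba", ["@esmalTT", "@uaydonat", "@kpaigwar"]),
]


def owners_for(path: str, rules: List[Rule]) -> Optional[List[str]]:
    """Return the owners of the last rule whose directory prefix matches path."""
    owners = None
    for pattern, rule_owners in rules:
        prefix = pattern.rstrip("/")
        if path == prefix or path.startswith(prefix + "/"):
            owners = rule_owners  # later matches override earlier ones
    return owners


if __name__ == "__main__":
    print(owners_for("models/demos/wormhole/mamba/demo/demo.py", RULES))
    print(owners_for("models/demos/mamba/demo/demo.py", RULES))
```

Under these two rules the old path `models/demos/mamba/...` matches nothing, which is why the entry is updated in the same commit as the move.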
2 changes: 1 addition & 1 deletion README.md
@@ -48,7 +48,7 @@
|----------------------------------------------------------------------------------------|--------------------|----------------------|------------------------------|-----------------------------|----------------|
| [Falcon7B](./models/demos/wormhole/falcon7b) | 129th | 32 | 13.3 t/s/u - 425 t/s | 15.4 t/s/u - 493 t/s | 26 |
| [Mistral-7B](./models/demos/wormhole/mistral7b) | 129th | 32 | 9.9 t/s/u - 317 t/s | 11.0 t/s/u - 352 t/s | 25 |
| [Mamba-2.8B](./models/demos/mamba) | any | 32 | 11.6 t/s/u - 370 t/s | 16.5 t/s/u - 528 t/s | 41 |
| [Mamba-2.8B](./models/demos/wormhole/mamba) | any | 32 | 11.6 t/s/u - 370 t/s | 16.5 t/s/u - 528 t/s | 41 |
| [BERT-Large](./models/demos/metal_BERT_large_11/) (sen/s) [4] | | 8 | 270 | 340 | 400 |
| [Stable Diffusion 1.4](./models/demos/wormhole/stable_diffusion) 512x512 (sec/img) [5] | | 1 | 6 | 5 | 3 |
| [ResNet-50](./models/demos/ttnn_resnet) (fps) | | 16 | 4,300 | 5,550 | 7,000 |
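Each throughput cell in the LLM rows above pairs a per-user rate (`t/s/u`) with an aggregate rate (`t/s`). Reading `t/s/u` as tokens per second per user is an assumption, but the numbers are consistent with aggregate = per-user rate × batch size, to within rounding. A small sketch of that arithmetic, with a hypothetical helper name:

```python
def aggregate_tokens_per_second(tokens_per_second_per_user: float, batch: int) -> float:
    """Aggregate throughput, assuming every user in the batch decodes at the same rate."""
    return tokens_per_second_per_user * batch


if __name__ == "__main__":
    # (model, per-user rate, batch, aggregate rate reported in the table)
    rows = [
        ("Falcon7B", 13.3, 32, 425),
        ("Mistral-7B", 9.9, 32, 317),
        ("Mamba-2.8B", 11.6, 32, 370),
    ]
    for name, per_user, batch, reported in rows:
        computed = aggregate_tokens_per_second(per_user, batch)
        print(f"{name}: computed {computed:.1f} t/s, table reports {reported} t/s")
```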
@@ -15,22 +15,22 @@ export WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml
To run the model for a single user you can use the command line input:

```
pytest --disable-warnings -q -s --input-method=cli --cli-input="YOUR PROMPT GOES HERE!" models/demos/mamba/demo/demo.py
pytest --disable-warnings -q -s --input-method=cli --cli-input="YOUR PROMPT GOES HERE!" models/demos/wormhole/mamba/demo/demo.py
```

To run the demo using pre-written prompts for a batch of 32 users run:

```
pytest --disable-warnings -q -s --input-method=json --input-path='models/demos/mamba/demo/prompts.json' models/demos/mamba/demo/demo.py
pytest --disable-warnings -q -s --input-method=json --input-path='models/demos/wormhole/mamba/demo/prompts.json' models/demos/wormhole/mamba/demo/demo.py
```

To run the demo using custom input prompts, you can provide a different path to the input prompts file for e.g.:

```
pytest --disable-warnings -q -s --input-method=json --input-path='path_to_input_prompts.json' models/demos/mamba/demo/demo.py
pytest --disable-warnings -q -s --input-method=json --input-path='path_to_input_prompts.json' models/demos/wormhole/mamba/demo/demo.py
```

Any sequence length is supported. We currently only support JSON file with strictly 32 user prompts with same token length. Check the `models/demos/mamba/demo/prompts.json` file for reference.
Any sequence length is supported. We currently only support JSON file with strictly 32 user prompts with same token length. Check the `models/demos/wormhole/mamba/demo/prompts.json` file for reference.
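The two constraints stated above (exactly 32 prompts, all with the same token length) can be checked before launching the demo. The sketch below is a hypothetical pre-flight check, not part of the demo. It takes the prompts as a list of strings and a `tokenize` callable, because the schema of `prompts.json` and the tokenizer setup are not shown in this diff.

```python
from typing import Callable, Sequence


def check_prompt_batch(
    prompts: Sequence[str],
    tokenize: Callable[[str], Sequence],
    batch_size: int = 32,
) -> int:
    """Validate the batch and return the shared token length of the prompts."""
    if len(prompts) != batch_size:
        raise ValueError(f"expected {batch_size} prompts, got {len(prompts)}")
    lengths = {len(tokenize(prompt)) for prompt in prompts}
    if len(lengths) != 1:
        raise ValueError(f"prompts have differing token lengths: {sorted(lengths)}")
    return lengths.pop()


if __name__ == "__main__":
    # str.split stands in for a real tokenizer purely for illustration.
    batch = ["the quick brown fox"] * 32
    print(check_prompt_batch(batch, str.split))
```

In practice `tokenize` would wrap whichever tokenizer the demo loads, so the lengths compared are the ones the model actually sees.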

The prefill graph is not currently integrated into the demo. Therefore we currently process the prompt a single token at a time using the decode graph.

@@ -48,27 +48,27 @@ cd tt-metal
### SSM Block

```
pytest -svv models/demos/mamba/tests/test_mamba_ssm.py
pytest -svv models/demos/wormhole/mamba/tests/test_mamba_ssm.py
```

### Mamba Block

```
pytest -svv models/demos/mamba/tests/test_mamba_block.py
pytest -svv models/demos/wormhole/mamba/tests/test_mamba_block.py
```

### Residual Block

```
pytest -svv models/demos/mamba/tests/test_residual_block.py
pytest -svv models/demos/wormhole/mamba/tests/test_residual_block.py
```

### Full Model

Note : input embedding layer and TopK are on CPU

```
pytest -svv models/demos/mamba/tests/test_full_model.py::test_inference
pytest -svv models/demos/wormhole/mamba/tests/test_full_model.py::test_inference
```

## Performance Tests
@@ -79,15 +79,15 @@ These tests are designed to evaluate device-side and host performance of Mamba m
### End-to-End Model Performance

```bash
pytest -svv models/demos/mamba/tests/test_mamba_perf.py -m models_performance_bare_metal
pytest -svv models/demos/wormhole/mamba/tests/test_mamba_perf.py -m models_performance_bare_metal
```

### Device-Side Performance

Build with profiler support enabled (use the build script `./scripts/build_scripts/build_with_profiler_opt.sh`) and run the following command to test device-side performance:

```
pytest -svv models/demos/mamba/tests/test_mamba_perf.py -m models_device_performance_bare_metal
pytest -svv models/demos/wormhole/mamba/tests/test_mamba_perf.py -m models_device_performance_bare_metal
```

This will also generate device and host profiling logs in directory `generated/profiler/reports/ttnn_mamba`
File renamed without changes.
@@ -9,8 +9,8 @@

from transformers import AutoTokenizer

from models.demos.mamba.reference.decode_model import MambaDecode, MambaPretrainedModelName
from models.demos.mamba.benchmarks.loglikelihood import compute_loglikelihood_given_prompt_and_target
from models.demos.wormhole.mamba.reference.decode_model import MambaDecode, MambaPretrainedModelName
from models.demos.wormhole.mamba.benchmarks.loglikelihood import compute_loglikelihood_given_prompt_and_target

from lm_eval.api.model import LM
from lm_eval.api.instance import Instance
@@ -4,7 +4,7 @@

import torch

from models.demos.mamba.reference.decode_model import MambaDecode
from models.demos.wormhole.mamba.reference.decode_model import MambaDecode


def compute_loglikelihood(logits, labels) -> float:
@@ -12,14 +12,14 @@

from transformers import AutoTokenizer

from models.demos.mamba.reference.decode_model import MambaPretrainedModelName
from models.demos.mamba.reference.args import ModelMode
from models.demos.mamba.tt import model_config
from models.demos.mamba.tt.preprocessing import split_sequence_length
from models.demos.wormhole.mamba.reference.decode_model import MambaPretrainedModelName
from models.demos.wormhole.mamba.reference.args import ModelMode
from models.demos.wormhole.mamba.tt import model_config
from models.demos.wormhole.mamba.tt.preprocessing import split_sequence_length


def get_cpu_reference_model(version: MambaPretrainedModelName, batch_size: int):
    from models.demos.mamba.reference.decode_model import MambaDecode
    from models.demos.wormhole.mamba.reference.decode_model import MambaDecode

    return MambaDecode.from_pretrained(version, batch_size=batch_size)

@@ -32,8 +32,8 @@ def get_tt_metal_model(
    mode: ModelMode = ModelMode.DECODE,
    seq_len: int = 1,
):
    from models.demos.mamba.tt.full_model import MambaTT
    from models.demos.mamba.tt import model_config
    from models.demos.wormhole.mamba.tt.full_model import MambaTT
    from models.demos.wormhole.mamba.tt import model_config

    reference_model = get_cpu_reference_model(version, batch_size=batch_size)
    config = model_config.create_model_config(batch_size, reference_model.args.d_model, mode=mode, seq_len=seq_len)
File renamed without changes.
File renamed without changes.
@@ -46,7 +46,7 @@
import torch.nn.functional as F
from einops import rearrange, repeat, einsum

from models.demos.mamba.reference.args import ModelArgs
from models.demos.wormhole.mamba.reference.args import ModelArgs

from typing import Literal, cast

File renamed without changes.
@@ -46,7 +46,7 @@
import torch.nn.functional as F
from einops import rearrange, repeat, einsum

from models.demos.mamba.reference.args import ModelArgs, ModelMode
from models.demos.wormhole.mamba.reference.args import ModelArgs, ModelMode

from typing import Literal, cast

@@ -5,8 +5,8 @@
import pytest
import torch

from models.demos.mamba.reference.decode_model import MambaDecode, MambaPretrainedModelName
from models.demos.mamba.benchmarks.loglikelihood import (
from models.demos.wormhole.mamba.reference.decode_model import MambaDecode, MambaPretrainedModelName
from models.demos.wormhole.mamba.benchmarks.loglikelihood import (
compute_loglikelihood,
compute_loglikelihood_given_prompt_and_target,
)
@@ -6,7 +6,7 @@
import ttnn
import torch

from models.demos.mamba.tt.cache import TensorCache
from models.demos.wormhole.mamba.tt.cache import TensorCache

from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import (
comp_pcc,
@@ -9,11 +9,11 @@
from transformers import AutoTokenizer
from typing import Optional
import ttnn
from models.demos.mamba.reference.decode_model import MambaDecode, MambaPretrainedModelName
from models.demos.mamba.reference.prefill_decode_model import Mamba, MambaPretrainedModelName
from models.demos.mamba.reference.args import ModelMode
from models.demos.mamba.tt.full_model import MambaTT
from models.demos.mamba.tt import model_config
from models.demos.wormhole.mamba.reference.decode_model import MambaDecode, MambaPretrainedModelName
from models.demos.wormhole.mamba.reference.prefill_decode_model import Mamba, MambaPretrainedModelName
from models.demos.wormhole.mamba.reference.args import ModelMode
from models.demos.wormhole.mamba.tt.full_model import MambaTT
from models.demos.wormhole.mamba.tt import model_config
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import (
comp_allclose,
comp_pcc,
@@ -7,11 +7,11 @@
from loguru import logger
from typing import Optional
import ttnn
from models.demos.mamba.tt.full_model import TtTensorLoader
from models.demos.mamba.reference.prefill_decode_model import Mamba, MambaPretrainedModelName
from models.demos.mamba.reference.args import ModelMode
from models.demos.mamba.tt.mamba_block import TtMambaBlock
from models.demos.mamba.tt import model_config
from models.demos.wormhole.mamba.tt.full_model import TtTensorLoader
from models.demos.wormhole.mamba.reference.prefill_decode_model import Mamba, MambaPretrainedModelName
from models.demos.wormhole.mamba.reference.args import ModelMode
from models.demos.wormhole.mamba.tt.mamba_block import TtMambaBlock
from models.demos.wormhole.mamba.tt import model_config
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import (
comp_allclose,
comp_pcc,
@@ -2,7 +2,7 @@

# SPDX-License-Identifier: Apache-2.0

from models.demos.mamba.demo.demo import run_mamba_demo, run_mamba_prefill_decode_demo
from models.demos.wormhole.mamba.demo.demo import run_mamba_demo, run_mamba_prefill_decode_demo
import pytest


@@ -6,7 +6,7 @@
import time
import json

from models.demos.mamba.demo.demo import (
from models.demos.wormhole.mamba.demo.demo import (
get_tokenizer,
get_cpu_reference_model,
get_tt_metal_model,
@@ -50,7 +50,7 @@ def test_mamba_e2e_perf(
    profiler.clear()

    # Load prompts
    with open("models/demos/mamba/demo/prompts.json", "r") as f:
    with open("models/demos/wormhole/mamba/demo/prompts.json", "r") as f:
        prompts = json.load(f)

    profiler.start("pytorch_ref_model_setup")
@@ -138,7 +138,7 @@ def test_mamba_perf_device(batch, warmup, expected_device_fw_duration_ms, reset_
inference_iterations = 2
else:
inference_iterations = 1
command = f"pytest models/demos/mamba/tests/test_full_model.py::test_device_perf[{inference_iterations}]"
command = f"pytest models/demos/wormhole/mamba/tests/test_full_model.py::test_device_perf[{inference_iterations}]"
cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]

# convert expected perf (ms) to samples/s
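The conversion named in the comment above is not shown in this hunk. As a rough sketch, assuming the expected duration covers one forward pass over the whole batch, samples per second is the batch size divided by the duration in seconds. The helper name `samples_per_second` is hypothetical and is not the function used by the test.

```python
def samples_per_second(duration_ms: float, batch: int) -> float:
    """Convert a per-batch duration in milliseconds to samples per second."""
    if duration_ms <= 0:
        raise ValueError("duration_ms must be positive")
    return batch / (duration_ms / 1000.0)


if __name__ == "__main__":
    # A batch of 32 that takes 500 ms corresponds to 64 samples per second.
    print(samples_per_second(500.0, 32))
```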
@@ -7,11 +7,11 @@
from loguru import logger
from typing import Optional
import ttnn
from models.demos.mamba.reference.prefill_decode_model import Mamba, MambaPretrainedModelName
from models.demos.mamba.reference.args import ModelMode
from models.demos.mamba.tt.full_model import TtTensorLoader
from models.demos.mamba.tt.mamba_one_step_ssm import TtMambaSSM
from models.demos.mamba.tt import model_config
from models.demos.wormhole.mamba.reference.prefill_decode_model import Mamba, MambaPretrainedModelName
from models.demos.wormhole.mamba.reference.args import ModelMode
from models.demos.wormhole.mamba.tt.full_model import TtTensorLoader
from models.demos.wormhole.mamba.tt.mamba_one_step_ssm import TtMambaSSM
from models.demos.wormhole.mamba.tt import model_config
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import (
comp_allclose,
comp_pcc,
@@ -11,7 +11,7 @@
comp_allclose,
comp_pcc,
)
from models.demos.mamba.tt.preprocessing import split_sequence_length
from models.demos.wormhole.mamba.tt.preprocessing import split_sequence_length


@pytest.mark.parametrize(
@@ -8,9 +8,9 @@
from typing import Optional
from transformers import AutoTokenizer

from models.demos.mamba.reference.decode_model import MambaDecode, MambaPretrainedModelName
from models.demos.mamba.reference.prefill_decode_model import Mamba as MambaPrefillDecode
from models.demos.mamba.reference.model import Mamba
from models.demos.wormhole.mamba.reference.decode_model import MambaDecode, MambaPretrainedModelName
from models.demos.wormhole.mamba.reference.prefill_decode_model import Mamba as MambaPrefillDecode
from models.demos.wormhole.mamba.reference.model import Mamba


def generate_through_selective_scan(
@@ -7,11 +7,11 @@
from loguru import logger
from typing import Optional
import ttnn
from models.demos.mamba.tt.full_model import TtTensorLoader
from models.demos.mamba.reference.prefill_decode_model import Mamba, MambaPretrainedModelName
from models.demos.mamba.reference.args import ModelMode
from models.demos.mamba.tt.residual_block import TtResidualBlock
from models.demos.mamba.tt import model_config
from models.demos.wormhole.mamba.tt.full_model import TtTensorLoader
from models.demos.wormhole.mamba.reference.prefill_decode_model import Mamba, MambaPretrainedModelName
from models.demos.wormhole.mamba.reference.args import ModelMode
from models.demos.wormhole.mamba.tt.residual_block import TtResidualBlock
from models.demos.wormhole.mamba.tt import model_config
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import (
comp_allclose,
comp_pcc,
File renamed without changes.
@@ -11,8 +11,8 @@
from pathlib import Path
from typing import Callable, Optional

from models.demos.mamba.tt.residual_block import TtResidualBlock
from models.demos.mamba.reference.args import ModelMode
from models.demos.wormhole.mamba.tt.residual_block import TtResidualBlock
from models.demos.wormhole.mamba.reference.args import ModelMode


class TtTensorLoader:
@@ -8,10 +8,10 @@
import tt_lib as ttl
from typing import Callable

from models.demos.mamba.reference.args import ModelArgs
from models.demos.mamba.reference.args import ModelMode
from models.demos.mamba.tt.mamba_one_step_ssm import TtMambaSSM
from models.demos.mamba.tt.cache import TensorCache
from models.demos.wormhole.mamba.reference.args import ModelArgs
from models.demos.wormhole.mamba.reference.args import ModelMode
from models.demos.wormhole.mamba.tt.mamba_one_step_ssm import TtMambaSSM
from models.demos.wormhole.mamba.tt.cache import TensorCache


class TtMambaBlock(torch.nn.Module):
@@ -8,8 +8,8 @@
import tt_lib as ttl
from typing import Callable

from models.demos.mamba.reference.args import ModelArgs, ModelMode
from models.demos.mamba.tt.cache import TensorCache
from models.demos.wormhole.mamba.reference.args import ModelArgs, ModelMode
from models.demos.wormhole.mamba.tt.cache import TensorCache


class TtMambaSSM(torch.nn.Module):
@@ -4,7 +4,8 @@

import ttnn
import os
from models.demos.mamba.reference.args import ModelMode

from models.demos.wormhole.mamba.reference.args import ModelMode


def create_model_config(batch_size, hidden_size, mode=ModelMode.DECODE, seq_len=1):
File renamed without changes.
@@ -8,8 +8,8 @@

from typing import Callable

from models.demos.mamba.reference.args import ModelArgs
from models.demos.mamba.tt.mamba_block import TtMambaBlock
from models.demos.wormhole.mamba.reference.args import ModelArgs
from models.demos.wormhole.mamba.tt.mamba_block import TtMambaBlock


class TtResidualBlock(torch.nn.Module):

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.
