From 7f57b738ec2f7546e65b5d131666bcb4e4d9c4b4 Mon Sep 17 00:00:00 2001 From: Evan Smal Date: Thu, 18 Jul 2024 20:40:27 +0000 Subject: [PATCH] #0: Move Mamba demo to models/demos/wormhole --- CODEOWNERS | 2 +- README.md | 2 +- models/demos/{ => wormhole}/mamba/README.md | 20 +++++++++---------- .../{ => wormhole}/mamba/benchmarks/README.md | 0 .../mamba/benchmarks/lm_harness_eval.py | 4 ++-- .../mamba/benchmarks/loglikelihood.py | 2 +- .../demos/{ => wormhole}/mamba/demo/demo.py | 14 ++++++------- .../{ => wormhole}/mamba/demo/prompts.json | 0 .../{ => wormhole}/mamba/reference/args.py | 0 .../mamba/reference/decode_model.py | 2 +- .../{ => wormhole}/mamba/reference/model.py | 0 .../mamba/reference/prefill_decode_model.py | 2 +- .../mamba/tests/test_benchmarks.py | 4 ++-- .../{ => wormhole}/mamba/tests/test_cache.py | 2 +- .../mamba/tests/test_full_model.py | 10 +++++----- .../mamba/tests/test_mamba_block.py | 10 +++++----- .../mamba/tests/test_mamba_demo.py | 2 +- .../mamba/tests/test_mamba_perf.py | 6 +++--- .../mamba/tests/test_mamba_ssm.py | 10 +++++----- .../mamba/tests/test_preprocessing.py | 2 +- .../mamba/tests/test_reference_model.py | 6 +++--- .../mamba/tests/test_residual_block.py | 10 +++++----- models/demos/{ => wormhole}/mamba/tt/cache.py | 0 .../{ => wormhole}/mamba/tt/full_model.py | 4 ++-- .../{ => wormhole}/mamba/tt/mamba_block.py | 8 ++++---- .../mamba/tt/mamba_one_step_ssm.py | 4 ++-- .../{ => wormhole}/mamba/tt/model_config.py | 3 ++- .../{ => wormhole}/mamba/tt/preprocessing.py | 0 .../{ => wormhole}/mamba/tt/residual_block.py | 4 ++-- .../demos/mamba/tests/test_benchmarks.py | 1 - .../models/demos/mamba/tests/test_cache.py | 1 - .../demos/mamba/tests/test_full_model.py | 1 - .../demos/mamba/tests/test_mamba_block.py | 1 - .../demos/mamba/tests/test_mamba_demo.py | 1 - .../demos/mamba/tests/test_mamba_ssm.py | 1 - .../demos/mamba/tests/test_preprocessing.py | 1 - .../demos/mamba/tests/test_reference_model.py | 1 - .../demos/mamba/tests/test_residual_block.py | 1 - .../demos/wormhole/mamba/test_benchmarks.py | 1 + .../models/demos/wormhole/mamba/test_cache.py | 1 + .../demos/wormhole/mamba/test_full_model.py | 1 + .../demos/wormhole/mamba/test_mamba_block.py | 1 + .../demos/wormhole/mamba/test_mamba_demo.py | 1 + .../demos/wormhole/mamba/test_mamba_ssm.py | 1 + .../wormhole/mamba/test_preprocessing.py | 1 + .../wormhole/mamba/test_reference_model.py | 1 + .../wormhole/mamba/test_residual_block.py | 1 + 47 files changed, 76 insertions(+), 75 deletions(-) rename models/demos/{ => wormhole}/mamba/README.md (73%) rename models/demos/{ => wormhole}/mamba/benchmarks/README.md (100%) rename models/demos/{ => wormhole}/mamba/benchmarks/lm_harness_eval.py (92%) rename models/demos/{ => wormhole}/mamba/benchmarks/loglikelihood.py (96%) rename models/demos/{ => wormhole}/mamba/demo/demo.py (94%) rename models/demos/{ => wormhole}/mamba/demo/prompts.json (100%) rename models/demos/{ => wormhole}/mamba/reference/args.py (100%) rename models/demos/{ => wormhole}/mamba/reference/decode_model.py (99%) rename models/demos/{ => wormhole}/mamba/reference/model.py (100%) rename models/demos/{ => wormhole}/mamba/reference/prefill_decode_model.py (99%) rename models/demos/{ => wormhole}/mamba/tests/test_benchmarks.py (93%) rename models/demos/{ => wormhole}/mamba/tests/test_cache.py (96%) rename models/demos/{ => wormhole}/mamba/tests/test_full_model.py (93%) rename models/demos/{ => wormhole}/mamba/tests/test_mamba_block.py (87%) rename models/demos/{ => wormhole}/mamba/tests/test_mamba_demo.py (95%) rename models/demos/{ => wormhole}/mamba/tests/test_mamba_perf.py (95%) rename models/demos/{ => wormhole}/mamba/tests/test_mamba_ssm.py (87%) rename models/demos/{ => wormhole}/mamba/tests/test_preprocessing.py (95%) rename models/demos/{ => wormhole}/mamba/tests/test_reference_model.py (89%) rename models/demos/{ => wormhole}/mamba/tests/test_residual_block.py (87%) rename models/demos/{ => wormhole}/mamba/tt/cache.py (100%) rename models/demos/{ => wormhole}/mamba/tt/full_model.py (97%) rename models/demos/{ => wormhole}/mamba/tt/mamba_block.py (97%) rename models/demos/{ => wormhole}/mamba/tt/mamba_one_step_ssm.py (98%) rename models/demos/{ => wormhole}/mamba/tt/model_config.py (98%) rename models/demos/{ => wormhole}/mamba/tt/preprocessing.py (100%) rename models/demos/{ => wormhole}/mamba/tt/residual_block.py (93%) delete mode 120000 tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_benchmarks.py delete mode 120000 tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_cache.py delete mode 120000 tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_full_model.py delete mode 120000 tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_mamba_block.py delete mode 120000 tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_mamba_demo.py delete mode 120000 tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_mamba_ssm.py delete mode 120000 tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_preprocessing.py delete mode 120000 tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_reference_model.py delete mode 120000 tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_residual_block.py create mode 120000 tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_benchmarks.py create mode 120000 tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_cache.py create mode 120000 tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_full_model.py create mode 120000 tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_mamba_block.py create mode 120000 tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_mamba_demo.py create mode 120000 tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_mamba_ssm.py create mode 120000 tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_preprocessing.py create mode 120000 tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_reference_model.py create mode 120000 tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_residual_block.py diff --git a/CODEOWNERS b/CODEOWNERS index 55bca38f156a..42647cf42127 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -141,7 +141,7 @@ models/demos/metal_BERT_large_11 @tt-aho @TT-BrianLiu models/demos/wormhole @uaydonat @eyonland @AleksKnezevic @nsmithtt models/demos/t3000 @uaydonat @AleksKnezevic @nsmithtt models/demos/falcon7b @skhorasganiTT @djordje-tt @uaydonat @pavlejosipovic @pavlepopovic @s-jovic -models/demos/mamba @esmalTT @uaydonat @kpaigwar +models/demos/wormhole/mamba @esmalTT @uaydonat @kpaigwar models/demos/wormhole/falcon7b @skhorasganiTT @djordje-tt @uaydonat @pavlejosipovic @pavlepopovic @s-jovic models/demos/wormhole/mistral7b @yieldthought @uaydonat @mtairum models/demos/t3000/falcon40b @johanna-rock-tt @uaydonat @s-jovic diff --git a/README.md b/README.md index e58e63543634..1a3e14ba7424 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ |----------------------------------------------------------------------------------------|--------------------|----------------------|------------------------------|-----------------------------|----------------| | [Falcon7B](./models/demos/wormhole/falcon7b) | 129th | 32 | 13.3 t/s/u - 425 t/s | 15.4 t/s/u - 493 t/s | 26 | | [Mistral-7B](./models/demos/wormhole/mistral7b) | 129th | 32 | 9.9 t/s/u - 317 t/s | 11.0 t/s/u - 352 t/s | 25 | -| [Mamba-2.8B](./models/demos/mamba) | any | 32 | 11.6 t/s/u - 370 t/s | 16.5 t/s/u - 528 t/s | 41 | +| [Mamba-2.8B](./models/demos/wormhole/mamba) | any | 32 | 11.6 t/s/u - 370 t/s | 16.5 t/s/u - 528 t/s | 41 | | [BERT-Large](./models/demos/metal_BERT_large_11/) (sen/s) [4] | | 8 | 270 | 340 | 400 | | [Stable Diffusion 1.4](./models/demos/wormhole/stable_diffusion) 512x512 (sec/img) [5] | | 1 | 6 | 5 | 3 | | [ResNet-50](./models/demos/ttnn_resnet) (fps) | | 16 | 4,300 | 5,550 | 7,000 | diff --git a/models/demos/mamba/README.md b/models/demos/wormhole/mamba/README.md similarity index 73% rename from models/demos/mamba/README.md rename to models/demos/wormhole/mamba/README.md index 94bc6e189e3b..37e9bf1b58df 100644 --- a/models/demos/mamba/README.md +++ b/models/demos/wormhole/mamba/README.md @@ -15,22 +15,22 @@ export WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml To run the model for a single user you can use the command line input: ``` -pytest --disable-warnings -q -s --input-method=cli --cli-input="YOUR PROMPT GOES HERE!" models/demos/mamba/demo/demo.py +pytest --disable-warnings -q -s --input-method=cli --cli-input="YOUR PROMPT GOES HERE!" models/demos/wormhole/mamba/demo/demo.py ``` To run the demo using pre-written prompts for a batch of 32 users run: ``` -pytest --disable-warnings -q -s --input-method=json --input-path='models/demos/mamba/demo/prompts.json' models/demos/mamba/demo/demo.py +pytest --disable-warnings -q -s --input-method=json --input-path='models/demos/wormhole/mamba/demo/prompts.json' models/demos/wormhole/mamba/demo/demo.py ``` To run the demo using custom input prompts, you can provide a different path to the input prompts file for e.g.: ``` -pytest --disable-warnings -q -s --input-method=json --input-path='path_to_input_prompts.json' models/demos/mamba/demo/demo.py +pytest --disable-warnings -q -s --input-method=json --input-path='path_to_input_prompts.json' models/demos/wormhole/mamba/demo/demo.py ``` -Any sequence length is supported. We currently only support JSON file with strictly 32 user prompts with same token length. Check the `models/demos/mamba/demo/prompts.json` file for reference. +Any sequence length is supported. We currently only support JSON file with strictly 32 user prompts with same token length. Check the `models/demos/wormhole/mamba/demo/prompts.json` file for reference. The prefill graph is not currently integrated into the demo. Therefore we currently process the prompt a single token at a time using the decode graph. @@ -48,19 +48,19 @@ cd tt-metal ### SSM Block ``` -pytest -svv models/demos/mamba/tests/test_mamba_ssm.py +pytest -svv models/demos/wormhole/mamba/tests/test_mamba_ssm.py ``` ### Mamba Block ``` -pytest -svv models/demos/mamba/tests/test_mamba_block.py +pytest -svv models/demos/wormhole/mamba/tests/test_mamba_block.py ``` ### Residual Block ``` -pytest -svv models/demos/mamba/tests/test_residual_block.py +pytest -svv models/demos/wormhole/mamba/tests/test_residual_block.py ``` ### Full Model @@ -68,7 +68,7 @@ pytest -svv models/demos/mamba/tests/test_residual_block.py Note : input embedding layer and TopK are on CPU ``` -pytest -svv models/demos/mamba/tests/test_full_model.py::test_inference +pytest -svv models/demos/wormhole/mamba/tests/test_full_model.py::test_inference ``` ## Performance Tests @@ -79,7 +79,7 @@ These tests are designed to evaluate device-side and host performance of Mamba m ### End-to-End Model Performance ```bash -pytest -svv models/demos/mamba/tests/test_mamba_perf.py -m models_performance_bare_metal +pytest -svv models/demos/wormhole/mamba/tests/test_mamba_perf.py -m models_performance_bare_metal ``` ### Device-Side Performance @@ -87,7 +87,7 @@ pytest -svv models/demos/mamba/tests/test_mamba_perf.py -m models_performance_ba Build with profiler support enabled (use the build script `./scripts/build_scripts/build_with_profiler_opt.sh`) and run the following command to test device-side performance: ``` -pytest -svv models/demos/mamba/tests/test_mamba_perf.py -m models_device_performance_bare_metal +pytest -svv models/demos/wormhole/mamba/tests/test_mamba_perf.py -m models_device_performance_bare_metal ``` This will also generate device and host profiling logs in directory `generated/profiler/reports/ttnn_mamba` diff --git a/models/demos/mamba/benchmarks/README.md b/models/demos/wormhole/mamba/benchmarks/README.md similarity index 100% rename from models/demos/mamba/benchmarks/README.md rename to models/demos/wormhole/mamba/benchmarks/README.md diff --git a/models/demos/mamba/benchmarks/lm_harness_eval.py b/models/demos/wormhole/mamba/benchmarks/lm_harness_eval.py similarity index 92% rename from models/demos/mamba/benchmarks/lm_harness_eval.py rename to models/demos/wormhole/mamba/benchmarks/lm_harness_eval.py index 14beaba3d8be..20164fa58e14 100644 --- a/models/demos/mamba/benchmarks/lm_harness_eval.py +++ b/models/demos/wormhole/mamba/benchmarks/lm_harness_eval.py @@ -9,8 +9,8 @@ from transformers import AutoTokenizer -from models.demos.mamba.reference.decode_model import MambaDecode, MambaPretrainedModelName -from models.demos.mamba.benchmarks.loglikelihood import compute_loglikelihood_given_prompt_and_target +from models.demos.wormhole.mamba.reference.decode_model import MambaDecode, MambaPretrainedModelName +from models.demos.wormhole.mamba.benchmarks.loglikelihood import compute_loglikelihood_given_prompt_and_target from lm_eval.api.model import LM from lm_eval.api.instance import Instance diff --git a/models/demos/mamba/benchmarks/loglikelihood.py b/models/demos/wormhole/mamba/benchmarks/loglikelihood.py similarity index 96% rename from models/demos/mamba/benchmarks/loglikelihood.py rename to models/demos/wormhole/mamba/benchmarks/loglikelihood.py index 0984361c7c6a..83690bddbf26 100644 --- a/models/demos/mamba/benchmarks/loglikelihood.py +++ b/models/demos/wormhole/mamba/benchmarks/loglikelihood.py @@ -4,7 +4,7 @@ import torch -from models.demos.mamba.reference.decode_model import MambaDecode +from models.demos.wormhole.mamba.reference.decode_model import MambaDecode def compute_loglikelihood(logits, labels) -> float: diff --git a/models/demos/mamba/demo/demo.py b/models/demos/wormhole/mamba/demo/demo.py similarity index 94% rename from models/demos/mamba/demo/demo.py rename to models/demos/wormhole/mamba/demo/demo.py index 0b7913972b89..f490f1f61a4e 100644 --- a/models/demos/mamba/demo/demo.py +++ b/models/demos/wormhole/mamba/demo/demo.py @@ -12,14 +12,14 @@ from transformers import AutoTokenizer -from models.demos.mamba.reference.decode_model import MambaPretrainedModelName -from models.demos.mamba.reference.args import ModelMode -from models.demos.mamba.tt import model_config -from models.demos.mamba.tt.preprocessing import split_sequence_length +from models.demos.wormhole.mamba.reference.decode_model import MambaPretrainedModelName +from models.demos.wormhole.mamba.reference.args import ModelMode +from models.demos.wormhole.mamba.tt import model_config +from models.demos.wormhole.mamba.tt.preprocessing import split_sequence_length def get_cpu_reference_model(version: MambaPretrainedModelName, batch_size: int): - from models.demos.mamba.reference.decode_model import MambaDecode + from models.demos.wormhole.mamba.reference.decode_model import MambaDecode return MambaDecode.from_pretrained(version, batch_size=batch_size) @@ -32,8 +32,8 @@ def get_tt_metal_model( mode: ModelMode = ModelMode.DECODE, seq_len: int = 1, ): - from models.demos.mamba.tt.full_model import MambaTT - from models.demos.mamba.tt import model_config + from models.demos.wormhole.mamba.tt.full_model import MambaTT + from models.demos.wormhole.mamba.tt import model_config reference_model = get_cpu_reference_model(version, batch_size=batch_size) config = model_config.create_model_config(batch_size, reference_model.args.d_model, mode=mode, seq_len=seq_len) diff --git a/models/demos/mamba/demo/prompts.json b/models/demos/wormhole/mamba/demo/prompts.json similarity index 100% rename from models/demos/mamba/demo/prompts.json rename to models/demos/wormhole/mamba/demo/prompts.json diff --git a/models/demos/mamba/reference/args.py b/models/demos/wormhole/mamba/reference/args.py similarity index 100% rename from models/demos/mamba/reference/args.py rename to models/demos/wormhole/mamba/reference/args.py diff --git a/models/demos/mamba/reference/decode_model.py b/models/demos/wormhole/mamba/reference/decode_model.py similarity index 99% rename from models/demos/mamba/reference/decode_model.py rename to models/demos/wormhole/mamba/reference/decode_model.py index 491bb84fe9fd..041c5175c892 100644 --- a/models/demos/mamba/reference/decode_model.py +++ b/models/demos/wormhole/mamba/reference/decode_model.py @@ -46,7 +46,7 @@ import torch.nn.functional as F from einops import rearrange, repeat, einsum -from models.demos.mamba.reference.args import ModelArgs +from models.demos.wormhole.mamba.reference.args import ModelArgs from typing import Literal, cast diff --git a/models/demos/mamba/reference/model.py b/models/demos/wormhole/mamba/reference/model.py similarity index 100% rename from models/demos/mamba/reference/model.py rename to models/demos/wormhole/mamba/reference/model.py diff --git a/models/demos/mamba/reference/prefill_decode_model.py b/models/demos/wormhole/mamba/reference/prefill_decode_model.py similarity index 99% rename from models/demos/mamba/reference/prefill_decode_model.py rename to models/demos/wormhole/mamba/reference/prefill_decode_model.py index d986bf9d8b37..34bde838a5f7 100644 --- a/models/demos/mamba/reference/prefill_decode_model.py +++ b/models/demos/wormhole/mamba/reference/prefill_decode_model.py @@ -46,7 +46,7 @@ import torch.nn.functional as F from einops import rearrange, repeat, einsum -from models.demos.mamba.reference.args import ModelArgs, ModelMode +from models.demos.wormhole.mamba.reference.args import ModelArgs, ModelMode from typing import Literal, cast diff --git a/models/demos/mamba/tests/test_benchmarks.py b/models/demos/wormhole/mamba/tests/test_benchmarks.py similarity index 93% rename from models/demos/mamba/tests/test_benchmarks.py rename to models/demos/wormhole/mamba/tests/test_benchmarks.py index b01090e33f7a..65e9e45db5e0 100644 --- a/models/demos/mamba/tests/test_benchmarks.py +++ b/models/demos/wormhole/mamba/tests/test_benchmarks.py @@ -5,8 +5,8 @@ import pytest import torch -from models.demos.mamba.reference.decode_model import MambaDecode, MambaPretrainedModelName -from models.demos.mamba.benchmarks.loglikelihood import ( +from models.demos.wormhole.mamba.reference.decode_model import MambaDecode, MambaPretrainedModelName +from models.demos.wormhole.mamba.benchmarks.loglikelihood import ( compute_loglikelihood, compute_loglikelihood_given_prompt_and_target, ) diff --git a/models/demos/mamba/tests/test_cache.py b/models/demos/wormhole/mamba/tests/test_cache.py similarity index 96% rename from models/demos/mamba/tests/test_cache.py rename to models/demos/wormhole/mamba/tests/test_cache.py index ac4b1a9ddd7e..cd5c05f560c0 100644 --- a/models/demos/mamba/tests/test_cache.py +++ b/models/demos/wormhole/mamba/tests/test_cache.py @@ -6,7 +6,7 @@ import ttnn import torch -from models.demos.mamba.tt.cache import TensorCache +from models.demos.wormhole.mamba.tt.cache import TensorCache from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( comp_pcc, diff --git a/models/demos/mamba/tests/test_full_model.py b/models/demos/wormhole/mamba/tests/test_full_model.py similarity index 93% rename from models/demos/mamba/tests/test_full_model.py rename to models/demos/wormhole/mamba/tests/test_full_model.py index f388dfea9e6f..c215b1c60a06 100644 --- a/models/demos/mamba/tests/test_full_model.py +++ b/models/demos/wormhole/mamba/tests/test_full_model.py @@ -9,11 +9,11 @@ from transformers import AutoTokenizer from typing import Optional import ttnn -from models.demos.mamba.reference.decode_model import MambaDecode, MambaPretrainedModelName -from models.demos.mamba.reference.prefill_decode_model import Mamba, MambaPretrainedModelName -from models.demos.mamba.reference.args import ModelMode -from models.demos.mamba.tt.full_model import MambaTT -from models.demos.mamba.tt import model_config +from models.demos.wormhole.mamba.reference.decode_model import MambaDecode, MambaPretrainedModelName +from models.demos.wormhole.mamba.reference.prefill_decode_model import Mamba, MambaPretrainedModelName +from models.demos.wormhole.mamba.reference.args import ModelMode +from models.demos.wormhole.mamba.tt.full_model import MambaTT +from models.demos.wormhole.mamba.tt import model_config from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( comp_allclose, comp_pcc, diff --git a/models/demos/mamba/tests/test_mamba_block.py b/models/demos/wormhole/mamba/tests/test_mamba_block.py similarity index 87% rename from models/demos/mamba/tests/test_mamba_block.py rename to models/demos/wormhole/mamba/tests/test_mamba_block.py index d67c04e69d57..acb193092613 100644 --- a/models/demos/mamba/tests/test_mamba_block.py +++ b/models/demos/wormhole/mamba/tests/test_mamba_block.py @@ -7,11 +7,11 @@ from loguru import logger from typing import Optional import ttnn -from models.demos.mamba.tt.full_model import TtTensorLoader -from models.demos.mamba.reference.prefill_decode_model import Mamba, MambaPretrainedModelName -from models.demos.mamba.reference.args import ModelMode -from models.demos.mamba.tt.mamba_block import TtMambaBlock -from models.demos.mamba.tt import model_config +from models.demos.wormhole.mamba.tt.full_model import TtTensorLoader +from models.demos.wormhole.mamba.reference.prefill_decode_model import Mamba, MambaPretrainedModelName +from models.demos.wormhole.mamba.reference.args import ModelMode +from models.demos.wormhole.mamba.tt.mamba_block import TtMambaBlock +from models.demos.wormhole.mamba.tt import model_config from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( comp_allclose, comp_pcc, diff --git a/models/demos/mamba/tests/test_mamba_demo.py b/models/demos/wormhole/mamba/tests/test_mamba_demo.py similarity index 95% rename from models/demos/mamba/tests/test_mamba_demo.py rename to models/demos/wormhole/mamba/tests/test_mamba_demo.py index 29f864fceeb0..9752705414bf 100644 --- a/models/demos/mamba/tests/test_mamba_demo.py +++ b/models/demos/wormhole/mamba/tests/test_mamba_demo.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 -from models.demos.mamba.demo.demo import run_mamba_demo, run_mamba_prefill_decode_demo +from models.demos.wormhole.mamba.demo.demo import run_mamba_demo, run_mamba_prefill_decode_demo import pytest diff --git a/models/demos/mamba/tests/test_mamba_perf.py b/models/demos/wormhole/mamba/tests/test_mamba_perf.py similarity index 95% rename from models/demos/mamba/tests/test_mamba_perf.py rename to models/demos/wormhole/mamba/tests/test_mamba_perf.py index dbd888f698b5..d257829302fc 100644 --- a/models/demos/mamba/tests/test_mamba_perf.py +++ b/models/demos/wormhole/mamba/tests/test_mamba_perf.py @@ -6,7 +6,7 @@ import time import json -from models.demos.mamba.demo.demo import ( +from models.demos.wormhole.mamba.demo.demo import ( get_tokenizer, get_cpu_reference_model, get_tt_metal_model, @@ -50,7 +50,7 @@ def test_mamba_e2e_perf( profiler.clear() # Load prompts - with open("models/demos/mamba/demo/prompts.json", "r") as f: + with open("models/demos/wormhole/mamba/demo/prompts.json", "r") as f: prompts = json.load(f) profiler.start("pytorch_ref_model_setup") @@ -138,7 +138,7 @@ def test_mamba_perf_device(batch, warmup, expected_device_fw_duration_ms, reset_ inference_iterations = 2 else: inference_iterations = 1 - command = f"pytest models/demos/mamba/tests/test_full_model.py::test_device_perf[{inference_iterations}]" + command = f"pytest models/demos/wormhole/mamba/tests/test_full_model.py::test_device_perf[{inference_iterations}]" cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"] # convert expected perf (ms) to samples/s diff --git a/models/demos/mamba/tests/test_mamba_ssm.py b/models/demos/wormhole/mamba/tests/test_mamba_ssm.py similarity index 87% rename from models/demos/mamba/tests/test_mamba_ssm.py rename to models/demos/wormhole/mamba/tests/test_mamba_ssm.py index 34c3b173d162..cfe3869f1e51 100644 --- a/models/demos/mamba/tests/test_mamba_ssm.py +++ b/models/demos/wormhole/mamba/tests/test_mamba_ssm.py @@ -7,11 +7,11 @@ from loguru import logger from typing import Optional import ttnn -from models.demos.mamba.reference.prefill_decode_model import Mamba, MambaPretrainedModelName -from models.demos.mamba.reference.args import ModelMode -from models.demos.mamba.tt.full_model import TtTensorLoader -from models.demos.mamba.tt.mamba_one_step_ssm import TtMambaSSM -from models.demos.mamba.tt import model_config +from models.demos.wormhole.mamba.reference.prefill_decode_model import Mamba, MambaPretrainedModelName +from models.demos.wormhole.mamba.reference.args import ModelMode +from models.demos.wormhole.mamba.tt.full_model import TtTensorLoader +from models.demos.wormhole.mamba.tt.mamba_one_step_ssm import TtMambaSSM +from models.demos.wormhole.mamba.tt import model_config from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( comp_allclose, comp_pcc, diff --git a/models/demos/mamba/tests/test_preprocessing.py b/models/demos/wormhole/mamba/tests/test_preprocessing.py similarity index 95% rename from models/demos/mamba/tests/test_preprocessing.py rename to models/demos/wormhole/mamba/tests/test_preprocessing.py index ae807a4f383a..f702bd4002bb 100644 --- a/models/demos/mamba/tests/test_preprocessing.py +++ b/models/demos/wormhole/mamba/tests/test_preprocessing.py @@ -11,7 +11,7 @@ comp_allclose, comp_pcc, ) -from models.demos.mamba.tt.preprocessing import split_sequence_length +from models.demos.wormhole.mamba.tt.preprocessing import split_sequence_length @pytest.mark.parametrize( diff --git a/models/demos/mamba/tests/test_reference_model.py b/models/demos/wormhole/mamba/tests/test_reference_model.py similarity index 89% rename from models/demos/mamba/tests/test_reference_model.py rename to models/demos/wormhole/mamba/tests/test_reference_model.py index 6505a04b31f6..b4160db33001 100644 --- a/models/demos/mamba/tests/test_reference_model.py +++ b/models/demos/wormhole/mamba/tests/test_reference_model.py @@ -8,9 +8,9 @@ from typing import Optional from transformers import AutoTokenizer -from models.demos.mamba.reference.decode_model import MambaDecode, MambaPretrainedModelName -from models.demos.mamba.reference.prefill_decode_model import Mamba as MambaPrefillDecode -from models.demos.mamba.reference.model import Mamba +from models.demos.wormhole.mamba.reference.decode_model import MambaDecode, MambaPretrainedModelName +from models.demos.wormhole.mamba.reference.prefill_decode_model import Mamba as MambaPrefillDecode +from models.demos.wormhole.mamba.reference.model import Mamba def generate_through_selective_scan( diff --git a/models/demos/mamba/tests/test_residual_block.py b/models/demos/wormhole/mamba/tests/test_residual_block.py similarity index 87% rename from models/demos/mamba/tests/test_residual_block.py rename to models/demos/wormhole/mamba/tests/test_residual_block.py index 9d6bcbf36066..700f5427c971 100644 --- a/models/demos/mamba/tests/test_residual_block.py +++ b/models/demos/wormhole/mamba/tests/test_residual_block.py @@ -7,11 +7,11 @@ from loguru import logger from typing import Optional import ttnn -from models.demos.mamba.tt.full_model import TtTensorLoader -from models.demos.mamba.reference.prefill_decode_model import Mamba, MambaPretrainedModelName -from models.demos.mamba.reference.args import ModelMode -from models.demos.mamba.tt.residual_block import TtResidualBlock -from models.demos.mamba.tt import model_config +from models.demos.wormhole.mamba.tt.full_model import TtTensorLoader +from models.demos.wormhole.mamba.reference.prefill_decode_model import Mamba, MambaPretrainedModelName +from models.demos.wormhole.mamba.reference.args import ModelMode +from models.demos.wormhole.mamba.tt.residual_block import TtResidualBlock +from models.demos.wormhole.mamba.tt import model_config from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( comp_allclose, comp_pcc, diff --git a/models/demos/mamba/tt/cache.py b/models/demos/wormhole/mamba/tt/cache.py similarity index 100% rename from models/demos/mamba/tt/cache.py rename to models/demos/wormhole/mamba/tt/cache.py diff --git a/models/demos/mamba/tt/full_model.py b/models/demos/wormhole/mamba/tt/full_model.py similarity index 97% rename from models/demos/mamba/tt/full_model.py rename to models/demos/wormhole/mamba/tt/full_model.py index e7864861cf79..bb62e5f5b0f2 100644 --- a/models/demos/mamba/tt/full_model.py +++ b/models/demos/wormhole/mamba/tt/full_model.py @@ -11,8 +11,8 @@ from pathlib import Path from typing import Callable, Optional -from models.demos.mamba.tt.residual_block import TtResidualBlock -from models.demos.mamba.reference.args import ModelMode +from models.demos.wormhole.mamba.tt.residual_block import TtResidualBlock +from models.demos.wormhole.mamba.reference.args import ModelMode class TtTensorLoader: diff --git a/models/demos/mamba/tt/mamba_block.py b/models/demos/wormhole/mamba/tt/mamba_block.py similarity index 97% rename from models/demos/mamba/tt/mamba_block.py rename to models/demos/wormhole/mamba/tt/mamba_block.py index 9c521ab2d838..ab8445afa064 100644 --- a/models/demos/mamba/tt/mamba_block.py +++ b/models/demos/wormhole/mamba/tt/mamba_block.py @@ -8,10 +8,10 @@ import tt_lib as ttl from typing import Callable -from models.demos.mamba.reference.args import ModelArgs -from models.demos.mamba.reference.args import ModelMode -from models.demos.mamba.tt.mamba_one_step_ssm import TtMambaSSM -from models.demos.mamba.tt.cache import TensorCache +from models.demos.wormhole.mamba.reference.args import ModelArgs +from models.demos.wormhole.mamba.reference.args import ModelMode +from models.demos.wormhole.mamba.tt.mamba_one_step_ssm import TtMambaSSM +from models.demos.wormhole.mamba.tt.cache import TensorCache class TtMambaBlock(torch.nn.Module): diff --git a/models/demos/mamba/tt/mamba_one_step_ssm.py b/models/demos/wormhole/mamba/tt/mamba_one_step_ssm.py similarity index 98% rename from models/demos/mamba/tt/mamba_one_step_ssm.py rename to models/demos/wormhole/mamba/tt/mamba_one_step_ssm.py index 5972b1c057ca..e26227fb0a1e 100644 --- a/models/demos/mamba/tt/mamba_one_step_ssm.py +++ b/models/demos/wormhole/mamba/tt/mamba_one_step_ssm.py @@ -8,8 +8,8 @@ import tt_lib as ttl from typing import Callable -from models.demos.mamba.reference.args import ModelArgs, ModelMode -from models.demos.mamba.tt.cache import TensorCache +from models.demos.wormhole.mamba.reference.args import ModelArgs, ModelMode +from models.demos.wormhole.mamba.tt.cache import TensorCache class TtMambaSSM(torch.nn.Module): diff --git a/models/demos/mamba/tt/model_config.py b/models/demos/wormhole/mamba/tt/model_config.py similarity index 98% rename from models/demos/mamba/tt/model_config.py rename to models/demos/wormhole/mamba/tt/model_config.py index bdf085510eff..662222dc654c 100644 --- a/models/demos/mamba/tt/model_config.py +++ b/models/demos/wormhole/mamba/tt/model_config.py @@ -4,7 +4,8 @@ import ttnn import os -from models.demos.mamba.reference.args import ModelMode + +from models.demos.wormhole.mamba.reference.args import ModelMode def create_model_config(batch_size, hidden_size, mode=ModelMode.DECODE, seq_len=1): diff --git a/models/demos/mamba/tt/preprocessing.py b/models/demos/wormhole/mamba/tt/preprocessing.py similarity index 100% rename from models/demos/mamba/tt/preprocessing.py rename to models/demos/wormhole/mamba/tt/preprocessing.py diff --git a/models/demos/mamba/tt/residual_block.py b/models/demos/wormhole/mamba/tt/residual_block.py similarity index 93% rename from models/demos/mamba/tt/residual_block.py rename to models/demos/wormhole/mamba/tt/residual_block.py index 0bf3fbac0285..c4d111fc9d3d 100644 --- a/models/demos/mamba/tt/residual_block.py +++ b/models/demos/wormhole/mamba/tt/residual_block.py @@ -8,8 +8,8 @@ from typing import Callable -from models.demos.mamba.reference.args import ModelArgs -from models.demos.mamba.tt.mamba_block import TtMambaBlock +from models.demos.wormhole.mamba.reference.args import ModelArgs +from models.demos.wormhole.mamba.tt.mamba_block import TtMambaBlock class TtResidualBlock(torch.nn.Module): diff --git a/tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_benchmarks.py b/tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_benchmarks.py deleted file mode 120000 index 1b727ea63bea..000000000000 --- a/tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_benchmarks.py +++ /dev/null @@ -1 +0,0 @@ -../../../../../../../models/demos/mamba/tests/test_benchmarks.py \ No newline at end of file diff --git a/tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_cache.py b/tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_cache.py deleted file mode 120000 index 526ba62eb781..000000000000 --- a/tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_cache.py +++ /dev/null @@ -1 +0,0 @@ -../../../../../../../models/demos/mamba/tests/test_cache.py \ No newline at end of file diff --git a/tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_full_model.py b/tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_full_model.py deleted file mode 120000 index 607b21c241ca..000000000000 --- a/tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_full_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../../../../../models/demos/mamba/tests/test_full_model.py \ No newline at end of file diff --git a/tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_mamba_block.py b/tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_mamba_block.py deleted file mode 120000 index a4d1345b156e..000000000000 --- a/tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_mamba_block.py +++ /dev/null @@ -1 +0,0 @@ -../../../../../../../models/demos/mamba/tests/test_mamba_block.py \ No newline at end of file diff --git a/tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_mamba_demo.py b/tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_mamba_demo.py deleted file mode 120000 index b79b60213b45..000000000000 --- a/tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_mamba_demo.py +++ /dev/null @@ -1 +0,0 @@ -../../../../../../../models/demos/mamba/tests/test_mamba_demo.py \ No newline at end of file diff --git a/tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_mamba_ssm.py b/tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_mamba_ssm.py deleted file mode 120000 index d29b061a6cb3..000000000000 --- a/tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_mamba_ssm.py +++ /dev/null @@ -1 +0,0 @@ -../../../../../../../models/demos/mamba/tests/test_mamba_ssm.py \ No newline at end of file diff --git a/tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_preprocessing.py b/tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_preprocessing.py deleted file mode 120000 index a08dfcf11212..000000000000 --- a/tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_preprocessing.py +++ /dev/null @@ -1 +0,0 @@ -../../../../../../../models/demos/mamba/tests/test_preprocessing.py \ No newline at end of file diff --git a/tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_reference_model.py b/tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_reference_model.py deleted file mode 120000 index 272e796ae38a..000000000000 --- a/tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_reference_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../../../../../models/demos/mamba/tests/test_reference_model.py \ No newline at end of file diff --git a/tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_residual_block.py b/tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_residual_block.py deleted file mode 120000 index e96177f04568..000000000000 --- a/tests/nightly/wh_b0_only_eth/models/demos/mamba/tests/test_residual_block.py +++ /dev/null @@ -1 +0,0 @@ -../../../../../../../models/demos/mamba/tests/test_residual_block.py \ No newline at end of file diff --git a/tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_benchmarks.py b/tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_benchmarks.py new file mode 120000 index 000000000000..4a3e58facc23 --- /dev/null +++ b/tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_benchmarks.py @@ -0,0 +1 @@ +../../../../../../../models/demos/wormhole/mamba/tests/test_benchmarks.py \ No newline at end of file diff --git a/tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_cache.py b/tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_cache.py new file mode 120000 index 000000000000..46c7a7317e26 --- /dev/null +++ b/tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_cache.py @@ -0,0 +1 @@ +../../../../../../../models/demos/wormhole/mamba/tests/test_cache.py \ No newline at end of file diff --git a/tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_full_model.py b/tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_full_model.py new file mode 120000 index 000000000000..a96840622e75 --- /dev/null +++ b/tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_full_model.py @@ -0,0 +1 @@ +../../../../../../../models/demos/wormhole/mamba/tests/test_full_model.py \ No newline at end of file diff --git a/tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_mamba_block.py b/tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_mamba_block.py new file mode 120000 index 000000000000..0019019c066f --- /dev/null +++ b/tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_mamba_block.py @@ -0,0 +1 @@ +../../../../../../../models/demos/wormhole/mamba/tests/test_mamba_block.py \ No newline at end of file diff --git a/tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_mamba_demo.py b/tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_mamba_demo.py new file mode 120000 index 000000000000..11ccea316d0c --- /dev/null +++ b/tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_mamba_demo.py @@ -0,0 +1 @@ +../../../../../../../models/demos/wormhole/mamba/tests/test_mamba_demo.py \ No newline at end of file diff --git a/tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_mamba_ssm.py b/tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_mamba_ssm.py new file mode 120000 index 000000000000..84ab661445a1 --- /dev/null +++ b/tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_mamba_ssm.py @@ -0,0 +1 @@ +../../../../../../../models/demos/wormhole/mamba/tests/test_mamba_ssm.py \ No newline at end of file diff --git a/tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_preprocessing.py b/tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_preprocessing.py new file mode 120000 index 000000000000..73f85b0e4267 --- /dev/null +++ b/tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_preprocessing.py @@ -0,0 +1 @@ +../../../../../../../models/demos/wormhole/mamba/tests/test_preprocessing.py \ No newline at end of file diff --git a/tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_reference_model.py b/tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_reference_model.py new file mode 120000 index 000000000000..226115b67001 --- /dev/null +++ b/tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_reference_model.py @@ -0,0 +1 @@ +../../../../../../../models/demos/wormhole/mamba/tests/test_reference_model.py \ No newline at end of file diff --git a/tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_residual_block.py b/tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_residual_block.py new file mode 120000 index 000000000000..92ae92b970e1 --- /dev/null +++ b/tests/nightly/wh_b0_only_eth/models/demos/wormhole/mamba/test_residual_block.py @@ -0,0 +1 @@ +../../../../../../../models/demos/wormhole/mamba/tests/test_residual_block.py \ No newline at end of file