diff --git a/models/demos/wormhole/mistral7b/README.md b/models/demos/wormhole/mistral7b/README.md
index 82202250b044..b798a0d0321c 100644
--- a/models/demos/wormhole/mistral7b/README.md
+++ b/models/demos/wormhole/mistral7b/README.md
@@ -1,47 +1,67 @@
 # Mistral7B Demo

-Demo showcasing Mistral-7B-instruct running on Wormhole, using ttnn.
+Demo showcasing Mistral-7B running on Wormhole, using ttnn.

 ## How to Run

-If you are running on a T3000 please set the following:
+### Download the weights

-`export WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml`
+Download the weights tarfile directly from Mistral-AI:
+- General weights: [Mistral-7B-v0.1](https://models.mistralcdn.com/mistral-7b-v0-1/mistral-7B-v0.1.tar)
+- Finetuned instruct weights: [Mistral-7B-Instruct-v0.2](https://models.mistralcdn.com/mistral-7b-v0-2/Mistral-7B-v0.2-Instruct.tar)

-To run the model for a single user you can use the command line input:
+Both of the above tarfiles consolidate the weights into a single file `consolidated.00.pth`. They also contain the tokenizer `tokenizer.model`.

-`pytest --disable-warnings -q -s --input-method=cli --cli-input="YOUR PROMPT GOES HERE!" models/demos/wormhole/mistral7b/demo/demo.py`

-To run the demo using pre-written prompts for a batch of 32 users run (currently only supports same token-length inputs):
+### Set up environment

-`pytest --disable-warnings -q -s --input-method=json --input-path='models/demos/wormhole/mistral7b/demo/input_data_questions.json' models/demos/wormhole/mistral7b/demo/demo.py`
+1. Prepare the weight cache directory:
+```
+# Make a directory for ttnn to cache weights into. This speeds up subsequent runs.
+mkdir <weight_cache_dir>
+```

-## Inputs
+2. Set up environment variables:
+```
+export MISTRAL_CKPT_DIR=<weights_dir>
+export MISTRAL_TOKENIZER_PATH=<tokenizer_dir>
+export MISTRAL_CACHE_PATH=<weight_cache_dir>
+```

-A sample of input prompts for 32 users is provided in `input_data_question.json` in the demo directory. These are to be used in instruct-mode (default).
-We also provide another set of generative inputs `input_data.json` for generative-mode of open-ended generation.
+A typical environment will have all of the above point to the same folder.

-If you wish you to run the model using a different set of input prompts you can provide a different path, e.g.:
+Note that, once generated, the weight cache folder will contain the general and instruct cached weights in separate directories, like so:

-`pytest --disable-warnings -q -s --input-method=json --input-path='path_to_input_prompts.json' models/demos/wormhole/mistral7b/demo/demo.py`
+```
+<weight_cache_dir>
+  /mistral_tensor_cache_bfp8
+  /mistral_tensor_cache_instruct_bfp8
+  ...
+```

-Keep in mind that for the instruct-mode, the prompts are automatically prefixed and suffixed by `[INST]` and `[/INST]`, respectively.
+3. Cache the weights (first-time setup only).
+If the cached weights have not yet been created, the first execution will generate them. You can run the model test for this step:
+```
+# Build the full 32-layer model to cache the weights. This will take some time (one time only).
+pytest models/demos/wormhole/mistral7b/tests/test_mistral_model.py::test_mistral_model_inference[17-generative]
+```

-## Details
+### Run the demo

-This model can be used with the general weights from Mistral-AI [Mistral-7B-v0.1](https://models.mistralcdn.com/mistral-7b-v0-1/mistral-7B-v0.1.tar) or the instruct weights - [Mistral-7B-Instruct-v0.2](https://models.mistralcdn.com/mistral-7b-v0-2/Mistral-7B-v0.2-Instruct.tar).
+Mistral-7B runs on a single chip. If you are running on a T3000, please also set the following: `export WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml`

-Both these weights are consolidated into a single file `consolidated.00.pth`.
-Keep in mind that the demo code expects the instruct weights to be named `consolidated_instruct.00.pth` instead, and the tokenizer to be named `tokenizer_instruct.model`.
+```
+# Run the demo with a pre-written batch of 32 user prompts:
+pytest models/demos/wormhole/mistral7b/demo/demo.py::test_demo[general_weights]
+```

-You can provide a custom path to the folder containing the weights by adding the path argument to `TtModelArgs(model_base_path=)`.
+We also provide an input file with 32 user question-prompts for the instruct weights (don't forget to point your env flags at the correct instruct weights folder):
+```
+pytest models/demos/wormhole/mistral7b/demo/demo.py::test_demo[instruct_weights]
+```

-For more configuration settings, please check the file `tt/model_config.py`.
+Both input files are provided inside `models/demos/wormhole/mistral7b/demo/`.

-The `demo.py` code is set to run in instruct-mode by default. Change the hardcoded flag inside the code for the general weights.
-The `test_mistral_model.py` is currently parametrized to choose between the general generative weights or the instruct weights.
-
-The first time you run the model, the weights will be processed into the target data type and stored on your machine, which will take a few minutes for the full model. In future runs, the weights will be loaded from your machine and it will be faster.
+If you wish to run the model using a different set of input prompts, you can provide a different path to pytest inside the demo code. Keep in mind that for the instruct-mode, the prompts are automatically prefixed and suffixed by `[INST]` and `[/INST]`, respectively, so there's no need to add them to your file.
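Before moving to the code changes, here is a minimal pre-flight sketch of the environment the setup above assumes: the three `MISTRAL_*` variables point at the unpacked weights, the tokenizer, and the cache folder that the refactored `TtModelArgs` (later in this diff) validates on construction. The helper function and its fallback path below are illustrative only and not part of this change.

```python
import os
from pathlib import Path


# Illustrative pre-flight check; the env var names mirror the ones read by
# TtModelArgs, but this helper and its default path are not part of the PR.
def check_mistral_env(default: str = "~/mistral-7B-v0.1") -> tuple[Path, Path, Path]:
    ckpt_dir = Path(os.getenv("MISTRAL_CKPT_DIR", default)).expanduser()
    tokenizer = Path(os.getenv("MISTRAL_TOKENIZER_PATH", default)).expanduser() / "tokenizer.model"
    cache_dir = Path(os.getenv("MISTRAL_CACHE_PATH", default)).expanduser()

    # The demo expects the tarfile to be downloaded and untarred already.
    assert (ckpt_dir / "consolidated.00.pth").is_file(), "weights not found; see 'Download the weights'"
    assert tokenizer.is_file(), "tokenizer.model not found"
    cache_dir.mkdir(parents=True, exist_ok=True)  # converted ttnn tensors are cached here on first run
    return ckpt_dir, tokenizer, cache_dir


if __name__ == "__main__":
    print(check_mistral_env())
```

Running something like this before the pytest commands above surfaces a missing weights file or tokenizer immediately, mirroring the asserts added to `tt/model_config.py` in this diff.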
diff --git a/models/demos/wormhole/mistral7b/demo/demo.py b/models/demos/wormhole/mistral7b/demo/demo.py
index ab3f15eda8e3..02660e4ff9b0 100644
--- a/models/demos/wormhole/mistral7b/demo/demo.py
+++ b/models/demos/wormhole/mistral7b/demo/demo.py
@@ -6,7 +6,18 @@
 import json
 from time import time
 from loguru import logger
+import os
+
+# Set Mistral flags for CI, if the CI environment is set up
+if os.getenv("CI") == "true":
+    os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
+    os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
+    os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
+    os.environ["TT_METAL_ASYNC_DEVICE_QUEUE"] = "1"
+    os.environ["WH_ARCH_YAML"] = "wormhole_b0_80_arch_eth_dispatch.yaml"
+
 import ttnn
+import pytest
 from models.demos.wormhole.mistral7b.tt.mistral_common import (
     prepare_inputs_ttnn,
     sample,
@@ -84,10 +95,9 @@ def preprocess_inputs(input_prompts, tokenizer, model_args, dtype, embd, instruc
     return emb_inputs, pt_tokenized_inputs, input_mask, rot_emb_matrix_list


-def run_mistral_demo(user_input, batch_size, device):
+def run_mistral_demo(user_input, batch_size, device, instruct_mode):
     assert batch_size == 32, "Batch size must be 32"

-    instruct_mode = True
     embed_on_device = False
     dtype = ttnn.bfloat8_b

@@ -98,10 +108,11 @@ def run_mistral_demo(user_input, batch_size, device):
     input_prompts = load_inputs(user_input, 32)

     # Load model args, weights, and tokenizer
-    # Specify model_base_path= below to use your own weights
-    model_args = TtModelArgs(device, instruct=instruct_mode)  # TtModelArgs(model_base_path=)
+    model_args = TtModelArgs(device, instruct=instruct_mode)
     tokenizer = Tokenizer(model_args.tokenizer_path)

+    model_args.n_layers = 1
+
     logger.info("Loading weights...")
     state_dict = torch.load(model_args.consolidated_weights_path)
     state_dict = {
@@ -140,7 +151,7 @@ def run_mistral_demo(user_input, batch_size, device):
         device=device,
         dtype=dtype,
         state_dict=state_dict,
-        weight_cache_path=model_args.weight_cache_path(dtype, instruct=instruct_mode),
+        weight_cache_path=model_args.weight_cache_path(dtype),
         layers=list(range(model_args.n_layers)),
         rot_mat=rot_emb_matrix_list,
         start_pos=generation_start_pos,
@@ -148,7 +159,7 @@ def run_mistral_demo(user_input, batch_size, device):
     tt_embd = TtMistralEmbedding(
         device=device,
         args=model_args,
-        weight_cache_path=model_args.weight_cache_path(dtype, instruct=instruct_mode),
+        weight_cache_path=model_args.weight_cache_path(dtype),
         state_dict=state_dict,
         dtype=ttnn.bfloat16,  # Row major layout requires bfloat16
     )
@@ -241,5 +252,13 @@ def run_mistral_demo(user_input, batch_size, device):
                 users_decoding = False


-def test_demo(user_input, device, use_program_cache):
-    return run_mistral_demo(user_input=user_input, batch_size=32, device=device)
+@pytest.mark.parametrize(
+    "input_prompts, instruct_weights",
+    [
+        ("models/demos/wormhole/mistral7b/demo/input_data.json", False),
+        ("models/demos/wormhole/mistral7b/demo/input_data_questions.json", True),
+    ],
+    ids=["general_weights", "instruct_weights"],
+)
+def test_demo(device, use_program_cache, input_prompts, instruct_weights):
+    return run_mistral_demo(user_input=input_prompts, batch_size=32, device=device, instruct_mode=instruct_weights)
diff --git a/models/demos/wormhole/mistral7b/tests/test_mistral_attention.py b/models/demos/wormhole/mistral7b/tests/test_mistral_attention.py
index a63dc745657a..fb12d41bc60a 100644
--- a/models/demos/wormhole/mistral7b/tests/test_mistral_attention.py
+++ b/models/demos/wormhole/mistral7b/tests/test_mistral_attention.py
@@ -4,6 +4,15 @@
 import torch
 import pytest
 from loguru import logger
+import os
+
+# Set Mistral flags for CI, if the CI environment is set up
+if os.getenv("CI") == "true":
+    os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
+    os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
+    os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
+    os.environ["TT_METAL_ASYNC_DEVICE_QUEUE"] = "1"
+    os.environ["WH_ARCH_YAML"] = "wormhole_b0_80_arch_eth_dispatch.yaml"
 import ttnn
 from models.demos.wormhole.mistral7b.tt.mistral_attention import TtMistralAttention
diff --git a/models/demos/wormhole/mistral7b/tests/test_mistral_decoder.py b/models/demos/wormhole/mistral7b/tests/test_mistral_decoder.py
index 11f262037118..f9231228cc21 100644
--- a/models/demos/wormhole/mistral7b/tests/test_mistral_decoder.py
+++ b/models/demos/wormhole/mistral7b/tests/test_mistral_decoder.py
@@ -4,6 +4,16 @@
 import torch
 import pytest
 from loguru import logger
+import os
+
+# Set Mistral flags for CI, if the CI environment is set up
+if os.getenv("CI") == "true":
+    os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
+    os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
+    os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
+    os.environ["TT_METAL_ASYNC_DEVICE_QUEUE"] = "1"
+    os.environ["WH_ARCH_YAML"] = "wormhole_b0_80_arch_eth_dispatch.yaml"
+
 import ttnn
 from models.demos.wormhole.mistral7b.tt.mistral_common import (
     precompute_freqs,
diff --git a/models/demos/wormhole/mistral7b/tests/test_mistral_embedding.py b/models/demos/wormhole/mistral7b/tests/test_mistral_embedding.py
index e7a3653d5522..ea20ca0cd0d6 100644
--- a/models/demos/wormhole/mistral7b/tests/test_mistral_embedding.py
+++ b/models/demos/wormhole/mistral7b/tests/test_mistral_embedding.py
@@ -4,6 +4,16 @@
 import torch
 import pytest
 from loguru import logger
+import os
+
+# Set Mistral flags for CI, if the CI environment is set up
+if os.getenv("CI") == "true":
+    os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
+    os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
+    os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
+    os.environ["TT_METAL_ASYNC_DEVICE_QUEUE"] = "1"
+    os.environ["WH_ARCH_YAML"] = "wormhole_b0_80_arch_eth_dispatch.yaml"
+
 import ttnn
 from models.demos.wormhole.mistral7b.tt.model_config import TtModelArgs
 from models.demos.wormhole.mistral7b.tt.mistral_embedding import TtMistralEmbedding
diff --git a/models/demos/wormhole/mistral7b/tests/test_mistral_mlp.py b/models/demos/wormhole/mistral7b/tests/test_mistral_mlp.py
index f8abb7bbc985..2ce2aae85ecb 100644
--- a/models/demos/wormhole/mistral7b/tests/test_mistral_mlp.py
+++ b/models/demos/wormhole/mistral7b/tests/test_mistral_mlp.py
@@ -5,6 +5,16 @@
 import torch
 import pytest
 from loguru import logger
+import os
+
+# Set Mistral flags for CI, if the CI environment is set up
+if os.getenv("CI") == "true":
+    os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
+    os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
+    os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
+    os.environ["TT_METAL_ASYNC_DEVICE_QUEUE"] = "1"
+    os.environ["WH_ARCH_YAML"] = "wormhole_b0_80_arch_eth_dispatch.yaml"
+
 import ttnn
 from models.demos.wormhole.mistral7b.tt.model_config import TtModelArgs
 from models.demos.wormhole.mistral7b.tt.mistral_mlp import TtMistralMLP
diff --git a/models/demos/wormhole/mistral7b/tests/test_mistral_model.py b/models/demos/wormhole/mistral7b/tests/test_mistral_model.py
index 40a5e278d31d..c7e354b16008 100644
--- a/models/demos/wormhole/mistral7b/tests/test_mistral_model.py
+++ b/models/demos/wormhole/mistral7b/tests/test_mistral_model.py
@@ -57,7 +57,7 @@ def test_mistral_model_inference(device, iterations, version, use_program_cache,
     dtype = ttnn.bfloat8_b

-    model_args = TtModelArgs(device, instruct=instruct)
+    model_args = TtModelArgs(device)
     model_args.max_batch_size = 32
     model_args.n_layers = 32  # Full model

@@ -112,7 +112,7 @@ def test_mistral_model_inference(device, iterations, version, use_program_cache,
         device=device,
         dtype=dtype,
         state_dict=state_dict,
-        weight_cache_path=model_args.weight_cache_path(dtype, instruct=instruct),
+        weight_cache_path=model_args.weight_cache_path(dtype),
         layers=list(range(model_args.n_layers)),
         rot_mat=rot_emb_matrix_list,
         start_pos=generation_start_pos,
diff --git a/models/demos/wormhole/mistral7b/tests/test_mistral_perf.py b/models/demos/wormhole/mistral7b/tests/test_mistral_perf.py
index 32ee0acecb8d..979b9e6f1a00 100644
--- a/models/demos/wormhole/mistral7b/tests/test_mistral_perf.py
+++ b/models/demos/wormhole/mistral7b/tests/test_mistral_perf.py
@@ -4,6 +4,16 @@
 import torch
 import pytest
 from loguru import logger
+import os
+
+# Set Mistral flags for CI, if the CI environment is set up
+if os.getenv("CI") == "true":
+    os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
+    os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
+    os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
+    os.environ["TT_METAL_ASYNC_DEVICE_QUEUE"] = "1"
+    os.environ["WH_ARCH_YAML"] = "wormhole_b0_80_arch_eth_dispatch.yaml"
+
 import ttnn
 from models.demos.wormhole.mistral7b.tt.mistral_common import (
     precompute_freqs,
diff --git a/models/demos/wormhole/mistral7b/tests/test_mistral_rms_norm.py b/models/demos/wormhole/mistral7b/tests/test_mistral_rms_norm.py
index 07472ac2daf5..84e1acb60986 100644
--- a/models/demos/wormhole/mistral7b/tests/test_mistral_rms_norm.py
+++ b/models/demos/wormhole/mistral7b/tests/test_mistral_rms_norm.py
@@ -4,6 +4,16 @@
 import torch
 import pytest
 from loguru import logger
+import os
+
+# Set Mistral flags for CI, if the CI environment is set up
+if os.getenv("CI") == "true":
+    os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
+    os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
+    os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
+    os.environ["TT_METAL_ASYNC_DEVICE_QUEUE"] = "1"
+    os.environ["WH_ARCH_YAML"] = "wormhole_b0_80_arch_eth_dispatch.yaml"
+
 import ttnn
 from models.demos.wormhole.mistral7b.tt.model_config import TtModelArgs
 from models.demos.wormhole.mistral7b.tt.mistral_rms_norm import TtRMSNorm
diff --git a/models/demos/wormhole/mistral7b/tt/model_config.py b/models/demos/wormhole/mistral7b/tt/model_config.py
index 765233d13834..1c89029c22e4 100644
--- a/models/demos/wormhole/mistral7b/tt/model_config.py
+++ b/models/demos/wormhole/mistral7b/tt/model_config.py
@@ -2,9 +2,13 @@

 # SPDX-License-Identifier: Apache-2.0

+import os
 import ttnn
 from pathlib import Path
 from models.utility_functions import is_wormhole_b0
+from loguru import logger
+import tarfile
+import urllib.request


 class TtModelArgs:
@@ -25,6 +29,11 @@ class TtModelArgs:
     max_seq_len = 4096
     kv_seq_len = 1024  # TODO Update the initial cache size when scaling up (Should be window_size == 4096)

+    # Default folder location for weights and cached files
+    DEFAULT_CKPT_DIR = os.getenv("MISTRAL_CKPT_DIR", "/proj_sw/user_dev/hf_data/mistral/mistral-7B-v0.1/")
+    DEFAULT_TOKENIZER_PATH = os.getenv("MISTRAL_TOKENIZER_PATH", "/proj_sw/user_dev/hf_data/mistral/mistral-7B-v0.1/")
+    DEFAULT_CACHE_PATH = os.getenv("MISTRAL_CACHE_PATH", "/proj_sw/user_dev/hf_data/mistral/mistral-7B-v0.1/")
+
     OP_KEYS = (
         # Embedding
         "EMB_WEIGHTS",
@@ -49,15 +58,35 @@ class TtModelArgs:
         "DEC_SKIP_OUTPUT",
     )

-    def __init__(self, device, model_base_path="/mnt/MLPerf/ttnn/models/demos/mistral7b", instruct=False):
-        self.model_base_path = Path(model_base_path)
+    def __init__(self, device, instruct=False):
+        # Assert that all folders and files exist
+        assert os.path.exists(
+            self.DEFAULT_CKPT_DIR
+        ), f"Checkpoint directory {self.DEFAULT_CKPT_DIR} does not exist, please use export MISTRAL_CKPT_DIR=..."
+        assert os.path.isfile(
+            self.DEFAULT_TOKENIZER_PATH + "/tokenizer.model"
+        ), f"Tokenizer file {self.DEFAULT_TOKENIZER_PATH + '/tokenizer.model'} does not exist, please use export MISTRAL_TOKENIZER_PATH=..."
+        assert os.path.exists(
+            self.DEFAULT_CACHE_PATH
+        ), f"Cache directory {self.DEFAULT_CACHE_PATH} does not exist, please use export MISTRAL_CACHE_PATH=..."
+        # Check if the weights exist in the specified folder. If not, warn the user to run the download and untar script.
+        assert os.path.isfile(
+            self.DEFAULT_CACHE_PATH + "/consolidated.00.pth"
+        ), "Weights file consolidated.00.pth does not exist. Please use the script `models/demos/wormhole/mistral7b/scripts/get_weights.py` to download and untar the weights."
+ + logger.info(f"Checkpoint directory: {self.DEFAULT_CKPT_DIR}") + logger.info(f"Tokenizer file: {self.DEFAULT_TOKENIZER_PATH + '/tokenizer.model'}") + logger.info(f"Cache directory: {self.DEFAULT_CACHE_PATH}") + # Some consumers like SentencePiece only accept str not Path for files - if instruct: # Load instruct weights and tokenizer (Mistral-7B-Instruct-v0.2) - self.consolidated_weights_path = str(self.model_base_path / "consolidated_instruct.00.pth") - self.tokenizer_path = str(self.model_base_path / "tokenizer_instruct.model") - else: # Load generative weights and tokenizer (Mistral-7B-v0.1) - self.consolidated_weights_path = str(self.model_base_path / "consolidated.00.pth") - self.tokenizer_path = str(self.model_base_path / "tokenizer.model") + self.model_base_path = Path(self.DEFAULT_CKPT_DIR) + self.model_cache_path = Path(self.DEFAULT_CACHE_PATH) + + # Load weights and tokenizer + self.consolidated_weights_path = self.DEFAULT_CKPT_DIR + "/consolidated.00.pth" + self.tokenizer_path = self.DEFAULT_TOKENIZER_PATH + "/tokenizer.model" + + self.instruct = instruct DRAM_MEMCFG = ttnn.DRAM_MEMORY_CONFIG L1_MEMCFG = ttnn.L1_MEMORY_CONFIG @@ -116,16 +145,16 @@ def __init__(self, device, model_base_path="/mnt/MLPerf/ttnn/models/demos/mistra packer_l1_acc=True, ) - def weight_cache_path(self, dtype, instruct=False): + def weight_cache_path(self, dtype): # Keep the weight cache separate for generative and instruct weights - if instruct: + if self.instruct: return ( - self.model_base_path + self.model_cache_path / {ttnn.bfloat16: "tensor_cache_instruct_bf16", ttnn.bfloat8_b: "tensor_cache_instruct_bfp8"}[dtype] ) else: return ( - self.model_base_path / {ttnn.bfloat16: "tensor_cache_bf16", ttnn.bfloat8_b: "tensor_cache_bfp8"}[dtype] + self.model_cache_path / {ttnn.bfloat16: "tensor_cache_bf16", ttnn.bfloat8_b: "tensor_cache_bfp8"}[dtype] ) def get_model_config(self):