#5337: Update Mistral-7B model config to support weight and cache flags. Update README
mtairum committed Jun 5, 2024
1 parent 9bdbbe5 commit 0aa1a20
Showing 10 changed files with 173 additions and 46 deletions.
68 changes: 44 additions & 24 deletions models/demos/wormhole/mistral7b/README.md
@@ -1,47 +1,67 @@
# Mistral7B Demo

Demo showcasing Mistral-7B-instruct running on Wormhole, using ttnn.
Demo showcasing Mistral-7B running on Wormhole, using ttnn.

## How to Run

If you are running on a T3000 please set the following:
### Download the weights

`export WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml`
Download the weights tarfile directly from Mistral-AI:
- General weights: [Mistral-7B-v0.1](https://models.mistralcdn.com/mistral-7b-v0-1/mistral-7B-v0.1.tar)
- Finetune instruct weights: [Mistral-7B-Instruct-v0.2](https://models.mistralcdn.com/mistral-7b-v0-2/Mistral-7B-v0.2-Instruct.tar)

To run the model for a single user you can use the command line input:
Both tarfiles contain the weights consolidated into a single file, `consolidated.00.pth`, along with the tokenizer `tokenizer.model`.
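If you prefer to script this step, the sketch below shows one way to download and extract a tarfile using only the Python standard library. It is a minimal illustration (the target directory is a placeholder); the repository's `models/demos/wormhole/mistral7b/scripts/get_weights.py` is the supported way to download and untar the weights.

```
# Minimal sketch: download and extract the general-weights tarfile.
# The target directory below is a placeholder -- point it at your own weights folder.
import tarfile
import urllib.request
from pathlib import Path

url = "https://models.mistralcdn.com/mistral-7b-v0-1/mistral-7B-v0.1.tar"
target_dir = Path("/path/to/weights_dir")
target_dir.mkdir(parents=True, exist_ok=True)

tar_path = target_dir / "mistral-7B-v0.1.tar"
urllib.request.urlretrieve(url, tar_path)   # fetch the tarfile
with tarfile.open(tar_path) as tar:
    tar.extractall(target_dir)              # yields consolidated.00.pth and tokenizer.model
```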

`pytest --disable-warnings -q -s --input-method=cli --cli-input="YOUR PROMPT GOES HERE!" models/demos/wormhole/mistral7b/demo/demo.py`

To run the demo using pre-written prompts for a batch of 32 users run (currently only supports same token-length inputs):
### Set up environment

`pytest --disable-warnings -q -s --input-method=json --input-path='models/demos/wormhole/mistral7b/demo/input_data_questions.json' models/demos/wormhole/mistral7b/demo/demo.py`
1. Prepare the weight cache directory:

```
# Make a directory for ttnn to cache weights into. This speeds up subsequent runs.
mkdir <weight_cache_dir>
```

## Inputs
2. Set up environment variables:
```
export MISTRAL_CKPT_DIR=<weights_dir>
export MISTRAL_TOKENIZER_PATH=<path_to_tokenizer_dir>
export MISTRAL_CACHE_PATH=<weights_cache_dir>
```

A sample of input prompts for 32 users is provided in `input_data_questions.json` in the demo directory. These are to be used in instruct-mode (default).
We also provide another set of generative inputs, `input_data.json`, for open-ended generation in generative mode.
A typical setup points all of the above to the same folder (a condensed sketch of how these variables are resolved appears after these setup steps).

If you wish to run the model using a different set of input prompts, you can provide a different path, e.g.:
Note that, once generated, the cached weights folder will contain the general and instruct caches in separate directories, like so:

`pytest --disable-warnings -q -s --input-method=json --input-path='path_to_input_prompts.json' models/demos/wormhole/mistral7b/demo/demo.py`
```
<weights_cache_dir>
/mistral_tensor_cache_bfp8
/mistral_tensor_cache_instruct_bfp8
...
```

Keep in mind that for the instruct-mode, the prompts are automatically prefixed and suffixed by `[INST]` and `[/INST]`, respectively.
3. Cache the weights (first-time setup).
If the cached weights have not yet been created, the first execution will generate them. You can run the model test for this step:

```
# Build a full 32 layer model to cache the weights. This will take some time (1 time only).
pytest models/demos/wormhole/mistral7b/tests/test_mistral_model.py::test_mistral_model_inference[17-generative]
```
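For reference, the following is a condensed sketch of how `tt/model_config.py` (updated in this commit) resolves and validates these environment variables; it is an illustration, not the exact code.

```
# Condensed illustration (not the exact code) of the env-var handling in tt/model_config.py.
import os

ckpt_dir = os.getenv("MISTRAL_CKPT_DIR", "/proj_sw/user_dev/hf_data/mistral/mistral-7B-v0.1/")
tokenizer_path = os.getenv("MISTRAL_TOKENIZER_PATH", "/proj_sw/user_dev/hf_data/mistral/mistral-7B-v0.1/")
cache_path = os.getenv("MISTRAL_CACHE_PATH", "/proj_sw/user_dev/hf_data/mistral/mistral-7B-v0.1/")

# The model config asserts that these locations exist before loading anything.
assert os.path.exists(ckpt_dir), "export MISTRAL_CKPT_DIR=<folder containing consolidated.00.pth>"
assert os.path.isfile(os.path.join(tokenizer_path, "tokenizer.model")), "export MISTRAL_TOKENIZER_PATH=<folder containing tokenizer.model>"
assert os.path.exists(cache_path), "export MISTRAL_CACHE_PATH=<writable cache folder>"
```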

## Details
### Run the demo

This model can be used with the general weights from Mistral-AI [Mistral-7B-v0.1](https://models.mistralcdn.com/mistral-7b-v0-1/mistral-7B-v0.1.tar) or the instruct weights
[Mistral-7B-Instruct-v0.2](https://models.mistralcdn.com/mistral-7b-v0-2/Mistral-7B-v0.2-Instruct.tar).
Mistral-7B runs on a single chip. If you are running on a T3000, please set the following: `export WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml`

Both these weights are consolidated into a single file `consolidated.00.pth`.
Keep in mind that the demo code expects the instruct weights to be named `consolidated_instruct.00.pth` instead, and the tokenizer to be named `tokenizer_instruct.model`.
```
# Run the demo with a pre-written batch of 32 user prompts:
pytest models/demos/wormhole/mistral7b/demo/demo.py::test_demo[general_weights]
```

You can provide a custom path to the folder containing the weights by adding the path argument to `TtModelArgs(model_base_path=<weights_folder>)`.
We also provide an input file with 32 user question prompts for the instruct weights (don't forget to point your env flags at the correct instruct weights folder):
```
pytest models/demos/wormhole/mistral7b/demo/demo.py::test_demo[instruct_weights]
```

For more configuration settings, please check the file `tt/model_config.py`.
Both input files are provided inside `models/demos/wormhole/mistral7b/demo/`.

The `demo.py` code is set to run in instruct-mode by default. Change the hardcoded flag inside the code for the general weights.
The `test_mistral_model.py` is currently parametrized to choose between the general generative weights or the instruct weights.

The first time you run the model, the weights will be processed into the target data type and stored on your machine, which will take a few minutes for the full model. In future runs, the weights will be loaded from your machine and it will be faster.
If you wish to run the model with a different set of input prompts, you can change the input file path in the pytest parametrization inside the demo code. Keep in mind that in instruct-mode the prompts are automatically prefixed and suffixed by `[INST]` and `[/INST]`, respectively, so there is no need to add them to your file.
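To make the instruct-mode wrapping concrete, each prompt is effectively templated as follows before tokenization (simplified sketch; the exact spacing and formatting live in the demo code):

```
# Simplified sketch of instruct-mode prompt wrapping; exact formatting may differ.
def wrap_instruct(prompt: str) -> str:
    return f"[INST] {prompt} [/INST]"

print(wrap_instruct("What is the capital of France?"))
# -> [INST] What is the capital of France? [/INST]
```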
35 changes: 27 additions & 8 deletions models/demos/wormhole/mistral7b/demo/demo.py
@@ -6,7 +6,18 @@
import json
from time import time
from loguru import logger
import os

# Set Mistral flags for CI, if CI environment is setup
if os.getenv("CI") == "true":
os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["TT_METAL_ASYNC_DEVICE_QUEUE"] = "1"
os.environ["WH_ARCH_YAML"] = "wormhole_b0_80_arch_eth_dispatch.yaml"

import ttnn
import pytest
from models.demos.wormhole.mistral7b.tt.mistral_common import (
prepare_inputs_ttnn,
sample,
@@ -84,10 +95,9 @@ def preprocess_inputs(input_prompts, tokenizer, model_args, dtype, embd, instruc
return emb_inputs, pt_tokenized_inputs, input_mask, rot_emb_matrix_list


def run_mistral_demo(user_input, batch_size, device):
def run_mistral_demo(user_input, batch_size, device, instruct_mode):
assert batch_size == 32, "Batch size must be 32"

instruct_mode = True
embed_on_device = False
dtype = ttnn.bfloat8_b

@@ -98,10 +108,11 @@ def run_mistral_demo(user_input, batch_size, device):
input_prompts = load_inputs(user_input, 32)

# Load model args, weights, and tokenizer
# Specify model_base_path=<MISTRAL_WEIGHTS_PATH> below to use your own weights
model_args = TtModelArgs(device, instruct=instruct_mode) # TtModelArgs(model_base_path=<weights_path>)
model_args = TtModelArgs(device, instruct=instruct_mode)
tokenizer = Tokenizer(model_args.tokenizer_path)

model_args.n_layers = 1

logger.info("Loading weights...")
state_dict = torch.load(model_args.consolidated_weights_path)
state_dict = {
@@ -140,15 +151,15 @@ def run_mistral_demo(user_input, batch_size, device):
device=device,
dtype=dtype,
state_dict=state_dict,
weight_cache_path=model_args.weight_cache_path(dtype, instruct=instruct_mode),
weight_cache_path=model_args.weight_cache_path(dtype),
layers=list(range(model_args.n_layers)),
rot_mat=rot_emb_matrix_list,
start_pos=generation_start_pos,
)
tt_embd = TtMistralEmbedding(
device=device,
args=model_args,
weight_cache_path=model_args.weight_cache_path(dtype, instruct=instruct_mode),
weight_cache_path=model_args.weight_cache_path(dtype),
state_dict=state_dict,
dtype=ttnn.bfloat16, # Row major layout requires bfloat16
)
@@ -241,5 +252,13 @@ def run_mistral_demo(user_input, batch_size, device):
users_decoding = False


def test_demo(user_input, device, use_program_cache):
return run_mistral_demo(user_input=user_input, batch_size=32, device=device)
@pytest.mark.parametrize(
"input_prompts, instruct_weights",
[
("models/demos/wormhole/mistral7b/demo/input_data.json", False),
("models/demos/wormhole/mistral7b/demo/input_data_questions.json", True),
],
ids=["general_weights", "instruct_weights"],
)
def test_demo(device, use_program_cache, input_prompts, instruct_weights):
return run_mistral_demo(user_input=input_prompts, batch_size=32, device=device, instruct_mode=instruct_weights)
9 changes: 9 additions & 0 deletions models/demos/wormhole/mistral7b/tests/test_mistral_attention.py
@@ -4,6 +4,15 @@
import torch
import pytest
from loguru import logger
import os

# Set Mistral flags for CI, if CI environment is setup
if os.getenv("CI") == "true":
os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["TT_METAL_ASYNC_DEVICE_QUEUE"] = "1"
os.environ["WH_ARCH_YAML"] = "wormhole_b0_80_arch_eth_dispatch.yaml"

import ttnn
from models.demos.wormhole.mistral7b.tt.mistral_attention import TtMistralAttention
10 changes: 10 additions & 0 deletions models/demos/wormhole/mistral7b/tests/test_mistral_decoder.py
@@ -4,6 +4,16 @@
import torch
import pytest
from loguru import logger
import os

# Set Mistral flags for CI, if CI environment is setup
if os.getenv("CI") == "true":
os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["TT_METAL_ASYNC_DEVICE_QUEUE"] = "1"
os.environ["WH_ARCH_YAML"] = "wormhole_b0_80_arch_eth_dispatch.yaml"

import ttnn
from models.demos.wormhole.mistral7b.tt.mistral_common import (
precompute_freqs,
10 changes: 10 additions & 0 deletions models/demos/wormhole/mistral7b/tests/test_mistral_embedding.py
@@ -4,6 +4,16 @@
import torch
import pytest
from loguru import logger
import os

# Set Mistral flags for CI, if CI environment is setup
if os.getenv("CI") == "true":
os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["TT_METAL_ASYNC_DEVICE_QUEUE"] = "1"
os.environ["WH_ARCH_YAML"] = "wormhole_b0_80_arch_eth_dispatch.yaml"

import ttnn
from models.demos.wormhole.mistral7b.tt.model_config import TtModelArgs
from models.demos.wormhole.mistral7b.tt.mistral_embedding import TtMistralEmbedding
10 changes: 10 additions & 0 deletions models/demos/wormhole/mistral7b/tests/test_mistral_mlp.py
@@ -5,6 +5,16 @@
import torch
import pytest
from loguru import logger
import os

# Set Mistral flags for CI, if CI environment is setup
if os.getenv("CI") == "true":
os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["TT_METAL_ASYNC_DEVICE_QUEUE"] = "1"
os.environ["WH_ARCH_YAML"] = "wormhole_b0_80_arch_eth_dispatch.yaml"

import ttnn
from models.demos.wormhole.mistral7b.tt.model_config import TtModelArgs
from models.demos.wormhole.mistral7b.tt.mistral_mlp import TtMistralMLP
4 changes: 2 additions & 2 deletions models/demos/wormhole/mistral7b/tests/test_mistral_model.py
@@ -57,7 +57,7 @@ def test_mistral_model_inference(device, iterations, version, use_program_cache,

dtype = ttnn.bfloat8_b

model_args = TtModelArgs(device, instruct=instruct)
model_args = TtModelArgs(device)
model_args.max_batch_size = 32
model_args.n_layers = 32 # Full model

@@ -112,7 +112,7 @@ def test_mistral_model_inference(device, iterations, version, use_program_cache,
device=device,
dtype=dtype,
state_dict=state_dict,
weight_cache_path=model_args.weight_cache_path(dtype, instruct=instruct),
weight_cache_path=model_args.weight_cache_path(dtype),
layers=list(range(model_args.n_layers)),
rot_mat=rot_emb_matrix_list,
start_pos=generation_start_pos,
10 changes: 10 additions & 0 deletions models/demos/wormhole/mistral7b/tests/test_mistral_perf.py
@@ -4,6 +4,16 @@
import torch
import pytest
from loguru import logger
import os

# Set Mistral flags for CI, if CI environment is setup
if os.getenv("CI") == "true":
os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["TT_METAL_ASYNC_DEVICE_QUEUE"] = "1"
os.environ["WH_ARCH_YAML"] = "wormhole_b0_80_arch_eth_dispatch.yaml"

import ttnn
from models.demos.wormhole.mistral7b.tt.mistral_common import (
precompute_freqs,
10 changes: 10 additions & 0 deletions models/demos/wormhole/mistral7b/tests/test_mistral_rms_norm.py
@@ -4,6 +4,16 @@
import torch
import pytest
from loguru import logger
import os

# Set Mistral flags for CI, if CI environment is setup
if os.getenv("CI") == "true":
os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["TT_METAL_ASYNC_DEVICE_QUEUE"] = "1"
os.environ["WH_ARCH_YAML"] = "wormhole_b0_80_arch_eth_dispatch.yaml"

import ttnn
from models.demos.wormhole.mistral7b.tt.model_config import TtModelArgs
from models.demos.wormhole.mistral7b.tt.mistral_rms_norm import TtRMSNorm
53 changes: 41 additions & 12 deletions models/demos/wormhole/mistral7b/tt/model_config.py
@@ -2,9 +2,13 @@

# SPDX-License-Identifier: Apache-2.0

import os
import ttnn
from pathlib import Path
from models.utility_functions import is_wormhole_b0
from loguru import logger
import tarfile
import urllib.request


class TtModelArgs:
@@ -25,6 +29,11 @@ class TtModelArgs:
max_seq_len = 4096
kv_seq_len = 1024 # TODO Update the initial cache size when scaling up (Should be window_size == 4096)

# Default folder location for weights and cached files
DEFAULT_CKPT_DIR = os.getenv("MISTRAL_CKPT_DIR", "/proj_sw/user_dev/hf_data/mistral/mistral-7B-v0.1/")
DEFAULT_TOKENIZER_PATH = os.getenv("MISTRAL_TOKENIZER_PATH", "/proj_sw/user_dev/hf_data/mistral/mistral-7B-v0.1/")
DEFAULT_CACHE_PATH = os.getenv("MISTRAL_CACHE_PATH", "/proj_sw/user_dev/hf_data/mistral/mistral-7B-v0.1/")

OP_KEYS = (
# Embedding
"EMB_WEIGHTS",
@@ -49,15 +58,35 @@ class TtModelArgs:
"DEC_SKIP_OUTPUT",
)

def __init__(self, device, model_base_path="/mnt/MLPerf/ttnn/models/demos/mistral7b", instruct=False):
self.model_base_path = Path(model_base_path)
def __init__(self, device, instruct=False):
# Assert if all folders and files exist
assert os.path.exists(
self.DEFAULT_CKPT_DIR
), f"Checkpoint directory {self.DEFAULT_CKPT_DIR} does not exist, please use export MISTRAL_CKPT_DIR=..."
assert os.path.isfile(
self.DEFAULT_TOKENIZER_PATH + "/tokenizer.model"
), f"Tokenizer file {self.DEFAULT_TOKENIZER_PATH + '/tokenizer.model'} does not exist, please use export MISTRAL_TOKENIZER_PATH=..."
assert os.path.exists(
self.DEFAULT_CACHE_PATH
), f"Cache directory {self.DEFAULT_CACHE_PATH} does not exist, please use export MISTRAL_CACHE_PATH=..."
# Check if weights exist in the specified folder. If not, warn the user to run the download and untar script.
assert os.path.isfile(
self.DEFAULT_CKPT_DIR + "/consolidated.00.pth"
), f"weights consolidated.00.pth file does not exist. Please use the script `models/demos/wormhole/mistral7b/scripts/get_weights.py` to download and untar the weights."

logger.info(f"Checkpoint directory: {self.DEFAULT_CKPT_DIR}")
logger.info(f"Tokenizer file: {self.DEFAULT_TOKENIZER_PATH + '/tokenizer.model'}")
logger.info(f"Cache directory: {self.DEFAULT_CACHE_PATH}")

# Some consumers like SentencePiece only accept str not Path for files
if instruct: # Load instruct weights and tokenizer (Mistral-7B-Instruct-v0.2)
self.consolidated_weights_path = str(self.model_base_path / "consolidated_instruct.00.pth")
self.tokenizer_path = str(self.model_base_path / "tokenizer_instruct.model")
else: # Load generative weights and tokenizer (Mistral-7B-v0.1)
self.consolidated_weights_path = str(self.model_base_path / "consolidated.00.pth")
self.tokenizer_path = str(self.model_base_path / "tokenizer.model")
self.model_base_path = Path(self.DEFAULT_CKPT_DIR)
self.model_cache_path = Path(self.DEFAULT_CACHE_PATH)

# Load weights and tokenizer
self.consolidated_weights_path = self.DEFAULT_CKPT_DIR + "/consolidated.00.pth"
self.tokenizer_path = self.DEFAULT_TOKENIZER_PATH + "/tokenizer.model"

self.instruct = instruct

DRAM_MEMCFG = ttnn.DRAM_MEMORY_CONFIG
L1_MEMCFG = ttnn.L1_MEMORY_CONFIG
@@ -116,16 +145,16 @@ def __init__(self, device, model_base_path="/mnt/MLPerf/ttnn/models/demos/mistra
packer_l1_acc=True,
)

def weight_cache_path(self, dtype, instruct=False):
def weight_cache_path(self, dtype):
# Keep the weight cache separate for generative and instruct weights
if instruct:
if self.instruct:
return (
self.model_base_path
self.model_cache_path
/ {ttnn.bfloat16: "tensor_cache_instruct_bf16", ttnn.bfloat8_b: "tensor_cache_instruct_bfp8"}[dtype]
)
else:
return (
self.model_base_path / {ttnn.bfloat16: "tensor_cache_bf16", ttnn.bfloat8_b: "tensor_cache_bfp8"}[dtype]
self.model_cache_path / {ttnn.bfloat16: "tensor_cache_bf16", ttnn.bfloat8_b: "tensor_cache_bfp8"}[dtype]
)

def get_model_config(self):
