Skip to content

Commit

Permalink
#5337: Refactored all Mistral demo and test scripts to use CI fixture
Browse files (browse the repository at this point in the history)
  • Loading branch information
mtairum committed Jul 20, 2024
1 parent 52bcf58 commit 457bcd0
Show file tree
Hide file tree
Showing 13 changed files with 147 additions and 236 deletions.
38 changes: 11 additions & 27 deletions models/demos/wormhole/mistral7b/demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,6 @@
from time import time
from loguru import logger
import os

# Set Mistral flags for CI, if CI environment is setup
if os.getenv("CI") == "true":
os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["WH_ARCH_YAML"] = "wormhole_b0_80_arch_eth_dispatch.yaml"

import ttnn
import pytest
from models.demos.wormhole.mistral7b.tt.mistral_common import (
Expand All @@ -26,7 +18,6 @@
)
from models.demos.wormhole.mistral7b.tt.mistral_model import TtTransformer
from models.demos.wormhole.mistral7b.tt.mistral_embedding import TtMistralEmbedding
from models.demos.wormhole.mistral7b.tt.model_config import TtModelArgs
from models.demos.wormhole.mistral7b.reference.tokenizer import Tokenizer


Expand Down Expand Up @@ -95,6 +86,15 @@ def preprocess_inputs(input_prompts, tokenizer, model_args, dtype, embd, instruc


def run_mistral_demo(user_input, batch_size, device, instruct_mode, is_ci_env):
# Set Mistral flags for CI
if is_ci_env:
os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"

# This module requires the env paths above for CI runs
from models.demos.wormhole.mistral7b.tt.model_config import TtModelArgs

assert batch_size == 32, "Batch size must be 32"

embed_on_device = False
Expand Down Expand Up @@ -269,8 +269,6 @@ def run_mistral_demo(user_input, batch_size, device, instruct_mode, is_ci_env):
logger.info("[User {}] {}".format(user, text))


# Avoid running this test when in CI
@pytest.mark.skipif(os.getenv("CI") == "true", reason="Non-CI tests")
@pytest.mark.parametrize(
"input_prompts, instruct_weights",
[
Expand All @@ -280,23 +278,9 @@ def run_mistral_demo(user_input, batch_size, device, instruct_mode, is_ci_env):
ids=["general_weights", "instruct_weights"],
)
def test_mistral7B_demo(device, use_program_cache, input_prompts, instruct_weights, is_ci_env):
return run_mistral_demo(
user_input=input_prompts, batch_size=32, device=device, instruct_mode=instruct_weights, is_ci_env=is_ci_env
)
if is_ci_env and instruct_weights == False:
pytest.skip("CI demo test only runs instruct weights (to reduce CI pipeline load)")


# CI only runs general-weights demo
@pytest.mark.skipif(not os.getenv("CI") == "true", reason="CI-only test")
@pytest.mark.parametrize(
"input_prompts, instruct_weights",
[
("models/demos/wormhole/mistral7b/demo/input_data.json", False),
],
ids=[
"general_weights",
],
)
def test_mistral7B_demo_CI(device, use_program_cache, input_prompts, instruct_weights, is_ci_env):
return run_mistral_demo(
user_input=input_prompts, batch_size=32, device=device, instruct_mode=instruct_weights, is_ci_env=is_ci_env
)
38 changes: 11 additions & 27 deletions models/demos/wormhole/mistral7b/demo/demo_with_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,6 @@
from time import time
from loguru import logger
import os

# Set Mistral flags for CI, if CI environment is setup
if os.getenv("CI") == "true":
os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["WH_ARCH_YAML"] = "wormhole_b0_80_arch_eth_dispatch.yaml"

import ttnn
import pytest
from models.demos.wormhole.mistral7b.tt.mistral_common import (
Expand All @@ -29,7 +21,6 @@
)
from models.demos.wormhole.mistral7b.tt.mistral_model import TtTransformer
from models.demos.wormhole.mistral7b.tt.mistral_embedding import TtMistralEmbedding
from models.demos.wormhole.mistral7b.tt.model_config import TtModelArgs
from models.demos.wormhole.mistral7b.reference.tokenizer import Tokenizer


Expand Down Expand Up @@ -136,6 +127,15 @@ def preprocess_inputs_prefill(input_prompts, tokenizer, model_args, dtype, embd,


def run_mistral_demo(user_input, batch_size, device, instruct_mode, is_ci_env):
# Set Mistral flags for CI
if is_ci_env:
os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"

# This module requires the env paths above for CI runs
from models.demos.wormhole.mistral7b.tt.model_config import TtModelArgs

assert batch_size == 32, "Batch size must be 32"

embed_on_device = False
Expand Down Expand Up @@ -348,8 +348,6 @@ def run_mistral_demo(user_input, batch_size, device, instruct_mode, is_ci_env):
logger.info("[User {}] {}".format(user, text))


# Avoid running this test when in CI
@pytest.mark.skipif(os.getenv("CI") == "true", reason="Non-CI tests")
@pytest.mark.parametrize(
"input_prompts, instruct_weights",
[
Expand All @@ -359,23 +357,9 @@ def run_mistral_demo(user_input, batch_size, device, instruct_mode, is_ci_env):
ids=["general_weights", "instruct_weights"],
)
def test_mistral7B_demo(device, use_program_cache, input_prompts, instruct_weights, is_ci_env):
return run_mistral_demo(
user_input=input_prompts, batch_size=32, device=device, instruct_mode=instruct_weights, is_ci_env=is_ci_env
)
if is_ci_env and instruct_weights == False:
pytest.skip("CI demo test only runs instruct weights (to reduce CI pipeline load)")


# CI only runs general-weights demo
@pytest.mark.skipif(not os.getenv("CI") == "true", reason="CI-only test")
@pytest.mark.parametrize(
"input_prompts, instruct_weights",
[
("models/demos/wormhole/mistral7b/demo/input_data_questions_prefill_128.json", False),
],
ids=[
"general_weights",
],
)
def test_mistral7B_demo_CI(device, use_program_cache, input_prompts, instruct_weights, is_ci_env):
return run_mistral_demo(
user_input=input_prompts, batch_size=32, device=device, instruct_mode=instruct_weights, is_ci_env=is_ci_env
)
25 changes: 11 additions & 14 deletions models/demos/wormhole/mistral7b/tests/test_mistral_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,13 @@
import pytest
from loguru import logger
import os

# Set Mistral flags for CI, if CI environment is setup
if os.getenv("CI") == "true":
os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"

import ttnn
from models.demos.wormhole.mistral7b.tt.mistral_attention import TtMistralAttention
from models.demos.wormhole.mistral7b.tt.mistral_common import (
precompute_freqs,
prepare_inputs_ttnn,
freqs_to_rotation_matrix,
)
from models.demos.wormhole.mistral7b.tt.model_config import TtModelArgs
from models.demos.wormhole.mistral7b.reference.model import Attention
from models.utility_functions import (
comp_pcc,
Expand All @@ -29,11 +21,16 @@


@skip_for_grayskull("Requires wormhole_b0 to run")
@pytest.mark.parametrize(
"iterations",
((1),),
)
def test_mistral_attention_inference(iterations, device, use_program_cache, reset_seeds):
def test_mistral_attention_inference(device, use_program_cache, reset_seeds, is_ci_env):
# Set Mistral flags for CI
if is_ci_env:
os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"

# This module requires the env paths above for CI runs
from models.demos.wormhole.mistral7b.tt.model_config import TtModelArgs

dtype = ttnn.bfloat8_b
pcc = 0.99

Expand Down Expand Up @@ -63,7 +60,7 @@ def test_mistral_attention_inference(iterations, device, use_program_cache, rese
) # ttnn.bfloat16

generation_start_pos = 0
generation_length = iterations
generation_length = 3
all_tests_pass = True

tt_model = TtMistralAttention(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,13 @@
import pytest
from loguru import logger
import os

# Set Mistral flags for CI, if CI environment is setup
if os.getenv("CI") == "true":
os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"

import ttnn
from models.demos.wormhole.mistral7b.tt.mistral_attention import TtMistralAttention
from models.demos.wormhole.mistral7b.tt.mistral_common import (
get_prefill_rot_mat,
prepare_inputs_ttnn_prefill,
get_rot_transformation_mat,
)
from models.demos.wormhole.mistral7b.tt.model_config import TtModelArgs
from models.demos.wormhole.mistral7b.reference.model import Attention, precompute_freqs_cis
from models.utility_functions import (
comp_pcc,
Expand All @@ -39,11 +31,16 @@
4096,
),
)
@pytest.mark.parametrize(
"iterations",
((1),),
)
def test_mistral_attention_inference(iterations, seq_len, device, use_program_cache, reset_seeds):
def test_mistral_attention_inference(seq_len, device, use_program_cache, reset_seeds, is_ci_env):
# Set Mistral flags for CI
if is_ci_env:
os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"

# This module requires the env paths above for CI runs
from models.demos.wormhole.mistral7b.tt.model_config import TtModelArgs

dtype = ttnn.bfloat8_b
pcc = 0.99

Expand All @@ -68,7 +65,7 @@ def test_mistral_attention_inference(iterations, seq_len, device, use_program_ca
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
generation_start_pos = 0
generation_length = iterations
generation_length = 3
all_tests_pass = True

tt_model = TtMistralAttention(
Expand Down
25 changes: 11 additions & 14 deletions models/demos/wormhole/mistral7b/tests/test_mistral_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,13 @@
import pytest
from loguru import logger
import os

# Set Mistral flags for CI, if CI environment is setup
if os.getenv("CI") == "true":
os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"

import ttnn
from models.demos.wormhole.mistral7b.tt.mistral_common import (
precompute_freqs,
prepare_inputs_ttnn,
freqs_to_rotation_matrix,
)
from models.demos.wormhole.mistral7b.tt.mistral_decoder import TtTransformerBlock
from models.demos.wormhole.mistral7b.tt.model_config import TtModelArgs
from models.demos.wormhole.mistral7b.reference.model import TransformerBlock
from models.utility_functions import (
comp_pcc,
Expand All @@ -29,11 +21,16 @@


@skip_for_grayskull("Requires wormhole_b0 to run")
@pytest.mark.parametrize(
"iterations",
((1),),
)
def test_mistral_decoder_inference(device, iterations, use_program_cache, reset_seeds):
def test_mistral_decoder_inference(device, use_program_cache, reset_seeds, is_ci_env):
# Set Mistral flags for CI
if is_ci_env:
os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"

# This module requires the env paths above for CI runs
from models.demos.wormhole.mistral7b.tt.model_config import TtModelArgs

dtype = ttnn.bfloat8_b

model_args = TtModelArgs(device)
Expand All @@ -48,7 +45,7 @@ def test_mistral_decoder_inference(device, iterations, use_program_cache, reset_
reference_model.load_state_dict(partial_state_dict)

generation_start_pos = 0
generation_length = iterations
generation_length = 2
all_tests_pass = True

# pre-compute the rotational embedding matrix and send to device
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,13 @@
import pytest
from loguru import logger
import os

# Set Mistral flags for CI, if CI environment is setup
if os.getenv("CI") == "true":
os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"

import ttnn
from models.demos.wormhole.mistral7b.tt.mistral_common import (
get_prefill_rot_mat,
prepare_inputs_ttnn_prefill,
get_rot_transformation_mat,
)
from models.demos.wormhole.mistral7b.tt.mistral_decoder import TtTransformerBlock
from models.demos.wormhole.mistral7b.tt.model_config import TtModelArgs
from models.demos.wormhole.mistral7b.reference.model import TransformerBlock, precompute_freqs_cis
from models.utility_functions import (
comp_pcc,
Expand All @@ -39,11 +31,16 @@
4096,
),
)
@pytest.mark.parametrize(
"iterations",
((1),),
)
def test_mistral_decoder_inference(device, iterations, seq_len, use_program_cache, reset_seeds):
def test_mistral_decoder_inference(device, seq_len, use_program_cache, reset_seeds, is_ci_env):
# Set Mistral flags for CI
if is_ci_env:
os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"

# This module requires the env paths above for CI runs
from models.demos.wormhole.mistral7b.tt.model_config import TtModelArgs

dtype = ttnn.bfloat8_b

model_args = TtModelArgs(device)
Expand All @@ -58,7 +55,7 @@ def test_mistral_decoder_inference(device, iterations, seq_len, use_program_cach
reference_model.load_state_dict(partial_state_dict)

generation_start_pos = 0
generation_length = iterations
generation_length = 2
all_tests_pass = True

# pre-compute the rotational embedding matrix and send to device
Expand Down
Loading

0 comments on commit 457bcd0

Please sign in to comment.