From cef1262034c8c2f2320c149fcfefff92eb8365c4 Mon Sep 17 00:00:00 2001 From: kkeerthana0573 Date: Wed, 20 Nov 2024 10:57:24 +0000 Subject: [PATCH] #13331: Add Device Performance test for Whisper Model --- .../reference/torch_functional_whisper.py | 491 +++++++++++++ .../whisper/tests/test_perf_device_whisper.py | 34 + .../tt/ttnn_optimized_functional_whisper.py | 672 ++++++++++++++++++ .../test_ttnn_optimized_functional_whisper.py | 153 ++-- 4 files changed, 1279 insertions(+), 71 deletions(-) create mode 100644 models/demos/whisper/reference/torch_functional_whisper.py create mode 100644 models/demos/whisper/tests/test_perf_device_whisper.py create mode 100644 models/demos/whisper/tt/ttnn_optimized_functional_whisper.py diff --git a/models/demos/whisper/reference/torch_functional_whisper.py b/models/demos/whisper/reference/torch_functional_whisper.py new file mode 100644 index 000000000000..bb97bc164f13 --- /dev/null +++ b/models/demos/whisper/reference/torch_functional_whisper.py @@ -0,0 +1,491 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import transformers +from transformers import AutoFeatureExtractor, WhisperModel +from datasets import load_dataset +import torch +from typing import Optional + +from torch.nn import functional as F +from ttnn.model_preprocessing import preprocess_model_parameters +from loguru import logger + + +def conv(input, weight, bias, stride=1, padding=1, dilation=1, groups=1): + return F.conv1d(input, weight, bias, stride, padding, dilation, groups) + + +def gelu(tensor): + return torch.nn.functional.gelu(tensor) + + +def dropout(hidden_states, p, training): + return hidden_states + # return torch.nn.functional.dropout(hidden_states, p=p, training=training) + + +def calculate_key_values(config, key_value_states, parameters): + bsz, tgt_len, hidden_size = key_value_states.size() + head_size = hidden_size // config.encoder_attention_heads + + fused_qkv = key_value_states @ parameters.key_value.weight + parameters.key_value.bias + fused_qkv = torch.reshape(fused_qkv, shape=(bsz, tgt_len, 2, config.encoder_attention_heads, head_size)) + key_states, value_states = fused_qkv[..., 0, :, :], fused_qkv[..., 1, :, :] + + key_states = torch.reshape(key_states, shape=(bsz, tgt_len, config.encoder_attention_heads, head_size)) + key_states = torch.permute(key_states, (0, 2, 1, 3)) + key_states = key_states.contiguous() + + value_states = torch.reshape(value_states, shape=(bsz, tgt_len, config.encoder_attention_heads, head_size)) + value_states = torch.permute(value_states, (0, 2, 1, 3)) + value_states = value_states.contiguous() + + return key_states, value_states + + +def split_query_key_value_and_split_heads(config, fused_qkv): + head_size = config.d_model // config.encoder_attention_heads + batch_size, seq_length, three_times_hidden_size = fused_qkv.shape + hidden_size = three_times_hidden_size // 3 + num_heads = hidden_size // head_size + + fused_qkv = torch.reshape(fused_qkv, shape=(batch_size, seq_length, 3, num_heads, head_size)) + query_states, key_states, value_states = fused_qkv[..., 0, :, :], fused_qkv[..., 1, :, :], fused_qkv[..., 2, :, :] + + query_states = torch.reshape(query_states, shape=(batch_size, seq_length, num_heads, head_size)) + query_states = torch.permute(query_states, (0, 2, 1, 3)) + + key_states = torch.reshape(key_states, shape=(batch_size, seq_length, num_heads, head_size)) + key_states = torch.permute(key_states, (0, 2, 1, 3)) + + value_states = torch.reshape(value_states, shape=(batch_size, seq_length, num_heads, head_size)) + value_states = torch.permute(value_states, (0, 2, 1, 3)) + + return query_states, key_states, value_states + + +def calculate_query_key_values(config, hidden_states, *, parameters): + fused_qkv = hidden_states @ parameters.query_key_value.weight + parameters.query_key_value.bias + return split_query_key_value_and_split_heads(config, fused_qkv) + + +def whisper_attention(config, hidden_states, attention_mask, key_value_states, *, parameters): + head_size = config.d_model // config.encoder_attention_heads + scaling = head_size**-0.5 + bsz, tgt_len, _ = hidden_states.size() + + is_cross_attention = key_value_states is not None + if is_cross_attention: + query_states = hidden_states @ parameters.q_proj.weight + parameters.q_proj.bias + query_states = torch.reshape(query_states, shape=(bsz, tgt_len, config.encoder_attention_heads, head_size)) + query_states = torch.permute(query_states, (0, 2, 1, 3)) + key_states, value_states = calculate_key_values(config, key_value_states, parameters=parameters) + else: + query_states, key_states, value_states = calculate_query_key_values( + config, hidden_states, parameters=parameters + ) + query_states *= scaling + + proj_shape = (bsz * config.encoder_attention_heads, -1, head_size) + query_states = torch.reshape(query_states, shape=proj_shape) + key_states = torch.reshape(key_states, shape=proj_shape) + value_states = torch.reshape(value_states, shape=proj_shape) + + attn_weights = query_states @ torch.permute(key_states, (0, 2, 1)) + if attention_mask is not None: + bsz, _, tgt_len, src_len = attention_mask.size() + attn_weights = ( + torch.reshape(attn_weights, shape=(bsz, config.encoder_attention_heads, tgt_len, src_len)) + attention_mask + ) + attn_weights = torch.reshape(attn_weights, shape=(bsz * config.encoder_attention_heads, tgt_len, src_len)) + + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + + attn_probs = dropout(attn_weights, p=0, training=False) + attn_output = attn_probs @ value_states + attn_output = torch.reshape(attn_output, shape=(bsz, config.encoder_attention_heads, tgt_len, head_size)) + attn_output = torch.permute(attn_output, (0, 2, 1, 3)) + attn_output = attn_output.reshape(bsz, tgt_len, config.d_model) + attn_output = attn_output @ parameters.out_proj.weight + parameters.out_proj.bias + return attn_output + + +def encoder_layer(config, hidden_states, *, parameters): + residual = hidden_states + + hidden_states = F.layer_norm( + hidden_states, + (config.d_model,), + parameters.self_attn_layer_norm.weight, + parameters.self_attn_layer_norm.bias, + ) + hidden_states = whisper_attention( + config, hidden_states, attention_mask=None, key_value_states=None, parameters=parameters.self_attn + ) + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = residual + hidden_states + + residual = hidden_states + + hidden_states = F.layer_norm( + hidden_states, + (config.d_model,), + parameters.final_layer_norm.weight, + parameters.final_layer_norm.bias, + ) + hidden_states = hidden_states @ parameters.fc1.weight + parameters.fc1.bias + hidden_states = gelu(hidden_states) + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = hidden_states @ parameters.fc2.weight + parameters.fc2.bias + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16 and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + return hidden_states + + +def encoder(config, inputs_embeds, *, parameters): + hidden_states = inputs_embeds + parameters.embed_positions.weight + hidden_states = dropout(hidden_states, p=0, training=False) + + for encoder_layer_parameter in parameters.layers: + hidden_states = encoder_layer(config, hidden_states, parameters=encoder_layer_parameter) + + hidden_states = F.layer_norm( + hidden_states, + (config.d_model,), + parameters.layer_norm.weight, + parameters.layer_norm.bias, + ) + return hidden_states + + +def encoder_original(config, input_features, *, parameters): + inputs_embeds = gelu( + conv( + input_features, + weight=parameters.conv1.weight, + bias=parameters.conv1.bias, + padding=1, + ) + ) + inputs_embeds = gelu( + conv( + inputs_embeds, + weight=parameters.conv2.weight, + bias=parameters.conv2.bias, + stride=2, + padding=1, + ) + ) + inputs_embeds = inputs_embeds.permute(0, 2, 1) + hidden_states = inputs_embeds + parameters.embed_positions.weight + hidden_states = dropout(hidden_states, p=0, training=False) + + for encoder_layer_parameter in parameters.layers: + hidden_states = encoder_layer( + config, + hidden_states, + parameters=encoder_layer_parameter, + ) + + hidden_states = F.layer_norm( + hidden_states, + (config.d_model,), + parameters.layer_norm.weight, + parameters.layer_norm.bias, + ) + return hidden_states + + +def make_causal_mask(input_ids_shape, dtype): + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min)) + mask_cond = torch.arange(mask.size(-1)) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len) + + +def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +def decoder_layer(config, hidden_states, attention_mask, encoder_hidden_states, *, parameters): + residual = hidden_states + hidden_states = F.layer_norm( + hidden_states, + (config.d_model,), + parameters.self_attn_layer_norm.weight, + parameters.self_attn_layer_norm.bias, + ) + + hidden_states = whisper_attention( + config, + hidden_states=hidden_states, + attention_mask=attention_mask, + key_value_states=None, + parameters=parameters.self_attn, + ) + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = residual + hidden_states + + # Cross-Attention Block + residual = hidden_states + hidden_states = F.layer_norm( + hidden_states, + (config.d_model,), + parameters.encoder_attn_layer_norm.weight, + parameters.encoder_attn_layer_norm.bias, + ) + + hidden_states = whisper_attention( + config, + hidden_states, + attention_mask=None, + key_value_states=encoder_hidden_states, + parameters=parameters.encoder_attn, + ) + + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = F.layer_norm( + hidden_states, + (config.d_model,), + parameters.final_layer_norm.weight, + parameters.final_layer_norm.bias, + ) + hidden_states = hidden_states @ parameters.fc1.weight + parameters.fc1.bias + hidden_states = gelu(hidden_states) + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = hidden_states @ parameters.fc2.weight + parameters.fc2.bias + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = residual + hidden_states + + return hidden_states + + +def prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + + if input_shape[-1] > 1: + combined_attention_mask = make_causal_mask(input_shape, inputs_embeds.dtype) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + +def decoder(config, hidden_states, decoder_attention_mask, encoder_hidden_states, *, parameters): + hidden_states = dropout(hidden_states, p=0, training=False) + + for decoder_layer_parameter in parameters.layers: + hidden_states = decoder_layer( + config, + hidden_states, + decoder_attention_mask, + encoder_hidden_states, + parameters=decoder_layer_parameter, + ) + + hidden_states = F.layer_norm( + hidden_states, + (config.d_model,), + parameters.layer_norm.weight, + parameters.layer_norm.bias, + ) + + return hidden_states + + +def decoder_original(config, input_ids, attention_mask, encoder_hidden_states, parameters): + input_shape = input_ids.size() + input_ids = torch.reshape(input_ids, (-1, input_shape[-1])) + inputs_embeds = F.embedding(input_ids, parameters.embed_tokens.weight) + attention_mask = prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds) + positions = parameters.embed_positions.weight[0 : input_ids.shape[-1]] + + hidden_states = inputs_embeds + positions + hidden_states = dropout(hidden_states, p=0, training=False) + + for decoder_layer_parameter in parameters.layers: + hidden_states = decoder_layer( + config, + hidden_states, + attention_mask, + encoder_hidden_states, + parameters=decoder_layer_parameter, + ) + + hidden_states = F.layer_norm( + hidden_states, + (config.d_model,), + parameters.layer_norm.weight, + parameters.layer_norm.bias, + ) + + return hidden_states + + +def preprocess_encoder_inputs(input_features, parameters): + inputs_embeds = gelu( + conv( + input_features, + weight=parameters.conv1.weight, + bias=parameters.conv1.bias, + padding=1, + ) + ) + inputs_embeds = gelu( + conv( + inputs_embeds, + weight=parameters.conv2.weight, + bias=parameters.conv2.bias, + stride=2, + padding=1, + ) + ) + + inputs_embeds = inputs_embeds.permute(0, 2, 1) + return inputs_embeds + + +def preprocess_decoder_inputs(input_ids, attention_mask, *, parameters): + input_shape = input_ids.size() + input_ids = torch.reshape(input_ids, (-1, input_shape[-1])) + inputs_embeds = F.embedding(input_ids, parameters.embed_tokens.weight) + attention_mask = prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds) + + positions = parameters.embed_positions.weight[0 : input_ids.shape[-1]] + decoder_hidden_states = inputs_embeds + positions + + return decoder_hidden_states, attention_mask + + +def preprocess_inputs( + *, + input_features, + input_ids, + attention_mask, + parameters, +): + input_embeds = preprocess_encoder_inputs(input_features, parameters.encoder) + (decoder_hidden_states, attention_mask) = preprocess_decoder_inputs( + input_ids, attention_mask, parameters=parameters.decoder + ) + return input_embeds, decoder_hidden_states, attention_mask + + +def whisper_original(config, input_features, decoder_input_ids, attention_mask, *, parameters): + encoder_hidden_states = encoder_original(config, input_features, parameters=parameters.encoder) + return decoder_original( + config, + input_ids=decoder_input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + parameters=parameters.decoder, + ) + + +def whisper(config, input_embeds, decoder_hidden_states, decoder_attention_mask, *, parameters): + encoder_hidden_states = encoder(config, input_embeds, parameters=parameters.encoder) + return decoder( + config, + decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + parameters=parameters.decoder, + ) + + +def custom_preprocessor(torch_model, name): + parameters = {} + if isinstance(torch_model, transformers.models.whisper.modeling_whisper.WhisperAttention): + if "encoder_attn" in name: + parameters = {"key_value": {}, "q_proj": {}, "out_proj": {}} + preprocessed_weight = torch.cat([torch_model.k_proj.weight, torch_model.v_proj.weight], dim=0) + preprocessed_bias = torch.cat([torch.zeros_like(torch_model.v_proj.bias), torch_model.v_proj.bias], dim=0) + parameters["key_value"]["weight"] = preprocessed_weight.T.contiguous() + parameters["key_value"]["bias"] = preprocessed_bias + parameters["q_proj"]["weight"] = torch_model.q_proj.weight.T.contiguous() + parameters["q_proj"]["bias"] = torch_model.q_proj.bias + else: + parameters = {"query_key_value": {}, "out_proj": {}} + preprocessed_weight = torch.cat( + [torch_model.q_proj.weight, torch_model.k_proj.weight, torch_model.v_proj.weight], dim=0 + ) + preprocessed_bias = torch.cat( + [torch_model.q_proj.bias, torch.zeros_like(torch_model.v_proj.bias), torch_model.v_proj.bias], dim=0 + ) + parameters["query_key_value"]["weight"] = preprocessed_weight.T.contiguous() + parameters["query_key_value"]["bias"] = preprocessed_bias + + parameters["out_proj"]["weight"] = torch_model.out_proj.weight.T.contiguous() + parameters["out_proj"]["bias"] = torch_model.out_proj.bias + return parameters + + +if __name__ == "__main__": + # The following is simply to visualize the operations from pytorch + # sudo apt install graphviz + # pip install graphviz torchview + from torchview import draw_graph + from datasets import load_dataset + + model_name = "openai/whisper-base" + model = WhisperModel.from_pretrained(model_name).to(torch.bfloat16).eval() + + feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base") + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + inputs = feature_extractor(ds[0]["audio"]["array"], sampling_rate=16000, return_tensors="pt") + input_features = inputs.input_features + decoder_input_ids = torch.ones(1, 1).type(torch.int32) * model.config.decoder_start_token_id + + model_graph = draw_graph( + model, + input_size=((1, 80, 3000), (1, 2), (1, 80)), + dtypes=[torch.bfloat16, torch.int64, torch.int64], + expand_nested=True, + depth=10, + directory="out", + ) + model_graph.visual_graph.render(format="svg") + + # Sanity check the torch functional approach + parameters = preprocess_model_parameters( + model_name=f"torch_{model_name}", + initialize_model=lambda: model, + custom_preprocessor=custom_preprocessor, + convert_to_ttnn=lambda *_: False, + ) + last_hidden_state = whisper_original( + model.config, input_features, decoder_input_ids, attention_mask=None, parameters=parameters + ) + logger.info(last_hidden_state.shape) + last_three = last_hidden_state[0, -1, -3:] + logger.info(last_three) diff --git a/models/demos/whisper/tests/test_perf_device_whisper.py b/models/demos/whisper/tests/test_perf_device_whisper.py new file mode 100644 index 000000000000..f8b812fdc329 --- /dev/null +++ b/models/demos/whisper/tests/test_perf_device_whisper.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from models.utility_functions import is_grayskull +from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report + + +@pytest.mark.models_device_performance_bare_metal +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize("batch_size", [8]) +def test_perf_device_bare_metal(device, batch_size, reset_seeds): + subdir = "ttnn_whisper_optimized_" + margin = 0.03 + num_iterations = 1 + + expected_perf = 13.38 if is_grayskull else 3.06 + command = ( + f"pytest tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py::test_ttnn_whisper" + ) + cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"] + inference_time_key = "AVG DEVICE KERNEL SAMPLES/S" + expected_perf_cols = {inference_time_key: expected_perf} + + post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size) + expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols) + prep_device_perf_report( + model_name=f"ttnn_optimized_whisper_{batch_size}", + batch_size=batch_size, + post_processed_results=post_processed_results, + expected_results=expected_results, + comments=test.replace("/", "_"), + ) diff --git a/models/demos/whisper/tt/ttnn_optimized_functional_whisper.py b/models/demos/whisper/tt/ttnn_optimized_functional_whisper.py new file mode 100644 index 000000000000..172668664b9a --- /dev/null +++ b/models/demos/whisper/tt/ttnn_optimized_functional_whisper.py @@ -0,0 +1,672 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +import torch +import transformers +from typing import Optional +from torch.nn import functional as F +from ttnn.model_preprocessing import preprocess_linear_weight, preprocess_linear_bias + + +WHISPER_DTYPE = ttnn.bfloat8_b + + +def dropout(hidden_states, p, training): + # ignored for inference + return hidden_states + + +# The split_query_key_value_and_split_heads requires the query to have the same volume as the key and values +# This is not the case however for whisper so we currently cannot swap out calculate_key_values below +# def calculate_key_values(config, query_states, key_value_states, *, parameters): +# fused_kv = key_value_states @ parameters.key_value.weight + parameters.key_value.bias +# head_size = config.d_model // config.encoder_attention_heads +# batch_size, *_, _, two_times_hidden_size = fused_kv.shape.with_tile_padding() +# hidden_size = two_times_hidden_size // 2 +# encoder_attention_heads = hidden_size // head_size +# query_states, key_states, value_states = ttnn.transformer.split_query_key_value_and_split_heads( +# query_states, +# kv_input_tensor=fused_kv, +# num_heads=encoder_attention_heads, +# memory_config=WHISPER_MEMORY_CONFIG, +# ) +# key_states = ttnn.permute(key_states, (0, 1, 3, 2)) +# return query_states, key_states, value_states + + +def calculate_key_values(config, key_value_states, *, parameters, whisper_memory_config): + bsz, tgt_len, hidden_size = key_value_states.shape + bsz, tgt_len_padded, _ = key_value_states.shape.with_tile_padding() + head_size = hidden_size // config.encoder_attention_heads + + fused_qkv = ttnn.linear( + key_value_states, + parameters.weight, + bias=parameters.bias, + memory_config=whisper_memory_config, + ) + dtype = fused_qkv.dtype + device = fused_qkv.device() + + # fused_qkv = ttnn.to_layout(fused_qkv, layout=ttnn.ROW_MAJOR_LAYOUT) + # fused_qkv = ttnn.from_device(fused_qkv) + # fused_qkv = ttnn.reshape(fused_qkv, (bsz, tgt_len, config.encoder_attention_heads, 2, head_size)) + # # Without Split: 0.84 pcc + # key_states = ttnn.reshape(fused_qkv, (bsz, tgt_len, config.encoder_attention_heads, head_size * 2))[..., :head_size] + # value_states = ttnn.reshape(fused_qkv, (bsz, tgt_len, config.encoder_attention_heads, head_size * 2))[..., head_size:] + + # key_states = ttnn.to_device(key_states, device) + # key_states = ttnn.to_layout(key_states, ttnn.TILE_LAYOUT) + # key_states = ttnn.permute(key_states, (0, 2, 3, 1)) + + # value_states = ttnn.to_device(value_states, device) + # value_states = ttnn.to_layout(value_states, ttnn.TILE_LAYOUT) + # value_states = ttnn.permute(value_states, (0, 2, 1, 3)) + + fused_qkv = ttnn.to_layout(fused_qkv, layout=ttnn.ROW_MAJOR_LAYOUT) + fused_qkv = ttnn.from_device(fused_qkv) + fused_qkv = ttnn.reshape(fused_qkv, (bsz, tgt_len, 2, config.encoder_attention_heads, head_size)) + fused_qkv = ttnn.to_layout(fused_qkv, layout=ttnn.TILE_LAYOUT) + fused_qkv = ttnn.to_device(fused_qkv, device=device) + + # #13672: Slice op Not supported for 5d tensors. + fused_qkv = ttnn.to_torch(fused_qkv) + key_states, value_states = fused_qkv[..., 0, :, :], fused_qkv[..., 1, :, :] # + key_states = ttnn.from_torch(key_states, dtype=dtype, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + value_states = ttnn.from_torch(value_states, dtype=dtype, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + + key_states = ttnn.permute(key_states, (0, 2, 3, 1)) + value_states = ttnn.permute(value_states, (0, 2, 1, 3)) + key_states = ttnn.to_layout(key_states, ttnn.TILE_LAYOUT) + value_states = ttnn.to_layout(value_states, ttnn.TILE_LAYOUT) + + desired_shape = ttnn.Shape( + [bsz, config.encoder_attention_heads, head_size, tgt_len], + [bsz, config.encoder_attention_heads, head_size, tgt_len_padded], + ) + key_states = ttnn.reshape(key_states, shape=desired_shape) + + desired_shape = ttnn.Shape( + [bsz, config.encoder_attention_heads, tgt_len, head_size], + [bsz, config.encoder_attention_heads, tgt_len_padded, head_size], + ) + value_states = ttnn.reshape(value_states, shape=desired_shape) + + return key_states, value_states + + +def calculate_query_key_values(config, hidden_states, *, parameters, whisper_memory_config): + fused_qkv = ttnn.linear( + hidden_states, + parameters.weight, + bias=parameters.bias, + ) + + return ttnn.transformer.split_query_key_value_and_split_heads( + fused_qkv, memory_config=whisper_memory_config, num_heads=config.num_attention_heads + ) + + +def whisper_attention( + config, device, hidden_states, attention_mask, key_value_states=None, *, parameters, whisper_memory_config +): + head_size = config.d_model // config.encoder_attention_heads + scaling = head_size**-0.5 + bsz, *_, tgt_len, _ = hidden_states.shape + + is_cross_attention = key_value_states is not None + if is_cross_attention: + query_states = ttnn.linear( + hidden_states, + parameters.q_proj.weight, + bias=parameters.q_proj.bias, + memory_config=whisper_memory_config, + ) + query_states = ttnn.to_layout(query_states, layout=ttnn.ROW_MAJOR_LAYOUT) + query_states = ttnn.from_device(query_states) + query_states = ttnn.reshape(query_states, (bsz, tgt_len, config.encoder_attention_heads, head_size)) + query_states = ttnn.to_layout(query_states, layout=ttnn.TILE_LAYOUT) + query_states = ttnn.to_device(query_states, device=device) + query_states = ttnn.permute(query_states, (0, 2, 1, 3)) + key_states, value_states = calculate_key_values( + config, key_value_states, parameters=parameters.key_value, whisper_memory_config=whisper_memory_config + ) + else: + query_states, key_states, value_states = calculate_query_key_values( + config, hidden_states, parameters=parameters.query_key_value, whisper_memory_config=whisper_memory_config + ) + + query_states *= scaling + attn_weights = ttnn.matmul(query_states, key_states) + + if attention_mask is not None: + attn_weights = ttnn.add(attn_weights, attention_mask) + + # differences in ttnn.softmax vs torch.softmax cause the attn_weights to be slightly different + attn_weights = ttnn.softmax(attn_weights, dim=-1) + + attn_probs = dropout(attn_weights, p=0, training=False) + attn_output = ttnn.matmul(attn_probs, value_states, memory_config=whisper_memory_config) + + ttnn.deallocate(attn_probs) + ttnn.deallocate(attn_weights) + ttnn.deallocate(query_states) + + attn_output = ttnn.transformer.concatenate_heads(attn_output) + + attn_output = ttnn.linear( + attn_output, + parameters.out_proj.weight, + bias=parameters.out_proj.bias, + memory_config=whisper_memory_config, + ) + + return attn_output + + +def encoder_layer(config, device, hidden_states, *, parameters, whisper_memory_config): + residual = hidden_states + + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.self_attn_layer_norm.weight, + bias=parameters.self_attn_layer_norm.bias, + memory_config=whisper_memory_config, + ) + + hidden_states = whisper_attention( + config, + device, + hidden_states, + attention_mask=None, + parameters=parameters.self_attn, + whisper_memory_config=whisper_memory_config, + ) + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = ttnn.add(residual, hidden_states) + + residual = hidden_states + + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.final_layer_norm.weight, + bias=parameters.final_layer_norm.bias, + memory_config=whisper_memory_config, + ) + + hidden_states = ttnn.linear( + hidden_states, + parameters.fc1.weight, + bias=parameters.fc1.bias, + ) + + hidden_states = ttnn.gelu(hidden_states, memory_config=whisper_memory_config) + hidden_states = dropout(hidden_states, p=0, training=False) + + hidden_states = ttnn.linear( + hidden_states, + parameters.fc2.weight, + bias=parameters.fc2.bias, + memory_config=whisper_memory_config, + ) + + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = ttnn.add(residual, hidden_states) + + # if hidden_states.dtype == torch.float16 and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()): + # clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + # hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + return hidden_states + + +def encoder(config, device, inputs_embeds, *, parameters, whisper_memory_config): + hidden_states = ttnn.add(inputs_embeds, parameters.embed_positions.weight) + hidden_states = dropout(hidden_states, p=0, training=False) + + for encoder_layer_parameter in parameters.layers: + hidden_states = encoder_layer( + config, + device, + hidden_states, + parameters=encoder_layer_parameter, + whisper_memory_config=whisper_memory_config, + ) + + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.layer_norm.weight, + bias=parameters.layer_norm.bias, + ) + + return hidden_states + + +def make_causal_mask(input_ids_shape, dtype): + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min)) + mask_cond = torch.arange(mask.size(-1)) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len) + + +def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.shape + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +def decoder_layer( + config, device, hidden_states, attention_mask, encoder_hidden_states, *, parameters, whisper_memory_config +): + residual = hidden_states + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.self_attn_layer_norm.weight, + bias=parameters.self_attn_layer_norm.bias, + memory_config=whisper_memory_config, + ) + + hidden_states = whisper_attention( + config, + device, + hidden_states=hidden_states, + attention_mask=attention_mask, + parameters=parameters.self_attn, + whisper_memory_config=whisper_memory_config, + ) + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = ttnn.add(residual, hidden_states) + + # Cross-Attention Block + residual = hidden_states + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.encoder_attn_layer_norm.weight, + bias=parameters.encoder_attn_layer_norm.bias, + ) + + hidden_states = whisper_attention( + config, + device, + hidden_states, + attention_mask=None, + key_value_states=encoder_hidden_states, + parameters=parameters.encoder_attn, + whisper_memory_config=whisper_memory_config, + ) + + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = ttnn.add(residual, hidden_states) + + residual = hidden_states + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.final_layer_norm.weight, + bias=parameters.final_layer_norm.bias, + ) + + hidden_states = ttnn.linear( + hidden_states, parameters.fc1.weight, bias=parameters.fc1.bias, memory_config=whisper_memory_config + ) + hidden_states = ttnn.gelu(hidden_states, memory_config=whisper_memory_config) + hidden_states = dropout(hidden_states, p=0, training=False) + + hidden_states = ttnn.linear( + hidden_states, parameters.fc2.weight, bias=parameters.fc2.bias, memory_config=whisper_memory_config + ) + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = ttnn.add(residual, hidden_states) + return hidden_states + + +def prepare_decoder_attention_mask(attention_mask, input_shape, input_embeds): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + + if input_shape[-1] > 1: + combined_attention_mask = make_causal_mask(input_shape, input_embeds.dtype) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = expand_mask(attention_mask, input_embeds.dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + +def decoder( + config, device, hidden_states, decoder_attention_mask, encoder_hidden_states, *, parameters, whisper_memory_config +): + hidden_states = dropout(hidden_states, p=0, training=False) + + for decoder_layer_parameter in parameters.layers: + hidden_states = decoder_layer( + config, + device, + hidden_states, + decoder_attention_mask, + encoder_hidden_states, + parameters=decoder_layer_parameter, + whisper_memory_config=whisper_memory_config, + ) + + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.layer_norm.weight, + bias=parameters.layer_norm.bias, + ) + + return hidden_states + + +def convert_to_ttnn(model, name): + return name not in [ + "encoder.conv1", + "encoder.conv2", + "decoder.embed_tokens", + "decoder.embed_positions", + ] + + +def preprocess_encoder_inputs(input_features, *, parameters, device, whisper_memory_config): + def conv(input, weight, bias, stride=1, padding=1, dilation=1, groups=1): + return F.conv1d(input, weight, bias, stride, padding, dilation, groups) + + def ttnn_conv1d( + device, + tt_input_tensor, + weights, + conv_params, + bias, + *, + output_dtype=ttnn.bfloat16, + weights_dtype=ttnn.bfloat8_b, + math_fidelity=ttnn.MathFidelity.LoFi, + deallocate_activation=True, + act_block_h=32, + height_sharding=True, + use_shallow_conv_variant=False, + fp32_accum=False, + packer_l1_acc=False, + debug=False, + groups=1, + math_approx=False, + activation="", + reallocate_halo=False, + reshard=False, + whisper_memory_config=ttnn.L1_MEMORY_CONFIG, + ): + weights = ttnn.from_torch(weights, dtype=ttnn.float32) + bias = ttnn.from_torch(bias.unsqueeze(0).unsqueeze(0).unsqueeze(0), dtype=ttnn.float32) + + conv_config = ttnn.Conv1dConfig( + dtype=output_dtype, + weights_dtype=weights_dtype, + math_approx_mode_enabled=math_approx, + fp32_dest_acc_enabled=fp32_accum, + packer_l1_accum_enabled=packer_l1_acc, + activation=activation, + input_channels_alignment=(16 if use_shallow_conv_variant else 32), + deallocate_activation=deallocate_activation, + reallocate_halo_output=reallocate_halo, + act_block_h_override=act_block_h, + reshard_if_not_optimal=reshard, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), + math_fidelity=math_fidelity, + ) + + [tt_output_tensor_on_device, out_length, weights_device, bias_device] = ttnn.Conv1d( + input_tensor=tt_input_tensor, + weight_tensor=weights, + in_channels=tt_input_tensor.shape[-1], + out_channels=weights.shape[0], + device=device, + bias_tensor=bias, + kernel_size=3, + stride=conv_params[0], + padding=conv_params[1], + batch_size=tt_input_tensor.shape[0], + input_length=tt_input_tensor.shape[1], + conv_config=conv_config, + conv_op_cache={}, + debug=debug, + groups=groups, + ) + tt_output_tensor_on_device = ttnn.squeeze(tt_output_tensor_on_device, 0) + tt_output_tensor_on_device = ttnn.to_layout(tt_output_tensor_on_device, layout=ttnn.ROW_MAJOR_LAYOUT) + tt_output_tensor_on_device = ttnn.reshape( + tt_output_tensor_on_device, (tt_input_tensor.shape[0], out_length, tt_output_tensor_on_device.shape[-1]) + ) + tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) + + return tt_output_tensor + + if parameters.conv1.weight.shape[0] == 512: + input_features = ttnn.from_torch(input_features, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT) + input_features = ttnn.permute(input_features, (0, 2, 1)) + conv1 = ttnn_conv1d( + device, + input_features, + parameters.conv1.weight, + [1, 1], + parameters.conv1.bias, + ) + conv1 = ttnn.to_layout(conv1, ttnn.TILE_LAYOUT) + conv1 = ttnn.to_device(conv1, device) + conv1 = ttnn.permute(conv1, (0, 2, 1)) + + else: + conv1 = conv( + input_features.float(), + weight=parameters.conv1.weight, + bias=parameters.conv1.bias, + stride=1, + padding=1, + ) + conv1 = ttnn.from_torch(conv1, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + + input_embeds = ttnn.gelu(conv1, memory_config=whisper_memory_config) + input_embeds = ttnn.to_layout(input_embeds, layout=ttnn.ROW_MAJOR_LAYOUT) + + # input_embeds = ttnn.permute(input_embeds, (0, 2, 1)) + input_embeds = ttnn.to_torch(input_embeds) + + # #13529 ttnn.conv1d throws OOM here. + # conv2 = ttnn_conv1d( + # device, + # input_embeds, + # parameters.conv2.weight, + # [2, 1], + # parameters.conv2.bias, + # ) + # conv2 = ttnn.to_layout(conv2, ttnn.TILE_LAYOUT) + # conv2 = ttnn.to_device(conv2, device) + # conv2 = ttnn.permute(conv2, (0, 2, 1)) + # input_embeds = ttnn.gelu(conv2, memory_config=whisper_memory_config) + + conv = conv( + input_embeds.float(), + weight=parameters.conv2.weight, + bias=parameters.conv2.bias, + stride=2, + padding=1, + ) + conv = ttnn.from_torch(conv, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + + input_embeds = ttnn.gelu(conv, memory_config=whisper_memory_config) + input_embeds = ttnn.permute(input_embeds, (0, 2, 1)) + + return input_embeds + + +def preprocess_decoder_inputs(config, input_ids, attention_mask, *, parameters, device): + input_shape = input_ids.size() + input_ids = torch.reshape(input_ids, (-1, input_shape[-1])) + inputs_embeds = F.embedding(input_ids, parameters.embed_tokens.weight) + attention_mask = prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds) + # ttnn cannot broadcast when adding on the batch or channel dimensions so this is a workaround + attention_mask = attention_mask.expand(-1, config.encoder_attention_heads, -1, -1) + + positions = parameters.embed_positions.weight[0 : input_ids.shape[-1]] + decoder_hidden_states = inputs_embeds + positions + + decoder_hidden_states = ttnn.from_torch( + decoder_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device + ) + attention_mask = ttnn.from_torch(attention_mask, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + + return decoder_hidden_states, attention_mask + + +def preprocess_inputs(*, config, input_features, input_ids, attention_mask, parameters, device, whisper_memory_config): + input_embeds = preprocess_encoder_inputs( + input_features, parameters=parameters.encoder, device=device, whisper_memory_config=whisper_memory_config + ) + (decoder_hidden_states, attention_mask) = preprocess_decoder_inputs( + config, input_ids, attention_mask, parameters=parameters.decoder, device=device + ) + return input_embeds, decoder_hidden_states, attention_mask + + +def whisper( + config, + device, + encoder_hidden_states, + decoder_hidden_states, + decoder_attention_mask, + *, + parameters, + whisper_memory_config, +): + encoder_hidden_states = encoder( + config, + device, + encoder_hidden_states, + parameters=parameters.encoder, + whisper_memory_config=whisper_memory_config, + ) + + last_hidden_state = decoder( + config, + device, + decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + parameters=parameters.decoder, + whisper_memory_config=whisper_memory_config, + ) + + return last_hidden_state + + +def whisper_for_audio_classification(config, inputs_embeds, *, parameters, device, batch_size, whisper_memory_config): + encoder_outputs = encoder( + config=config, + device=device, + inputs_embeds=inputs_embeds, + parameters=parameters.encoder, + whisper_memory_config=whisper_memory_config, + ) + hidden_states = ttnn.linear( + encoder_outputs, + parameters.projector.weight, + bias=parameters.projector.bias, + memory_config=whisper_memory_config, + ) + pooled_output = ttnn.mean(hidden_states, dim=-2, keepdim=True) + + logits = ttnn.linear( + pooled_output, + parameters.classifier.weight, + bias=parameters.classifier.bias, + memory_config=whisper_memory_config, + ) + return logits + + +def whisper_for_conditional_generation( + config, + input_embeds, + decoder_hidden_states, + decoder_attention_mask, + *, + parameters, + device, + ttnn_linear_weight, + whisper_memory_config, +): + output = whisper( + config=config, + device=device, + encoder_hidden_states=input_embeds, + decoder_hidden_states=decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + parameters=parameters, + whisper_memory_config=whisper_memory_config, + ) + + ttnn_output = ttnn.matmul( + output, + ttnn_linear_weight, + dtype=ttnn.bfloat16, + ) + return ttnn_output + + +def preprocess_conv_parameter(parameter, *, dtype): + parameter = ttnn.from_torch(parameter, dtype=dtype) + return parameter + + +def custom_preprocessor(torch_model, name): + parameters = {} + if isinstance(torch_model, transformers.models.whisper.modeling_whisper.WhisperAttention): + height, width = torch_model.k_proj.weight.shape + + if "encoder_attn" in name: + parameters = {"key_value": {}, "q_proj": {}, "out_proj": {}} + preprocessed_weight = torch.cat([torch_model.k_proj.weight, torch_model.v_proj.weight], dim=0) + preprocessed_bias = torch.cat([torch.zeros(height), torch_model.v_proj.bias], dim=0) + parameters["key_value"]["weight"] = preprocess_linear_weight(preprocessed_weight, dtype=ttnn.bfloat16) + parameters["key_value"]["bias"] = preprocess_linear_bias(preprocessed_bias, dtype=ttnn.bfloat16) + parameters["q_proj"]["weight"] = preprocess_linear_weight(torch_model.q_proj.weight, dtype=ttnn.bfloat16) + parameters["q_proj"]["bias"] = preprocess_linear_bias(torch_model.q_proj.bias, dtype=ttnn.bfloat16) + else: + parameters = {"query_key_value": {}, "out_proj": {}} + preprocessed_weight = torch.cat( + [torch_model.q_proj.weight, torch_model.k_proj.weight, torch_model.v_proj.weight], dim=0 + ) + preprocessed_bias = torch.cat( + [torch_model.q_proj.bias, torch.zeros(height), torch_model.v_proj.bias], dim=0 + ) + parameters["query_key_value"]["weight"] = preprocess_linear_weight(preprocessed_weight, dtype=ttnn.bfloat16) + parameters["query_key_value"]["bias"] = preprocess_linear_bias(preprocessed_bias, dtype=ttnn.bfloat16) + + parameters["out_proj"]["weight"] = preprocess_linear_weight(torch_model.out_proj.weight, dtype=ttnn.bfloat16) + parameters["out_proj"]["bias"] = preprocess_linear_bias(torch_model.out_proj.bias, dtype=ttnn.bfloat16) + + elif name == "encoder.embed_positions" and isinstance(torch_model, torch.nn.Embedding): + embeddings = torch_model.weight.unsqueeze(0).expand(8, -1, -1) + embeddings = ttnn.from_torch(embeddings, dtype=ttnn.bfloat16) + embeddings = ttnn.to_layout(embeddings, ttnn.TILE_LAYOUT) + parameters["weight"] = embeddings + + return parameters diff --git a/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py b/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py index e6cea2f8870d..61f04e5b7f81 100644 --- a/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py +++ b/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py @@ -2,30 +2,29 @@ # SPDX-License-Identifier: Apache-2.0 +import ttnn +import torch import pytest -from models.experimental.functional_whisper.reference import torch_functional_whisper -from models.experimental.functional_whisper.tt import ttnn_optimized_functional_whisper import transformers -from transformers import AutoFeatureExtractor, WhisperModel, WhisperConfig from datasets import load_dataset -import torch -import ttnn from tests.ttnn.utils_for_testing import assert_with_pcc -from models.utility_functions import torch_random +from models.utility_functions import torch_random, is_grayskull from ttnn.model_preprocessing import preprocess_model_parameters -from models.utility_functions import is_wormhole_b0, is_blackhole +from models.demos.whisper.reference import torch_functional_whisper +from models.demos.whisper.tt import ttnn_optimized_functional_whisper MODEL_NAME = "openai/whisper-base" -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("sequence_size", [1500]) @pytest.mark.parametrize("use_key_value_states", [False, True]) -def test_whisper_attention(device, ttnn_model, model_name, batch_size, sequence_size, use_key_value_states): - torch.manual_seed(0) +def test_whisper_attention( + device, ttnn_model, model_name, batch_size, sequence_size, use_key_value_states, reset_seeds +): config = transformers.WhisperConfig.from_pretrained(model_name) model = transformers.models.whisper.modeling_whisper.WhisperAttention( embed_dim=config.d_model, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout @@ -72,23 +71,24 @@ def test_whisper_attention(device, ttnn_model, model_name, batch_size, sequence_ attention_mask = None output = ttnn_model.whisper_attention( config, + device, ttnn_hidden_states, attention_mask, key_value_states=ttnn_key_value_states, parameters=ttnn_parameters, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, ) output = ttnn.to_torch(output) - assert_with_pcc(torch_output, output, 0.98) + assert_with_pcc(torch_output, output, 0.99) -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("sequence_size", [1500]) -def test_encoder_layer(device, ttnn_model, model_name, batch_size, sequence_size): - torch.manual_seed(0) +def test_encoder_layer(device, ttnn_model, model_name, batch_size, sequence_size, reset_seeds): config = transformers.WhisperConfig.from_pretrained(model_name) model = transformers.models.whisper.modeling_whisper.WhisperEncoderLayer(config).eval() model = model @@ -113,20 +113,25 @@ def test_encoder_layer(device, ttnn_model, model_name, batch_size, sequence_size torch_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device ) - output = ttnn_model.encoder_layer(config, ttnn_hidden_states, parameters=ttnn_parameters) + output = ttnn_model.encoder_layer( + config, + device, + ttnn_hidden_states, + parameters=ttnn_parameters, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, + ) output = ttnn.to_torch(output) assert_with_pcc(torch_output, output, pcc=0.99) -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("feature_size", [80]) @pytest.mark.parametrize("sequence_length", [3000]) -def test_encoder(device, ttnn_model, model_name, batch_size, feature_size, sequence_length): - torch.manual_seed(0) +def test_encoder(device, ttnn_model, model_name, batch_size, feature_size, sequence_length, reset_seeds): config = transformers.WhisperConfig.from_pretrained(model_name) model = transformers.models.whisper.modeling_whisper.WhisperEncoder(config).eval() model = model @@ -139,10 +144,6 @@ def test_encoder(device, ttnn_model, model_name, batch_size, feature_size, seque custom_preprocessor=torch_functional_whisper.custom_preprocessor, ) - # torch_original_output = torch_functional_whisper.encoder_original( - # torch_input_features, parameters, embed_dim, num_heads - # ) - inputs_embeds = torch_functional_whisper.preprocess_encoder_inputs( input_features=torch_input_features, parameters=parameters, @@ -162,17 +163,24 @@ def test_encoder(device, ttnn_model, model_name, batch_size, feature_size, seque input_features=torch_input_features, parameters=ttnn_parameters, device=device, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, ) input_embeds = ttnn.to_layout(input_embeds, ttnn.TILE_LAYOUT) input_embeds = ttnn.to_device(input_embeds, device) - output = ttnn_model.encoder(config, input_embeds, parameters=ttnn_parameters) + output = ttnn_model.encoder( + config, + device, + input_embeds, + parameters=ttnn_parameters, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, + ) output = ttnn.to_torch(output) - assert_with_pcc(torch_output, output, 0.968) + assert_with_pcc(torch_output, output, 0.99) -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("batch_size", [1]) @@ -222,20 +230,25 @@ def test_decoder_layer(device, ttnn_model, model_name, batch_size, sequence_size ttnn_encoder_hidden_states = ttnn.to_device(ttnn_encoder_hidden_states, device) output = ttnn_model.decoder_layer( - config, ttnn_hidden_states, ttnn_attention_mask, ttnn_encoder_hidden_states, parameters=ttnn_parameters + config, + device, + ttnn_hidden_states, + ttnn_attention_mask, + ttnn_encoder_hidden_states, + parameters=ttnn_parameters, + whisper_memory_config=ttnn.L1_MEMORY_CONFIG, ) output = ttnn.to_torch(output) - assert_with_pcc(torch_output, output, 0.97) + assert_with_pcc(torch_output, output, 0.99) -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("batch_size", [1]) @pytest.mark.parametrize("sequence_size", [1500]) -def test_decoder(device, ttnn_model, model_name, batch_size, sequence_size): - torch.manual_seed(0) +def test_decoder(device, ttnn_model, model_name, batch_size, sequence_size, reset_seeds): config = transformers.WhisperConfig.from_pretrained(model_name) model = transformers.models.whisper.modeling_whisper.WhisperDecoder(config).eval() model = model @@ -244,7 +257,6 @@ def test_decoder(device, ttnn_model, model_name, batch_size, sequence_size): torch_encoder_hidden_states = torch_random((batch_size, sequence_size, embed_dim), -0.1, 0.1, dtype=torch.float32) - # decoder_input_ids = torch.ones(1, 32).type(torch.int32) * config.decoder_start_token_id decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id attention_mask = None @@ -255,10 +267,6 @@ def test_decoder(device, ttnn_model, model_name, batch_size, sequence_size): custom_preprocessor=torch_functional_whisper.custom_preprocessor, ) - # torch_original_output = torch_functional_whisper.decoder_original( - # decoder_input_ids, attention_mask, torch_encoder_hidden_states, parameters, embed_dim, num_heads - # ) - (decoder_hidden_states, decoder_attention_mask) = torch_functional_whisper.preprocess_decoder_inputs( decoder_input_ids, attention_mask, parameters=parameters ) @@ -291,32 +299,33 @@ def test_decoder(device, ttnn_model, model_name, batch_size, sequence_size): output = ttnn_model.decoder( config, + device=device, hidden_states=decoder_hidden_states, decoder_attention_mask=decoder_attention_mask, encoder_hidden_states=ttnn_encoder_hidden_states, parameters=ttnn_parameters, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, ) output = ttnn.to_torch(output) assert_with_pcc(torch_output, output, pcc=0.99) -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") -@pytest.mark.requires_fast_runtime_mode_off +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) -def test_ttnn_whisper(tmp_path, device, ttnn_model): - torch.manual_seed(0) - model_name = "openai/whisper-base" - config = WhisperConfig.from_pretrained(model_name) - feature_extractor = AutoFeatureExtractor.from_pretrained(model_name) +def test_ttnn_whisper(reset_seeds, device, batch_size, model_name, ttnn_model): + config = transformers.WhisperConfig.from_pretrained(model_name) + feature_extractor = transformers.AutoFeatureExtractor.from_pretrained(model_name) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - inputs = feature_extractor(ds[0]["audio"]["array"], sampling_rate=16000, return_tensors="pt") + inputs = feature_extractor( + [ds[i]["audio"]["array"] for i in range(batch_size)], sampling_rate=16000, return_tensors="pt" + ) input_features = inputs.input_features - decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id - - attention_mask = None + decoder_input_ids = torch.tensor([[1, 1]] * batch_size) * config.decoder_start_token_id - model = WhisperModel.from_pretrained(model_name).eval() + model = transformers.WhisperModel.from_pretrained(model_name).eval() parameters = preprocess_model_parameters( initialize_model=lambda: model, @@ -327,11 +336,11 @@ def test_ttnn_whisper(tmp_path, device, ttnn_model): (encoder_hidden_states, decoder_hidden_states, decoder_attention_mask) = torch_functional_whisper.preprocess_inputs( input_features=input_features, input_ids=decoder_input_ids, - attention_mask=attention_mask, + attention_mask=None, parameters=parameters, ) - expected_last_hidden_state = torch_functional_whisper.whisper( + torch_last_hidden_state = torch_functional_whisper.whisper( config, encoder_hidden_states, decoder_hidden_states, @@ -346,24 +355,26 @@ def test_ttnn_whisper(tmp_path, device, ttnn_model): device=device, ) - with ttnn.tracer.trace(): - (input_embeds, decoder_hidden_states, decoder_attention_mask) = ttnn_model.preprocess_inputs( - config=config, - input_features=input_features, - input_ids=decoder_input_ids, - attention_mask=attention_mask, - parameters=ttnn_parameters, - device=device, - ) + (input_embeds, decoder_hidden_states, decoder_attention_mask) = ttnn_model.preprocess_inputs( + config=config, + input_features=input_features, + input_ids=decoder_input_ids, + attention_mask=None, + parameters=ttnn_parameters, + device=device, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, + ) - last_hidden_state = ttnn_model.whisper( - config, - input_embeds, - decoder_hidden_states, - decoder_attention_mask=decoder_attention_mask, - parameters=ttnn_parameters, - ) - last_hidden_state = ttnn.to_torch(last_hidden_state) - ttnn.tracer.visualize(last_hidden_state, file_name=tmp_path / "whisper.svg") + last_hidden_state = ttnn_model.whisper( + config, + device, + input_embeds, + decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + parameters=ttnn_parameters, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, + ) + + last_hidden_state = ttnn.to_torch(last_hidden_state) - assert_with_pcc(expected_last_hidden_state, last_hidden_state, 0.964) + assert_with_pcc(torch_last_hidden_state, last_hidden_state, 0.97 if is_grayskull else 0.989)