diff --git a/models/demos/wormhole/whisper/README.md b/models/demos/wormhole/whisper/README.md new file mode 100644 index 00000000000..230b00c2236 --- /dev/null +++ b/models/demos/wormhole/whisper/README.md @@ -0,0 +1,43 @@ +# Whisper Demo + +Demo showcasing Data Parallel implementation of Whisper running on Wormhole - n150, n300 using ttnn. + +## Introduction + +Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multitasking model that can perform multilingual speech recognition, speech translation, and language identification. These tasks are jointly represented as a sequence of tokens to be predicted by the decoder, allowing a single model to replace many stages of a traditional speech-processing pipeline. The multitask training format uses a set of special tokens that serve as task specifiers or classification targets. + +## Details + +The entry point to whisper model is `whisper` in `models/demos/wormhole/whisper/tt/ttnn_optimized_functional_whisper.py` for optimized version.. The model picks up certain configs and weights from huggingface pretrained model. We have used openai/whisper-base version from huggingface as our reference. + +### Max Tokens: 32 + +Max Tokens determines the maximum number of input tokens processed by the model in a single pass durig transcription, optimizing performance and compatibility. It's recommended to set the max_tokens to 32 + +### Batch size: 8 + +Batch Size determines the number of input sequences processed simultaneously during training or inference, impacting computational efficiency and memory usage. It's recommended to set the batch_size to 8 + +## How to Run + +### Whisper For Audio Classification +Use `pytest --disable-warnings models/demos/wormhole/whisper/demo/demo.py::test_demo_for_audio_classification[wormhole_b0-True-models.demos.wormhole.whisper.tt.ttnn_optimized_functional_whisper-1-4-WHISPER_MEMORY_CONFIG0-sanchit-gandhi/whisper-medium-fleurs-lang-id-models/demos/wormhole/whisper/demo/dataset/audio_classification]` to run the ttnn optimized functional whisper demo for audio classification. + +#### Our another demo is designed to run with `google/fleurs` for Audio classification + +Use `pytest --disable-warnings models/demos/wormhole/whisper/demo/demo.py::test_demo_for_audio_classification_dataset` to run audio classification demo with dataset inputs. + +### Whisper For Conditional Generation + +Use `pytest --disable-warnings models/demos/wormhole/whisper/demo/demo.py::test_demo_for_conditional_generation[wormhole_b0-True-models.demos.wormhole.whisper.tt.ttnn_optimized_functional_whisper-4-32-WHISPER_MEMORY_CONFIG0-openai/whisper-tiny.en-models/demos/wormhole/whisper/demo/dataset/conditional_generation-device_params0]` to run the ttnn optimized functional whisper demo for conditional generation. + +#### Our another demo is designed to run with `hf-internal-testing/librispeech_asr_dummy` for Conditional generation + +Use `pytest --disable-warnings models/demos/wormhole/whisper/demo/demo.py::test_demo_for_conditional_generation_dataset` to run conditional generation demo with dataset inputs. + + +## Inputs + +Inputs by default are provided from `dataset/audio_classification` and `dataset/conditional_generation` folder. If you wish to change the inputs, provide a different path to demo. + +For demo with dataset, Inputs for Audio classification is taken from `google/fleurs` dataset and Inputs for Conditional generation is taken from `hf-internal-testing/librispeech_asr_dummy` dataset. diff --git a/models/demos/wormhole/whisper/demo/dataset/audio_classification/10116516891483200485.wav b/models/demos/wormhole/whisper/demo/dataset/audio_classification/10116516891483200485.wav new file mode 100644 index 00000000000..2003cba3fce Binary files /dev/null and b/models/demos/wormhole/whisper/demo/dataset/audio_classification/10116516891483200485.wav differ diff --git a/models/demos/wormhole/whisper/demo/dataset/audio_classification/140291826269534354.wav b/models/demos/wormhole/whisper/demo/dataset/audio_classification/140291826269534354.wav new file mode 100644 index 00000000000..74c05c9bca6 Binary files /dev/null and b/models/demos/wormhole/whisper/demo/dataset/audio_classification/140291826269534354.wav differ diff --git a/models/demos/wormhole/whisper/demo/dataset/audio_classification/1689242038473278354.wav b/models/demos/wormhole/whisper/demo/dataset/audio_classification/1689242038473278354.wav new file mode 100644 index 00000000000..d1dbd535229 Binary files /dev/null and b/models/demos/wormhole/whisper/demo/dataset/audio_classification/1689242038473278354.wav differ diff --git a/models/demos/wormhole/whisper/demo/dataset/audio_classification/17340315164505628698.wav b/models/demos/wormhole/whisper/demo/dataset/audio_classification/17340315164505628698.wav new file mode 100644 index 00000000000..c8f031ca3d7 Binary files /dev/null and b/models/demos/wormhole/whisper/demo/dataset/audio_classification/17340315164505628698.wav differ diff --git a/models/demos/wormhole/whisper/demo/dataset/audio_classification/17659141715436566244.wav b/models/demos/wormhole/whisper/demo/dataset/audio_classification/17659141715436566244.wav new file mode 100644 index 00000000000..3b60f9d9350 Binary files /dev/null and b/models/demos/wormhole/whisper/demo/dataset/audio_classification/17659141715436566244.wav differ diff --git a/models/demos/wormhole/whisper/demo/dataset/audio_classification/17928171511082320095.wav b/models/demos/wormhole/whisper/demo/dataset/audio_classification/17928171511082320095.wav new file mode 100644 index 00000000000..780372c5f01 Binary files /dev/null and b/models/demos/wormhole/whisper/demo/dataset/audio_classification/17928171511082320095.wav differ diff --git a/models/demos/wormhole/whisper/demo/dataset/audio_classification/2086639904747050008.wav b/models/demos/wormhole/whisper/demo/dataset/audio_classification/2086639904747050008.wav new file mode 100644 index 00000000000..06cb1be0c62 Binary files /dev/null and b/models/demos/wormhole/whisper/demo/dataset/audio_classification/2086639904747050008.wav differ diff --git a/models/demos/wormhole/whisper/demo/dataset/audio_classification/622196158886216764.wav b/models/demos/wormhole/whisper/demo/dataset/audio_classification/622196158886216764.wav new file mode 100644 index 00000000000..d200b14eca7 Binary files /dev/null and b/models/demos/wormhole/whisper/demo/dataset/audio_classification/622196158886216764.wav differ diff --git a/models/demos/wormhole/whisper/demo/dataset/audio_classification/7043619860143829064.wav b/models/demos/wormhole/whisper/demo/dataset/audio_classification/7043619860143829064.wav new file mode 100644 index 00000000000..c41e1ebe75e Binary files /dev/null and b/models/demos/wormhole/whisper/demo/dataset/audio_classification/7043619860143829064.wav differ diff --git a/models/demos/wormhole/whisper/demo/dataset/audio_classification/9522084197299278725.wav b/models/demos/wormhole/whisper/demo/dataset/audio_classification/9522084197299278725.wav new file mode 100644 index 00000000000..6615ab4898c Binary files /dev/null and b/models/demos/wormhole/whisper/demo/dataset/audio_classification/9522084197299278725.wav differ diff --git a/models/demos/wormhole/whisper/demo/dataset/conditional_generation/11150113890463037787.wav b/models/demos/wormhole/whisper/demo/dataset/conditional_generation/11150113890463037787.wav new file mode 100644 index 00000000000..21bc223a8f5 Binary files /dev/null and b/models/demos/wormhole/whisper/demo/dataset/conditional_generation/11150113890463037787.wav differ diff --git a/models/demos/wormhole/whisper/demo/dataset/conditional_generation/1298409023920250606.wav b/models/demos/wormhole/whisper/demo/dataset/conditional_generation/1298409023920250606.wav new file mode 100644 index 00000000000..397987797b9 Binary files /dev/null and b/models/demos/wormhole/whisper/demo/dataset/conditional_generation/1298409023920250606.wav differ diff --git a/models/demos/wormhole/whisper/demo/dataset/conditional_generation/17566024285835266239.wav b/models/demos/wormhole/whisper/demo/dataset/conditional_generation/17566024285835266239.wav new file mode 100644 index 00000000000..6ad289c431a Binary files /dev/null and b/models/demos/wormhole/whisper/demo/dataset/conditional_generation/17566024285835266239.wav differ diff --git a/models/demos/wormhole/whisper/demo/dataset/conditional_generation/17646385371758249908.wav b/models/demos/wormhole/whisper/demo/dataset/conditional_generation/17646385371758249908.wav new file mode 100644 index 00000000000..fdeacb473d9 Binary files /dev/null and b/models/demos/wormhole/whisper/demo/dataset/conditional_generation/17646385371758249908.wav differ diff --git a/models/demos/wormhole/whisper/demo/dataset/conditional_generation/17659141715436566244.wav b/models/demos/wormhole/whisper/demo/dataset/conditional_generation/17659141715436566244.wav new file mode 100644 index 00000000000..3b60f9d9350 Binary files /dev/null and b/models/demos/wormhole/whisper/demo/dataset/conditional_generation/17659141715436566244.wav differ diff --git a/models/demos/wormhole/whisper/demo/dataset/conditional_generation/17928171511082320095.wav b/models/demos/wormhole/whisper/demo/dataset/conditional_generation/17928171511082320095.wav new file mode 100644 index 00000000000..780372c5f01 Binary files /dev/null and b/models/demos/wormhole/whisper/demo/dataset/conditional_generation/17928171511082320095.wav differ diff --git a/models/demos/wormhole/whisper/demo/dataset/conditional_generation/17938133003986293739.wav b/models/demos/wormhole/whisper/demo/dataset/conditional_generation/17938133003986293739.wav new file mode 100644 index 00000000000..d0c87b033ff Binary files /dev/null and b/models/demos/wormhole/whisper/demo/dataset/conditional_generation/17938133003986293739.wav differ diff --git a/models/demos/wormhole/whisper/demo/dataset/conditional_generation/2842775607363710885.wav b/models/demos/wormhole/whisper/demo/dataset/conditional_generation/2842775607363710885.wav new file mode 100644 index 00000000000..61f6441346b Binary files /dev/null and b/models/demos/wormhole/whisper/demo/dataset/conditional_generation/2842775607363710885.wav differ diff --git a/models/demos/wormhole/whisper/demo/dataset/conditional_generation/6757317816154782558.wav b/models/demos/wormhole/whisper/demo/dataset/conditional_generation/6757317816154782558.wav new file mode 100644 index 00000000000..badc92b1f2d Binary files /dev/null and b/models/demos/wormhole/whisper/demo/dataset/conditional_generation/6757317816154782558.wav differ diff --git a/models/demos/wormhole/whisper/demo/dataset/conditional_generation/6969469525741631060.wav b/models/demos/wormhole/whisper/demo/dataset/conditional_generation/6969469525741631060.wav new file mode 100644 index 00000000000..2e9495b476d Binary files /dev/null and b/models/demos/wormhole/whisper/demo/dataset/conditional_generation/6969469525741631060.wav differ diff --git a/models/demos/wormhole/whisper/demo/demo.py b/models/demos/wormhole/whisper/demo/demo.py new file mode 100644 index 00000000000..7cdc85d1bdb --- /dev/null +++ b/models/demos/wormhole/whisper/demo/demo.py @@ -0,0 +1,620 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + + +import os +import ttnn +import torch +import pytest +import transformers + +from os import listdir +from loguru import logger +from scipy.io import wavfile +from os.path import isfile, join +from datasets import load_dataset +from sklearn.metrics import accuracy_score +from torchmetrics.text import WordErrorRate + +from models.utility_functions import ( + profiler, + is_grayskull, + is_wormhole_b0, + run_for_wormhole_b0, + disable_compilation_reports, + disable_persistent_kernel_cache, +) + +from ttnn.model_preprocessing import preprocess_model_parameters +from models.generation_utils import get_logits_processor, pad_input_32 +from models.demos.wormhole.whisper.tt import ttnn_functional_whisper, ttnn_optimized_functional_whisper + + +def load_input_paths(folder_path): + files = [os.path.join(folder_path, f) for f in listdir(folder_path) if isfile(join(folder_path, f))] + return files + + +def run_generate( + config, + input_embeds, + input_features, + ttnn_model, + decoder_hidden_states, + decoder_attention_mask, + parameters, + ttnn_linear_weight, + device, + decoder_input_ids, + generation_config, + batch_size, + max_tokens, + whisper_memory_config, + output_mesh_composer, + inputs_mesh_mapper, +): + logits_processor = get_logits_processor(decoder_input_ids, config) + decoder_start_values = generation_config.pad_token_id * torch.ones(batch_size, input_features.shape[1]).to( + torch.long + ) + eos_reached = torch.zeros(batch_size, dtype=torch.bool) + + profiler.start(f"inference_time") + for i in range(max_tokens): + ttnn_output = ttnn_model.whisper_for_conditional_generation( + config=config, + input_embeds=input_embeds, + decoder_hidden_states=decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + parameters=parameters, + device=device, + ttnn_linear_weight=ttnn_linear_weight, + whisper_memory_config=whisper_memory_config, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + logits_to_torch = ttnn.to_torch(ttnn_output, mesh_composer=output_mesh_composer) + next_token_logits = logits_to_torch[:, i, :] + next_tokens_scores = logits_processor(input_features, next_token_logits) + next_tokens = torch.argmax(next_tokens_scores, dim=-1).unsqueeze(0) + + # Check if EOS token is generated for any sample in the batch and + # Setting subsequent next_tokens to config.pad_token_id if EOS token is reached. + eos_generated_flags = next_tokens == config.eos_token_id + eos_reached = eos_reached | eos_generated_flags.squeeze(0) + next_tokens[:, eos_reached] = config.pad_token_id + + if (i + 1) % 32 == 0: + decoder_input_ids = torch.cat([decoder_input_ids, decoder_start_values], dim=1) + + decoder_input_ids[:, i + 1] = next_tokens[:, None] + decoder_hidden_states, decoder_attention_mask = ttnn_model.preprocess_decoder_inputs( + config=config, + input_ids=decoder_input_ids, + attention_mask=None, + parameters=parameters.decoder, + device=device, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + + if torch.all(next_tokens == config.eos_token_id): + break + + profiler.end(f"inference_time") + return decoder_input_ids + + +def run_demo_functional_whisper_for_audio_classification_inference( + mesh_device, model_name, input_path, ttnn_model, num_inputs, batch_size, whisper_memory_config +): + feature_extractor = transformers.AutoFeatureExtractor.from_pretrained( + "sanchit-gandhi/whisper-medium-fleurs-lang-id" + ) + model = transformers.WhisperForAudioClassification.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id") + + model.eval() + input_data = load_input_paths(input_path) + tt_model_name = f"ttnn_{model_name}_optimized" + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + ttnn_parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + convert_to_ttnn=ttnn_model.convert_to_ttnn, + custom_preprocessor=ttnn_model.custom_preprocessor, + device=mesh_device, + ) + + if len(input_data) < batch_size: + assert False, "batch_size exceeds number of audio files available in folder" + + batched_inputs = [] + for i in range(batch_size): + input_file_path = input_data[i] + samplerate, data = wavfile.read(input_file_path) + + inputs = feature_extractor( + data, + sampling_rate=samplerate, + return_tensors="pt", + ) + + input_features = inputs.input_features + batched_inputs = input_features if i == 0 else torch.cat((batched_inputs, input_features), dim=0) + + config = model.config + input_embedding = ttnn_model.preprocess_encoder_inputs( + input_features=batched_inputs, + parameters=ttnn_parameters.encoder, + device=mesh_device, + whisper_memory_config=whisper_memory_config, + weights_mesh_mapper=weights_mesh_mapper, + inputs_mesh_mapper=inputs_mesh_mapper, + output_mesh_composer=output_mesh_composer, + ) + + out_logits = ttnn_model.whisper_for_audio_classification( + config=config, + inputs_embeds=input_embedding, + parameters=ttnn_parameters, + device=mesh_device, + batch_size=batch_size, + whisper_memory_config=whisper_memory_config, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + + logits_torch = ttnn.to_torch(out_logits, mesh_composer=output_mesh_composer) + predicted_list = [] + for i in range(batch_size): + single_logits_torch = logits_torch[i].squeeze(0) + predicted_class_ids = torch.argmax(single_logits_torch).item() + predicted_label = model.config.id2label[predicted_class_ids] + logger.info(f"Predicted label: {predicted_label}") + predicted_list.append(predicted_label) + + return predicted_list + + +def run_demo_functional_whisper_for_conditional_generation_inference( + mesh_device, + reset_seeds, + batch_size, + model_name, + input_path, + ttnn_model, + max_tokens=32, + whisper_memory_config=ttnn.L1_MEMORY_CONFIG, +): + model = transformers.WhisperModel.from_pretrained(model_name).eval() + config = transformers.WhisperConfig.from_pretrained(model_name) + processor = transformers.AutoProcessor.from_pretrained(model_name, language="English", task="transcribe") + hf_reference_model = transformers.WhisperForConditionalGeneration.from_pretrained(model_name) + + feature_extractor = transformers.AutoFeatureExtractor.from_pretrained(model_name) + input_data = load_input_paths(input_path) + tt_model_name = f"ttnn_{model_name}_optimized" + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + ttnn_parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + convert_to_ttnn=ttnn_model.convert_to_ttnn, + custom_preprocessor=ttnn_model.custom_preprocessor, + device=mesh_device, + ) + + model = model.eval() + + linear_weight = hf_reference_model.proj_out.weight + ttnn_linear_weight = ttnn.from_torch( + linear_weight, layout=ttnn.TILE_LAYOUT, device=mesh_device, dtype=ttnn.bfloat16, mesh_mapper=weights_mesh_mapper + ) + ttnn_linear_weight = ttnn.permute(ttnn_linear_weight, (1, 0)) + ttnn_linear_weight = ttnn.to_layout(ttnn_linear_weight, layout=ttnn.TILE_LAYOUT) + + if len(input_data) < batch_size: + assert False, "batch_size exceeds number of audio files available in folder" + + for i in range(batch_size): + input_file_path = input_data[i] + samplerate, data = wavfile.read(input_file_path) + inputs = feature_extractor(data, sampling_rate=samplerate, return_tensors="pt") + dtype_to_use = torch.bfloat16 + input_features = inputs.input_features.type(dtype_to_use) + batched_inputs = input_features if i == 0 else torch.cat((batched_inputs, input_features), dim=0) + + decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id + decoder_input_ids = pad_input_32(decoder_input_ids, config.pad_token_id).to(torch.long) + batched_decoder_input_ids = ( + decoder_input_ids if i == 0 else torch.cat((batched_decoder_input_ids, decoder_input_ids), dim=0) + ) + + profiler.start(f"preprocessing_inputs") + (input_embeds, decoder_hidden_states, decoder_attention_mask) = ttnn_model.preprocess_inputs( + config=config, + input_features=batched_inputs, + input_ids=batched_decoder_input_ids, + attention_mask=None, + parameters=ttnn_parameters, + device=mesh_device, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, + weights_mesh_mapper=weights_mesh_mapper, + inputs_mesh_mapper=inputs_mesh_mapper, + output_mesh_composer=output_mesh_composer, + ) + profiler.end(f"preprocessing_inputs") + + generation_config = hf_reference_model.generation_config + ttnn_output = run_generate( + config, + input_embeds, + batched_inputs, + ttnn_model, + decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + parameters=ttnn_parameters, + ttnn_linear_weight=ttnn_linear_weight, + device=mesh_device, + decoder_input_ids=batched_decoder_input_ids, + generation_config=generation_config, + batch_size=batch_size, + max_tokens=max_tokens, + whisper_memory_config=whisper_memory_config, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + + profiler.start(f"post_processing_output_to_string") + ttnn_transcription = processor.batch_decode(ttnn_output, skip_special_tokens=True) + profiler.end(f"post_processing_output_to_string") + + logger.info("Model Output") + logger.info(ttnn_transcription) + + measurements = { + "preprocessing_input": profiler.get("preprocessing_input"), + "inference_time": profiler.get("inference_time"), + "post_processing": profiler.get("post_processing_output_to_string"), + } + + logger.info(f"preprocessing_input: {measurements['preprocessing_input']} s") + logger.info(f"inference_time: {measurements['inference_time']} s") + logger.info(f"post_processing : {measurements['post_processing']} s") + + return measurements, ttnn_transcription + + +def run_demo_functional_whisper_for_audio_classification_dataset( + mesh_device, reset_seeds, model_name, ttnn_model, batch_size, n_iterations, whisper_memory_config +): + feature_extractor = transformers.AutoFeatureExtractor.from_pretrained(model_name) + model = transformers.WhisperForAudioClassification.from_pretrained(model_name).eval() + + ds = load_dataset("google/fleurs", "all", split="validation", streaming=True) + sample = iter(ds) + + reference_labels = [] + predicted_labels = [] + config = model.config + + tt_model_name = f"ttnn_{model_name}_optimized" + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + ttnn_parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + convert_to_ttnn=ttnn_model.convert_to_ttnn, + custom_preprocessor=ttnn_model.custom_preprocessor, + device=mesh_device, + ) + + for _ in range(n_iterations): + batch_input = [] + for i in range(batch_size): + s = next(sample) + inputs = feature_extractor(s["audio"]["array"], sampling_rate=16000, return_tensors="pt") + input_features = inputs.input_features.type(torch.bfloat16) + batch_input = input_features if i == 0 else torch.cat((batch_input, input_features), dim=0) + reference_labels.append(s["language"]) + + input_embedding = ttnn_model.preprocess_encoder_inputs( + input_features=batch_input, + parameters=ttnn_parameters.encoder, + device=mesh_device, + whisper_memory_config=whisper_memory_config, + weights_mesh_mapper=weights_mesh_mapper, + inputs_mesh_mapper=inputs_mesh_mapper, + output_mesh_composer=output_mesh_composer, + ) + + out_logits = ttnn_model.whisper_for_audio_classification( + config=config, + inputs_embeds=input_embedding, + parameters=ttnn_parameters, + device=mesh_device, + batch_size=batch_size, + whisper_memory_config=whisper_memory_config, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + logits_torch = ttnn.to_torch(out_logits, mesh_composer=output_mesh_composer) + + for i in range(batch_size): + single_logits_torch = logits_torch[i].squeeze(0) + predicted_class_ids = torch.argmax(single_logits_torch).item() + predicted_label = model.config.id2label[predicted_class_ids] + predicted_labels.append(predicted_label) + + accuracy = accuracy_score(reference_labels, predicted_labels) + logger.info(f"Reference labels: {reference_labels}") + logger.info(f"Predicted labels: {predicted_labels}") + logger.info(f"Accuracy: {accuracy}") + return accuracy + + +def run_demo_functional_whisper_for_conditional_generation_dataset( + mesh_device, + reset_seeds, + model_name, + ttnn_model, + batch_size=1, + n_iterations=1, + max_tokens=32, + whisper_memory_config=ttnn.L1_MEMORY_CONFIG, +): + model = transformers.WhisperModel.from_pretrained(model_name).eval() + config = transformers.WhisperConfig.from_pretrained(model_name) + processor = transformers.AutoProcessor.from_pretrained(model_name, language="English", task="transcribe") + hf_reference_model = transformers.WhisperForConditionalGeneration.from_pretrained(model_name) + + feature_extractor = transformers.AutoFeatureExtractor.from_pretrained(model_name) + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + sample = iter(ds) + batched_ground_truth_transcriptions = [] + + tt_model_name = f"ttnn_{model_name}_optimized" + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + ttnn_parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + convert_to_ttnn=ttnn_model.convert_to_ttnn, + custom_preprocessor=ttnn_model.custom_preprocessor, + device=mesh_device, + ) + + model = model.eval() + + linear_weight = hf_reference_model.proj_out.weight + ttnn_linear_weight = ttnn.from_torch( + linear_weight, layout=ttnn.TILE_LAYOUT, device=mesh_device, dtype=ttnn.bfloat16, mesh_mapper=weights_mesh_mapper + ) + ttnn_linear_weight = ttnn.permute(ttnn_linear_weight, (1, 0)) + ttnn_linear_weight = ttnn.to_layout(ttnn_linear_weight, layout=ttnn.TILE_LAYOUT) + + for _ in range(n_iterations): + for i in range(batch_size): + s = next(sample) + inputs = feature_extractor(s["audio"]["array"], sampling_rate=16000, return_tensors="pt") + ground_truth_transcriptions = s["text"] + dtype_to_use = torch.bfloat16 + input_features = inputs.input_features.type(dtype_to_use) + + batched_inputs = input_features if i == 0 else torch.cat((batched_inputs, input_features), dim=0) + + decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id + decoder_input_ids = pad_input_32(decoder_input_ids, config.pad_token_id).to(torch.long) + batched_decoder_input_ids = ( + decoder_input_ids if i == 0 else torch.cat((batched_decoder_input_ids, decoder_input_ids), dim=0) + ) + + batched_ground_truth_transcriptions.append(ground_truth_transcriptions) + + (input_embeds, decoder_hidden_states, decoder_attention_mask) = ttnn_model.preprocess_inputs( + config=config, + input_features=batched_inputs, + input_ids=batched_decoder_input_ids, + attention_mask=None, + parameters=ttnn_parameters, + device=mesh_device, + whisper_memory_config=whisper_memory_config, + weights_mesh_mapper=weights_mesh_mapper, + inputs_mesh_mapper=inputs_mesh_mapper, + output_mesh_composer=output_mesh_composer, + ) + ttnn_output = run_generate( + config, + input_embeds, + batched_inputs, + ttnn_model, + decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + parameters=ttnn_parameters, + ttnn_linear_weight=ttnn_linear_weight, + device=mesh_device, + decoder_input_ids=batched_decoder_input_ids, + generation_config=hf_reference_model.generation_config, + batch_size=batch_size, + max_tokens=max_tokens, + whisper_memory_config=whisper_memory_config, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + ttnn_transcription = processor.batch_decode(ttnn_output, skip_special_tokens=True) + + logger.info("Model Output") + logger.info(ttnn_transcription) + + wer = WordErrorRate() + wer_scores = [] + for transcription, ground_truth in zip(ttnn_transcription, batched_ground_truth_transcriptions): + transcription = transcription.upper() + individual_wer_score = wer([transcription], [ground_truth]) + wer_scores.append(individual_wer_score) + logger.info(f"Individual Sample WER score: {individual_wer_score}") + + average_wer_score = sum(wer_scores) / len(wer_scores) + logger.info(f"Average WER score: {average_wer_score}") + accuracy = 1 - average_wer_score + logger.info(f"Accuracy: {accuracy}") + + return average_wer_score + + +@run_for_wormhole_b0() +@pytest.mark.parametrize( + "model_name, input_loc", + ( + ( + [ + "sanchit-gandhi/whisper-medium-fleurs-lang-id", + "models/demos/wormhole/whisper/demo/dataset/audio_classification", + ] + ), + ), +) +@pytest.mark.parametrize( + ("ttnn_model", "num_inputs", "batch_size", "WHISPER_MEMORY_CONFIG"), + ((ttnn_optimized_functional_whisper, 1, 4, ttnn.DRAM_MEMORY_CONFIG),), +) +def test_demo_for_audio_classification( + mesh_device, + reset_seeds, + use_program_cache, + model_name, + input_loc, + ttnn_model, + num_inputs, + batch_size, + WHISPER_MEMORY_CONFIG, +): + disable_persistent_kernel_cache() + disable_compilation_reports() + return run_demo_functional_whisper_for_audio_classification_inference( + mesh_device, + model_name=model_name, + input_path=input_loc, + ttnn_model=ttnn_model, + num_inputs=num_inputs, + batch_size=batch_size, + whisper_memory_config=WHISPER_MEMORY_CONFIG, + ) + + +@run_for_wormhole_b0() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize( + "model_name, input_loc", + ((["openai/whisper-tiny.en", "models/demos/wormhole/whisper/demo/dataset/conditional_generation"]),), +) +@pytest.mark.parametrize( + ("ttnn_model", "batch_size", "max_tokens", "WHISPER_MEMORY_CONFIG"), + ((ttnn_optimized_functional_whisper, 4, 32, ttnn.DRAM_MEMORY_CONFIG),), +) +def test_demo_for_conditional_generation( + mesh_device, + reset_seeds, + use_program_cache, + model_name, + input_loc, + ttnn_model, + batch_size, + max_tokens, + WHISPER_MEMORY_CONFIG, +): + disable_persistent_kernel_cache() + disable_compilation_reports() + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size * 2 if mesh_device_flag else batch_size + + return run_demo_functional_whisper_for_conditional_generation_inference( + mesh_device, + reset_seeds, + batch_size=batch_size, + model_name=model_name, + input_path=input_loc, + ttnn_model=ttnn_model, + max_tokens=max_tokens, + whisper_memory_config=WHISPER_MEMORY_CONFIG, + ) + + +@run_for_wormhole_b0() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize( + "model_name", + (["sanchit-gandhi/whisper-medium-fleurs-lang-id"]), +) +@pytest.mark.parametrize( + ("ttnn_model", "batch_size", "n_iterations", "WHISPER_MEMORY_CONFIG"), + ((ttnn_optimized_functional_whisper, 8, 1, ttnn.DRAM_MEMORY_CONFIG),), +) +def test_demo_for_audio_classification_dataset( + mesh_device, reset_seeds, use_program_cache, model_name, ttnn_model, batch_size, n_iterations, WHISPER_MEMORY_CONFIG +): + disable_persistent_kernel_cache() + disable_compilation_reports() + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size * 2 if mesh_device_flag else batch_size + + return run_demo_functional_whisper_for_audio_classification_dataset( + mesh_device, + reset_seeds, + model_name=model_name, + ttnn_model=ttnn_model, + batch_size=batch_size, + n_iterations=n_iterations, + whisper_memory_config=WHISPER_MEMORY_CONFIG, + ) + + +@run_for_wormhole_b0() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize( + "model_name", + (["openai/whisper-tiny.en"]), +) +@pytest.mark.parametrize( + ("ttnn_model", "batch_size", "n_iterations", "max_tokens", "WHISPER_MEMORY_CONFIG"), + ((ttnn_optimized_functional_whisper, 8, 1, 32, ttnn.DRAM_MEMORY_CONFIG),), +) +def test_demo_for_conditional_generation_dataset( + mesh_device, + reset_seeds, + use_program_cache, + model_name, + ttnn_model, + batch_size, + n_iterations, + max_tokens, + WHISPER_MEMORY_CONFIG, +): + disable_persistent_kernel_cache() + disable_compilation_reports() + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size * 2 if mesh_device_flag else batch_size + + return run_demo_functional_whisper_for_conditional_generation_dataset( + mesh_device, + reset_seeds, + model_name=model_name, + ttnn_model=ttnn_model, + batch_size=batch_size, + n_iterations=n_iterations, + max_tokens=max_tokens, + whisper_memory_config=WHISPER_MEMORY_CONFIG, + ) diff --git a/models/demos/wormhole/whisper/tests/test_performance.py b/models/demos/wormhole/whisper/tests/test_performance.py new file mode 100644 index 00000000000..7921bd67827 --- /dev/null +++ b/models/demos/wormhole/whisper/tests/test_performance.py @@ -0,0 +1,149 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +import time +import torch +import pytest +import transformers +from loguru import logger +from datasets import load_dataset +from models.perf.perf_utils import prep_perf_report +from ttnn.model_preprocessing import preprocess_model_parameters +from models.demos.wormhole.whisper.tt import ttnn_optimized_functional_whisper +from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report +from models.utility_functions import is_wormhole_b0, is_grayskull, skip_for_grayskull, run_for_wormhole_b0 + + +def get_expected_times(functional_whisper): + return {ttnn_optimized_functional_whisper: (43.84, 10.9)}[functional_whisper] + + +@run_for_wormhole_b0() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.models_performance_bare_metal +@pytest.mark.models_performance_virtual_machine +@pytest.mark.parametrize("model_name", ["openai/whisper-base"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [500]) +@pytest.mark.parametrize("functional_whisper", [ttnn_optimized_functional_whisper]) +def test_performance(mesh_device, use_program_cache, model_name, batch_size, sequence_size, functional_whisper): + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size * 2 if mesh_device_flag else batch_size + + config = transformers.WhisperConfig.from_pretrained(model_name) + tt_model_name = f"ttnn_{model_name}_optimized" + config = transformers.WhisperConfig.from_pretrained(model_name) + feature_extractor = transformers.AutoFeatureExtractor.from_pretrained(model_name) + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + inputs = feature_extractor( + [ds[i]["audio"]["array"] for i in range(batch_size)], sampling_rate=16000, return_tensors="pt" + ) + input_features = inputs.input_features + decoder_input_ids = torch.tensor([[1, 1]] * batch_size) * config.decoder_start_token_id + model = transformers.WhisperModel.from_pretrained(model_name) + + attention_mask = None + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + ttnn_parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + convert_to_ttnn=functional_whisper.convert_to_ttnn, + custom_preprocessor=functional_whisper.custom_preprocessor, + device=mesh_device, + ) + + durations = [] + for _ in range(2): + (input_embeds, decoder_hidden_states, decoder_attention_mask) = functional_whisper.preprocess_inputs( + config=config, + input_features=input_features, + input_ids=decoder_input_ids, + attention_mask=attention_mask, + parameters=ttnn_parameters, + device=mesh_device, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, + weights_mesh_mapper=weights_mesh_mapper, + inputs_mesh_mapper=inputs_mesh_mapper, + output_mesh_composer=output_mesh_composer, + ) + + start = time.time() + tt_output = functional_whisper.whisper( + config, + mesh_device, + input_embeds, + decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + parameters=ttnn_parameters, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + tt_output = ttnn.to_torch(tt_output, mesh_composer=output_mesh_composer) + end = time.time() + + duration = end - start + durations.append(duration) + + inference_and_compile_time, inference_time, *_ = durations + + expected_compile_time, expected_inference_time = get_expected_times(functional_whisper) + prep_perf_report( + model_name=tt_model_name, + batch_size=batch_size, + inference_and_compile_time=durations[0], + inference_time=durations[1], + expected_compile_time=expected_compile_time, + expected_inference_time=expected_inference_time, + comments="", + inference_time_cpu=0.0, + ) + + logger.info(f"Compile time: {inference_and_compile_time - inference_time}") + logger.info(f"Inference time: {inference_time}") + logger.info(f"Samples per second: {1 / inference_time * batch_size}") + + +@run_for_wormhole_b0() +@pytest.mark.models_device_performance_bare_metal +@pytest.mark.parametrize( + "batch_size, test", + [ + [ + 8, + "silicon_arch_name=wormhole_b0-silicon_arch_wormhole_b0=True-ttnn_model=models.demos.wormhole.whisper.tt.ttnn_optimized_functional_whisper-model_name=openai/whisper-base-batch_size=8", + ], + ], +) +def test_perf_device_bare_metal(batch_size, test): + subdir = "ttnn_whisper" + num_iterations = 1 + margin = 0.03 + expected_perf = 35.97 + + command = ( + f"pytest tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper_wh.py::test_ttnn_whisper" + ) + cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"] + + inference_time_key = "AVG DEVICE KERNEL SAMPLES/S" + expected_perf_cols = {inference_time_key: expected_perf} + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size * 2 if mesh_device_flag else batch_size + + post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size) + expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols, assert_on_fail=False) + prep_device_perf_report( + model_name=f"ttnn_whisper_{batch_size}", + batch_size=batch_size, + post_processed_results=post_processed_results, + expected_results=expected_results, + comments=test.replace("/", "_"), + ) diff --git a/models/demos/wormhole/whisper/tt/ttnn_functional_whisper.py b/models/demos/wormhole/whisper/tt/ttnn_functional_whisper.py new file mode 100644 index 00000000000..f9ec760b230 --- /dev/null +++ b/models/demos/wormhole/whisper/tt/ttnn_functional_whisper.py @@ -0,0 +1,444 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import transformers +import torch +from typing import Optional, Tuple + +from torch.nn import functional as F +from ttnn.model_preprocessing import preprocess_linear_weight, preprocess_linear_bias +import ttnn + + +def gelu(tensor): + return ttnn.gelu(tensor) + + +def dropout(hidden_states, p, training): + # ignored for inference + return hidden_states + + +def calculate_key_values(config, key_value_states, *, parameters): + bsz, tgt_len, hidden_size = key_value_states.shape + bsz, tgt_len_padded, _ = key_value_states.shape.with_tile_padding() + head_size = hidden_size // config.encoder_attention_heads + + fused_qkv = key_value_states @ parameters.key_value.weight + parameters.key_value.bias + + dtype = fused_qkv.dtype + device = fused_qkv.device() + fused_qkv = ttnn.to_torch(fused_qkv) + fused_qkv = torch.reshape(fused_qkv, (bsz, tgt_len, 2, config.encoder_attention_heads, head_size)) + key_states, value_states = fused_qkv[..., 0, :, :], fused_qkv[..., 1, :, :] + + key_states = ttnn.from_torch(key_states, dtype=dtype, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + value_states = ttnn.from_torch(value_states, dtype=dtype, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + + key_states = ttnn.permute(key_states, (0, 2, 1, 3)) + value_states = ttnn.permute(value_states, (0, 2, 1, 3)) + key_states = ttnn.to_layout(key_states, ttnn.TILE_LAYOUT) + value_states = ttnn.to_layout(value_states, ttnn.TILE_LAYOUT) + + desired_shape = ttnn.Shape( + [bsz, config.encoder_attention_heads, tgt_len, head_size], + [bsz, config.encoder_attention_heads, tgt_len_padded, head_size], + ) + key_states = ttnn.reshape(key_states, shape=desired_shape) + value_states = ttnn.reshape(value_states, shape=desired_shape) + + return key_states, value_states + + +def split_query_key_value_and_split_heads( + config, fused_qkv: ttnn.Tensor +) -> Tuple[ttnn.Tensor, ttnn.Tensor, ttnn.Tensor]: + head_size = config.d_model // config.encoder_attention_heads + batch_size, *_, seq_length, three_times_hidden_size = fused_qkv.shape + batch_size, *_, padded_seq_length, three_times_hidden_size = fused_qkv.shape.with_tile_padding() + hidden_size = three_times_hidden_size // 3 + encoder_attention_heads = hidden_size // head_size + + dtype = fused_qkv.dtype + device = fused_qkv.device() + fused_qkv = ttnn.to_torch(fused_qkv) + fused_qkv = torch.reshape(fused_qkv, shape=(batch_size, seq_length, 3, encoder_attention_heads, head_size)) + query_states, key_states, value_states = fused_qkv[..., 0, :, :], fused_qkv[..., 1, :, :], fused_qkv[..., 2, :, :] + + query_states = ttnn.from_torch(query_states, dtype=dtype, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + key_states = ttnn.from_torch(key_states, dtype=dtype, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + value_states = ttnn.from_torch(value_states, dtype=dtype, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + + query_states = ttnn.permute(query_states, (0, 2, 1, 3)) + key_states = ttnn.permute(key_states, (0, 2, 1, 3)) + value_states = ttnn.permute(value_states, (0, 2, 1, 3)) + + query_states = ttnn.to_layout(query_states, ttnn.TILE_LAYOUT) + key_states = ttnn.to_layout(key_states, ttnn.TILE_LAYOUT) + value_states = ttnn.to_layout(value_states, ttnn.TILE_LAYOUT) + + desired_shape = ttnn.Shape( + [batch_size, encoder_attention_heads, seq_length, head_size], + [batch_size, encoder_attention_heads, padded_seq_length, head_size], + ) + query_states = ttnn.reshape(query_states, shape=desired_shape) + key_states = ttnn.reshape(key_states, shape=desired_shape) + value_states = ttnn.reshape(value_states, shape=desired_shape) + return query_states, key_states, value_states + + +def calculate_query_key_values(config, hidden_states, *, parameters): + fused_qkv = hidden_states @ parameters.query_key_value.weight + parameters.query_key_value.bias + return split_query_key_value_and_split_heads(config, fused_qkv) + + +def whisper_attention(config, hidden_states, attention_mask, key_value_states=None, *, parameters): + head_size = config.d_model // config.encoder_attention_heads + scaling = head_size**-0.5 + bsz, *_, padded_tgt_len, _ = hidden_states.shape.with_tile_padding() + bsz, *_, tgt_len, _ = hidden_states.shape + + is_cross_attention = key_value_states is not None + if is_cross_attention: + query_states = hidden_states @ parameters.q_proj.weight + parameters.q_proj.bias + + dtype = query_states.dtype + device = query_states.device() + query_states = ttnn.to_torch(query_states) + query_states = torch.reshape(query_states, (bsz, tgt_len, config.encoder_attention_heads, head_size)) + query_states = ttnn.from_torch(query_states, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device) + query_states = ttnn.permute(query_states, (0, 2, 1, 3)) + key_states, value_states = calculate_key_values(config, key_value_states, parameters=parameters) + padded_key_value_tgt_len = key_states.shape.with_tile_padding()[2] + key_value_tgt_len = key_states.shape[2] + else: + query_states, key_states, value_states = calculate_query_key_values( + config, hidden_states, parameters=parameters + ) + padded_key_value_tgt_len = padded_tgt_len + key_value_tgt_len = tgt_len + + query_states *= scaling + + proj_shape = ttnn.Shape( + [bsz * config.encoder_attention_heads, tgt_len, head_size], + [bsz * config.encoder_attention_heads, padded_tgt_len, head_size], + ) + query_states = ttnn.reshape(query_states, shape=proj_shape) + proj_shape = ttnn.Shape( + [bsz * config.encoder_attention_heads, key_value_tgt_len, head_size], + [bsz * config.encoder_attention_heads, padded_key_value_tgt_len, head_size], + ) + key_states = ttnn.reshape(key_states, shape=proj_shape) + value_states = ttnn.reshape(value_states, shape=proj_shape) + + query_states = ttnn.to_layout(query_states, layout=ttnn.TILE_LAYOUT) + key_states = ttnn.to_layout(key_states, layout=ttnn.TILE_LAYOUT) + value_states = ttnn.to_layout(value_states, layout=ttnn.TILE_LAYOUT) + + attn_weights = query_states @ ttnn.permute(key_states, (0, 2, 1)) + if attention_mask is not None: + bsz, _, tgt_len, src_len = attention_mask.shape + attn_weights = ttnn.to_layout(attn_weights, layout=ttnn.ROW_MAJOR_LAYOUT) + attn_weights = ttnn.reshape(attn_weights, shape=(bsz, config.encoder_attention_heads, tgt_len, src_len)) + attn_weights = ttnn.to_layout(attn_weights, layout=ttnn.TILE_LAYOUT) + attn_weights = attn_weights + attention_mask + attn_weights = ttnn.to_layout(attn_weights, layout=ttnn.ROW_MAJOR_LAYOUT) + attn_weights = ttnn.reshape(attn_weights, shape=(bsz * config.encoder_attention_heads, tgt_len, src_len)) + attn_weights = ttnn.to_layout(attn_weights, layout=ttnn.TILE_LAYOUT) + + # differences in ttnn.softmax vs torch.softmax cause the attn_weights to be slightly different + attn_weights = ttnn.softmax(attn_weights, dim=-1) + + attn_probs = dropout(attn_weights, p=0, training=False) + attn_output = attn_probs @ value_states + attn_output = ttnn.to_layout(attn_output, layout=ttnn.ROW_MAJOR_LAYOUT) + attn_output = ttnn.reshape(attn_output, shape=(bsz, config.encoder_attention_heads, tgt_len, head_size)) + attn_output = ttnn.permute(attn_output, (0, 2, 1, 3)) + + dtype = attn_output.dtype + device = attn_output.device() + attn_output = ttnn.to_torch(attn_output) + attn_output = torch.reshape(attn_output, (bsz, tgt_len, config.d_model)) + attn_output = ttnn.from_torch(attn_output, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device) + attn_output = attn_output @ parameters.out_proj.weight + parameters.out_proj.bias + return attn_output + + +def encoder_layer(config, hidden_states, *, parameters): + residual = hidden_states + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.self_attn_layer_norm.weight, + bias=parameters.self_attn_layer_norm.bias, + ) + + hidden_states = whisper_attention(config, hidden_states, attention_mask=None, parameters=parameters.self_attn) + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.final_layer_norm.weight, + bias=parameters.final_layer_norm.bias, + ) + hidden_states = hidden_states @ parameters.fc1.weight + parameters.fc1.bias + hidden_states = gelu(hidden_states) + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = hidden_states @ parameters.fc2.weight + parameters.fc2.bias + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = residual + hidden_states + + # if hidden_states.dtype == torch.float16 and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()): + # clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + # hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + return hidden_states + + +def encoder(config, inputs_embeds, *, parameters): + hidden_states = inputs_embeds + parameters.embed_positions.weight + hidden_states = dropout(hidden_states, p=0, training=False) + + for encoder_layer_parameter in parameters.layers: + hidden_states = encoder_layer(config, hidden_states, parameters=encoder_layer_parameter) + + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.layer_norm.weight, + bias=parameters.layer_norm.bias, + ) + return hidden_states + + +def make_causal_mask(input_ids_shape, dtype): + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min)) + mask_cond = torch.arange(mask.size(-1)) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len) + + +def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.shape + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +def decoder_layer(config, hidden_states, attention_mask, encoder_hidden_states, *, parameters): + residual = hidden_states + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.self_attn_layer_norm.weight, + bias=parameters.self_attn_layer_norm.bias, + ) + + hidden_states = whisper_attention( + config, + hidden_states=hidden_states, + attention_mask=attention_mask, + parameters=parameters.self_attn, + ) + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = residual + hidden_states + + # Cross-Attention Block + residual = hidden_states + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.encoder_attn_layer_norm.weight, + bias=parameters.encoder_attn_layer_norm.bias, + ) + + hidden_states = whisper_attention( + config, + hidden_states, + attention_mask=None, + key_value_states=encoder_hidden_states, + parameters=parameters.encoder_attn, + ) + + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.final_layer_norm.weight, + bias=parameters.final_layer_norm.bias, + ) + hidden_states = hidden_states @ parameters.fc1.weight + parameters.fc1.bias + hidden_states = gelu(hidden_states) + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = hidden_states @ parameters.fc2.weight + parameters.fc2.bias + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = residual + hidden_states + + return hidden_states + + +def prepare_decoder_attention_mask(attention_mask, input_shape, input_embeds): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + + if input_shape[-1] > 1: + combined_attention_mask = make_causal_mask(input_shape, input_embeds.dtype) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = expand_mask(attention_mask, input_embeds.dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + +def decoder(config, hidden_states, decoder_attention_mask, encoder_hidden_states, *, parameters): + hidden_states = dropout(hidden_states, p=0, training=False) + + for decoder_layer_parameter in parameters.layers: + hidden_states = decoder_layer( + config, + hidden_states, + decoder_attention_mask, + encoder_hidden_states, + parameters=decoder_layer_parameter, + ) + + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.layer_norm.weight, + bias=parameters.layer_norm.bias, + ) + + return hidden_states + + +def convert_to_ttnn(model, name): + return name not in [ + "encoder.conv1", + "encoder.conv2", + "decoder.embed_tokens", + "decoder.embed_positions", + ] + + +def preprocess_encoder_inputs(input_features, *, parameters, device): + def conv(input, weight, bias, stride=1, padding=1, dilation=1, groups=1): + return F.conv1d(input, weight, bias, stride, padding, dilation, groups) + + input_embeds = torch.nn.functional.gelu( + conv( + input_features, + weight=parameters.conv1.weight, + bias=parameters.conv1.bias, + padding=1, + ) + ) + input_embeds = torch.nn.functional.gelu( + conv( + input_embeds, + weight=parameters.conv2.weight, + bias=parameters.conv2.bias, + stride=2, + padding=1, + ) + ) + input_embeds = input_embeds.permute(0, 2, 1) + input_embeds = ttnn.from_torch(input_embeds, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + + return input_embeds + + +def preprocess_decoder_inputs(config, input_ids, attention_mask, *, parameters, device): + input_shape = input_ids.size() + input_ids = torch.reshape(input_ids, (-1, input_shape[-1])) + inputs_embeds = F.embedding(input_ids, parameters.embed_tokens.weight) + attention_mask = prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds) + # ttnn cannot broadcast when adding on the batch or channel dimensions so this is a workaround + attention_mask = attention_mask.expand(-1, config.encoder_attention_heads, -1, -1) + + positions = parameters.embed_positions.weight[0 : input_ids.shape[-1]] + decoder_hidden_states = inputs_embeds + positions + + decoder_hidden_states = ttnn.from_torch( + decoder_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device + ) + attention_mask = ttnn.from_torch(attention_mask, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + + return decoder_hidden_states, attention_mask + + +def preprocess_inputs( + *, + config, + input_features, + input_ids, + attention_mask, + parameters, + device, +): + input_embeds = preprocess_encoder_inputs(input_features, parameters=parameters.encoder, device=device) + (decoder_hidden_states, attention_mask) = preprocess_decoder_inputs( + config, input_ids, attention_mask, parameters=parameters.decoder, device=device + ) + return input_embeds, decoder_hidden_states, attention_mask + + +def whisper(config, encoder_hidden_states, decoder_hidden_states, decoder_attention_mask, *, parameters): + encoder_hidden_states = encoder(config, encoder_hidden_states, parameters=parameters.encoder) + last_hidden_state = decoder( + config, + decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + parameters=parameters.decoder, + ) + return last_hidden_state + + +def custom_preprocessor(torch_model, name): + parameters = {} + if isinstance(torch_model, transformers.models.whisper.modeling_whisper.WhisperAttention): + height, width = torch_model.k_proj.weight.shape + + if "encoder_attn" in name: + parameters = {"key_value": {}, "q_proj": {}, "out_proj": {}} + preprocessed_weight = torch.cat([torch_model.k_proj.weight, torch_model.v_proj.weight], dim=0) + preprocessed_bias = torch.cat([torch.zeros(height), torch_model.v_proj.bias], dim=0) + parameters["key_value"]["weight"] = preprocess_linear_weight(preprocessed_weight, dtype=ttnn.bfloat16) + parameters["key_value"]["bias"] = preprocess_linear_bias(preprocessed_bias, dtype=ttnn.bfloat16) + parameters["q_proj"]["weight"] = preprocess_linear_weight(torch_model.q_proj.weight, dtype=ttnn.bfloat16) + parameters["q_proj"]["bias"] = preprocess_linear_bias(torch_model.q_proj.bias, dtype=ttnn.bfloat16) + else: + parameters = {"query_key_value": {}, "out_proj": {}} + preprocessed_weight = torch.cat( + [torch_model.q_proj.weight, torch_model.k_proj.weight, torch_model.v_proj.weight], dim=0 + ) + preprocessed_bias = torch.cat( + [torch_model.q_proj.bias, torch.zeros(height), torch_model.v_proj.bias], dim=0 + ) + parameters["query_key_value"]["weight"] = preprocess_linear_weight(preprocessed_weight, dtype=ttnn.bfloat16) + parameters["query_key_value"]["bias"] = preprocess_linear_bias(preprocessed_bias, dtype=ttnn.bfloat16) + + parameters["out_proj"]["weight"] = preprocess_linear_weight(torch_model.out_proj.weight, dtype=ttnn.bfloat16) + parameters["out_proj"]["bias"] = preprocess_linear_bias(torch_model.out_proj.bias, dtype=ttnn.bfloat16) + elif name == "encoder.embed_positions" and isinstance(torch_model, torch.nn.Embedding): + embeddings = ttnn.from_torch(torch_model.weight, dtype=ttnn.bfloat16) + embeddings = ttnn.to_layout(embeddings, ttnn.TILE_LAYOUT) + parameters["weight"] = embeddings + return parameters diff --git a/models/demos/wormhole/whisper/tt/ttnn_optimized_functional_whisper.py b/models/demos/wormhole/whisper/tt/ttnn_optimized_functional_whisper.py new file mode 100644 index 00000000000..afd97175970 --- /dev/null +++ b/models/demos/wormhole/whisper/tt/ttnn_optimized_functional_whisper.py @@ -0,0 +1,782 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +import torch +import transformers +from typing import Optional +from torch.nn import functional as F +from ttnn.model_preprocessing import preprocess_linear_weight, preprocess_linear_bias + + +WHISPER_DTYPE = ttnn.bfloat8_b + + +def dropout(hidden_states, p, training): + # ignored for inference + return hidden_states + + +# The split_query_key_value_and_split_heads requires the query to have the same volume as the key and values +# This is not the case however for whisper so we currently cannot swap out calculate_key_values below +# def calculate_key_values(config, query_states, key_value_states, *, parameters): +# fused_kv = key_value_states @ parameters.key_value.weight + parameters.key_value.bias +# head_size = config.d_model // config.encoder_attention_heads +# batch_size, *_, _, two_times_hidden_size = fused_kv.shape.with_tile_padding() +# hidden_size = two_times_hidden_size // 2 +# encoder_attention_heads = hidden_size // head_size +# query_states, key_states, value_states = ttnn.transformer.split_query_key_value_and_split_heads( +# query_states, +# kv_input_tensor=fused_kv, +# num_heads=encoder_attention_heads, +# memory_config=WHISPER_MEMORY_CONFIG, +# ) +# key_states = ttnn.permute(key_states, (0, 1, 3, 2)) +# return query_states, key_states, value_states + + +def calculate_key_values( + config, device, key_value_states, *, parameters, whisper_memory_config, output_mesh_composer, inputs_mesh_mapper +): + bsz, tgt_len, hidden_size = key_value_states.shape + bsz, tgt_len_padded, _ = key_value_states.shape.with_tile_padding() + head_size = hidden_size // config.encoder_attention_heads + fused_qkv = ttnn.linear( + key_value_states, + parameters.weight, + bias=parameters.bias, + memory_config=whisper_memory_config, + ) + dtype = fused_qkv.dtype + + fused_qkv = ttnn.to_layout(fused_qkv, layout=ttnn.ROW_MAJOR_LAYOUT) + fused_qkv = ttnn.from_device(fused_qkv) + fused_qkv = ttnn.reshape(fused_qkv, (bsz, tgt_len, 2, config.encoder_attention_heads, head_size)) + + fused_qkv = ttnn.to_layout(fused_qkv, layout=ttnn.TILE_LAYOUT) + fused_qkv = ttnn.to_device(fused_qkv, device=device) + # #13672: Slice op doesn't support for 5d tensors. + fused_qkv = ttnn.to_torch(fused_qkv, mesh_composer=output_mesh_composer) + key_states, value_states = fused_qkv[..., 0, :, :], fused_qkv[..., 1, :, :] # + key_states = ttnn.from_torch( + key_states, dtype=dtype, layout=ttnn.ROW_MAJOR_LAYOUT, device=device, mesh_mapper=inputs_mesh_mapper + ) + value_states = ttnn.from_torch( + value_states, dtype=dtype, layout=ttnn.ROW_MAJOR_LAYOUT, device=device, mesh_mapper=inputs_mesh_mapper + ) + + key_states = ttnn.permute(key_states, (0, 2, 3, 1)) + value_states = ttnn.permute(value_states, (0, 2, 1, 3)) + + key_states = ttnn.to_layout(key_states, ttnn.TILE_LAYOUT) + value_states = ttnn.to_layout(value_states, ttnn.TILE_LAYOUT) + + desired_shape = ttnn.Shape( + [bsz, config.encoder_attention_heads, head_size, tgt_len], + [bsz, config.encoder_attention_heads, head_size, tgt_len_padded], + ) + key_states = ttnn.reshape(key_states, shape=desired_shape) + + desired_shape = ttnn.Shape( + [bsz, config.encoder_attention_heads, tgt_len, head_size], + [bsz, config.encoder_attention_heads, tgt_len_padded, head_size], + ) + value_states = ttnn.reshape(value_states, shape=desired_shape) + + return key_states, value_states + + +def calculate_query_key_values(config, hidden_states, *, parameters, whisper_memory_config): + fused_qkv = ttnn.linear( + hidden_states, + parameters.weight, + bias=parameters.bias, + ) + + return ttnn.transformer.split_query_key_value_and_split_heads( + fused_qkv, memory_config=whisper_memory_config, num_heads=config.num_attention_heads + ) + + +def whisper_attention( + config, + device, + hidden_states, + attention_mask, + key_value_states=None, + *, + parameters, + whisper_memory_config, + output_mesh_composer, + inputs_mesh_mapper, +): + head_size = config.d_model // config.encoder_attention_heads + scaling = head_size**-0.5 + bsz, *_, tgt_len, _ = hidden_states.shape + is_cross_attention = key_value_states is not None + + if is_cross_attention: + query_states = ttnn.linear( + hidden_states, + parameters.q_proj.weight, + bias=parameters.q_proj.bias, + memory_config=whisper_memory_config, + ) + query_states = ttnn.to_layout(query_states, layout=ttnn.ROW_MAJOR_LAYOUT) + query_states = ttnn.from_device(query_states) + query_states = ttnn.reshape(query_states, (bsz, tgt_len, config.encoder_attention_heads, head_size)) + query_states = ttnn.to_layout(query_states, layout=ttnn.TILE_LAYOUT) + query_states = ttnn.to_device(query_states, device=device) + query_states = ttnn.permute(query_states, (0, 2, 1, 3)) + + key_states, value_states = calculate_key_values( + config, + device, + key_value_states, + parameters=parameters.key_value, + whisper_memory_config=whisper_memory_config, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + else: + query_states, key_states, value_states = calculate_query_key_values( + config, hidden_states, parameters=parameters.query_key_value, whisper_memory_config=whisper_memory_config + ) + + query_states *= scaling + attn_weights = ttnn.matmul(query_states, key_states) + + if attention_mask is not None: + attn_weights = ttnn.add(attn_weights, attention_mask) + + # differences in ttnn.softmax vs torch.softmax cause the attn_weights to be slightly different + attn_weights = ttnn.softmax(attn_weights, dim=-1) + attn_probs = dropout(attn_weights, p=0, training=False) + attn_output = ttnn.matmul(attn_probs, value_states, memory_config=whisper_memory_config) + + ttnn.deallocate(attn_probs) + ttnn.deallocate(attn_weights) + ttnn.deallocate(query_states) + attn_output = ttnn.transformer.concatenate_heads(attn_output) + attn_output = ttnn.linear( + attn_output, + parameters.out_proj.weight, + bias=parameters.out_proj.bias, + memory_config=whisper_memory_config, + ) + + return attn_output + + +def encoder_layer( + config, + device, + hidden_states, + *, + parameters, + whisper_memory_config, + output_mesh_composer, + inputs_mesh_mapper, +): + residual = hidden_states + + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.self_attn_layer_norm.weight, + bias=parameters.self_attn_layer_norm.bias, + memory_config=whisper_memory_config, + ) + + hidden_states = whisper_attention( + config, + device, + hidden_states, + attention_mask=None, + parameters=parameters.self_attn, + whisper_memory_config=whisper_memory_config, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = ttnn.add(residual, hidden_states) + residual = hidden_states + + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.final_layer_norm.weight, + bias=parameters.final_layer_norm.bias, + memory_config=whisper_memory_config, + ) + + hidden_states = ttnn.linear( + hidden_states, + parameters.fc1.weight, + bias=parameters.fc1.bias, + ) + + hidden_states = ttnn.gelu(hidden_states, memory_config=whisper_memory_config) + hidden_states = dropout(hidden_states, p=0, training=False) + + hidden_states = ttnn.linear( + hidden_states, + parameters.fc2.weight, + bias=parameters.fc2.bias, + memory_config=whisper_memory_config, + ) + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = ttnn.add(residual, hidden_states) + + # if hidden_states.dtype == torch.float16 and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()): + # clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + # hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + return hidden_states + + +def encoder( + config, device, inputs_embeds, *, parameters, whisper_memory_config, output_mesh_composer, inputs_mesh_mapper +): + bs, _, _ = inputs_embeds.shape + weights = parameters.embed_positions.weight + weights = ttnn.concat([weights] * bs, dim=0) + hidden_states = ttnn.add(inputs_embeds, weights) + hidden_states = dropout(hidden_states, p=0, training=False) + + for encoder_layer_parameter in parameters.layers: + hidden_states = encoder_layer( + config, + device, + hidden_states, + parameters=encoder_layer_parameter, + whisper_memory_config=whisper_memory_config, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.layer_norm.weight, + bias=parameters.layer_norm.bias, + ) + + return hidden_states + + +def make_causal_mask(input_ids_shape, dtype): + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min)) + mask_cond = torch.arange(mask.size(-1)) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len) + + +def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.shape + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +def decoder_layer( + config, + device, + hidden_states, + attention_mask, + encoder_hidden_states, + *, + parameters, + whisper_memory_config, + output_mesh_composer, + inputs_mesh_mapper, +): + residual = hidden_states + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.self_attn_layer_norm.weight, + bias=parameters.self_attn_layer_norm.bias, + memory_config=whisper_memory_config, + ) + + hidden_states = whisper_attention( + config, + device, + hidden_states=hidden_states, + attention_mask=attention_mask, + parameters=parameters.self_attn, + whisper_memory_config=whisper_memory_config, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = ttnn.add(residual, hidden_states) + + # Cross-Attention Block + residual = hidden_states + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.encoder_attn_layer_norm.weight, + bias=parameters.encoder_attn_layer_norm.bias, + ) + + hidden_states = whisper_attention( + config, + device, + hidden_states, + attention_mask=None, + key_value_states=encoder_hidden_states, + parameters=parameters.encoder_attn, + whisper_memory_config=whisper_memory_config, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = ttnn.add(residual, hidden_states) + + residual = hidden_states + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.final_layer_norm.weight, + bias=parameters.final_layer_norm.bias, + ) + + hidden_states = ttnn.linear( + hidden_states, parameters.fc1.weight, bias=parameters.fc1.bias, memory_config=whisper_memory_config + ) + hidden_states = ttnn.gelu(hidden_states, memory_config=whisper_memory_config) + hidden_states = dropout(hidden_states, p=0, training=False) + + hidden_states = ttnn.linear( + hidden_states, parameters.fc2.weight, bias=parameters.fc2.bias, memory_config=whisper_memory_config + ) + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = ttnn.add(residual, hidden_states) + return hidden_states + + +def prepare_decoder_attention_mask(attention_mask, input_shape, input_embeds): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + + if input_shape[-1] > 1: + combined_attention_mask = make_causal_mask(input_shape, input_embeds.dtype) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = expand_mask(attention_mask, input_embeds.dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + +def decoder( + config, + device, + hidden_states, + decoder_attention_mask, + encoder_hidden_states, + *, + parameters, + whisper_memory_config, + output_mesh_composer, + inputs_mesh_mapper, +): + hidden_states = dropout(hidden_states, p=0, training=False) + for decoder_layer_parameter in parameters.layers: + hidden_states = decoder_layer( + config, + device, + hidden_states, + decoder_attention_mask, + encoder_hidden_states, + parameters=decoder_layer_parameter, + whisper_memory_config=whisper_memory_config, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.layer_norm.weight, + bias=parameters.layer_norm.bias, + ) + + return hidden_states + + +def convert_to_ttnn(model, name): + return name not in [ + "encoder.conv1", + "encoder.conv2", + "decoder.embed_tokens", + "decoder.embed_positions", + ] + + +def preprocess_encoder_inputs( + input_features, + *, + parameters, + device, + whisper_memory_config, + weights_mesh_mapper, + inputs_mesh_mapper, + output_mesh_composer, +): + def conv(input, weight, bias, stride=1, padding=1, dilation=1, groups=1): + return F.conv1d(input, weight, bias, stride, padding, dilation, groups) + + def ttnn_conv1d( + device, + tt_input_tensor, + weights, + conv_params, + bias, + *, + output_dtype=ttnn.bfloat16, + weights_dtype=ttnn.bfloat8_b, + math_fidelity=ttnn.MathFidelity.LoFi, + deallocate_activation=True, + act_block_h=32, + height_sharding=True, + use_shallow_conv_variant=False, + fp32_accum=False, + packer_l1_acc=False, + debug=False, + groups=1, + math_approx=False, + activation="", + reallocate_halo=False, + reshard=False, + whisper_memory_config=ttnn.L1_MEMORY_CONFIG, + weights_mesh_mapper=None, + inputs_mesh_mapper=None, + ): + weights = ttnn.from_torch(weights, dtype=ttnn.float32, mesh_mapper=weights_mesh_mapper) + bias = ttnn.from_torch( + bias.unsqueeze(0).unsqueeze(0).unsqueeze(0), dtype=ttnn.float32, mesh_mapper=weights_mesh_mapper + ) + + conv_config = ttnn.Conv1dConfig( + dtype=output_dtype, + weights_dtype=weights_dtype, + math_approx_mode_enabled=math_approx, + fp32_dest_acc_enabled=fp32_accum, + packer_l1_accum_enabled=packer_l1_acc, + activation=activation, + input_channels_alignment=(16 if use_shallow_conv_variant else 32), + deallocate_activation=deallocate_activation, + reallocate_halo_output=reallocate_halo, + act_block_h_override=act_block_h, + reshard_if_not_optimal=reshard, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), + math_fidelity=math_fidelity, + ) + [tt_output_tensor_on_device, out_length, weights_device, bias_device] = ttnn.Conv1d( + input_tensor=tt_input_tensor, + weight_tensor=weights, + in_channels=tt_input_tensor.shape[-1], + out_channels=weights.shape[0], + device=device, + bias_tensor=bias, + kernel_size=3, + stride=conv_params[0], + padding=conv_params[1], + batch_size=tt_input_tensor.shape[0], + input_length=tt_input_tensor.shape[1], + conv_config=conv_config, + conv_op_cache={}, + debug=debug, + groups=groups, + ) + tt_output_tensor_on_device = ttnn.squeeze(tt_output_tensor_on_device, 0) + tt_output_tensor_on_device = ttnn.to_layout(tt_output_tensor_on_device, layout=ttnn.ROW_MAJOR_LAYOUT) + tt_output_tensor_on_device = ttnn.reshape( + tt_output_tensor_on_device, (tt_input_tensor.shape[0], out_length, tt_output_tensor_on_device.shape[-1]) + ) + tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) + + return tt_output_tensor + + if parameters.conv1.weight.shape[0] == 512: + input_features = ttnn.from_torch( + input_features, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT, mesh_mapper=inputs_mesh_mapper + ) + input_features = ttnn.permute(input_features, (0, 2, 1)) + conv1 = ttnn_conv1d( + device, + input_features, + parameters.conv1.weight, + [1, 1], + parameters.conv1.bias, + ) + conv1 = ttnn.to_layout(conv1, ttnn.TILE_LAYOUT) + conv1 = ttnn.to_device(conv1, device) + conv1 = ttnn.permute(conv1, (0, 2, 1)) + + else: + conv1 = conv( + input_features.float(), + weight=parameters.conv1.weight, + bias=parameters.conv1.bias, + stride=1, + padding=1, + ) + conv1 = ttnn.from_torch( + conv1, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device, mesh_mapper=inputs_mesh_mapper + ) + + input_embeds = ttnn.gelu(conv1, memory_config=whisper_memory_config) + input_embeds = ttnn.to_layout(input_embeds, layout=ttnn.ROW_MAJOR_LAYOUT) + + # input_embeds = ttnn.permute(input_embeds, (0, 2, 1)) + input_embeds = ttnn.to_torch(input_embeds, mesh_composer=output_mesh_composer) + + # #13529 ttnn.conv1d throws OOM here. + # conv2 = ttnn_conv1d( + # device, + # input_embeds, + # parameters.conv2.weight, + # [2, 1], + # parameters.conv2.bias, + # ) + # conv2 = ttnn.to_layout(conv2, ttnn.TILE_LAYOUT) + # conv2 = ttnn.to_device(conv2, device) + # conv2 = ttnn.permute(conv2, (0, 2, 1)) + # input_embeds = ttnn.gelu(conv2, memory_config=whisper_memory_config) + + conv = conv( + input_embeds.float(), + weight=parameters.conv2.weight, + bias=parameters.conv2.bias, + stride=2, + padding=1, + ) + conv = ttnn.from_torch( + conv, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device, mesh_mapper=inputs_mesh_mapper + ) + + input_embeds = ttnn.gelu(conv, memory_config=whisper_memory_config) + input_embeds = ttnn.permute(input_embeds, (0, 2, 1)) + + return input_embeds + + +def preprocess_decoder_inputs(config, input_ids, attention_mask, *, parameters, device, inputs_mesh_mapper): + input_shape = input_ids.size() + input_ids = torch.reshape(input_ids, (-1, input_shape[-1])) + inputs_embeds = F.embedding(input_ids, parameters.embed_tokens.weight) + attention_mask = prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds) + # ttnn cannot broadcast when adding on the batch or channel dimensions so this is a workaround + attention_mask = attention_mask.expand(-1, config.encoder_attention_heads, -1, -1) + + positions = parameters.embed_positions.weight[0 : input_ids.shape[-1]] + decoder_hidden_states = inputs_embeds + positions + + decoder_hidden_states = ttnn.from_torch( + decoder_hidden_states, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + mesh_mapper=inputs_mesh_mapper, + ) + attention_mask = ttnn.from_torch( + attention_mask, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device, mesh_mapper=inputs_mesh_mapper + ) + + return decoder_hidden_states, attention_mask + + +def preprocess_inputs( + *, + config, + input_features, + input_ids, + attention_mask, + parameters, + device, + whisper_memory_config, + weights_mesh_mapper, + inputs_mesh_mapper, + output_mesh_composer, +): + input_embeds = preprocess_encoder_inputs( + input_features, + parameters=parameters.encoder, + device=device, + whisper_memory_config=whisper_memory_config, + weights_mesh_mapper=weights_mesh_mapper, + inputs_mesh_mapper=inputs_mesh_mapper, + output_mesh_composer=output_mesh_composer, + ) + + (decoder_hidden_states, attention_mask) = preprocess_decoder_inputs( + config, + input_ids, + attention_mask, + parameters=parameters.decoder, + device=device, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + return input_embeds, decoder_hidden_states, attention_mask + + +def whisper( + config, + device, + encoder_hidden_states, + decoder_hidden_states, + decoder_attention_mask, + *, + parameters, + whisper_memory_config, + output_mesh_composer, + inputs_mesh_mapper, +): + encoder_hidden_states = encoder( + config, + device, + encoder_hidden_states, + parameters=parameters.encoder, + whisper_memory_config=whisper_memory_config, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + + last_hidden_state = decoder( + config, + device, + decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + parameters=parameters.decoder, + whisper_memory_config=whisper_memory_config, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + + return last_hidden_state + + +def whisper_for_audio_classification( + config, + inputs_embeds, + *, + parameters, + device, + batch_size, + whisper_memory_config, + output_mesh_composer, + inputs_mesh_mapper, +): + encoder_outputs = encoder( + config=config, + device=device, + inputs_embeds=inputs_embeds, + parameters=parameters.encoder, + whisper_memory_config=whisper_memory_config, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + hidden_states = ttnn.linear( + encoder_outputs, + parameters.projector.weight, + bias=parameters.projector.bias, + memory_config=whisper_memory_config, + ) + pooled_output = ttnn.mean(hidden_states, dim=-2, keepdim=True) + + logits = ttnn.linear( + pooled_output, + parameters.classifier.weight, + bias=parameters.classifier.bias, + memory_config=whisper_memory_config, + ) + return logits + + +def whisper_for_conditional_generation( + config, + input_embeds, + decoder_hidden_states, + decoder_attention_mask, + *, + parameters, + device, + ttnn_linear_weight, + whisper_memory_config, + output_mesh_composer, + inputs_mesh_mapper, +): + output = whisper( + config=config, + device=device, + encoder_hidden_states=input_embeds, + decoder_hidden_states=decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + parameters=parameters, + whisper_memory_config=whisper_memory_config, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + ttnn_output = ttnn.matmul( + output, + ttnn_linear_weight, + dtype=ttnn.bfloat16, + ) + return ttnn_output + + +def preprocess_conv_parameter(parameter, *, dtype): + parameter = ttnn.from_torch(parameter, dtype=dtype) + return parameter + + +def custom_preprocessor(torch_model, name): + parameters = {} + if isinstance(torch_model, transformers.models.whisper.modeling_whisper.WhisperAttention): + height, width = torch_model.k_proj.weight.shape + + if "encoder_attn" in name: + parameters = {"key_value": {}, "q_proj": {}, "out_proj": {}} + preprocessed_weight = torch.cat([torch_model.k_proj.weight, torch_model.v_proj.weight], dim=0) + preprocessed_bias = torch.cat([torch.zeros(height), torch_model.v_proj.bias], dim=0) + parameters["key_value"]["weight"] = preprocess_linear_weight(preprocessed_weight, dtype=ttnn.bfloat16) + parameters["key_value"]["bias"] = preprocess_linear_bias(preprocessed_bias, dtype=ttnn.bfloat16) + parameters["q_proj"]["weight"] = preprocess_linear_weight(torch_model.q_proj.weight, dtype=ttnn.bfloat16) + parameters["q_proj"]["bias"] = preprocess_linear_bias(torch_model.q_proj.bias, dtype=ttnn.bfloat16) + else: + parameters = {"query_key_value": {}, "out_proj": {}} + preprocessed_weight = torch.cat( + [torch_model.q_proj.weight, torch_model.k_proj.weight, torch_model.v_proj.weight], dim=0 + ) + preprocessed_bias = torch.cat( + [torch_model.q_proj.bias, torch.zeros(height), torch_model.v_proj.bias], dim=0 + ) + parameters["query_key_value"]["weight"] = preprocess_linear_weight(preprocessed_weight, dtype=ttnn.bfloat16) + parameters["query_key_value"]["bias"] = preprocess_linear_bias(preprocessed_bias, dtype=ttnn.bfloat16) + + parameters["out_proj"]["weight"] = preprocess_linear_weight(torch_model.out_proj.weight, dtype=ttnn.bfloat16) + parameters["out_proj"]["bias"] = preprocess_linear_bias(torch_model.out_proj.bias, dtype=ttnn.bfloat16) + + elif name == "encoder.embed_positions" and isinstance(torch_model, torch.nn.Embedding): + embeddings = torch_model.weight.unsqueeze(0) + embeddings = ttnn.from_torch(embeddings, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) + parameters["weight"] = embeddings + + return parameters diff --git a/tests/scripts/run_performance.sh b/tests/scripts/run_performance.sh index 7956d1c7b03..50334e453a7 100755 --- a/tests/scripts/run_performance.sh +++ b/tests/scripts/run_performance.sh @@ -21,6 +21,8 @@ run_perf_models_other() { env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/bert_tiny/tests/test_performance.py -m $test_marker env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/yolov4/tests/test_perf_yolo.py -m $test_marker + + env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/whisper/tests -m $test_marker fi env pytest -n auto tests/ttnn/integration_tests/bert/test_performance.py -m $test_marker @@ -121,6 +123,8 @@ run_device_perf_models() { env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/bert_tiny/tests -m $test_marker env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/yolov4/tests/ -m $test_marker + + env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/whisper/tests/test_performance.py -m $test_marker fi ## Merge all the generated reports diff --git a/tests/scripts/single_card/run_single_card_demo_tests.sh b/tests/scripts/single_card/run_single_card_demo_tests.sh index 5f5642483f6..10ca7a5f735 100755 --- a/tests/scripts/single_card/run_single_card_demo_tests.sh +++ b/tests/scripts/single_card/run_single_card_demo_tests.sh @@ -39,6 +39,9 @@ run_common_func_tests() { # Mnist pytest --disable-warnings models/demos/mnist/demo/demo.py --timeout 600; fail+=$? + # Whisper + pytest --disable-warnings models/demos/wormhole/whisper/demo/demo.py --timeout 600; fail+=$? + return $fail } diff --git a/tests/ttnn/integration_tests/whisper/test_performance.py b/tests/ttnn/integration_tests/whisper/test_performance.py deleted file mode 100644 index 288afa78719..00000000000 --- a/tests/ttnn/integration_tests/whisper/test_performance.py +++ /dev/null @@ -1,101 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -import pytest -from models.experimental.functional_whisper.tt import ttnn_functional_whisper, ttnn_optimized_functional_whisper -from transformers import AutoFeatureExtractor, WhisperModel, WhisperConfig -from datasets import load_dataset -import torch -from ttnn.model_preprocessing import preprocess_model_parameters -from loguru import logger -from models.utility_functions import is_wormhole_b0, is_blackhole -from models.perf.perf_utils import prep_perf_report -import time -import ttnn - - -def get_expected_times(functional_whisper): - return { - ttnn_functional_whisper: (11.7, 4.16), - ttnn_optimized_functional_whisper: (1.5, 1.35), - }[functional_whisper] - - -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Not tested on single WH") -@pytest.mark.models_performance_bare_metal -@pytest.mark.models_performance_virtual_machine -@pytest.mark.parametrize("model_name", ["openai/whisper-base"]) -@pytest.mark.parametrize("batch_size", [1]) -@pytest.mark.parametrize("sequence_size", [500]) -@pytest.mark.parametrize("functional_whisper", [ttnn_functional_whisper, ttnn_optimized_functional_whisper]) -def test_performance(device, use_program_cache, model_name, batch_size, sequence_size, functional_whisper): - config = WhisperConfig.from_pretrained(model_name) - - # Run TT Model - if functional_whisper == ttnn_functional_whisper: - tt_model_name = f"ttnn_{model_name}" - elif functional_whisper == ttnn_optimized_functional_whisper: - tt_model_name = f"ttnn_{model_name}_optimized" - else: - raise ValueError(f"Unknown functional_t5: {functional_whisper}") - - config = WhisperConfig.from_pretrained(model_name) - feature_extractor = AutoFeatureExtractor.from_pretrained(model_name) - ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - inputs = feature_extractor(ds[0]["audio"]["array"], sampling_rate=16000, return_tensors="pt") - input_features = inputs.input_features - decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id - - attention_mask = None - - parameters = preprocess_model_parameters( - model_name=tt_model_name, - initialize_model=lambda: WhisperModel.from_pretrained(model_name).eval(), - convert_to_ttnn=functional_whisper.convert_to_ttnn, - custom_preprocessor=functional_whisper.custom_preprocessor, - device=device, - ) - - durations = [] - for _ in range(2): - (input_embeds, decoder_hidden_states, decoder_attention_mask) = functional_whisper.preprocess_inputs( - config=config, - input_features=input_features, - input_ids=decoder_input_ids, - attention_mask=attention_mask, - parameters=parameters, - device=device, - ) - - start = time.time() - tt_output = functional_whisper.whisper( - config, - input_embeds, - decoder_hidden_states, - decoder_attention_mask=decoder_attention_mask, - parameters=parameters, - ) - tt_output = ttnn.to_torch(tt_output) - end = time.time() - - duration = end - start - durations.append(duration) - - inference_and_compile_time, inference_time, *_ = durations - - expected_compile_time, expected_inference_time = get_expected_times(functional_whisper) - prep_perf_report( - model_name=tt_model_name, - batch_size=batch_size, - inference_and_compile_time=durations[0], - inference_time=durations[1], - expected_compile_time=expected_compile_time, - expected_inference_time=expected_inference_time, - comments="", - inference_time_cpu=0.0, - ) - - logger.info(f"Compile time: {inference_and_compile_time - inference_time}") - logger.info(f"Inference time: {inference_time}") - logger.info(f"Samples per second: {1 / inference_time * batch_size}") diff --git a/tests/ttnn/integration_tests/whisper/test_ttnn_functional_whisper.py b/tests/ttnn/integration_tests/whisper/test_ttnn_functional_whisper.py index e6f02bf3203..5c2af553ef2 100644 --- a/tests/ttnn/integration_tests/whisper/test_ttnn_functional_whisper.py +++ b/tests/ttnn/integration_tests/whisper/test_ttnn_functional_whisper.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 @@ -19,7 +19,6 @@ MODEL_NAME = "openai/whisper-base" -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") @pytest.mark.parametrize("ttnn_model", [ttnn_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("batch_size", [1]) @@ -84,7 +83,6 @@ def test_whisper_attention(device, ttnn_model, model_name, batch_size, sequence_ assert_with_pcc(torch_output, output, 0.98) -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") @pytest.mark.parametrize("ttnn_model", [ttnn_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("batch_size", [1]) @@ -120,7 +118,6 @@ def test_encoder_layer(device, ttnn_model, model_name, batch_size, sequence_size assert_with_pcc(torch_output, output, pcc=0.99) -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") @pytest.mark.parametrize("ttnn_model", [ttnn_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("batch_size", [1]) @@ -174,7 +171,6 @@ def test_encoder(device, ttnn_model, model_name, batch_size, feature_size, seque assert_with_pcc(torch_output, output, 0.97) -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") @pytest.mark.parametrize("ttnn_model", [ttnn_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("batch_size", [1]) @@ -233,7 +229,6 @@ def test_decoder_layer(device, ttnn_model, model_name, batch_size, sequence_size assert_with_pcc(torch_output, output, 0.97) -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") @pytest.mark.parametrize("ttnn_model", [ttnn_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("batch_size", [1]) @@ -305,7 +300,6 @@ def test_decoder(device, ttnn_model, model_name, batch_size, sequence_size): assert_with_pcc(torch_output, output, pcc=0.99) -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") @pytest.mark.parametrize("ttnn_model", [ttnn_functional_whisper]) def test_ttnn_whisper(device, ttnn_model): torch.manual_seed(0) diff --git a/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py b/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py index e6cea2f8870..b0b40660440 100644 --- a/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py +++ b/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py @@ -2,44 +2,67 @@ # SPDX-License-Identifier: Apache-2.0 +import ttnn +import torch import pytest -from models.experimental.functional_whisper.reference import torch_functional_whisper -from models.experimental.functional_whisper.tt import ttnn_optimized_functional_whisper import transformers -from transformers import AutoFeatureExtractor, WhisperModel, WhisperConfig from datasets import load_dataset -import torch -import ttnn from tests.ttnn.utils_for_testing import assert_with_pcc -from models.utility_functions import torch_random from ttnn.model_preprocessing import preprocess_model_parameters -from models.utility_functions import is_wormhole_b0, is_blackhole +from models.demos.wormhole.whisper.reference import torch_functional_whisper +from models.demos.wormhole.whisper.tt import ttnn_optimized_functional_whisper +from models.utility_functions import torch_random, is_grayskull, is_wormhole_b0, skip_for_grayskull MODEL_NAME = "openai/whisper-base" -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") +@skip_for_grayskull() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("batch_size", [16]) @pytest.mark.parametrize("sequence_size", [1500]) @pytest.mark.parametrize("use_key_value_states", [False, True]) -def test_whisper_attention(device, ttnn_model, model_name, batch_size, sequence_size, use_key_value_states): - torch.manual_seed(0) +def test_whisper_attention( + mesh_device, ttnn_model, model_name, batch_size, sequence_size, use_key_value_states, reset_seeds +): config = transformers.WhisperConfig.from_pretrained(model_name) model = transformers.models.whisper.modeling_whisper.WhisperAttention( embed_dim=config.d_model, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout ).eval() torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32) + tt_model_name = f"ttnn_{model_name}_optimized" + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + ttnn_parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + convert_to_ttnn=ttnn_model.convert_to_ttnn, + custom_preprocessor=ttnn_model.custom_preprocessor, + prefix="encoder_attn" if use_key_value_states else "", + device=mesh_device, + ) + ttnn_hidden_states = ttnn.from_torch( - torch_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device + torch_hidden_states, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + mesh_mapper=inputs_mesh_mapper, ) if use_key_value_states: torch_key_value_states = torch_random( (batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32 ) ttnn_key_value_states = ttnn.from_torch( - torch_key_value_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device + torch_key_value_states, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + mesh_mapper=inputs_mesh_mapper, ) else: torch_key_value_states = None @@ -62,35 +85,32 @@ def test_whisper_attention(device, ttnn_model, model_name, batch_size, sequence_ parameters=torch_parameters, ) - ttnn_parameters = preprocess_model_parameters( - initialize_model=lambda: model, - convert_to_ttnn=lambda *_: True, - custom_preprocessor=ttnn_model.custom_preprocessor, - device=device, - prefix="encoder_attn" if use_key_value_states else "", - ) attention_mask = None output = ttnn_model.whisper_attention( config, + mesh_device, ttnn_hidden_states, attention_mask, key_value_states=ttnn_key_value_states, parameters=ttnn_parameters, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, ) - output = ttnn.to_torch(output) - assert_with_pcc(torch_output, output, 0.98) + output = ttnn.to_torch(output, mesh_composer=output_mesh_composer) + assert_with_pcc(torch_output, output, 0.99) -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") +@skip_for_grayskull() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("batch_size", [16]) @pytest.mark.parametrize("sequence_size", [1500]) -def test_encoder_layer(device, ttnn_model, model_name, batch_size, sequence_size): - torch.manual_seed(0) +def test_encoder_layer(mesh_device, ttnn_model, model_name, batch_size, sequence_size, reset_seeds): config = transformers.WhisperConfig.from_pretrained(model_name) - model = transformers.models.whisper.modeling_whisper.WhisperEncoderLayer(config).eval() + model = transformers.models.whisper.modeling_whisper.WhisperEncoderLayer(config) model = model embed_dim = config.d_model @@ -103,33 +123,53 @@ def test_encoder_layer(device, ttnn_model, model_name, batch_size, sequence_size ) torch_output = torch_functional_whisper.encoder_layer(config, torch_hidden_states, parameters=parameters) - ttnn_parameters = preprocess_model_parameters( - initialize_model=lambda: model, - convert_to_ttnn=ttnn_model.convert_to_ttnn, - custom_preprocessor=ttnn_model.custom_preprocessor, - device=device, - ) + tt_model_name = f"ttnn_{model_name}_optimized" + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + ttnn_parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + convert_to_ttnn=ttnn_model.convert_to_ttnn, + custom_preprocessor=ttnn_model.custom_preprocessor, + device=mesh_device, + ) + + model = model.eval() ttnn_hidden_states = ttnn.from_torch( - torch_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device + torch_hidden_states, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + mesh_mapper=inputs_mesh_mapper, ) - output = ttnn_model.encoder_layer(config, ttnn_hidden_states, parameters=ttnn_parameters) - output = ttnn.to_torch(output) + output = ttnn_model.encoder_layer( + config, + mesh_device, + ttnn_hidden_states, + parameters=ttnn_parameters, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + output = ttnn.to_torch(output, mesh_composer=output_mesh_composer) assert_with_pcc(torch_output, output, pcc=0.99) -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") +@skip_for_grayskull() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("batch_size", [16]) @pytest.mark.parametrize("feature_size", [80]) @pytest.mark.parametrize("sequence_length", [3000]) -def test_encoder(device, ttnn_model, model_name, batch_size, feature_size, sequence_length): - torch.manual_seed(0) +def test_encoder(mesh_device, ttnn_model, model_name, batch_size, feature_size, sequence_length, reset_seeds): config = transformers.WhisperConfig.from_pretrained(model_name) - model = transformers.models.whisper.modeling_whisper.WhisperEncoder(config).eval() - model = model + model = transformers.models.whisper.modeling_whisper.WhisperEncoder(config) torch_input_features = torch_random((batch_size, feature_size, sequence_length), -0.1, 0.1, dtype=torch.float32) @@ -139,9 +179,12 @@ def test_encoder(device, ttnn_model, model_name, batch_size, feature_size, seque custom_preprocessor=torch_functional_whisper.custom_preprocessor, ) - # torch_original_output = torch_functional_whisper.encoder_original( - # torch_input_features, parameters, embed_dim, num_heads - # ) + def convert_to_ttnn(model, name): + return name not in [ + "conv1", + "conv2", + "embed_positions", + ] inputs_embeds = torch_functional_whisper.preprocess_encoder_inputs( input_features=torch_input_features, @@ -150,39 +193,61 @@ def test_encoder(device, ttnn_model, model_name, batch_size, feature_size, seque torch_output = torch_functional_whisper.encoder(config, inputs_embeds, parameters=parameters) - ttnn_parameters = preprocess_model_parameters( - initialize_model=lambda: model, - convert_to_ttnn=ttnn_model.convert_to_ttnn, - custom_preprocessor=ttnn_model.custom_preprocessor, - prefix="encoder", - device=device, + tt_model_name = f"ttnn_{model_name}_optimized" + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + ttnn_parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + convert_to_ttnn=convert_to_ttnn, + custom_preprocessor=ttnn_model.custom_preprocessor, + device=mesh_device, + ) + + ttnn_parameters.embed_positions.weight = ttnn_parameters.embed_positions.weight.unsqueeze(0) + ttnn_parameters.embed_positions.weight = ttnn.from_torch( + ttnn_parameters.embed_positions.weight, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device ) + model = model.eval() input_embeds = ttnn_model.preprocess_encoder_inputs( input_features=torch_input_features, parameters=ttnn_parameters, - device=device, + device=mesh_device, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, + weights_mesh_mapper=weights_mesh_mapper, + inputs_mesh_mapper=inputs_mesh_mapper, + output_mesh_composer=output_mesh_composer, ) input_embeds = ttnn.to_layout(input_embeds, ttnn.TILE_LAYOUT) - input_embeds = ttnn.to_device(input_embeds, device) + input_embeds = ttnn.to_device(input_embeds, mesh_device) - output = ttnn_model.encoder(config, input_embeds, parameters=ttnn_parameters) - output = ttnn.to_torch(output) + output = ttnn_model.encoder( + config, + mesh_device, + input_embeds, + parameters=ttnn_parameters, + whisper_memory_config=ttnn.L1_MEMORY_CONFIG, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + output = ttnn.to_torch(output, mesh_composer=output_mesh_composer) - assert_with_pcc(torch_output, output, 0.968) + assert_with_pcc(torch_output, output, 0.99) -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") +@skip_for_grayskull() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("batch_size", [16]) @pytest.mark.parametrize("sequence_size", [1500]) -def test_decoder_layer(device, ttnn_model, model_name, batch_size, sequence_size): - torch.manual_seed(0) +def test_decoder_layer(mesh_device, reset_seeds, ttnn_model, model_name, batch_size, sequence_size): config = transformers.WhisperConfig.from_pretrained(model_name) model = transformers.models.whisper.modeling_whisper.WhisperDecoderLayer(config).eval() - model = model - num_heads = config.encoder_attention_heads embed_dim = config.d_model torch_hidden_states = torch_random((batch_size, 2, embed_dim), -0.1, 0.1, dtype=torch.float32) @@ -202,50 +267,71 @@ def test_decoder_layer(device, ttnn_model, model_name, batch_size, sequence_size torch_output = torch_functional_whisper.decoder_layer( config, torch_hidden_states, attention_mask, torch_encoder_hidden_states, parameters=parameters ) + tt_model_name = f"ttnn_{model_name}_optimized" + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + ttnn_parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + convert_to_ttnn=ttnn_model.convert_to_ttnn, + custom_preprocessor=ttnn_model.custom_preprocessor, + device=mesh_device, + ) - ttnn_parameters = preprocess_model_parameters( - initialize_model=lambda: model, - convert_to_ttnn=lambda *_: True, - custom_preprocessor=ttnn_model.custom_preprocessor, - device=device, + model = model.eval() + ttnn_hidden_states = ttnn.from_torch( + torch_hidden_states, + dtype=ttnn.bfloat16, + device=mesh_device, + mesh_mapper=inputs_mesh_mapper, + layout=ttnn.TILE_LAYOUT, ) - ttnn_hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16) - ttnn_hidden_states = ttnn.to_layout(ttnn_hidden_states, ttnn.TILE_LAYOUT) - ttnn_hidden_states = ttnn.to_device(ttnn_hidden_states, device) - ttnn_attention_mask = ttnn.from_torch(attention_mask, dtype=ttnn.bfloat16) - ttnn_attention_mask = ttnn.to_layout(ttnn_attention_mask, ttnn.TILE_LAYOUT) - ttnn_attention_mask = ttnn.to_device(ttnn_attention_mask, device) + ttnn_attention_mask = ttnn.from_torch( + attention_mask, dtype=ttnn.bfloat16, device=mesh_device, mesh_mapper=inputs_mesh_mapper, layout=ttnn.TILE_LAYOUT + ) - ttnn_encoder_hidden_states = ttnn.from_torch(torch_encoder_hidden_states, dtype=ttnn.bfloat16) - ttnn_encoder_hidden_states = ttnn.to_layout(ttnn_encoder_hidden_states, ttnn.TILE_LAYOUT) - ttnn_encoder_hidden_states = ttnn.to_device(ttnn_encoder_hidden_states, device) + ttnn_encoder_hidden_states = ttnn.from_torch( + torch_encoder_hidden_states, + dtype=ttnn.bfloat16, + device=mesh_device, + mesh_mapper=inputs_mesh_mapper, + layout=ttnn.TILE_LAYOUT, + ) output = ttnn_model.decoder_layer( - config, ttnn_hidden_states, ttnn_attention_mask, ttnn_encoder_hidden_states, parameters=ttnn_parameters + config, + mesh_device, + ttnn_hidden_states, + ttnn_attention_mask, + ttnn_encoder_hidden_states, + parameters=ttnn_parameters, + whisper_memory_config=ttnn.L1_MEMORY_CONFIG, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, ) - output = ttnn.to_torch(output) + output = ttnn.to_torch(output, mesh_composer=output_mesh_composer) - assert_with_pcc(torch_output, output, 0.97) + assert_with_pcc(torch_output, output, 0.99) -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") +@skip_for_grayskull() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("batch_size", [16]) @pytest.mark.parametrize("sequence_size", [1500]) -def test_decoder(device, ttnn_model, model_name, batch_size, sequence_size): - torch.manual_seed(0) +def test_decoder(mesh_device, ttnn_model, model_name, batch_size, sequence_size, reset_seeds): config = transformers.WhisperConfig.from_pretrained(model_name) model = transformers.models.whisper.modeling_whisper.WhisperDecoder(config).eval() - model = model embed_dim = config.d_model torch_encoder_hidden_states = torch_random((batch_size, sequence_size, embed_dim), -0.1, 0.1, dtype=torch.float32) - # decoder_input_ids = torch.ones(1, 32).type(torch.int32) * config.decoder_start_token_id - decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id + decoder_input_ids = torch.tensor([[1, 1]] * batch_size) * config.decoder_start_token_id attention_mask = None @@ -255,10 +341,6 @@ def test_decoder(device, ttnn_model, model_name, batch_size, sequence_size): custom_preprocessor=torch_functional_whisper.custom_preprocessor, ) - # torch_original_output = torch_functional_whisper.decoder_original( - # decoder_input_ids, attention_mask, torch_encoder_hidden_states, parameters, embed_dim, num_heads - # ) - (decoder_hidden_states, decoder_attention_mask) = torch_functional_whisper.preprocess_decoder_inputs( decoder_input_ids, attention_mask, parameters=parameters ) @@ -271,55 +353,77 @@ def test_decoder(device, ttnn_model, model_name, batch_size, sequence_size): parameters=parameters, ) - ttnn_parameters = preprocess_model_parameters( - initialize_model=lambda: model, - convert_to_ttnn=ttnn_model.convert_to_ttnn, - custom_preprocessor=ttnn_model.custom_preprocessor, - device=device, - prefix="decoder", - ) - ttnn_decoder_input_ids = ttnn.from_torch(decoder_input_ids, dtype=ttnn.bfloat16) - ttnn_decoder_input_ids = ttnn.to_device(ttnn_decoder_input_ids, device) + tt_model_name = f"ttnn_{model_name}_optimized" + + def convert_to_ttnn(model, name): + return name not in ["conv1", "conv2", "embed_positions", "embed_tokens"] + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + ttnn_parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + convert_to_ttnn=convert_to_ttnn, + custom_preprocessor=ttnn_model.custom_preprocessor, + device=mesh_device, + ) + model = model.eval() - ttnn_encoder_hidden_states = ttnn.from_torch(torch_encoder_hidden_states, dtype=ttnn.bfloat16) + ttnn_decoder_input_ids = ttnn.from_torch(decoder_input_ids, dtype=ttnn.bfloat16, mesh_mapper=inputs_mesh_mapper) + ttnn_decoder_input_ids = ttnn.to_device(ttnn_decoder_input_ids, mesh_device) + + ttnn_encoder_hidden_states = ttnn.from_torch( + torch_encoder_hidden_states, dtype=ttnn.bfloat16, mesh_mapper=inputs_mesh_mapper + ) ttnn_encoder_hidden_states = ttnn.to_layout(ttnn_encoder_hidden_states, ttnn.TILE_LAYOUT) - ttnn_encoder_hidden_states = ttnn.to_device(ttnn_encoder_hidden_states, device) + ttnn_encoder_hidden_states = ttnn.to_device(ttnn_encoder_hidden_states, mesh_device) (decoder_hidden_states, decoder_attention_mask) = ttnn_model.preprocess_decoder_inputs( - config, decoder_input_ids, attention_mask, parameters=ttnn_parameters, device=device + config, + decoder_input_ids, + attention_mask, + parameters=ttnn_parameters, + device=mesh_device, + inputs_mesh_mapper=inputs_mesh_mapper, ) output = ttnn_model.decoder( config, + device=mesh_device, hidden_states=decoder_hidden_states, decoder_attention_mask=decoder_attention_mask, encoder_hidden_states=ttnn_encoder_hidden_states, parameters=ttnn_parameters, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, ) - output = ttnn.to_torch(output) + output = ttnn.to_torch(output, mesh_composer=output_mesh_composer) assert_with_pcc(torch_output, output, pcc=0.99) -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") -@pytest.mark.requires_fast_runtime_mode_off +@skip_for_grayskull() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize("batch_size", [16]) +@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) -def test_ttnn_whisper(tmp_path, device, ttnn_model): - torch.manual_seed(0) - model_name = "openai/whisper-base" - config = WhisperConfig.from_pretrained(model_name) - feature_extractor = AutoFeatureExtractor.from_pretrained(model_name) +def test_ttnn_whisper(tmp_path, mesh_device, model_name, ttnn_model, batch_size, reset_seeds): + config = transformers.WhisperConfig.from_pretrained(model_name) + feature_extractor = transformers.AutoFeatureExtractor.from_pretrained(model_name) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - inputs = feature_extractor(ds[0]["audio"]["array"], sampling_rate=16000, return_tensors="pt") + inputs = feature_extractor( + [ds[i]["audio"]["array"] for i in range(batch_size)], sampling_rate=16000, return_tensors="pt" + ) input_features = inputs.input_features - decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id + decoder_input_ids = torch.tensor([[1, 1]] * batch_size) * config.decoder_start_token_id - attention_mask = None - - model = WhisperModel.from_pretrained(model_name).eval() + model = transformers.WhisperModel.from_pretrained(model_name) parameters = preprocess_model_parameters( - initialize_model=lambda: model, + initialize_model=lambda: model.eval(), convert_to_ttnn=lambda *_: False, custom_preprocessor=torch_functional_whisper.custom_preprocessor, ) @@ -327,11 +431,11 @@ def test_ttnn_whisper(tmp_path, device, ttnn_model): (encoder_hidden_states, decoder_hidden_states, decoder_attention_mask) = torch_functional_whisper.preprocess_inputs( input_features=input_features, input_ids=decoder_input_ids, - attention_mask=attention_mask, + attention_mask=None, parameters=parameters, ) - expected_last_hidden_state = torch_functional_whisper.whisper( + torch_last_hidden_state = torch_functional_whisper.whisper( config, encoder_hidden_states, decoder_hidden_states, @@ -339,31 +443,46 @@ def test_ttnn_whisper(tmp_path, device, ttnn_model): parameters=parameters, ) - ttnn_parameters = preprocess_model_parameters( - initialize_model=lambda: model, - convert_to_ttnn=ttnn_model.convert_to_ttnn, - custom_preprocessor=ttnn_model.custom_preprocessor, - device=device, + tt_model_name = f"ttnn_{model_name}_optimized" + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + ttnn_parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + convert_to_ttnn=ttnn_model.convert_to_ttnn, + custom_preprocessor=ttnn_model.custom_preprocessor, + device=mesh_device, + ) + + model = model.eval() + + (input_embeds, decoder_hidden_states, decoder_attention_mask) = ttnn_model.preprocess_inputs( + config=config, + input_features=input_features, + input_ids=decoder_input_ids, + attention_mask=None, + parameters=ttnn_parameters, + device=mesh_device, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, + weights_mesh_mapper=weights_mesh_mapper, + inputs_mesh_mapper=inputs_mesh_mapper, + output_mesh_composer=output_mesh_composer, ) - with ttnn.tracer.trace(): - (input_embeds, decoder_hidden_states, decoder_attention_mask) = ttnn_model.preprocess_inputs( - config=config, - input_features=input_features, - input_ids=decoder_input_ids, - attention_mask=attention_mask, - parameters=ttnn_parameters, - device=device, - ) + last_hidden_state = ttnn_model.whisper( + config, + mesh_device, + input_embeds, + decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + parameters=ttnn_parameters, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, + ) - last_hidden_state = ttnn_model.whisper( - config, - input_embeds, - decoder_hidden_states, - decoder_attention_mask=decoder_attention_mask, - parameters=ttnn_parameters, - ) - last_hidden_state = ttnn.to_torch(last_hidden_state) - ttnn.tracer.visualize(last_hidden_state, file_name=tmp_path / "whisper.svg") + last_hidden_state = ttnn.to_torch(last_hidden_state, mesh_composer=output_mesh_composer) - assert_with_pcc(expected_last_hidden_state, last_hidden_state, 0.964) + assert_with_pcc(torch_last_hidden_state, last_hidden_state, 0.857) diff --git a/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper_wh.py b/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper_wh.py new file mode 100644 index 00000000000..1893dac18e5 --- /dev/null +++ b/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper_wh.py @@ -0,0 +1,509 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +import torch +import pytest +import transformers +from datasets import load_dataset +from tests.ttnn.utils_for_testing import assert_with_pcc +from ttnn.model_preprocessing import preprocess_model_parameters +from models.demos.wormhole.whisper.reference import torch_functional_whisper +from models.demos.wormhole.whisper.tt import ttnn_optimized_functional_whisper +from models.utility_functions import torch_random, is_grayskull, is_wormhole_b0, skip_for_grayskull, run_for_wormhole_b0 + +MODEL_NAME = "openai/whisper-base" + + +@run_for_wormhole_b0() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [1500]) +@pytest.mark.parametrize("use_key_value_states", [False, True]) +def test_whisper_attention( + mesh_device, ttnn_model, model_name, batch_size, sequence_size, use_key_value_states, reset_seeds +): + config = transformers.WhisperConfig.from_pretrained(model_name) + model = transformers.models.whisper.modeling_whisper.WhisperAttention( + embed_dim=config.d_model, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout + ).eval() + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size * 2 if mesh_device_flag else batch_size + + torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32) + tt_model_name = f"ttnn_{model_name}_optimized" + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + ttnn_parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + convert_to_ttnn=ttnn_model.convert_to_ttnn, + custom_preprocessor=ttnn_model.custom_preprocessor, + prefix="encoder_attn" if use_key_value_states else "", + device=mesh_device, + ) + + ttnn_hidden_states = ttnn.from_torch( + torch_hidden_states, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + mesh_mapper=inputs_mesh_mapper, + ) + if use_key_value_states: + torch_key_value_states = torch_random( + (batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32 + ) + ttnn_key_value_states = ttnn.from_torch( + torch_key_value_states, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + mesh_mapper=inputs_mesh_mapper, + ) + else: + torch_key_value_states = None + ttnn_key_value_states = None + + torch_parameters = preprocess_model_parameters( + initialize_model=lambda: model, + convert_to_ttnn=lambda *_: False, + custom_preprocessor=torch_functional_whisper.custom_preprocessor, + prefix="encoder_attn" if use_key_value_states else "", + ) + + torch_attention_mask = None + + torch_output = torch_functional_whisper.whisper_attention( + config, + torch_hidden_states, + torch_attention_mask, + key_value_states=torch_key_value_states, + parameters=torch_parameters, + ) + + attention_mask = None + output = ttnn_model.whisper_attention( + config, + mesh_device, + ttnn_hidden_states, + attention_mask, + key_value_states=ttnn_key_value_states, + parameters=ttnn_parameters, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + + output = ttnn.to_torch(output, mesh_composer=output_mesh_composer) + assert_with_pcc(torch_output, output, 0.99) + + +@run_for_wormhole_b0() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [1500]) +def test_encoder_layer(mesh_device, ttnn_model, model_name, batch_size, sequence_size, reset_seeds): + config = transformers.WhisperConfig.from_pretrained(model_name) + model = transformers.models.whisper.modeling_whisper.WhisperEncoderLayer(config) + model = model + + embed_dim = config.d_model + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size * 2 if mesh_device_flag else batch_size + + torch_hidden_states = torch_random((batch_size, sequence_size, embed_dim), -0.1, 0.1, dtype=torch.float32) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + convert_to_ttnn=lambda *_: False, + custom_preprocessor=torch_functional_whisper.custom_preprocessor, + ) + torch_output = torch_functional_whisper.encoder_layer(config, torch_hidden_states, parameters=parameters) + + tt_model_name = f"ttnn_{model_name}_optimized" + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + ttnn_parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + convert_to_ttnn=ttnn_model.convert_to_ttnn, + custom_preprocessor=ttnn_model.custom_preprocessor, + device=mesh_device, + ) + + model = model.eval() + ttnn_hidden_states = ttnn.from_torch( + torch_hidden_states, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + mesh_mapper=inputs_mesh_mapper, + ) + + output = ttnn_model.encoder_layer( + config, + mesh_device, + ttnn_hidden_states, + parameters=ttnn_parameters, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + output = ttnn.to_torch(output, mesh_composer=output_mesh_composer) + + assert_with_pcc(torch_output, output, pcc=0.99) + + +@run_for_wormhole_b0() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("feature_size", [80]) +@pytest.mark.parametrize("sequence_length", [3000]) +def test_encoder(mesh_device, ttnn_model, model_name, batch_size, feature_size, sequence_length, reset_seeds): + config = transformers.WhisperConfig.from_pretrained(model_name) + model = transformers.models.whisper.modeling_whisper.WhisperEncoder(config) + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size * 2 if mesh_device_flag else batch_size + + torch_input_features = torch_random((batch_size, feature_size, sequence_length), -0.1, 0.1, dtype=torch.float32) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + convert_to_ttnn=lambda *_: False, + custom_preprocessor=torch_functional_whisper.custom_preprocessor, + ) + + def convert_to_ttnn(model, name): + return name not in [ + "conv1", + "conv2", + "embed_positions", + ] + + inputs_embeds = torch_functional_whisper.preprocess_encoder_inputs( + input_features=torch_input_features, + parameters=parameters, + ) + + torch_output = torch_functional_whisper.encoder(config, inputs_embeds, parameters=parameters) + + tt_model_name = f"ttnn_{model_name}_optimized" + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + ttnn_parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + convert_to_ttnn=convert_to_ttnn, + custom_preprocessor=ttnn_model.custom_preprocessor, + device=mesh_device, + ) + + ttnn_parameters.embed_positions.weight = ttnn_parameters.embed_positions.weight.unsqueeze(0) + ttnn_parameters.embed_positions.weight = ttnn.from_torch( + ttnn_parameters.embed_positions.weight, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device + ) + + model = model.eval() + input_embeds = ttnn_model.preprocess_encoder_inputs( + input_features=torch_input_features, + parameters=ttnn_parameters, + device=mesh_device, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, + weights_mesh_mapper=weights_mesh_mapper, + inputs_mesh_mapper=inputs_mesh_mapper, + output_mesh_composer=output_mesh_composer, + ) + input_embeds = ttnn.to_layout(input_embeds, ttnn.TILE_LAYOUT) + input_embeds = ttnn.to_device(input_embeds, mesh_device) + + output = ttnn_model.encoder( + config, + mesh_device, + input_embeds, + parameters=ttnn_parameters, + whisper_memory_config=ttnn.L1_MEMORY_CONFIG, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + output = ttnn.to_torch(output, mesh_composer=output_mesh_composer) + + assert_with_pcc(torch_output, output, 0.99) + + +@run_for_wormhole_b0() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [1500]) +def test_decoder_layer(mesh_device, reset_seeds, ttnn_model, model_name, batch_size, sequence_size): + config = transformers.WhisperConfig.from_pretrained(model_name) + model = transformers.models.whisper.modeling_whisper.WhisperDecoderLayer(config).eval() + num_heads = config.encoder_attention_heads + embed_dim = config.d_model + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size * 2 if mesh_device_flag else batch_size + + torch_hidden_states = torch_random((batch_size, 2, embed_dim), -0.1, 0.1, dtype=torch.float32) + + torch_encoder_hidden_states = torch_random((batch_size, sequence_size, embed_dim), -0.1, 0.1, dtype=torch.float32) + + attention_mask = torch_random((batch_size, 1, 2, 2), -0.1, 0.1, dtype=torch.float32) + # Putting num_heads in the channel because the add does not support broadcasting outside of the h and w dimensions. + attention_mask = attention_mask.expand(-1, num_heads, -1, -1) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + convert_to_ttnn=lambda *_: False, + custom_preprocessor=torch_functional_whisper.custom_preprocessor, + ) + + torch_output = torch_functional_whisper.decoder_layer( + config, torch_hidden_states, attention_mask, torch_encoder_hidden_states, parameters=parameters + ) + tt_model_name = f"ttnn_{model_name}_optimized" + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + ttnn_parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + convert_to_ttnn=ttnn_model.convert_to_ttnn, + custom_preprocessor=ttnn_model.custom_preprocessor, + device=mesh_device, + ) + + model = model.eval() + ttnn_hidden_states = ttnn.from_torch( + torch_hidden_states, + dtype=ttnn.bfloat16, + device=mesh_device, + mesh_mapper=inputs_mesh_mapper, + layout=ttnn.TILE_LAYOUT, + ) + + ttnn_attention_mask = ttnn.from_torch( + attention_mask, dtype=ttnn.bfloat16, device=mesh_device, mesh_mapper=inputs_mesh_mapper, layout=ttnn.TILE_LAYOUT + ) + + ttnn_encoder_hidden_states = ttnn.from_torch( + torch_encoder_hidden_states, + dtype=ttnn.bfloat16, + device=mesh_device, + mesh_mapper=inputs_mesh_mapper, + layout=ttnn.TILE_LAYOUT, + ) + + output = ttnn_model.decoder_layer( + config, + mesh_device, + ttnn_hidden_states, + ttnn_attention_mask, + ttnn_encoder_hidden_states, + parameters=ttnn_parameters, + whisper_memory_config=ttnn.L1_MEMORY_CONFIG, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + output = ttnn.to_torch(output, mesh_composer=output_mesh_composer) + + assert_with_pcc(torch_output, output, 0.99) + + +@run_for_wormhole_b0() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [1500]) +def test_decoder(mesh_device, ttnn_model, model_name, batch_size, sequence_size, reset_seeds): + config = transformers.WhisperConfig.from_pretrained(model_name) + model = transformers.models.whisper.modeling_whisper.WhisperDecoder(config).eval() + embed_dim = config.d_model + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size * 2 if mesh_device_flag else batch_size + + torch_encoder_hidden_states = torch_random((batch_size, sequence_size, embed_dim), -0.1, 0.1, dtype=torch.float32) + + decoder_input_ids = torch.tensor([[1, 1]] * batch_size) * config.decoder_start_token_id + + attention_mask = None + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + convert_to_ttnn=lambda *_: False, + custom_preprocessor=torch_functional_whisper.custom_preprocessor, + ) + + (decoder_hidden_states, decoder_attention_mask) = torch_functional_whisper.preprocess_decoder_inputs( + decoder_input_ids, attention_mask, parameters=parameters + ) + + torch_output = torch_functional_whisper.decoder( + config, + hidden_states=decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + encoder_hidden_states=torch_encoder_hidden_states, + parameters=parameters, + ) + + tt_model_name = f"ttnn_{model_name}_optimized" + + def convert_to_ttnn(model, name): + return name not in ["conv1", "conv2", "embed_positions", "embed_tokens"] + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + ttnn_parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + convert_to_ttnn=convert_to_ttnn, + custom_preprocessor=ttnn_model.custom_preprocessor, + device=mesh_device, + ) + model = model.eval() + + ttnn_decoder_input_ids = ttnn.from_torch(decoder_input_ids, dtype=ttnn.bfloat16, mesh_mapper=inputs_mesh_mapper) + ttnn_decoder_input_ids = ttnn.to_device(ttnn_decoder_input_ids, mesh_device) + + ttnn_encoder_hidden_states = ttnn.from_torch( + torch_encoder_hidden_states, dtype=ttnn.bfloat16, mesh_mapper=inputs_mesh_mapper + ) + ttnn_encoder_hidden_states = ttnn.to_layout(ttnn_encoder_hidden_states, ttnn.TILE_LAYOUT) + ttnn_encoder_hidden_states = ttnn.to_device(ttnn_encoder_hidden_states, mesh_device) + + (decoder_hidden_states, decoder_attention_mask) = ttnn_model.preprocess_decoder_inputs( + config, + decoder_input_ids, + attention_mask, + parameters=ttnn_parameters, + device=mesh_device, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + + output = ttnn_model.decoder( + config, + device=mesh_device, + hidden_states=decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + encoder_hidden_states=ttnn_encoder_hidden_states, + parameters=ttnn_parameters, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + output = ttnn.to_torch(output, mesh_composer=output_mesh_composer) + + assert_with_pcc(torch_output, output, pcc=0.99) + + +@run_for_wormhole_b0() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) +def test_ttnn_whisper(tmp_path, mesh_device, model_name, ttnn_model, batch_size, reset_seeds): + config = transformers.WhisperConfig.from_pretrained(model_name) + feature_extractor = transformers.AutoFeatureExtractor.from_pretrained(model_name) + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size * 2 if mesh_device_flag else batch_size + + inputs = feature_extractor( + [ds[i]["audio"]["array"] for i in range(batch_size)], sampling_rate=16000, return_tensors="pt" + ) + input_features = inputs.input_features + decoder_input_ids = torch.tensor([[1, 1]] * batch_size) * config.decoder_start_token_id + + model = transformers.WhisperModel.from_pretrained(model_name) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model.eval(), + convert_to_ttnn=lambda *_: False, + custom_preprocessor=torch_functional_whisper.custom_preprocessor, + ) + + (encoder_hidden_states, decoder_hidden_states, decoder_attention_mask) = torch_functional_whisper.preprocess_inputs( + input_features=input_features, + input_ids=decoder_input_ids, + attention_mask=None, + parameters=parameters, + ) + + torch_last_hidden_state = torch_functional_whisper.whisper( + config, + encoder_hidden_states, + decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + parameters=parameters, + ) + + tt_model_name = f"ttnn_{model_name}_optimized" + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + ttnn_parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + convert_to_ttnn=ttnn_model.convert_to_ttnn, + custom_preprocessor=ttnn_model.custom_preprocessor, + device=mesh_device, + ) + + model = model.eval() + + (input_embeds, decoder_hidden_states, decoder_attention_mask) = ttnn_model.preprocess_inputs( + config=config, + input_features=input_features, + input_ids=decoder_input_ids, + attention_mask=None, + parameters=ttnn_parameters, + device=mesh_device, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, + weights_mesh_mapper=weights_mesh_mapper, + inputs_mesh_mapper=inputs_mesh_mapper, + output_mesh_composer=output_mesh_composer, + ) + + last_hidden_state = ttnn_model.whisper( + config, + mesh_device, + input_embeds, + decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + parameters=ttnn_parameters, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, + output_mesh_composer=output_mesh_composer, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + + last_hidden_state = ttnn.to_torch(last_hidden_state, mesh_composer=output_mesh_composer) + + assert_with_pcc(torch_last_hidden_state, last_hidden_state, 0.97)