Commit

Add model modifications for StaticWhisperPipeline
eshiryae committed Oct 29, 2024
1 parent 0a5dc02 commit 0aea7f9
Showing 1 changed file with 186 additions and 15 deletions.
201 changes: 186 additions & 15 deletions src/cpp/src/whisper_pipeline_static.cpp
@@ -15,7 +15,8 @@
#include "whisper/whisper.hpp"
#include "whisper/whisper_config.hpp"


#include "openvino/core/layout.hpp"
#include "openvino/core/preprocess/pre_post_process.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/pass/graph_rewrite.hpp"
@@ -68,8 +69,9 @@ ov::Tensor make_tensor_slice(ov::Tensor tensor, size_t dim, size_t start_pos, si

void set_cross_attn_key_value(ov::InferRequest& source, ov::InferRequest& dest) {
// NB: Source outputs:
// present_key_values.0.encoder.key
// present_key_values.0.encoder.value
// for optimum-cli
// present.0.encoder.key
// present.0.encoder.value

// NB: Dest inputs:
// past_key_values.0.encoder.key
@@ -80,15 +82,16 @@ void set_cross_attn_key_value(ov::InferRequest& source, ov::InferRequest& dest)
if (source_output_name.find("encoder") == std::string::npos) {
continue;
}
std::string with_past_input_name = std::regex_replace(source_output_name, std::regex("present"), "past");
std::string with_past_input_name = std::regex_replace(source_output_name, std::regex("present"), "past_key_values");
dest.set_tensor(with_past_input_name, source.get_tensor(source_output_name));
}
}
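
A minimal standalone sketch (not part of this commit) of the name remapping used above for optimum-cli models: "present.*" output names are rewritten to the matching "past_key_values.*" input names with std::regex_replace. The example tensor name is assumed for illustration.

    #include <iostream>
    #include <regex>
    #include <string>

    int main() {
        const std::string source_output_name = "present.0.encoder.key";
        // Same substitution as in set_cross_attn_key_value() / update_past_key_value():
        const std::string with_past_input_name =
            std::regex_replace(source_output_name, std::regex("present"), "past_key_values");
        std::cout << with_past_input_name << "\n";  // prints: past_key_values.0.encoder.key
        return 0;
    }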

void update_past_key_value(ov::InferRequest& source, ov::InferRequest& dest, const size_t kv_pos = 0u) {
// NB: Source outputs:
// present_key_values.0.decoder.key
// present_key_values.0.decoder.value
// for optimum-cli
// present.0.decoder.key
// present.0.decoder.value

// NB: Dest inputs:
// past_key_values.0.decoder.key
@@ -100,7 +103,7 @@ void update_past_key_value(ov::InferRequest& source, ov::InferRequest& dest, con
continue;
}

std::string with_past_input_name = std::regex_replace(source_output_name, std::regex("present"), "past");
std::string with_past_input_name = std::regex_replace(source_output_name, std::regex("present"), "past_key_values");

auto src_kv_tensor = source.get_tensor(source_output_name);
auto dst_kv_tensor = dest.get_tensor(with_past_input_name);
@@ -133,6 +136,9 @@ void set_decoder_input_ids_attention_mask(ov::InferRequest& decoder,
auto attention_mask_data = attention_mask_tensor.data<ov::float16>();
std::fill_n(attention_mask_data, init_ids.size(), 1u);
std::fill(attention_mask_data + init_ids.size(), attention_mask_data + attention_mask_tensor.get_size(), 0u);

//decoder.get_tensor("attention_mask").data<ov::float16>()[input_ids.size() - 1] = 0u;
// ^ Need to use the attention_mask size here!
}
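
A small sketch (illustration only, sizes assumed) of the attention_mask layout the helper above prepares for the first decoder run: ones over the init_ids prefix, zeros for the remaining positions. In the pipeline the tensor is f16.

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
        const size_t init_ids_size = 3;              // e.g. sot / language / task tokens
        std::vector<float> attention_mask(8, 0.0f);  // assumed mask size; f16 in the pipeline
        std::fill_n(attention_mask.begin(), init_ids_size, 1.0f);
        for (float v : attention_mask) std::printf("%.0f ", v);  // 1 1 1 0 0 0 0 0
        std::printf("\n");
        return 0;
    }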

int64_t decode(ov::Tensor& encoder_hidden_state,
@@ -171,7 +177,8 @@ int64_t decode_with_past(ov::InferRequest& decoder_with_past,
// FIXME: Avoid this cast to i32. Why isn't it i64 precision in the model?
decoder_with_past.get_tensor("input_ids").data<int32_t>()[0] = static_cast<int32_t>(input_id);
// FIXME: Avoid this cast to i32. Why isn't it i64 precision in the model?
decoder_with_past.get_tensor("position_ids").data<int32_t>()[0] = static_cast<int32_t>(position_id);
//decoder_with_past.get_tensor("position_ids").data<int32_t>()[0] = static_cast<int32_t>(position_id);
decoder_with_past.get_tensor("cache_position").data<int64_t>()[0] = position_id; // for optimum-cli
// FIXME: Is "attention_mask" supposed to be f16?
decoder_with_past.get_tensor("attention_mask").data<ov::float16>()[position_id - 1] = 1u;

@@ -195,7 +202,7 @@ void zero_past_key_values(ov::InferRequest& request) {
past_key_value_decoder_name.find("past_key_values") == std::string::npos) {
continue;
}
fill_tensor<float>(request.get_tensor(past_key_value_decoder_name), 0);
fill_tensor<ov::float16>(request.get_tensor(past_key_value_decoder_name), 0); // for optimum-cli
}
}
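
fill_tensor is a helper defined elsewhere in this file; a plausible sketch of what fill_tensor<ov::float16>(tensor, 0) does is shown below (the actual implementation may differ).

    #include <algorithm>
    #include "openvino/runtime/tensor.hpp"

    template <typename T>
    void fill_tensor(ov::Tensor tensor, T fill_val) {
        // Overwrite every element of the tensor with fill_val.
        std::fill_n(tensor.data<T>(), tensor.get_size(), fill_val);
    }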

@@ -204,8 +211,12 @@ void prepare_decoder_with_past(ov::InferRequest& decoder_with_past, ov::InferReq
auto attention_mask = decoder_with_past.get_tensor("attention_mask");
auto* attention_mask_ptr = attention_mask.data<ov::float16>();
std::fill(attention_mask_ptr, attention_mask_ptr + 3u, 1);
std::fill(attention_mask_ptr + 3u, attention_mask_ptr + attention_mask.get_size() - 1, 0);
attention_mask_ptr[attention_mask.get_size() - 1] = 1;
//std::fill(attention_mask_ptr + 3u, attention_mask_ptr + attention_mask.get_size() - 1, 0);
//attention_mask_ptr[attention_mask.get_size() - 1] = 1;
// NB: for optimum-cli models the attention_mask should be [1, 1, 1, 0, 0, 0, 0, ..., 1, 0], i.e. one element longer than the KV-cache size. FIXME
std::fill(attention_mask_ptr + 3u, attention_mask_ptr + attention_mask.get_size() - 2, 0);
attention_mask_ptr[attention_mask.get_size() - 2] = 1;
attention_mask_ptr[attention_mask.get_size() - 1] = 0;
// NB: Zero past_key_values.*.decoder.value tensors
zero_past_key_values(decoder_with_past);
// NB: Copy KV-caches from decoder
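
A sketch (illustration only) of the decoder_with_past attention_mask layout described in the NB comment above for optimum-cli models: three leading ones for the init tokens, a one in the next-to-last slot, and a trailing zero. The mask size is assumed small here; in the pipeline it is kvcache_size + 1 and the tensor is f16.

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
        const size_t mask_size = 10;  // assumed; kvcache_size + 1 in the pipeline
        std::vector<float> mask(mask_size, 0.0f);
        std::fill(mask.begin(), mask.begin() + 3, 1.0f);  // init tokens
        mask[mask_size - 2] = 1.0f;                       // next-to-last element set to 1
        mask[mask_size - 1] = 0.0f;                       // trailing element stays 0
        for (float v : mask) std::printf("%.0f ", v);     // 1 1 1 0 0 0 0 0 1 0
        std::printf("\n");
        return 0;
    }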
@@ -395,6 +406,128 @@ void add_attention_mask_input(std::shared_ptr<ov::Model> model) {
pm.run_passes(model);
}

void reshape_to_static(std::shared_ptr<ov::Model> model, const uint32_t input_size, const uint32_t kvcache_size) {
//std::cout << "[DEBUG] Reshaping decoder_with_past_model ..." << std::endl;

std::map<std::string, ov::PartialShape> new_shapes;
for (auto input : model->inputs()) {
const auto& input_name = input.get_any_name();
ov::PartialShape new_shape;
if (input_name.find("input_ids") != std::string::npos) {
new_shape = ov::PartialShape({1, input_size});
} else if (input_name.find("attention_mask") != std::string::npos) {
new_shape = ov::PartialShape({1, kvcache_size + 1}); // Artefact in attention_mask
} else if (input_name.find("position_ids") != std::string::npos) {
new_shape = ov::PartialShape({1, input_size});
} else if (input_name.find("cache_position") != std::string::npos) {
new_shape = ov::PartialShape({1});
} else if (input_name.find("encoder_hidden_states") != std::string::npos) {
const auto& partial_shape = input.get_partial_shape();
new_shape = partial_shape;
new_shape[0] = 1; // batch_dim
new_shape[1] = 1500; // FIXME: where to get this? Is it taken from the encoder output 'last_hidden_state'?
} else if (input_name.find("past_key_values") != std::string::npos) {
const auto& partial_shape = input.get_partial_shape();
new_shape = partial_shape;
new_shape[0] = 1; // Use batch dim here
new_shape[2] = input_name.find(".decoder") != std::string::npos
? kvcache_size - input_size
: 1500; // kv_size for decoder, 1500 for encoder; is it taken from the encoder
// output 'last_hidden_state'?

// ^ use kv_dim here
}
new_shapes.emplace(input_name, new_shape);
}

model->reshape(new_shapes);
}
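
For reference, a sketch of the static shapes the helper above produces, with sizes assumed for illustration (input_size = 1, kvcache_size = 128, 1500 encoder frames); 'hidden', 'heads' and 'head_dim' stand for the dimensions the helper leaves unchanged.

    #include <cstdio>

    int main() {
        const unsigned input_size = 1, kvcache_size = 128;
        std::printf("input_ids                  : [1, %u]\n", input_size);
        std::printf("attention_mask             : [1, %u]\n", kvcache_size + 1);
        std::printf("position_ids               : [1, %u]\n", input_size);
        std::printf("cache_position             : [1]\n");
        std::printf("encoder_hidden_states      : [1, 1500, hidden]\n");
        std::printf("past_key_values.*.decoder.*: [1, heads, %u, head_dim]\n", kvcache_size - input_size);
        std::printf("past_key_values.*.encoder.*: [1, heads, 1500, head_dim]\n");
        return 0;
    }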

void reshape_to_static_encoder(std::shared_ptr<ov::Model> model) {
std::map<std::string, ov::PartialShape> new_shapes;
for (auto input : model->inputs()) {
const auto& input_name = input.get_any_name();
ov::PartialShape new_shape;
if (input_name.find("input_features") != std::string::npos) {
const auto& partial_shape = input.get_partial_shape();
new_shape = partial_shape;
new_shape[0] = 1; // batch_dim
}
new_shapes.emplace(input_name, new_shape);
}
model->reshape(new_shapes);
}

void preprocess_encoder(std::shared_ptr<ov::Model> model) {
ov::preprocess::PrePostProcessor preprocessor(model);

preprocessor.input("input_features").tensor().set_element_type(ov::element::Type_t::f32);
preprocessor.input("input_features").preprocess().convert_element_type(ov::element::Type_t::f32);
preprocessor.output("last_hidden_state").tensor().set_element_type(ov::element::Type_t::f16);

model = preprocessor.build();
}

void preprocess_decoder(std::shared_ptr<ov::Model> model) {
ov::preprocess::PrePostProcessor preprocessor(model);

for (auto tensor : model->inputs()) {
if (tensor.get_any_name().find("input_ids") != std::string::npos) {
preprocessor.input("input_ids").tensor().set_element_type(ov::element::Type_t::i32);
preprocessor.input("input_ids").preprocess().convert_element_type(ov::element::Type_t::i32);
} else if (tensor.get_any_name().find("attention_mask") != std::string::npos) {
preprocessor.input("attention_mask").tensor().set_element_type(ov::element::Type_t::f16);
preprocessor.input("attention_mask").preprocess().convert_element_type();
} else if (tensor.get_any_name().find("encoder_hidden_states") != std::string::npos) {
preprocessor.input("encoder_hidden_states").tensor().set_element_type(ov::element::Type_t::f16);
preprocessor.input("encoder_hidden_states").preprocess().convert_element_type(ov::element::Type_t::f32); // ()
} else if (tensor.get_any_name().find("past_key_values") != std::string::npos) {
preprocessor.input(tensor.get_any_name()).tensor().set_element_type(ov::element::Type_t::f16);
preprocessor.input(tensor.get_any_name()).preprocess().convert_element_type();

// if (tensor.get_any_name().find(".value") != std::string::npos) {
// preprocessor.output(tensor.get_any_name()).tensor().set_layout(ov::Layout("NCWH"));
// preprocessor.output(tensor.get_any_name()).model().set_layout(ov::Layout("NCHW"));
//} else if (tensor.get_any_name().find(".key") != std::string::npos) {
// preprocessor.output(tensor.get_any_name()).tensor().set_layout(ov::Layout("NCHW"));
// preprocessor.output(tensor.get_any_name()).model().set_layout(ov::Layout("NCHW"));
//}
}
}

for (auto tensor : model->outputs()) {
//preprocessor.output(tensor.get_any_name()).tensor().set_element_type(ov::element::Type_t::f16);
if (tensor.get_any_name().find("present") != std::string::npos) { // "present" for models from arch team
preprocessor.output(tensor.get_any_name()).tensor().set_element_type(ov::element::Type_t::f16);
preprocessor.output(tensor.get_any_name()).postprocess().convert_element_type();

// if (tensor.get_any_name().find(".value") != std::string::npos) {
// preprocessor.output(tensor.get_any_name()).tensor().set_layout(ov::Layout("NCWH"));
// preprocessor.output(tensor.get_any_name()).model().set_layout(ov::Layout("NCHW"));
//} else if (tensor.get_any_name().find(".key") != std::string::npos) {
// preprocessor.output(tensor.get_any_name()).tensor().set_layout(ov::Layout("NCHW"));
// preprocessor.output(tensor.get_any_name()).model().set_layout(ov::Layout("NCHW"));
//}
}
}

model = preprocessor.build();
}

std::shared_ptr<ov::Model> redirect_new_kv_to_output(const std::shared_ptr<ov::Model>& model) {
const auto kStartOutputKVCacheLayers = 1u;
for (int i = kStartOutputKVCacheLayers; i < model->outputs().size(); ++i) {
auto kvout = model->output(i);
auto kvrslt = kvout.get_node();
auto kvcat = kvrslt->inputs()[0].get_source_output().get_node();
auto kvval = kvcat->inputs()[1].get_source_output();
kvval.set_names({kvout.get_any_name()});
kvrslt->inputs()[0].replace_source_output(kvval);
}
model->validate_nodes_and_infer_types();
return model;
}
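
A self-contained sketch of the rewiring redirect_new_kv_to_output() performs: for each KV output, the Result node is re-pointed from the full Concat(past, new) cache to the Concat's second input, so the model returns only the newly computed KV slice. The toy graph below is assumed purely for illustration and is not part of the commit.

    #include <memory>
    #include "openvino/core/model.hpp"
    #include "openvino/op/concat.hpp"
    #include "openvino/op/parameter.hpp"
    #include "openvino/op/result.hpp"

    int main() {
        // Toy graph: Result <- Concat(past_kv, new_kv) along the sequence axis (dim 2).
        auto past_kv = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::Shape{1, 8, 127, 64});
        auto new_kv  = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::Shape{1, 8, 1, 64});
        auto concat  = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{past_kv, new_kv}, 2);
        auto result  = std::make_shared<ov::op::v0::Result>(concat);
        auto model   = std::make_shared<ov::Model>(ov::ResultVector{result}, ov::ParameterVector{past_kv, new_kv});

        // Same rewiring as above: take the Concat's second input instead of its output.
        auto new_kv_out = concat->inputs()[1].get_source_output();
        result->inputs()[0].replace_source_output(new_kv_out);
        model->validate_nodes_and_infer_types();
        // The model output shape is now {1, 8, 1, 64}, i.e. only the new token's KV.
        return 0;
    }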

} // namespace

namespace ov {
@@ -418,10 +551,48 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
}

// TODO: There must be model reshape to eliminate dynamism!

m_models.encoder = core.compile_model(encoder_model, "NPU").create_infer_request();
m_models.decoder = core.compile_model(decoder_model, "NPU").create_infer_request();
m_models.decoder_with_past = core.compile_model(decoder_with_past_model, "NPU").create_infer_request();
size_t max_sequence_length = 128;

reshape_to_static_encoder(encoder_model);
reshape_to_static(decoder_model, 4, 4); // What is 4 here??
reshape_to_static(decoder_with_past_model, 1, max_sequence_length);

// Replace KV-tensors for the entire cache to tensors only for new token
decoder_with_past_model = redirect_new_kv_to_output(decoder_with_past_model);

ov::AnyMap config_encoder = {
{"NPU_COMPILATION_MODE_PARAMS", "compute-layers-with-higher-precision=Sqrt,Power,ReduceMean,Add"},
{"NPU_USE_NPUW", "YES"},
{"NPUW_ONLINE_PIPELINE", "NONE"},
//{"NPUW_FOLD", "YES"},
//{"NPUW_DCOFF_TYPE", "f16"},
//{"NPUW_DCOFF_SCALE", "YES"},
{"NPUW_DEVICES", "CPU"}};

ov::AnyMap config = {
{"NPU_COMPILATION_MODE_PARAMS", "compute-layers-with-higher-precision=Sqrt,Power,ReduceMean,Add"},
{"NPU_USE_NPUW", "YES"},
//{"NPUW_FOLD", "YES"},
//{"NPUW_DCOFF_TYPE", "f16"},
//{"NPUW_DCOFF_SCALE", "YES"},
{"NPUW_DEVICES", "CPU"}};

preprocess_encoder(encoder_model);
preprocess_decoder(decoder_model);
preprocess_decoder(decoder_with_past_model);

std::cout << "[DEBUG] All model modifications are done, saving models..." << std::endl;
ov::save_model(encoder_model, models_path / "0_openvino_encoder_model_attn.xml");
ov::save_model(decoder_model, models_path / "0_openvino_decoder_model_attn.xml");
ov::save_model(decoder_with_past_model, models_path / "0_openvino_decoder_with_past_model_attn.xml");

m_models.encoder = core.compile_model(encoder_model, "NPU", config_encoder).create_infer_request();
std::cout << "[DEBUG] Compile encoder model - DONE" << std::endl;
m_models.decoder = core.compile_model(decoder_model, "NPU", config_encoder).create_infer_request();
std::cout << "[DEBUG] Compile decoder model - DONE" << std::endl;
m_models.decoder_with_past =
core.compile_model(decoder_with_past_model, "NPU", config_encoder).create_infer_request();
std::cout << "[DEBUG] Compile decoder with past model - DONE" << std::endl;

// If eos_token_id was not provided, take value
if (m_generation_config.eos_token_id == -1) {