Skip to content

Commit

Permalink
Add Type Argument To Greedy Decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
apaniukov committed Jan 31, 2024
1 parent 22e115b commit 2de6890
Showing 1 changed file with 15 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,22 @@ OutputVector translate_sentencepiece_op(const NodeContext& node) {
std::cout << "[ Trace 1 ] As string" << std::endl;
// str_spm_model = str_spm_model.substr(2);
// str_spm_model = str_spm_model.substr(0, str_spm_model.size() - 1);
auto sp_model_const = std::make_shared<Constant>(element::u8, Shape{ str_spm_model.size() }, str_spm_model);
auto sp_model_const = std::make_shared<Constant>(element::u8, Shape{ str_spm_model.size() }, str_spm_model.data());
// std::cout << "[ Trace 1 ] Successful size:"<< str_spm_model.size() << "\n" << str_spm_model.substr(0, 100) << std::endl;
std::cout << "[ Trace 1 ] Successful" << std::endl;
return { sp_model_const };
}

//OutputVector translate_sentencepiece_op(const NodeContext& node) {
// // extract model to configure SentencePieceTokenizer
// auto sp_model_ov_any = node.get_attribute_as_any("model");
// FRONT_END_GENERAL_CHECK(sp_model_ov_any.is<std::string>(),
// "SentencePieceOp configuration model is in incorrect format");
// auto str_spm_model = sp_model_ov_any.as<std::string>();
// auto sp_model_const = std::make_shared<Constant>(element::u8, Shape{ str_spm_model.size() }, str_spm_model.data());
// return { sp_model_const };
//}

NamedOutputVector translate_sentencepiece_tokenizer(const NodeContext& node) {
// this is custom translator that converts a sub-graph with SentencePieceOp, SentencePieceTokenizer,
// and RaggedTensorToSparse operation- into a custom operation SentencepieceTokenizerExtensionOp
Expand All @@ -76,8 +86,7 @@ NamedOutputVector translate_sentencepiece_tokenizer(const NodeContext& node) {

// prepare input
auto inputs = sp_tokenize_op->input_value(1);
std::cout << "[ Trace 222 ] Type: " << inputs.get_element_type() << std::endl;
auto parameter = std::dynamic_pointer_cast<Parameter>(inputs.get_node_shared_ptr())
auto parameter = std::dynamic_pointer_cast<Parameter>(inputs.get_node_shared_ptr());
parameter -> set_partial_shape(PartialShape{ Dimension() });

// extract values for nbest_size, alpha, add_bos, add_eos, reverse attributes
Expand Down Expand Up @@ -195,18 +204,9 @@ ov::OutputVector translate_reshape(const ov::frontend::NodeContext& node) {
FRONT_END_GENERAL_CHECK(node.get_input_size() == 2, "Tensorflow Reshape op should have two inputs");
auto tensor = node.get_input(0);
auto shape = node.get_input(1);
if(auto pack = dynamic_cast<StringTensorPack*>(tensor.get_node())) {
// TODO: If it is a beginning of the graph, how to detect strings? It falls in 'else' branch in this case.
// FIXME: Needs extension for a Parameter to prepare it first
auto begins = std::make_shared<Reshape>(pack->input_value(0), shape, false);
auto ends = std::make_shared<Reshape>(pack->input_value(1), shape, false);
auto chars = pack->input_value(2);
auto reshape = post_translate_string_tensor_output({begins, ends, chars});
return {reshape};
} else {
auto reshape = std::make_shared<Reshape>(tensor, shape, false);
return {reshape};
}
auto reshape = std::make_shared<Reshape>(tensor, shape, false);
return {reshape};
// }
// set_node_name(node.get_name(), reshape); // TODO: requires dependencies from TF FE internals
}

Expand Down

0 comments on commit 2de6890

Please sign in to comment.