Skip to content

Commit

Permalink
apply comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sbalandi committed Oct 16, 2024
1 parent 689fd3c commit 6ff3055
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 18 deletions.
19 changes: 10 additions & 9 deletions samples/cpp/visual_language_chat/visual_language_chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
#include <filesystem>
#include <openvino/runtime/intel_gpu/properties.hpp>

namespace fs = std::filesystem;

// Streaming callback for generation: prints each produced subword to stdout
// immediately (flushed so the user sees tokens as they arrive).
// Returns true to request that generation stop — which happens only if the
// output stream has entered a failed state — and false to continue streaming.
bool print_subword(std::string&& subword) {
    std::cout << subword << std::flush;
    return std::cout.fail();
}
Expand All @@ -19,6 +17,9 @@ int main(int argc, char* argv[]) try {

std::vector<ov::Tensor> images = utils::load_images(argv[2]);

ov::genai::GenerationConfig generation_config;
generation_config.max_new_tokens = 30;

std::string device = "CPU"; // GPU can be used as well
ov::AnyMap enable_compile_cache;
if ("GPU" == device) {
Expand All @@ -33,16 +34,16 @@ int main(int argc, char* argv[]) try {
std::cout << "question:\n";

std::getline(std::cin, prompt);
auto resuls = pipe.generate(prompt,
ov::genai::images(images),
ov::genai::generation_config(ov::genai::greedy()),
ov::genai::streamer(print_subword));
pipe.generate(prompt,
ov::genai::images(images),
ov::genai::generation_config(generation_config),
ov::genai::streamer(print_subword));
std::cout << "\n----------\n"
"question:\n";
while (std::getline(std::cin, prompt)) {
resuls = pipe.generate(prompt,
ov::genai::generation_config(ov::genai::greedy()),
ov::genai::streamer(print_subword));
pipe.generate(prompt,
ov::genai::generation_config(generation_config),
ov::genai::streamer(print_subword));
std::cout << "\n----------\n"
"question:\n";
}
Expand Down
5 changes: 3 additions & 2 deletions src/cpp/src/sampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -597,10 +597,11 @@ void register_new_token(const Token& sampled_token_id,
}
};

std::vector<int32_t> Sampler::get_beam_idxs(uint64_t request_id) {
// Returns the beam index for each running sequence of the given sequence group.
// If no beam-search state exists for this request (e.g. greedy or multinomial
// sampling was used), every running sequence maps to beam 0.
std::vector<int32_t> Sampler::get_beam_idxs(SequenceGroup::CPtr sequence_group) {
    const size_t request_id = sequence_group->get_request_id();
    // Look the request up once and reuse the iterator; the original code
    // called m_beam_search_info.find() a second time in the condition,
    // discarding the already-computed result.
    auto beam_searcher = m_beam_search_info.find(request_id);
    if (beam_searcher == m_beam_search_info.end()) {
        // No beam search for this request: all running sequences stay on beam 0.
        return std::vector<int32_t>(sequence_group->num_running_seqs(), 0);
    }
    return beam_searcher->second.get_beam_idxs();
}
Expand Down
2 changes: 1 addition & 1 deletion src/cpp/src/sampler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class Sampler {
SamplerOutput sample(std::vector<SequenceGroup::Ptr> & sequence_groups, ov::Tensor logits, bool is_validation_mode_enabled = false);
void set_seed(size_t seed) { rng_engine.seed(seed); }
void clear_beam_search_info(uint64_t request_id);
std::vector<int32_t> get_beam_idxs(uint64_t request_id);
std::vector<int32_t> get_beam_idxs(SequenceGroup::CPtr sequence_group);
};

class Sampler::GroupBeamSearcher {
Expand Down
8 changes: 2 additions & 6 deletions src/cpp/src/visual_language/pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,13 +326,9 @@ EncodedGenerationResult get_lm_encoded_results(

language.set_tensor("position_ids", position_ids);

std::vector<int32_t> beam_idxs = sampler.get_beam_idxs(request->get_request_id());
std::vector<int32_t> beam_idxs = sampler.get_beam_idxs(request);
int32_t *beam_idx_data = beam_idx.data<int32_t>();
if (total_num_tokens > beam_idxs.size()) {
std::fill_n(beam_idx_data, total_num_tokens, 0);
} else {
copy(beam_idxs.begin(), beam_idxs.end(), beam_idx_data);
}
std::fill_n(beam_idx_data, total_num_tokens, 0);
language.set_tensor("beam_idx", beam_idx);

language.infer();
Expand Down

0 comments on commit 6ff3055

Please sign in to comment.