[CB] Align the speculative decoding sample with other ones (#1007)
Spin off of openvinotoolkit/openvino.genai#907.
Merge after the head PR.

---------

Co-authored-by: Ilya Lavrenov <[email protected]>
iefode and ilya-lavrenov authored Oct 18, 2024
1 parent ba62a46 commit dce0cd0
Showing 8 changed files with 70 additions and 30 deletions.
33 changes: 28 additions & 5 deletions include/openvino/genai/llm_pipeline.hpp
@@ -272,12 +272,35 @@ class OPENVINO_GENAI_EXPORTS LLMPipeline {
 OPENVINO_GENAI_EXPORTS std::pair<std::string, Any> streamer(StreamerVariant func);
 OPENVINO_GENAI_EXPORTS std::pair<std::string, Any> generation_config(const GenerationConfig& config);
 
-OPENVINO_GENAI_EXPORTS
-std::pair<std::string, Any> draft_model(
+OPENVINO_GENAI_EXPORTS std::pair<std::string, Any> _draft_model(
     const std::string& model_path,
-    const std::string& device = "",
-    const ov::AnyMap& plugin_config = {},
-    const ov::genai::SchedulerConfig& scheduler_config = {});
+    const std::string& device,
+    const ov::AnyMap& llm_config);
 
+template <typename... Properties,
+          typename std::enable_if<ov::util::StringAny<Properties...>::value, bool>::type = true>
+inline std::pair<std::string, Any> draft_model(
+    const std::string& model_path,
+    const std::string& device,
+    Properties&&... properties) {
+    return _draft_model(model_path, device, ov::AnyMap{std::forward<Properties>(properties)...});
+}
+
+template <typename... Properties,
+          typename std::enable_if<ov::util::StringAny<Properties...>::value, bool>::type = true>
+inline std::pair<std::string, Any> draft_model(
+    const std::string& model_path,
+    Properties&&... properties) {
+    return _draft_model(model_path, "", ov::AnyMap{std::forward<Properties>(properties)...});
+}
+
+
+inline std::pair<std::string, Any>
+draft_model(
+    const std::string& model_path,
+    const std::string& device = "",
+    const ov::AnyMap& llm_config = ov::AnyMap()) {
+    return _draft_model(model_path, device, llm_config);
+}
 } // namespace genai
 } // namespace ov
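
Taken together, the new overloads let a caller attach a draft model to LLMPipeline as just another property, which is the point of aligning the speculative decoding sample with the other samples. The sketch below is illustrative only: the model paths and the "CPU" device strings are placeholders, and the surrounding sample boilerplate is assumed rather than copied from the repository.

#include <iostream>
#include <string>

#include "openvino/genai/llm_pipeline.hpp"

int main() {
    // Placeholder paths; any main/draft model pair exported for OpenVINO GenAI would do.
    std::string main_model_path = "TinyLlama-1.1B-Chat-v1.0";
    std::string draft_model_path = "TinyLlama-160M-draft";

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 100;
    // Request speculative decoding with a fixed number of candidate tokens per step.
    config.num_assistant_tokens = 5;

    // draft_model(...) now composes like streamer(...) or generation_config(...).
    ov::genai::LLMPipeline pipe(main_model_path, "CPU",
                                ov::genai::draft_model(draft_model_path, "CPU"));

    std::cout << pipe.generate("Why is the Sun yellow?", config) << std::endl;
    return 0;
}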
10 changes: 5 additions & 5 deletions src/continuous_batching_impl_interface.hpp
@@ -28,11 +28,11 @@ class ContinuousBatchingPipeline::ImplInterface {
         float m_infer_total_ms = 0.0f;
 
         ~PerfTime() {
-            std::cout << "Inference requests aggregated statistic: " << std::endl;
-            std::cout << "Paged attention % of inference execution: " << (m_paged_attention_time_ms / m_infer_total_ms) * 100 << std::endl;
-            std::cout << "MatMul % of inference execution: " << (m_matmul_time_ms / m_infer_total_ms) * 100 << std::endl;
-            std::cout << "Total inference execution secs: " << m_infer_total_ms / 1000. << std::endl;
-            std::cout << std::endl;
+            // std::cout << "Inference requests aggregated statistic: " << std::endl;
+            // std::cout << "Paged attention % of inference execution: " << (m_paged_attention_time_ms / m_infer_total_ms) * 100 << std::endl;
+            // std::cout << "MatMul % of inference execution: " << (m_matmul_time_ms / m_infer_total_ms) * 100 << std::endl;
+            // std::cout << "Total inference execution secs: " << m_infer_total_ms / 1000. << std::endl;
+            // std::cout << std::endl;
         }
     } m_perf;
     bool m_is_chat_conversation = false;
5 changes: 3 additions & 2 deletions src/generation_config.cpp
@@ -165,9 +165,9 @@ void GenerationConfig::validate() const {
     }
     if (is_speculative_decoding()) {
         if (assistant_confidence_threshold != 0.f) {
-            OPENVINO_ASSERT(num_assistant_tokens == 0);
+            OPENVINO_ASSERT(num_assistant_tokens == 0, "Parameters `assistant_confidence_threshold` and `num_assistant_tokens` are mutually excluded in `GenerationConfig`");
         } else {
-            OPENVINO_ASSERT(num_assistant_tokens > 0);
+            OPENVINO_ASSERT(num_assistant_tokens > 0, "Parameters `assistant_confidence_threshold` and `num_assistant_tokens` are mutually excluded in `GenerationConfig`");
         };
     }
 }
@@ -202,5 +202,6 @@ GenerationConfig multinomial() {
     return multinomial_config;
 }
 
+
 } // namespace genai
 } // namespace ov
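
For reference, the stricter validation reads as follows from the caller's side. This is a sketch, assuming only the GenerationConfig fields that appear in the diff above.

ov::genai::GenerationConfig config;
config.max_new_tokens = 128;

// Dynamic speculation: the draft model keeps proposing tokens while its
// confidence stays above the threshold.
config.assistant_confidence_threshold = 0.4f;

// Static speculation: a fixed number of candidate tokens per iteration.
// Setting this together with the threshold above now fails validate()
// with the explicit "mutually excluded" message instead of a bare assert.
// config.num_assistant_tokens = 5;

config.validate();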
15 changes: 11 additions & 4 deletions src/llm_pipeline.cpp
@@ -18,6 +18,7 @@
 #include "openvino/genai/lora_adapter.hpp"
 #include "lora_helper.hpp"
 #include "speculative_decoding/speculative_decoding_impl.hpp"
+#include "speculative_decoding/speculative_decoding_impl.hpp"
 
 namespace ov {
 namespace genai {
@@ -368,12 +369,18 @@ std::pair<std::string, Any> generation_config(const GenerationConfig& config) {
     return {utils::CONFIG_ARG_NAME, Any::make<GenerationConfig>(config)};
 }
 
-std::pair<std::string, Any> draft_model(
+std::pair<std::string, Any> _draft_model(
     const std::string& model_path,
     const std::string& device,
-    const ov::AnyMap& plugin_config,
-    const ov::genai::SchedulerConfig& scheduler_config) {
-    return { utils::DRAFT_MODEL_ARG_NAME, Any::make<ModelDesc>(model_path, device, plugin_config, scheduler_config) };
+    const ov::AnyMap& llm_config) {
+    ov::AnyMap plugin_config = llm_config;
+    if (plugin_config.count(ov::genai::scheduler_config.name())) {
+        auto scheduler_config = plugin_config.at(ov::genai::scheduler_config.name()).as<SchedulerConfig>();
+        plugin_config.erase(ov::genai::scheduler_config.name());
+        return { utils::DRAFT_MODEL_ARG_NAME, Any::make<ModelDesc>(model_path, device, plugin_config, scheduler_config) };
+    }
+    SchedulerConfig scheduler_config;
+    return { utils::DRAFT_MODEL_ARG_NAME, Any::make<ModelDesc>(model_path, device, plugin_config, scheduler_config) };
 }
 
 } // namespace genai
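
_draft_model() now digs the SchedulerConfig out of the generic property map instead of taking it as a separate argument. Assuming ov::genai::scheduler_config is an ordinary ov::Property (which the .name()/.as<SchedulerConfig>() calls above suggest), a draft-specific scheduler setup would be passed like this; the model path and cache size are placeholders.

ov::genai::SchedulerConfig draft_scheduler;
draft_scheduler.cache_size = 2;  // illustrative value

// The scheduler config rides inside the same AnyMap as any other plugin option;
// _draft_model() extracts it and forwards the remainder as plugin_config.
auto draft = ov::genai::draft_model("TinyLlama-160M-draft",  // placeholder path
                                    "CPU",
                                    ov::genai::scheduler_config(draft_scheduler));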
1 change: 1 addition & 0 deletions src/sampler.hpp
@@ -63,6 +63,7 @@ class Sampler {
 
     SamplerOutput sample(std::vector<SequenceGroup::Ptr> & sequence_groups, ov::Tensor logits, bool is_validation_mode_enabled = false);
     void set_seed(size_t seed) { rng_engine.seed(seed); }
+
     void clear_request_info(uint64_t request_id);
 
     LogitProcessor& get_logit_processor(uint64_t request_id);
@@ -148,6 +148,8 @@ init_request(
     LogitProcessor& logit_processor,
     bool is_update_logit_processor,
     bool is_init_all_sequences_in_request = false) {
+    OPENVINO_ASSERT(request->get_sampling_parameters().is_speculative_decoding(),
+                    "Speculative decoding should have initialized options `assistant_confidence_threshold` xor `num_assistant_tokens` in `GenerationConfig`.");
     if (candidates.begin()->second.token_ids.empty() && !is_init_all_sequences_in_request) {
         return 0;
     }
32 changes: 19 additions & 13 deletions src/speculative_decoding/speculative_decoding_impl.cpp
@@ -11,6 +11,10 @@ namespace ov::genai {
 template<class... Ts> struct overloaded : Ts... {using Ts::operator()...;};
 template<class... Ts> overloaded(Ts...) -> overloaded<Ts...>;
 
+bool operator==(const SchedulerConfig& lhs, const SchedulerConfig& rhs) {
+    return ov::Any(lhs).as<std::string>() == ov::Any(rhs).as<std::string>();
+}
+
 ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(
     const std::string& main_models_path,
     const SchedulerConfig& main_scheduler_config,
@@ -31,16 +35,13 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(
     utils::apply_paged_attention_transformations(main_model, main_scheduler_config.use_cache_eviction);
     utils::apply_paged_attention_transformations(draft_model, main_scheduler_config.use_cache_eviction);
 
-    std::string draft_device = draft_model_desc.device;
-    bool is_draft_device_undefined = false;
-    if (draft_device.empty() || draft_device == main_device) {
-        draft_device = main_device;
-        is_draft_device_undefined = true;
-    }
+    std::string draft_device = draft_model_desc.device.empty() ? main_device : draft_model_desc.device;
+
+    bool is_scheduler_undefined = draft_model_desc.scheduler_config == SchedulerConfig();
 
     ov::genai::SchedulerConfig main_scheduler_config_updated = main_scheduler_config,
-                               draft_scheduler_config = is_draft_device_undefined ? main_scheduler_config : draft_model_desc.scheduler_config;
-    if (is_draft_device_undefined) {
+                               draft_scheduler_config = is_scheduler_undefined ? main_scheduler_config : draft_model_desc.scheduler_config;
+    if (is_scheduler_undefined) {
         // split KV cache to 2 caches for main and draft models
         size_t main_model_cache_size = utils::get_kv_cache_size(main_model),
                draft_model_cache_size = utils::get_kv_cache_size(draft_model);
@@ -57,7 +58,7 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(
         draft_scheduler_config.cache_size = draft_cache_size;
     }
 
-    ov::AnyMap draft_plugin_config = is_draft_device_undefined ? compile_plugin_config : draft_model_desc.plugin_config;
+    ov::AnyMap draft_plugin_config = draft_model_desc.plugin_config == ov::AnyMap{} ? compile_plugin_config : draft_model_desc.plugin_config;
 
     DeviceConfig main_device_config(core, main_scheduler_config, main_device, compile_plugin_config),
                  draft_device_config(core, draft_scheduler_config, draft_device, draft_plugin_config);
@@ -194,11 +195,16 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<
     bool continue_generation = true;
     while (has_non_finished_requests() && continue_generation) {
         step();
-        if (streamer_ptr) {
+        if (streamer_ptr) {
            std::unordered_map<uint64_t, GenerationOutput> token = main_generations.at(0).get()->back();
-            OPENVINO_ASSERT(1 == token.size());
-            OPENVINO_ASSERT(1 == token.begin()->second.generated_ids.size());
-            continue_generation = !streamer_ptr->put(token.begin()->second.generated_ids.at(0));
+            OPENVINO_ASSERT(1 <= token.size());
+            OPENVINO_ASSERT(1 <= token.begin()->second.generated_ids.size());
+            for (const auto& gen_token : token.begin()->second.generated_ids) {
+                continue_generation = !streamer_ptr->put(gen_token);
+                if (!continue_generation) {
+                    break;
+                }
+            }
         }
     }
     if (streamer_ptr) {
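
The streaming loop above no longer assumes exactly one generated token per step, since an accepted speculation can emit several at once; every generated id is forwarded to the streamer, and generation stops as soon as the callback asks for it. A minimal callback-style streamer, assuming the bool(std::string) callback form that the streamer(...) property accepts and reusing `pipe` and `config` from the earlier sketch, would look like:

auto on_subword = [](std::string subword) {
    std::cout << subword << std::flush;
    return false;  // false keeps generation going; true requests a stop
};

pipe.generate("Why is the Sun yellow?", config, on_subword);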
2 changes: 1 addition & 1 deletion src/timer.hpp
@@ -31,6 +31,6 @@ class ManualTimer {
     }
 
     ~ManualTimer() {
-        std::cout << m_title << ": " << m_total / 1000. << " secs" << std::endl;
+        // std::cout << m_title << ": " << m_total / 1000. << " secs" << std::endl;
     }
 };
