[CB] Align the speculative decoding sample with other ones (#1007)
Spin-off of #907.
Merge after the head PR.

---------

Co-authored-by: Ilya Lavrenov <[email protected]>
iefode and ilya-lavrenov authored Oct 18, 2024
1 parent 2240669 commit b925196
Showing 15 changed files with 133 additions and 191 deletions.
12 changes: 10 additions & 2 deletions .github/workflows/causal_lm_cpp.yml
@@ -422,8 +422,16 @@ jobs:
- name: run and compare
run: |
source ./ov/setupvars.sh
./build/samples/cpp/speculative_decoding_lm/speculative_decoding_lm -a ./dolly-v2-3b/ -m ./dolly-v2-7b/ -n 5
./build/samples/cpp/speculative_decoding_lm/continuous_batching_speculative_decoding -a ./dolly-v2-3b/ -m ./dolly-v2-7b/ -n 5
./build/samples/cpp/speculative_decoding_lm/speculative_decoding_lm ./dolly-v2-7b/ ./dolly-v2-3b/ "Alan Turing was a" > predictions_speculative.txt
./build/samples/cpp/greedy_causal_lm/greedy_causal_lm ./dolly-v2-7b/ "Alan Turing was a" > predictions_greedy.txt
python -c "
with open('predictions_greedy.txt', 'r') as f:
predicted_greedy = f.readline()
with open('predictions_speculative.txt', 'r') as f:
predicted_speculative = f.readline()
assert predicted_greedy == predicted_speculative
"
echo "Alan Turing was a" passed
cpp-prompt_lookup_decoding_lm-ubuntu:
runs-on: ubuntu-20.04-16-cores
2 changes: 2 additions & 0 deletions samples/CMakeLists.txt
@@ -30,6 +30,7 @@ install(DIRECTORY
cpp/whisper_speech_recognition
cpp/text2image
cpp/lora_greedy_causal_lm
cpp/speculative_decoding_lm
DESTINATION samples/cpp COMPONENT cpp_samples_genai)

install(DIRECTORY
@@ -39,6 +40,7 @@ install(DIRECTORY
python/multinomial_causal_lm
python/visual_language_chat
python/whisper_speech_recognition
# python/speculative_decoding_lm
# python/text2image
DESTINATION samples/python COMPONENT cpp_samples_genai
USE_SOURCE_PERMISSIONS)
4 changes: 4 additions & 0 deletions samples/cpp/continuous_batching_accuracy/CMakeLists.txt
@@ -28,3 +28,7 @@ find_package(OpenVINO REQUIRED COMPONENTS Runtime)
set(TARGET_NAME continuous_batching_accuracy)
add_executable(${TARGET_NAME} ${TARGET_NAME}.cpp)
target_link_libraries(${TARGET_NAME} PRIVATE openvino::genai cxxopts::cxxopts)

set(TARGET_NAME_CB continuous_batching_speculative_decoding)
add_executable(${TARGET_NAME_CB} ${TARGET_NAME_CB}.cpp)
target_link_libraries(${TARGET_NAME_CB} PRIVATE openvino::genai cxxopts::cxxopts)
samples/cpp/{speculative_decoding_lm → continuous_batching_accuracy}/continuous_batching_speculative_decoding.cpp
@@ -55,7 +55,6 @@ int main(int argc, char* argv[]) try {
("m,model", "Path to model and tokenizers base directory", cxxopts::value<std::string>()->default_value("."))
("a,draft_model", "Path to assisting model base directory", cxxopts::value<std::string>()->default_value("."))
("d,device", "Target device to run the model", cxxopts::value<std::string>()->default_value("CPU"))
("use_prefix", "Whether to use a prefix or not", cxxopts::value<bool>()->default_value("false"))
("h,help", "Print usage");

cxxopts::ParseResult result;
@@ -77,14 +76,6 @@ int main(int argc, char* argv[]) try {
const std::string model_path = result["model"].as<std::string>();
const std::string draft_model_path = result["draft_model"].as<std::string>();
const std::string device = result["device"].as<std::string>();
const bool use_prefix = result["use_prefix"].as<bool>();

std::string prefix_str =
"You are an advanced language model designed to assist users by providing accurate, "
"relevant, and helpful information. Your responses should be accurate, concise, contextual, "
"respectful, and helpful. The request is: ";

// create dataset

std::vector<std::string> prompt_examples = {
"What is OpenVINO?",
@@ -117,15 +108,14 @@ int main(int argc, char* argv[]) try {

ov::genai::SchedulerConfig scheduler_config;
// batch size
scheduler_config.max_num_batched_tokens = use_prefix ? 256 : 32;
scheduler_config.max_num_batched_tokens = 32;
// cache params
scheduler_config.num_kv_blocks = 364;
scheduler_config.block_size = get_default_block_size(device);
// mode - vLLM or dynamic_split_fuse
scheduler_config.dynamic_split_fuse = dynamic_split_fuse;
// vLLM specific params
scheduler_config.max_num_seqs = 2;
scheduler_config.enable_prefix_caching = use_prefix;

ov::genai::ContinuousBatchingPipeline pipe(model_path, scheduler_config, device, {ov::genai::draft_model(draft_model_path, device)});
std::vector<ov::genai::GenerationResult> generation_results = pipe.generate(prompts, generation_config);
43 changes: 16 additions & 27 deletions samples/cpp/speculative_decoding_lm/CMakeLists.txt
@@ -1,34 +1,23 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# start of dependencies

include(FetchContent)

if(POLICY CMP0135)
cmake_policy(SET CMP0135 NEW)
endif()

FetchContent_Declare(cxxopts
URL https://github.com/jarro2783/cxxopts/archive/refs/tags/v3.1.1.tar.gz
URL_HASH SHA256=523175f792eb0ff04f9e653c90746c12655f10cb70f1d5e6d6d9491420298a08)
FetchContent_MakeAvailable(cxxopts)

if(NOT TARGET nlohmann_json)
FetchContent_Declare(nlohmann_json
URL https://github.com/nlohmann/json/archive/refs/tags/v3.11.3.tar.gz
URL_HASH SHA256=0d8ef5af7f9794e3263480193c491549b2ba6cc74bb018906202ada498a79406)
FetchContent_MakeAvailable(nlohmann_json)
endif()

find_package(OpenVINO REQUIRED COMPONENTS Runtime)

# end of dependencies
find_package(OpenVINOGenAI REQUIRED
PATHS
"${CMAKE_BINARY_DIR}" # Reuse the package from the build.
${OpenVINO_DIR} # GenAI may be installed alongside OpenVINO.
NO_CMAKE_FIND_ROOT_PATH
)

set(TARGET_NAME speculative_decoding_lm)
add_executable(${TARGET_NAME} ${TARGET_NAME}.cpp)
target_link_libraries(${TARGET_NAME} PRIVATE openvino::genai cxxopts::cxxopts)
target_link_libraries(${TARGET_NAME} PRIVATE openvino::genai)

set_target_properties(${TARGET_NAME} PROPERTIES
COMPILE_PDB_NAME ${TARGET_NAME}
# Ensure out of box LC_RPATH on macOS with SIP
INSTALL_RPATH_USE_LINK_PATH ON)

set(TARGET_NAME_CB continuous_batching_speculative_decoding)
add_executable(${TARGET_NAME_CB} ${TARGET_NAME_CB}.cpp)
target_link_libraries(${TARGET_NAME_CB} PRIVATE openvino::genai cxxopts::cxxopts)
install(TARGETS ${TARGET_NAME}
RUNTIME DESTINATION samples_bin/
COMPONENT samples_bin
EXCLUDE_FROM_ALL)
150 changes: 29 additions & 121 deletions samples/cpp/speculative_decoding_lm/speculative_decoding_lm.cpp
@@ -2,108 +2,29 @@
// SPDX-License-Identifier: Apache-2.0

#include <openvino/openvino.hpp>
#include <cxxopts.hpp>

#include "openvino/genai/llm_pipeline.hpp"

void print_generation_result(const std::vector<std::string>& texts, const std::vector<float>& log_probs) {
for (size_t output_id = 0; output_id < texts.size(); ++output_id) {
std::cout << "Answer " << output_id << " (" << log_probs[output_id] << ") : " << texts[output_id] << std::endl;
}
}

std::vector<ov::genai::GenerationConfig> get_spec_decoding_generation_config_examples() {

// sampling parameters for speculative decoding
ov::genai::GenerationConfig generation_config_greedy_constant = ov::genai::greedy();
{
generation_config_greedy_constant.num_assistant_tokens = 5;
}

ov::genai::GenerationConfig generation_config_multinomial_constant = ov::genai::multinomial();
{
generation_config_multinomial_constant.num_assistant_tokens = 5;
generation_config_multinomial_constant.num_return_sequences = 1;
}

ov::genai::GenerationConfig generation_config_greedy_dynamic = ov::genai::greedy();
{
generation_config_greedy_dynamic.assistant_confidence_threshold = 0.8f;
}

ov::genai::GenerationConfig generation_config_multinomial_dynamic = ov::genai::multinomial();
{
generation_config_multinomial_dynamic.assistant_confidence_threshold = 0.8f;
}

return {
generation_config_greedy_constant,
generation_config_multinomial_constant,
generation_config_greedy_dynamic,
generation_config_multinomial_dynamic,
};
}

int main(int argc, char* argv[]) try {
// Command line options

cxxopts::Options options("accuracy_sample", "Help command");

options.add_options()
("n,num_prompts", "A number of prompts", cxxopts::value<size_t>()->default_value("1"))
("dynamic_split_fuse", "Whether to use dynamic split-fuse or vLLM scheduling", cxxopts::value<bool>()->default_value("false"))
("m,model", "Path to model and tokenizers base directory", cxxopts::value<std::string>()->default_value("."))
("a,draft_model", "Path to assisting model base directory", cxxopts::value<std::string>()->default_value("."))
("d,device", "Target device to run the model", cxxopts::value<std::string>()->default_value("CPU"))
("use_prefix", "Whether to use a prefix or not", cxxopts::value<bool>()->default_value("false"))
("h,help", "Print usage");

cxxopts::ParseResult result;
try {
result = options.parse(argc, argv);
} catch (const cxxopts::exceptions::exception& e) {
std::cout << e.what() << "\n\n";
std::cout << options.help() << std::endl;
return EXIT_FAILURE;
}

if (result.count("help")) {
std::cout << options.help() << std::endl;
return EXIT_SUCCESS;
}

const size_t num_prompts = result["num_prompts"].as<size_t>();
const bool dynamic_split_fuse = result["dynamic_split_fuse"].as<bool>();
const std::string model_path = result["model"].as<std::string>();
const std::string draft_model_path = result["draft_model"].as<std::string>();
const std::string device = result["device"].as<std::string>();
const bool use_prefix = result["use_prefix"].as<bool>();

std::string prefix_str =
"You are an advanced language model designed to assist users by providing accurate, "
"relevant, and helpful information. Your responses should be accurate, concise, contextual, "
"respectful, and helpful. The request is: ";

// create dataset

std::vector<std::string> prompt_examples = {
"What is OpenVINO?",
"How are you?",
"What is your name?",
"Tell me something about Canada",
"What is OpenVINO?",
};

auto generation_config = get_spec_decoding_generation_config_examples();
auto default_config_size = generation_config.size();
for (size_t i = default_config_size; i < num_prompts; ++i) {
generation_config.push_back(generation_config[i % default_config_size]);
if (4 != argc) {
throw std::runtime_error(std::string{"Usage: "} + argv[0] + " <MODEL_DIR> <DRAFT_MODEL_DIR> '<PROMPT>'");
}

std::vector<std::string> prompts(num_prompts);
for (size_t i = 0; i < num_prompts; ++i) {
prompts[i] = prompt_examples[i % prompt_examples.size()];
}
ov::genai::GenerationConfig config;
config.max_new_tokens = 100;
// Speculative decoding generation parameters are mutually exclusive.
// `num_assistant_tokens` makes the draft model propose a fixed number of candidate tokens per iteration:
config.num_assistant_tokens = 5;
// `assistant_confidence_threshold` instead keeps the draft model proposing candidates while their probability stays above the threshold:
// config.assistant_confidence_threshold = 0.4f;

std::string main_model_path = argv[1];
std::string draft_model_path = argv[2];
std::string prompt = argv[3];

// The main and draft models can run on different devices.
// Set the main model's device in the `LLMPipeline` constructor and the draft model's device in `ov::genai::draft_model`.
std::string main_device = "CPU", draft_device = main_device;

// Perform the inference
auto get_default_block_size = [](const std::string& device) {
@@ -116,34 +37,21 @@ int main(int argc, char* argv[]) try {
};

ov::genai::SchedulerConfig scheduler_config;
// batch size
scheduler_config.max_num_batched_tokens = use_prefix ? 256 : 32;
// cache params
scheduler_config.num_kv_blocks = 364;
scheduler_config.block_size = get_default_block_size(device);
// mode - vLLM or dynamic_split_fuse
scheduler_config.dynamic_split_fuse = dynamic_split_fuse;
// vLLM specific params
scheduler_config.max_num_seqs = 2;
scheduler_config.enable_prefix_caching = use_prefix;
scheduler_config.cache_size = 5;
scheduler_config.block_size = get_default_block_size(main_device);

// It's possible to construct a Tokenizer from a different path.
// If the Tokenizer isn't specified, it's loaded from the same folder.
ov::genai::LLMPipeline pipe(model_path, device, ov::genai::draft_model(draft_model_path, device), ov::genai::scheduler_config(scheduler_config));
// Example to run main_model on GPU and draft_model on CPU:
// ov::genai::LLMPipeline pipe(main_model_path, "GPU", ov::genai::draft_model(draft_model_path, "CPU"), ov::genai::scheduler_config(scheduler_config));
ov::genai::LLMPipeline pipe(main_model_path, main_device, ov::genai::draft_model(draft_model_path, draft_device), ov::genai::scheduler_config(scheduler_config));

if (use_prefix) {
std::cout << "Running inference for prefix to compute the shared prompt's KV cache..." << std::endl;
auto generation_results = pipe.generate(prefix_str, ov::genai::greedy());
}
auto streamer = [](std::string subword) {
std::cout << subword << std::flush;
return false;
};

for (size_t request_id = 0; request_id < prompts.size(); ++request_id) {
ov::genai::DecodedResults generation_results = pipe.generate(prompts[request_id], generation_config[request_id]);
std::cout << "Question: " << prompts[request_id] << std::endl;
const std::vector<std::string>& text_results = generation_results.texts;
const std::vector<float>& log_prob_results = generation_results.scores;
print_generation_result(text_results, log_prob_results);
std::cout << std::endl;
}
// Since the streamer is set, the results will
// be printed each time a new token is generated.
pipe.generate(prompt, config, streamer);
} catch (const std::exception& error) {
try {
std::cerr << error.what() << '\n';
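For context, a minimal sketch (not from this commit) of the same pipeline used without a streamer, so the whole answer is collected at once. It assumes the ov::genai API used above; the model paths are placeholders, the block size of 32 mirrors what the sample's get_default_block_size() is assumed to return for CPU, and DecodedResults is the type the removed print_generation_result() path already relied on.

#include <iostream>
#include <openvino/openvino.hpp>
#include "openvino/genai/llm_pipeline.hpp"

int main() {
    ov::genai::GenerationConfig config;
    config.max_new_tokens = 100;
    config.num_assistant_tokens = 5;  // constant speculative decoding mode

    ov::genai::SchedulerConfig scheduler_config;
    scheduler_config.cache_size = 5;
    scheduler_config.block_size = 32;  // assumed CPU value, see get_default_block_size() above

    // Placeholder paths for the main and draft models.
    ov::genai::LLMPipeline pipe("path/to/main_model", "CPU",
                                ov::genai::draft_model("path/to/draft_model", "CPU"),
                                ov::genai::scheduler_config(scheduler_config));

    // Without a streamer, generate() returns the finished result instead of
    // printing tokens as they appear.
    ov::genai::DecodedResults results = pipe.generate("Alan Turing was a", config);
    std::cout << results.texts[0] << std::endl;
    return 0;
}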
33 changes: 28 additions & 5 deletions src/cpp/include/openvino/genai/llm_pipeline.hpp
@@ -272,12 +272,35 @@ class OPENVINO_GENAI_EXPORTS LLMPipeline {
OPENVINO_GENAI_EXPORTS std::pair<std::string, Any> streamer(StreamerVariant func);
OPENVINO_GENAI_EXPORTS std::pair<std::string, Any> generation_config(const GenerationConfig& config);

OPENVINO_GENAI_EXPORTS
std::pair<std::string, Any> draft_model(
    const std::string& model_path,
    const std::string& device = "",
    const ov::AnyMap& plugin_config = {},
    const ov::genai::SchedulerConfig& scheduler_config = {});
OPENVINO_GENAI_EXPORTS std::pair<std::string, Any> _draft_model(
    const std::string& model_path,
    const std::string& device,
    const ov::AnyMap& llm_config);

template <typename... Properties,
typename std::enable_if<ov::util::StringAny<Properties...>::value, bool>::type = true>
inline std::pair<std::string, Any> draft_model(
const std::string& model_path,
const std::string& device,
Properties&&... properties) {
return _draft_model(model_path, device, ov::AnyMap{std::forward<Properties>(properties)...});
}

template <typename... Properties,
typename std::enable_if<ov::util::StringAny<Properties...>::value, bool>::type = true>
inline std::pair<std::string, Any> draft_model(
const std::string& model_path,
Properties&&... properties) {
return _draft_model(model_path, "", ov::AnyMap{std::forward<Properties>(properties)...});
}


inline std::pair<std::string, Any>
draft_model(
const std::string& model_path,
const std::string& device = "",
const ov::AnyMap& llm_config = ov::AnyMap()) {
return _draft_model(model_path, device, llm_config);
}
} // namespace genai
} // namespace ov
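For context, a brief sketch (not from this commit) of how these draft_model overloads can be called after the change. The paths are placeholders, and ov::cache_dir is used only as an example of a property forwarded through the variadic overload into _draft_model().

#include <openvino/openvino.hpp>
#include "openvino/genai/llm_pipeline.hpp"

int main() {
    // Path only: device stays "" and the llm_config stays empty (inline overload).
    auto draft_default = ov::genai::draft_model("path/to/draft_model");

    // Path plus explicit device (inline overload with the default AnyMap).
    auto draft_on_cpu = ov::genai::draft_model("path/to/draft_model", "CPU");

    // Path, device, and extra properties: the variadic template overload packs
    // the properties into an ov::AnyMap and forwards them to _draft_model().
    auto draft_with_props = ov::genai::draft_model("path/to/draft_model", "GPU",
                                                   ov::cache_dir("model_cache"));

    // Any of these pairs is then passed to the LLMPipeline constructor,
    // e.g. ov::genai::LLMPipeline pipe("path/to/main_model", "CPU", draft_on_cpu);
    (void)draft_default; (void)draft_on_cpu; (void)draft_with_props;
    return 0;
}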
10 changes: 5 additions & 5 deletions src/cpp/src/continuous_batching_impl_interface.hpp
@@ -28,11 +28,11 @@ class ContinuousBatchingPipeline::ImplInterface {
float m_infer_total_ms = 0.0f;

~PerfTime() {
std::cout << "Inference requests aggregated statistic: " << std::endl;
std::cout << "Paged attention % of inference execution: " << (m_paged_attention_time_ms / m_infer_total_ms) * 100 << std::endl;
std::cout << "MatMul % of inference execution: " << (m_matmul_time_ms / m_infer_total_ms) * 100 << std::endl;
std::cout << "Total inference execution secs: " << m_infer_total_ms / 1000. << std::endl;
std::cout << std::endl;
// std::cout << "Inference requests aggregated statistic: " << std::endl;
// std::cout << "Paged attention % of inference execution: " << (m_paged_attention_time_ms / m_infer_total_ms) * 100 << std::endl;
// std::cout << "MatMul % of inference execution: " << (m_matmul_time_ms / m_infer_total_ms) * 100 << std::endl;
// std::cout << "Total inference execution secs: " << m_infer_total_ms / 1000. << std::endl;
// std::cout << std::endl;
}
} m_perf;
bool m_is_chat_conversation = false;
5 changes: 3 additions & 2 deletions src/cpp/src/generation_config.cpp
@@ -165,9 +165,9 @@ void GenerationConfig::validate() const {
}
if (is_speculative_decoding()) {
if (assistant_confidence_threshold != 0.f) {
OPENVINO_ASSERT(num_assistant_tokens == 0);
OPENVINO_ASSERT(num_assistant_tokens == 0, "Parameters `assistant_confidence_threshold` and `num_assistant_tokens` are mutually exclusive in `GenerationConfig`");
} else {
OPENVINO_ASSERT(num_assistant_tokens > 0);
OPENVINO_ASSERT(num_assistant_tokens > 0, "Parameters `assistant_confidence_threshold` and `num_assistant_tokens` are mutually exclusive in `GenerationConfig`");
};
}
}
@@ -202,5 +202,6 @@ GenerationConfig multinomial() {
return multinomial_config;
}


} // namespace genai
} // namespace ov
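For reference, a small sketch (not from this commit) of the two configurations that this check tells apart. GenerationConfig::validate() is the public method shown in the hunk above, and max_new_tokens is set only so the configs are complete enough to validate.

#include "openvino/genai/generation_config.hpp"

int main() {
    // Constant mode, as in the updated C++ sample: the draft model proposes a
    // fixed number of candidate tokens per step.
    ov::genai::GenerationConfig constant_mode;
    constant_mode.max_new_tokens = 100;
    constant_mode.num_assistant_tokens = 5;
    constant_mode.validate();

    // Dynamic mode: candidates are proposed while their probability stays above
    // the threshold; num_assistant_tokens must remain 0.
    ov::genai::GenerationConfig dynamic_mode;
    dynamic_mode.max_new_tokens = 100;
    dynamic_mode.assistant_confidence_threshold = 0.8f;
    dynamic_mode.validate();

    // Setting both num_assistant_tokens and assistant_confidence_threshold on
    // one config would make validate() fail with the messages above.
    return 0;
}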