[CB] SpeculativeDecoding impl C++ (#907)
Merge after [Validation mode implementation in Sampler](#904)

Tickets:
* [153599](https://jira.devtools.intel.com/browse/CVS-153599)
* [153604](https://jira.devtools.intel.com/browse/CVS-153604)
* [154104](https://jira.devtools.intel.com/browse/CVS-154104)
* [154885](https://jira.devtools.intel.com/browse/CVS-154885)

---------

Co-authored-by: Ilya Lavrenov <[email protected]>
iefode and ilya-lavrenov authored Oct 18, 2024
1 parent 64502bb · commit 2378ab0
Showing 36 changed files with 1,980 additions and 566 deletions.
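This change ports speculative decoding to the C++ continuous batching pipeline: a smaller draft model proposes candidate tokens and the main model validates them, building on the Sampler validation mode from #904. As a quick orientation before the diffs, the sketch below condenses the new continuous_batching_speculative_decoding sample into a minimal call sequence. It is only a sketch: the model directories and the prompt are placeholders, and the scheduler values simply mirror the sample's defaults; every API call used here appears verbatim in the sample further down.

// Minimal sketch (assumptions: placeholder model directories "./main_model_dir" and
// "./draft_model_dir", single CPU device); condensed from the new CB sample below.
#include <iostream>
#include <string>
#include <vector>
#include "openvino/genai/continuous_batching_pipeline.hpp"

int main() {
    // Constant speculative decoding: the draft model proposes 5 candidate tokens per step.
    ov::genai::GenerationConfig config = ov::genai::greedy();
    config.num_assistant_tokens = 5;

    // Scheduler values mirror the sample (CPU block size 32, 364 KV-cache blocks).
    ov::genai::SchedulerConfig scheduler_config;
    scheduler_config.num_kv_blocks = 364;
    scheduler_config.block_size = 32;

    // The draft model is attached as a property of the main pipeline.
    ov::genai::ContinuousBatchingPipeline pipe(
        "./main_model_dir", scheduler_config, "CPU",
        {ov::genai::draft_model("./draft_model_dir", "CPU")});

    std::vector<std::string> prompts = {"What is OpenVINO?"};
    std::vector<ov::genai::GenerationConfig> configs = {config};
    std::vector<ov::genai::GenerationResult> results = pipe.generate(prompts, configs);

    for (const auto& answer : results.front().m_generation_ids)
        std::cout << answer << std::endl;
    return 0;
}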
12 changes: 2 additions & 10 deletions .github/workflows/causal_lm_cpp.yml
@@ -422,16 +422,8 @@ jobs:
- name: run and compare
run: |
source ./ov/setupvars.sh
-  ./build/samples/cpp/speculative_decoding_lm/speculative_decoding_lm ./dolly-v2-3b/ ./dolly-v2-7b/ "Alan Turing was a" > predictions_speculative.txt
-  ./build/samples/cpp/greedy_causal_lm/greedy_causal_lm ./dolly-v2-7b/ "Alan Turing was a" > predictions_greedy.txt
-  python -c "
-  with open('predictions_greedy.txt', 'r') as f:
-      predicted_greedy = f.readline()
-  with open('predictions_speculative.txt', 'r') as f:
-      predicted_speculative = f.readline()
-  assert predicted_greedy == predicted_speculative
-  "
-  echo "Alan Turing was a" passed
+  ./build/samples/cpp/speculative_decoding_lm/speculative_decoding_lm -a ./dolly-v2-3b/ -m ./dolly-v2-7b/ -n 5
+  ./build/samples/cpp/speculative_decoding_lm/continuous_batching_speculative_decoding -a ./dolly-v2-3b/ -m ./dolly-v2-7b/ -n 5
cpp-prompt_lookup_decoding_lm-ubuntu:
runs-on: ubuntu-20.04-16-cores
60 changes: 32 additions & 28 deletions samples/cpp/speculative_decoding_lm/CMakeLists.txt
@@ -1,30 +1,34 @@
- # Copyright (C) 2023-2024 Intel Corporation
+ # Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

- find_package(OpenVINO REQUIRED COMPONENTS Runtime Threading)
-
- find_package(OpenVINOGenAI REQUIRED
-     PATHS
-         "${CMAKE_BINARY_DIR}"  # Reuse the package from the build.
-         ${OpenVINO_DIR}  # GenAI may be installed alogside OpenVINO.
-     NO_CMAKE_FIND_ROOT_PATH
- )
-
- add_executable(speculative_decoding_lm speculative_decoding_lm.cpp)
- target_link_libraries(speculative_decoding_lm PRIVATE openvino::runtime openvino::threading)
- set_target_properties(speculative_decoding_lm PROPERTIES
-     COMPILE_PDB_NAME speculative_decoding_lm
-     # Ensure out of box LC_RPATH on macOS with SIP
-     INSTALL_RPATH_USE_LINK_PATH ON)
- target_compile_features(speculative_decoding_lm PRIVATE cxx_std_17)
-
- get_target_property(genai_imported openvino::genai IMPORTED_LOCATION)
- set(OPENVINO_TOKENIZERS_PATH $<IF:$<BOOL:${genai_imported}>,${genai_imported},$<TARGET_FILE_DIR:openvino::genai>>)
- set(OPENVINO_TOKENIZERS_FILENAME "${CMAKE_SHARED_LIBRARY_PREFIX}openvino_tokenizers${CMAKE_SHARED_LIBRARY_SUFFIX}")
- target_compile_definitions(speculative_decoding_lm PRIVATE
-     OPENVINO_TOKENIZERS_PATH="${OPENVINO_TOKENIZERS_PATH}/${OPENVINO_TOKENIZERS_FILENAME}")
-
- install(TARGETS speculative_decoding_lm
-     RUNTIME DESTINATION samples_bin/
-     COMPONENT samples_bin
-     EXCLUDE_FROM_ALL)
+ # start of dependencies
+
+ include(FetchContent)
+
+ if(POLICY CMP0135)
+     cmake_policy(SET CMP0135 NEW)
+ endif()
+
+ FetchContent_Declare(cxxopts
+     URL https://github.com/jarro2783/cxxopts/archive/refs/tags/v3.1.1.tar.gz
+     URL_HASH SHA256=523175f792eb0ff04f9e653c90746c12655f10cb70f1d5e6d6d9491420298a08)
+ FetchContent_MakeAvailable(cxxopts)
+
+ if(NOT TARGET nlohmann_json)
+     FetchContent_Declare(nlohmann_json
+         URL https://github.com/nlohmann/json/archive/refs/tags/v3.11.3.tar.gz
+         URL_HASH SHA256=0d8ef5af7f9794e3263480193c491549b2ba6cc74bb018906202ada498a79406)
+     FetchContent_MakeAvailable(nlohmann_json)
+ endif()
+
+ find_package(OpenVINO REQUIRED COMPONENTS Runtime)
+
+ # end of dependencies
+
+ set(TARGET_NAME speculative_decoding_lm)
+ add_executable(${TARGET_NAME} ${TARGET_NAME}.cpp)
+ target_link_libraries(${TARGET_NAME} PRIVATE openvino::genai cxxopts::cxxopts)
+
+ set(TARGET_NAME_CB continuous_batching_speculative_decoding)
+ add_executable(${TARGET_NAME_CB} ${TARGET_NAME_CB}.cpp)
+ target_link_libraries(${TARGET_NAME_CB} PRIVATE openvino::genai cxxopts::cxxopts)
170 changes: 170 additions & 0 deletions samples/cpp/speculative_decoding_lm/continuous_batching_speculative_decoding.cpp
@@ -0,0 +1,170 @@
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include <openvino/openvino.hpp>
#include <cxxopts.hpp>

#include "openvino/genai/continuous_batching_pipeline.hpp"

void print_cb_generation_result(const ov::genai::GenerationResult& generation_result) {
for (size_t output_id = 0; output_id < generation_result.m_generation_ids.size(); ++output_id) {
std::cout << "Answer " << output_id << " (" << generation_result.m_scores[output_id] << ") : " << generation_result.m_generation_ids[output_id] << std::endl;
}
}

std::vector<ov::genai::GenerationConfig> get_spec_decoding_generation_config_examples() {

// sampling params for speculative decoding
ov::genai::GenerationConfig generation_config_greedy_constant = ov::genai::greedy();
{
generation_config_greedy_constant.num_assistant_tokens = 5;
}

ov::genai::GenerationConfig generation_config_multinomial_constant = ov::genai::multinomial();
{
generation_config_multinomial_constant.num_assistant_tokens = 5;
generation_config_multinomial_constant.num_return_sequences = 1;
}

ov::genai::GenerationConfig generation_config_greedy_dynamic = ov::genai::greedy();
{
generation_config_greedy_dynamic.assistant_confidence_threshold = 0.8f;
}

ov::genai::GenerationConfig generation_config_multinomial_dynamic = ov::genai::multinomial();
{
generation_config_multinomial_dynamic.assistant_confidence_threshold = 0.8f;
}

return {
generation_config_greedy_constant,
generation_config_multinomial_constant,
generation_config_greedy_dynamic,
generation_config_multinomial_dynamic,
};
}

int main(int argc, char* argv[]) try {
// Command line options

cxxopts::Options options("accuracy_sample", "Help command");

options.add_options()
("n,num_prompts", "A number of prompts", cxxopts::value<size_t>()->default_value("1"))
("dynamic_split_fuse", "Whether to use dynamic split-fuse or vLLM scheduling", cxxopts::value<bool>()->default_value("false"))
("m,model", "Path to model and tokenizers base directory", cxxopts::value<std::string>()->default_value("."))
("a,draft_model", "Path to assisting model base directory", cxxopts::value<std::string>()->default_value("."))
("d,device", "Target device to run the model", cxxopts::value<std::string>()->default_value("CPU"))
("use_prefix", "Whether to use a prefix or not", cxxopts::value<bool>()->default_value("false"))
("h,help", "Print usage");

cxxopts::ParseResult result;
try {
result = options.parse(argc, argv);
} catch (const cxxopts::exceptions::exception& e) {
std::cout << e.what() << "\n\n";
std::cout << options.help() << std::endl;
return EXIT_FAILURE;
}

if (result.count("help")) {
std::cout << options.help() << std::endl;
return EXIT_SUCCESS;
}

const size_t num_prompts = result["num_prompts"].as<size_t>();
const bool dynamic_split_fuse = result["dynamic_split_fuse"].as<bool>();
const std::string model_path = result["model"].as<std::string>();
const std::string draft_model_path = result["draft_model"].as<std::string>();
const std::string device = result["device"].as<std::string>();
const bool use_prefix = result["use_prefix"].as<bool>();

std::string prefix_str =
"You are an advanced language model designed to assist users by providing accurate, "
"relevant, and helpful information. Your responses should be accurate, concise, contextual, "
"respectful, and helpful. The request is: ";

// create dataset

std::vector<std::string> prompt_examples = {
"What is OpenVINO?",
"How are you?",
"What is your name?",
"Tell me something about Canada",
"What is OpenVINO?",
};

auto generation_config = get_spec_decoding_generation_config_examples();
auto default_config_size = generation_config.size();
for (size_t i = default_config_size; i < num_prompts; ++i) {
generation_config.push_back(generation_config[i % default_config_size]);
}

std::vector<std::string> prompts(num_prompts);
for (size_t i = 0; i < num_prompts; ++i) {
prompts[i] = prompt_examples[i % prompt_examples.size()];
}

// Perform the inference
auto get_default_block_size = [](const std::string& device) {
const size_t cpu_block_size = 32;
const size_t gpu_block_size = 16;

bool is_gpu = device.find("GPU") != std::string::npos;

return is_gpu ? gpu_block_size : cpu_block_size;
};

ov::genai::SchedulerConfig scheduler_config;
// batch size
scheduler_config.max_num_batched_tokens = use_prefix ? 256 : 32;
// cache params
scheduler_config.num_kv_blocks = 364;
scheduler_config.block_size = get_default_block_size(device);
// mode - vLLM or dynamic_split_fuse
scheduler_config.dynamic_split_fuse = dynamic_split_fuse;
// vLLM specific params
scheduler_config.max_num_seqs = 2;
scheduler_config.enable_prefix_caching = use_prefix;

ov::genai::ContinuousBatchingPipeline pipe(model_path, scheduler_config, device, {ov::genai::draft_model(draft_model_path, device)});
std::vector<ov::genai::GenerationResult> generation_results = pipe.generate(prompts, generation_config);

for (size_t request_id = 0; request_id < generation_results.size(); ++request_id) {
const ov::genai::GenerationResult & generation_result = generation_results[request_id];
std::cout << "Question: " << prompts[request_id] << std::endl;
switch (generation_result.m_status)
{
case ov::genai::GenerationStatus::FINISHED:
print_cb_generation_result(generation_result);
break;
case ov::genai::GenerationStatus::IGNORED:
std::cout << "Request was ignored due to lack of memory." <<std::endl;
if (generation_result.m_generation_ids.size() > 0) {
std::cout << "Partial result:" << std::endl;
print_cb_generation_result(generation_result);
}
break;
case ov::genai::GenerationStatus::DROPPED_BY_PIPELINE:
std::cout << "Request was aborted." <<std::endl;
if (generation_result.m_generation_ids.size() > 0) {
std::cout << "Partial result:" << std::endl;
print_cb_generation_result(generation_result);
}
break;
default:
break;
}
std::cout << std::endl;
}
} catch (const std::exception& error) {
try {
std::cerr << error.what() << '\n';
} catch (const std::ios_base::failure&) {}
return EXIT_FAILURE;
} catch (...) {
try {
std::cerr << "Non-exception object thrown\n";
} catch (const std::ios_base::failure&) {}
return EXIT_FAILURE;
}
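The four example configs in get_spec_decoding_generation_config_examples() exercise two candidate policies: a constant number of draft tokens (num_assistant_tokens) and a dynamic policy driven by the draft model's own confidence (assistant_confidence_threshold), each combined with greedy or multinomial sampling. The short sketch below shows how the dynamic, multinomial variant could be run for a single request through a pipeline built the same way as in main() above; run_dynamic_example is a hypothetical helper, and it only rearranges calls that already appear in the sample.

// Hypothetical helper: runs one prompt with the dynamic (confidence-threshold) policy
// through an already constructed speculative-decoding pipeline.
#include <string>
#include <vector>
#include "openvino/genai/continuous_batching_pipeline.hpp"

// Defined at the top of the sample.
void print_cb_generation_result(const ov::genai::GenerationResult& generation_result);

void run_dynamic_example(ov::genai::ContinuousBatchingPipeline& pipe) {
    ov::genai::GenerationConfig config = ov::genai::multinomial();
    // The number of draft candidates per step is chosen dynamically from the draft
    // model's confidence instead of being fixed via num_assistant_tokens.
    config.assistant_confidence_threshold = 0.8f;
    config.num_return_sequences = 1;

    std::vector<std::string> prompts = {"Tell me something about Canada"};
    std::vector<ov::genai::GenerationConfig> configs = {config};

    std::vector<ov::genai::GenerationResult> results = pipe.generate(prompts, configs);
    print_cb_generation_result(results.front());
}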