diff --git a/.github/workflows/causal_lm_cpp.yml b/.github/workflows/causal_lm_cpp.yml
index b8fbe397d2..5d688cfc39 100644
--- a/.github/workflows/causal_lm_cpp.yml
+++ b/.github/workflows/causal_lm_cpp.yml
@@ -254,7 +254,7 @@ jobs:
       && python samples\python\greedy_causal_lm\greedy_causal_lm.py .\TinyLlama-1.1B-Chat-v1.0\ 69 > .\py.txt
     - run: fc .\cpp.txt .\py.txt
 
-  cpp-beam_search_causal_lm-Qwen-7B-Chat:
+  cpp-greedy_causal_lm-Qwen-7B-Chat:
     runs-on: ubuntu-20.04-16-cores
     defaults:
       run:
@@ -866,7 +866,7 @@ jobs:
   Overall_Status:
     name: ci/gha_overall_status_causal_lm
     needs: [cpp-multinomial-greedy_causal_lm-ubuntu, cpp-beam_search_causal_lm-ubuntu, cpp-greedy_causal_lm-windows,
-      cpp-beam_search_causal_lm-Qwen-7B-Chat, cpp-beam_search_causal_lm-Qwen1_5-7B-Chat, cpp-beam_search_causal_lm-Phi-2,
+      cpp-greedy_causal_lm-Qwen-7B-Chat, cpp-beam_search_causal_lm-Qwen1_5-7B-Chat, cpp-beam_search_causal_lm-Phi-2,
       cpp-beam_search_causal_lm-notus-7b-v1, cpp-speculative_decoding_lm-ubuntu, cpp-prompt_lookup_decoding_lm-ubuntu,
       cpp-Phi-1_5, cpp-greedy_causal_lm-redpajama-3b-chat, cpp-chat_sample-ubuntu, cpp-continuous-batching-ubuntu,
       visual_language_chat_sample-ubuntu,
diff --git a/samples/cpp/text2image/README.md b/samples/cpp/text2image/README.md
index f73da334f4..16b1aff53c 100644
--- a/samples/cpp/text2image/README.md
+++ b/samples/cpp/text2image/README.md
@@ -36,17 +36,6 @@ Prompt: `cyberpunk cityscape like Tokyo New York with tall buildings at dusk gol
 
 ![](./512x512.bmp)
 
-## Supported models
-
-Models can be downloaded from [HuggingFace](https://huggingface.co/models). This sample can run the following list of models, but not limited to:
-
-- [botp/stable-diffusion-v1-5](https://huggingface.co/botp/stable-diffusion-v1-5)
-- [stabilityai/stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2)
-- [stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1)
-- [dreamlike-art/dreamlike-anime-1.0](https://huggingface.co/dreamlike-art/dreamlike-anime-1.0)
-- [SimianLuo/LCM_Dreamshaper_v7](https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7)
-- [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
-- [stabilityai/stable-diffusion-xl-base-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9)
-
 ## Run with optional LoRA adapters
diff --git a/samples/cpp/text2image/imwrite.cpp b/samples/cpp/text2image/imwrite.cpp
index 31c9a19d6d..b25db03051 100644
--- a/samples/cpp/text2image/imwrite.cpp
+++ b/samples/cpp/text2image/imwrite.cpp
@@ -30,60 +30,59 @@ unsigned char file[14] = {
 };
 
 unsigned char info[40] = {
-        40,
-        0,
-        0,
-        0,  // info hd size
-        0,
-        0,
-        0,
-        0,  // width
-        0,
-        0,
-        0,
-        0,  // height
-        1,
-        0,  // number color planes
-        24,
-        0,  // bits per pixel
-        0,
-        0,
-        0,
-        0,  // compression is none
-        0,
-        0,
-        0,
-        0,  // image bits size
-        0x13,
-        0x0B,
-        0,
-        0,  // horz resolution in pixel / m
-        0x13,
-        0x0B,
-        0,
-        0,  // vert resolution (0x03C3 = 96 dpi, 0x0B13 = 72
-            // dpi)
-        0,
-        0,
-        0,
-        0,  // #colors in palette
-        0,
-        0,
-        0,
-        0,  // #important colors
-    };
-
-}
-
-void imwrite(const std::string& name, ov::Tensor image, bool convert_bgr2rgb) {
-    std::ofstream output_file(name, std::ofstream::binary);
-    OPENVINO_ASSERT(output_file.is_open(), "Failed to open the output BMP image path");
+    40,
+    0,
+    0,
+    0,  // info hd size
+    0,
+    0,
+    0,
+    0,  // width
+    0,
+    0,
+    0,
+    0,  // height
+    1,
+    0,  // number color planes
+    24,
+    0,  // bits per pixel
+    0,
+    0,
+    0,
+    0,  // compression is none
+    0,
+    0,
+    0,
+    0,  // image bits size
+    0x13,
+    0x0B,
+    0,
+    0,  // horz resolution in pixel / m
+    0x13,
+    0x0B,
+    0,
+    0,  // vert resolution (0x03C3 = 96 dpi, 0x0B13 = 72
+        // dpi)
+    0,
+    0,
+    0,
+    0,  // #colors in palette
+    0,
+    0,
+    0,
+    0,  // #important colors
+};
+void imwrite_single_image(const std::string& name, ov::Tensor image, bool convert_bgr2rgb) {
     const ov::Shape shape = image.get_shape();
     const size_t width = shape[2], height = shape[1], channels = shape[3];
     OPENVINO_ASSERT(image.get_element_type() == ov::element::u8 && shape.size() == 4 && shape[0] == 1 && channels == 3,
-                    "Image of u8 type and [1, H, W, 3] shape is expected");
+                    "Image of u8 type and [1, H, W, 3] shape is expected.",
+                    "Given image has shape ", shape, " and element type ", image.get_element_type());
+
+    std::ofstream output_file(name, std::ofstream::binary);
+    OPENVINO_ASSERT(output_file.is_open(), "Failed to open the output BMP image path");
 
     int padSize = static_cast<int>(4 - (width * channels) % 4) % 4;
     int sizeData = static_cast<int>(width * height * channels + height * padSize);
@@ -131,3 +130,19 @@ void imwrite(const std::string& name, ov::Tensor image, bool convert_bgr2rgb) {
         output_file.write(reinterpret_cast<char*>(pad), padSize);
     }
 }
+
+} // namespace
+
+
+void imwrite(const std::string& name, ov::Tensor images, bool convert_bgr2rgb) {
+    const ov::Shape shape = images.get_shape(), img_shape = {1, shape[1], shape[2], shape[3]};
+    uint8_t* img_data = images.data<uint8_t>();
+
+    for (int img_num = 0, num_images = shape[0], img_size = ov::shape_size(img_shape); img_num < num_images; ++img_num, img_data += img_size) {
+        char img_name[25];
+        sprintf(img_name, name.c_str(), img_num);
+
+        ov::Tensor image(images.get_element_type(), img_shape, img_data);
+        imwrite_single_image(img_name, image, true);
+    }
+}
diff --git a/samples/cpp/text2image/imwrite.hpp b/samples/cpp/text2image/imwrite.hpp
index 4fd48004dd..9b8752fb07 100644
--- a/samples/cpp/text2image/imwrite.hpp
+++ b/samples/cpp/text2image/imwrite.hpp
@@ -8,9 +8,9 @@
 #include "openvino/runtime/tensor.hpp"
 
 /**
- * @brief Writes image to file
- * @param name File name
- * @param image Image tensor
+ * @brief Writes multiple images (depending on `images` tensor batch size) to BMP file(s)
+ * @param name File name or pattern to use to write images
+ * @param images Image(s) tensor
 * @param convert_bgr2rgb Convert BGR to RGB
 */
-void imwrite(const std::string& name, ov::Tensor image, bool convert_bgr2rgb);
+void imwrite(const std::string& name, ov::Tensor images, bool convert_bgr2rgb);
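
For illustration, a minimal sketch of calling the batched `imwrite` introduced above; the tensor shape and fill value are made up for the example, not taken from the sample:

```cpp
#include <algorithm>
#include <cstdint>

#include "imwrite.hpp"
#include "openvino/runtime/tensor.hpp"

int main() {
    // A dummy batch of two 8x8 RGB images in the expected [N, H, W, 3] u8 layout.
    ov::Tensor images(ov::element::u8, ov::Shape{2, 8, 8, 3});
    std::fill_n(images.data<std::uint8_t>(), images.get_size(), 128);

    // The name acts as a printf-style pattern: one BMP file per batch item,
    // i.e. image_0.bmp and image_1.bmp here.
    imwrite("image_%d.bmp", images, true);
    return 0;
}
```
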
diff --git a/samples/cpp/text2image/main.cpp b/samples/cpp/text2image/main.cpp
index 02c632d53e..1cef148796 100644
--- a/samples/cpp/text2image/main.cpp
+++ b/samples/cpp/text2image/main.cpp
@@ -5,35 +5,6 @@
 
 #include "imwrite.hpp"
 
-namespace {
-
-    void imwrite_output_imgs(const ov::Tensor& output) {
-        ov::Shape out_shape = output.get_shape();
-
-        if (out_shape[0] == 1) {
-            imwrite("image.bmp", output, true);
-            return;
-        }
-
-        ov::Shape img_shape = {1, out_shape[1], out_shape[2], out_shape[3]};
-        size_t img_size = output.get_size() / out_shape[0];
-
-        ov::Tensor image(output.get_element_type(), img_shape);
-        uint8_t* out_data = output.data<uint8_t>();
-        uint8_t* img_data = image.data<uint8_t>();
-
-        for (int img_num = 0; img_num < out_shape[0]; ++img_num) {
-            std::memcpy(img_data, out_data + img_size * img_num, img_size * sizeof(uint8_t));
-
-            char img_name[25];
-            sprintf(img_name, "image_%d.bmp", img_num);
-
-            imwrite(img_name, image, true);
-        }
-    }
-
-} //namespace
-
 int32_t main(int32_t argc, char* argv[]) try {
     OPENVINO_ASSERT(argc == 3, "Usage: ", argv[0], " <MODEL_DIR> '<PROMPT>'");
@@ -47,7 +18,8 @@ int32_t main(int32_t argc, char* argv[]) try {
         ov::genai::num_inference_steps(20),
         ov::genai::num_images_per_prompt(1));
 
-    imwrite_output_imgs(image);
+    // writes `num_images_per_prompt` images by pattern name
+    imwrite("image_%d.bmp", image, true);
 
     return EXIT_SUCCESS;
 } catch (const std::exception& error) {
diff --git a/src/cpp/CMakeLists.txt b/src/cpp/CMakeLists.txt
index 435a122d55..ae40818ed8 100644
--- a/src/cpp/CMakeLists.txt
+++ b/src/cpp/CMakeLists.txt
@@ -35,6 +35,13 @@ function(ov_genai_build_jinja2cpp)
     option(RAPIDJSON_BUILD_DOC "Build rapidjson documentation." OFF)
 
     add_subdirectory("${jinja2cpp_SOURCE_DIR}" "${jinja2cpp_BINARY_DIR}" EXCLUDE_FROM_ALL)
+
+    if(CMAKE_COMPILER_IS_GNUCXX OR OV_COMPILER_IS_CLANG OR (OV_COMPILER_IS_INTEL_LLVM AND UNIX))
+      target_compile_options(jinja2cpp PRIVATE -Wno-undef)
+    endif()
+    if(SUGGEST_OVERRIDE_SUPPORTED)
+      target_compile_options(jinja2cpp PRIVATE -Wno-suggest-override)
+    endif()
   endif()
 endfunction()
diff --git a/src/cpp/include/openvino/genai/text2image/pipeline.hpp b/src/cpp/include/openvino/genai/text2image/pipeline.hpp
index 5ce6a08b11..e3a59cf025 100644
--- a/src/cpp/include/openvino/genai/text2image/pipeline.hpp
+++ b/src/cpp/include/openvino/genai/text2image/pipeline.hpp
@@ -70,8 +70,8 @@ class OPENVINO_GENAI_EXPORTS Text2ImagePipeline {
         // SD XL: prompt2 and negative_prompt2
         // FLUX: prompt2 (prompt if prompt2 is not defined explicitly)
         // SD 3: prompt2, prompt3 (with fallback to prompt) and negative_prompt2, negative_prompt3
-        std::string prompt2, prompt3;
-        std::string negative_prompt, negative_prompt2, negative_prompt3;
+        std::optional<std::string> prompt_2 = std::nullopt, prompt_3 = std::nullopt;
+        std::string negative_prompt, negative_prompt_2, negative_prompt_3;
 
         size_t num_images_per_prompt = 1;
@@ -165,12 +165,12 @@ class OPENVINO_GENAI_EXPORTS Text2ImagePipeline {
 // Generation config properties
 //
 
-static constexpr ov::Property<std::string> prompt2{"prompt2"};
-static constexpr ov::Property<std::string> prompt3{"prompt3"};
+static constexpr ov::Property<std::string> prompt_2{"prompt_2"};
+static constexpr ov::Property<std::string> prompt_3{"prompt_3"};
 
 static constexpr ov::Property<std::string> negative_prompt{"negative_prompt"};
-static constexpr ov::Property<std::string> negative_prompt2{"negative_prompt2"};
-static constexpr ov::Property<std::string> negative_prompt3{"negative_prompt3"};
+static constexpr ov::Property<std::string> negative_prompt_2{"negative_prompt_2"};
+static constexpr ov::Property<std::string> negative_prompt_3{"negative_prompt_3"};
 
 static constexpr ov::Property<size_t> num_images_per_prompt{"num_images_per_prompt"};
 static constexpr ov::Property<float> guidance_scale{"guidance_scale"};
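
For context, a hedged sketch of how the renamed snake_case properties are meant to be passed to `generate()`; the model directory and prompt strings below are placeholders:

```cpp
#include "openvino/genai/text2image/pipeline.hpp"

int main() {
    // Placeholder model directory; an exported SD XL model layout is assumed.
    ov::genai::Text2ImagePipeline pipe("./stable-diffusion-xl-base-1.0", "CPU");

    // prompt_2 / prompt_3 and negative_prompt_2 / negative_prompt_3 are now
    // snake_case; pipelines that do not define a given prompt reject it.
    ov::Tensor image = pipe.generate(
        "cyberpunk cityscape at dusk",
        ov::genai::prompt_2("golden hour, highly detailed"),
        ov::genai::negative_prompt("blurry"),
        ov::genai::num_images_per_prompt(2));
    return 0;
}
```
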
diff --git a/src/cpp/src/text2image/diffusion_pipeline.hpp b/src/cpp/src/text2image/diffusion_pipeline.hpp
index 1884df4ca6..58843b8667 100644
--- a/src/cpp/src/text2image/diffusion_pipeline.hpp
+++ b/src/cpp/src/text2image/diffusion_pipeline.hpp
@@ -71,7 +71,9 @@ class Text2ImagePipeline::DiffusionPipeline {
 protected:
     virtual void initialize_generation_config(const std::string& class_name) = 0;
 
-    virtual void check_inputs(const int height, const int width) const = 0;
+    virtual void check_image_size(const int height, const int width) const = 0;
+
+    virtual void check_inputs(const GenerationConfig& generation_config) const = 0;
 
     std::shared_ptr<IScheduler> m_scheduler;
     GenerationConfig m_generation_config;
diff --git a/src/cpp/src/text2image/models/clip_text_model.cpp b/src/cpp/src/text2image/models/clip_text_model.cpp
index b8ec871eb0..06cbdd1852 100644
--- a/src/cpp/src/text2image/models/clip_text_model.cpp
+++ b/src/cpp/src/text2image/models/clip_text_model.cpp
@@ -94,16 +94,16 @@ ov::Tensor CLIPTextModel::infer(const std::string& pos_prompt, const std::string& neg_prompt, bool do_classifier_free_guidance) {
 
     if (do_classifier_free_guidance) {
         perform_tokenization(neg_prompt,
-            ov::Tensor(input_ids, {current_batch_idx , 0},
-                {current_batch_idx + 1, m_config.max_position_embeddings}));
+                             ov::Tensor(input_ids, {current_batch_idx , 0},
+                                        {current_batch_idx + 1, m_config.max_position_embeddings}));
         ++current_batch_idx;
     } else {
         // Negative prompt is ignored when --guidanceScale < 1.0
     }
 
     perform_tokenization(pos_prompt,
-        ov::Tensor(input_ids, {current_batch_idx , 0},
-            {current_batch_idx + 1, m_config.max_position_embeddings}));
+                         ov::Tensor(input_ids, {current_batch_idx , 0},
+                                    {current_batch_idx + 1, m_config.max_position_embeddings}));
 
     // text embeddings
     m_request.set_tensor("input_ids", input_ids);
diff --git a/src/cpp/src/text2image/models/clip_text_model_with_projection.cpp b/src/cpp/src/text2image/models/clip_text_model_with_projection.cpp
index 2fa7b83738..6a268402e1 100644
--- a/src/cpp/src/text2image/models/clip_text_model_with_projection.cpp
+++ b/src/cpp/src/text2image/models/clip_text_model_with_projection.cpp
@@ -83,16 +83,16 @@ ov::Tensor CLIPTextModelWithProjection::infer(const std::string& pos_prompt, const std::string& neg_prompt, bool do_classifier_free_guidance) {
 
     if (do_classifier_free_guidance) {
         perform_tokenization(neg_prompt,
-            ov::Tensor(input_ids, {current_batch_idx , 0},
-                {current_batch_idx + 1, m_config.max_position_embeddings}));
+                             ov::Tensor(input_ids, {current_batch_idx , 0},
+                                        {current_batch_idx + 1, m_config.max_position_embeddings}));
         ++current_batch_idx;
     } else {
         // Negative prompt is ignored when --guidanceScale < 1.0
     }
 
     perform_tokenization(pos_prompt,
-        ov::Tensor(input_ids, {current_batch_idx , 0},
-            {current_batch_idx + 1, m_config.max_position_embeddings}));
+                         ov::Tensor(input_ids, {current_batch_idx , 0},
+                                    {current_batch_idx + 1, m_config.max_position_embeddings}));
 
     // text embeddings
     m_request.set_tensor("input_ids", input_ids);
diff --git a/src/cpp/src/text2image/schedulers/lcm.cpp b/src/cpp/src/text2image/schedulers/lcm.cpp
index d8da78e5f9..f9a87da8fb 100644
--- a/src/cpp/src/text2image/schedulers/lcm.cpp
+++ b/src/cpp/src/text2image/schedulers/lcm.cpp
@@ -47,7 +47,7 @@ LCMScheduler::LCMScheduler(const std::string scheduler_config_path) :
 
 LCMScheduler::LCMScheduler(const Config& scheduler_config) :
     m_config(scheduler_config),
     m_seed(42),
-    m_gen(100, std::mt19937(m_seed)),
+    m_gen(m_seed),
     m_normal(0.0f, 1.0f) {
     m_sigma_data = 0.5f; // Default: 0.5
@@ -191,7 +191,7 @@ std::map<std::string, ov::Tensor> LCMScheduler::step(ov::Tensor noise_pred, ov::Tensor latents, size_t inference_step) {
 
     if (inference_step != m_num_inference_steps - 1) {
         for (std::size_t i = 0; i < batch_size * latent_size; ++i) {
-            float gen_noise = m_normal(m_gen[i / latent_size]);
+            float gen_noise = m_normal(m_gen);
             prev_sample_data[i] = alpha_prod_t_prev_sqrt * denoised_data[i] + beta_prod_t_prev_sqrt * gen_noise;
         }
     } else {
diff --git a/src/cpp/src/text2image/schedulers/lcm.hpp b/src/cpp/src/text2image/schedulers/lcm.hpp
index 374aea5d1e..8abbcd3e29 100644
--- a/src/cpp/src/text2image/schedulers/lcm.hpp
+++ b/src/cpp/src/text2image/schedulers/lcm.hpp
@@ -62,7 +62,7 @@ class LCMScheduler : public IScheduler {
     std::vector<int64_t> m_timesteps;
 
     uint32_t m_seed;
-    std::vector<std::mt19937> m_gen;
+    std::mt19937 m_gen;
     std::normal_distribution<float> m_normal;
 
     std::vector<float> threshold_sample(const std::vector<float>& flat_sample);
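
The scheduler change above drops the vector of 100 pre-seeded engines in favor of a single seeded `std::mt19937`; a standalone sketch of the resulting sampling pattern, using only the standard library with toy sizes:

```cpp
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

int main() {
    const std::uint32_t seed = 42;           // mirrors m_seed in LCMScheduler
    std::mt19937 gen(seed);                  // one engine instead of one per batch item
    std::normal_distribution<float> normal(0.0f, 1.0f);

    const std::size_t batch_size = 2, latent_size = 4;  // illustrative sizes
    std::vector<float> noise(batch_size * latent_size);
    for (float& value : noise)
        value = normal(gen);                 // every element draws from the same engine
    return 0;
}
```
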
diff --git a/src/cpp/src/text2image/stable_diffusion_pipeline.hpp b/src/cpp/src/text2image/stable_diffusion_pipeline.hpp
index 54d2d43c19..f2543474ec 100644
--- a/src/cpp/src/text2image/stable_diffusion_pipeline.hpp
+++ b/src/cpp/src/text2image/stable_diffusion_pipeline.hpp
@@ -120,7 +120,7 @@ class Text2ImagePipeline::StableDiffusionPipeline : public Text2ImagePipeline::DiffusionPipeline {
         m_vae_decoder(std::make_shared<AutoencoderKL>(vae_decoder)) { }
 
     void reshape(const int num_images_per_prompt, const int height, const int width, const float guidance_scale) override {
-        check_inputs(height, width);
+        check_image_size(height, width);
 
         const size_t batch_size_multiplier = do_classifier_free_guidance(guidance_scale) ? 2 : 1;  // Unet accepts 2x batch in case of CFG
         m_clip_text_encoder->reshape(batch_size_multiplier);
@@ -150,7 +150,7 @@ class Text2ImagePipeline::StableDiffusionPipeline : public Text2ImagePipeline::DiffusionPipeline {
             generation_config.height = unet_config.sample_size * vae_scale_factor;
         if (generation_config.width < 0)
             generation_config.width = unet_config.sample_size * vae_scale_factor;
-        check_inputs(generation_config.height, generation_config.width);
+        check_inputs(generation_config);
 
         m_clip_text_encoder->set_adapters(generation_config.adapters);
         m_unet->set_adapters(generation_config.adapters);
@@ -183,7 +183,7 @@ class Text2ImagePipeline::StableDiffusionPipeline : public Text2ImagePipeline::DiffusionPipeline {
             m_unet->set_hidden_states("encoder_hidden_states", encoder_hidden_states_repeated);
         }
 
-        if (unet_config.time_cond_proj_dim >= 0) {
+        if (unet_config.time_cond_proj_dim >= 0) { // LCM
             ov::Tensor guidance_scale_embedding = get_guidance_scale_embedding(generation_config.guidance_scale, unet_config.time_cond_proj_dim);
             m_unet->set_hidden_states("timestep_cond", guidance_scale_embedding);
         }
@@ -249,7 +249,7 @@ class Text2ImagePipeline::StableDiffusionPipeline : public Text2ImagePipeline::DiffusionPipeline {
 
 private:
     bool do_classifier_free_guidance(float guidance_scale) const {
-        return guidance_scale > 1.0 && m_unet->get_config().time_cond_proj_dim < 0;
+        return guidance_scale >= 1.0f && m_unet->get_config().time_cond_proj_dim < 0;
     }
 
     void initialize_generation_config(const std::string& class_name) override {
@@ -271,7 +271,7 @@ class Text2ImagePipeline::StableDiffusionPipeline : public Text2ImagePipeline::DiffusionPipeline {
         }
     }
 
-    void check_inputs(const int height, const int width) const override {
+    void check_image_size(const int height, const int width) const override {
         assert(m_unet != nullptr);
         const size_t vae_scale_factor = m_unet->get_vae_scale_factor();
         OPENVINO_ASSERT((height % vae_scale_factor == 0 || height < 0) &&
@@ -279,6 +279,24 @@ class Text2ImagePipeline::StableDiffusionPipeline : public Text2ImagePipeline::DiffusionPipeline {
                         vae_scale_factor);
     }
 
+    void check_inputs(const GenerationConfig& generation_config) const override {
+        check_image_size(generation_config.height, generation_config.width);
+
+        const bool is_classifier_free_guidance = do_classifier_free_guidance(generation_config.guidance_scale);
+        const bool is_lcm = m_unet->get_config().time_cond_proj_dim > 0;
+        const char * const pipeline_name = is_lcm ? "Latent Consistency Model" : "Stable Diffusion";
+
+        OPENVINO_ASSERT(generation_config.prompt_2 == std::nullopt, "Prompt 2 is not used by ", pipeline_name);
+        OPENVINO_ASSERT(generation_config.prompt_3 == std::nullopt, "Prompt 3 is not used by ", pipeline_name);
+        if (is_lcm) {
+            OPENVINO_ASSERT(generation_config.negative_prompt.empty(), "Negative prompt is not used by ", pipeline_name);
+        } else if (!is_classifier_free_guidance) {
+            OPENVINO_ASSERT(generation_config.negative_prompt.empty(), "Negative prompt is not used when guidance scale < 1.0");
+        }
+        OPENVINO_ASSERT(generation_config.negative_prompt_2.empty(), "Negative prompt 2 is not used by ", pipeline_name);
+        OPENVINO_ASSERT(generation_config.negative_prompt_3.empty(), "Negative prompt 3 is not used by ", pipeline_name);
+    }
+
     std::shared_ptr<CLIPTextModel> m_clip_text_encoder;
     std::shared_ptr<UNet2DConditionModel> m_unet;
     std::shared_ptr<AutoencoderKL> m_vae_decoder;
diff --git a/src/cpp/src/text2image/stable_diffusion_xl_pipeline.hpp b/src/cpp/src/text2image/stable_diffusion_xl_pipeline.hpp
index 95ea2abc5d..15c82fc36a 100644
--- a/src/cpp/src/text2image/stable_diffusion_xl_pipeline.hpp
+++ b/src/cpp/src/text2image/stable_diffusion_xl_pipeline.hpp
@@ -108,7 +108,7 @@ class Text2ImagePipeline::StableDiffusionXLPipeline : public Text2ImagePipeline::DiffusionPipeline {
         m_vae_decoder(std::make_shared<AutoencoderKL>(vae_decoder)) { }
 
     void reshape(const int num_images_per_prompt, const int height, const int width, const float guidance_scale) override {
-        check_inputs(height, width);
+        check_image_size(height, width);
 
         const size_t batch_size_multiplier = do_classifier_free_guidance(guidance_scale) ? 2 : 1;  // Unet accepts 2x batch in case of CFG
         m_clip_text_encoder->reshape(batch_size_multiplier);
@@ -140,7 +140,7 @@ class Text2ImagePipeline::StableDiffusionXLPipeline : public Text2ImagePipeline::DiffusionPipeline {
             generation_config.height = unet_config.sample_size * vae_scale_factor;
         if (generation_config.width < 0)
             generation_config.width = unet_config.sample_size * vae_scale_factor;
-        check_inputs(generation_config.height, generation_config.width);
+        check_image_size(generation_config.height, generation_config.width);
 
         if (generation_config.random_generator == nullptr) {
             uint32_t seed = time(NULL);
@@ -308,7 +308,7 @@ class Text2ImagePipeline::StableDiffusionXLPipeline : public Text2ImagePipeline::DiffusionPipeline {
 
 private:
     bool do_classifier_free_guidance(float guidance_scale) const {
-        return guidance_scale > 1.0 && m_unet->get_config().time_cond_proj_dim < 0;
+        return guidance_scale >= 1.0f && m_unet->get_config().time_cond_proj_dim < 0;
     }
 
     void initialize_generation_config(const std::string& class_name) override {
@@ -327,7 +327,7 @@ class Text2ImagePipeline::StableDiffusionXLPipeline : public Text2ImagePipeline::DiffusionPipeline {
         }
     }
 
-    void check_inputs(const int height, const int width) const override {
+    void check_image_size(const int height, const int width) const override {
         assert(m_unet != nullptr);
         const size_t vae_scale_factor = m_unet->get_vae_scale_factor();
         OPENVINO_ASSERT((height % vae_scale_factor == 0 || height < 0) &&
@@ -335,6 +335,18 @@ class Text2ImagePipeline::StableDiffusionXLPipeline : public Text2ImagePipeline::DiffusionPipeline {
                         vae_scale_factor);
     }
 
+    void check_inputs(const GenerationConfig& generation_config) const override {
+        check_image_size(generation_config.height, generation_config.width);
+
+        const bool is_classifier_free_guidance = do_classifier_free_guidance(generation_config.guidance_scale);
+        const char * const pipeline_name = "Stable Diffusion XL";
+
+        OPENVINO_ASSERT(generation_config.prompt_3 == std::nullopt, "Prompt 3 is not used by ", pipeline_name);
+        OPENVINO_ASSERT(is_classifier_free_guidance || generation_config.negative_prompt.empty(), "Negative prompt is not used when guidance scale < 1.0");
+        OPENVINO_ASSERT(is_classifier_free_guidance || generation_config.negative_prompt_2.empty(), "Negative prompt 2 is not used when guidance scale < 1.0");
+        OPENVINO_ASSERT(generation_config.negative_prompt_3.empty(), "Negative prompt 3 is not used by ", pipeline_name);
+    }
+
     std::shared_ptr<CLIPTextModel> m_clip_text_encoder;
     std::shared_ptr<CLIPTextModelWithProjection> m_clip_text_encoder_with_projection;
     std::shared_ptr<UNet2DConditionModel> m_unet;
diff --git a/src/cpp/src/text2image/text2image_pipeline.cpp b/src/cpp/src/text2image/text2image_pipeline.cpp
index f7a6ab65ae..04422ef12f 100644
--- a/src/cpp/src/text2image/text2image_pipeline.cpp
+++ b/src/cpp/src/text2image/text2image_pipeline.cpp
@@ -37,8 +37,13 @@ void Text2ImagePipeline::GenerationConfig::update_generation_config(const ov::AnyMap& properties) {
     // override whole generation config first
     read_anymap_param(properties, SD_GENERATION_CONFIG, *this);
 
+    // then try per-parameter values
+    read_anymap_param(properties, "prompt_2", prompt_2);
+    read_anymap_param(properties, "prompt_3", prompt_3);
     read_anymap_param(properties, "negative_prompt", negative_prompt);
+    read_anymap_param(properties, "negative_prompt_2", negative_prompt_2);
+    read_anymap_param(properties, "negative_prompt_3", negative_prompt_3);
     read_anymap_param(properties, "num_images_per_prompt", num_images_per_prompt);
     read_anymap_param(properties, "random_generator", random_generator);
     read_anymap_param(properties, "guidance_scale", guidance_scale);
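
A sketch of what the per-parameter reads enable from the caller side, assuming the `ov::AnyMap`-based `generate()` overload; keys follow the renamed properties above, and the model directory is a placeholder:

```cpp
#include "openvino/genai/text2image/pipeline.hpp"

int main() {
    ov::genai::Text2ImagePipeline pipe("./stable-diffusion-xl-base-1.0", "CPU");  // placeholder model dir

    // update_generation_config() applies a whole-config override first (if any),
    // then each individual key below takes effect.
    ov::AnyMap properties = {
        {"negative_prompt", std::string{"low quality"}},
        {"negative_prompt_2", std::string{"watermark"}},
        {"num_images_per_prompt", size_t{2}},
    };
    ov::Tensor images = pipe.generate("a lighthouse at dawn", properties);
    return 0;
}
```
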
diff --git a/src/docs/SUPPORTED_MODELS.md b/src/docs/SUPPORTED_MODELS.md
index fb6df36950..d9dddc64b7 100644
--- a/src/docs/SUPPORTED_MODELS.md
+++ b/src/docs/SUPPORTED_MODELS.md
@@ -157,6 +157,45 @@ The pipeline can work with other similar topologies produced by `optimum-intel`
 
 > [!NOTE]
 > Models should belong to the same family and have the same tokenizers.
 
+## Text 2 image models
+
+<table>
+  <thead>
+    <tr>
+      <th>Architecture</th>
+      <th>Example HuggingFace Models</th>
+    </tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>Latent Consistency Model</td>
+      <td>
+        <ul>
+          <li><a href="https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7"><code>SimianLuo/LCM_Dreamshaper_v7</code></a></li>
+        </ul>
+      </td>
+    </tr>
+    <tr>
+      <td>Stable Diffusion</td>
+      <td>
+        <ul>
+          <li><a href="https://huggingface.co/botp/stable-diffusion-v1-5"><code>botp/stable-diffusion-v1-5</code></a></li>
+          <li><a href="https://huggingface.co/stabilityai/stable-diffusion-2"><code>stabilityai/stable-diffusion-2</code></a></li>
+          <li><a href="https://huggingface.co/stabilityai/stable-diffusion-2-1"><code>stabilityai/stable-diffusion-2-1</code></a></li>
+          <li><a href="https://huggingface.co/dreamlike-art/dreamlike-anime-1.0"><code>dreamlike-art/dreamlike-anime-1.0</code></a></li>
+        </ul>
+      </td>
+    </tr>
+    <tr>
+      <td>Stable Diffusion XL</td>
+      <td>
+        <ul>
+          <li><a href="https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9"><code>stabilityai/stable-diffusion-xl-base-0.9</code></a></li>
+          <li><a href="https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0"><code>stabilityai/stable-diffusion-xl-base-1.0</code></a></li>
+        </ul>
+      </td>
+    </tr>
+  </tbody>
+</table>
+
 ## Visual language models