text2image: Add static batch-size 1 implementation of unet_inference for NPU
RyanMetcalfeInt8 committed Oct 29, 2024
1 parent 77226cc commit da387fd
Showing 4 changed files with 151 additions and 32 deletions.
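
For context: NPU compilation generally requires fully static shapes, so the new backend reshapes the UNet to batch size 1 and fans each batched request out across one infer request per original batch element, splitting inputs and reassembling outputs along the batch stride. Below is a minimal sketch of that setup step using standard OpenVINO 2.0 C++ APIs, assuming every input's first dimension is the batch; the helper name make_bs1_requests is illustrative, not part of this commit.

#include <openvino/openvino.hpp>

#include <map>
#include <string>
#include <vector>

// Sketch: reshape a batch-N model to batch 1 for a static-shape device,
// then create one infer request per original batch element.
std::vector<ov::InferRequest> make_bs1_requests(std::shared_ptr<ov::Model> model,
                                                const std::string& device) {
    // Read the native batch size from the "sample" input (as this commit does).
    size_t batch = model->input("sample").get_shape()[0];

    // Force every input's batch dimension to 1 before compiling.
    // (Assumes dimension 0 is the batch for every input.)
    std::map<std::string, ov::PartialShape> new_shapes;
    for (const auto& input : model->inputs()) {
        ov::PartialShape shape = input.get_partial_shape();
        shape[0] = 1;
        new_shapes[input.get_any_name()] = shape;
    }
    model->reshape(new_shapes);

    ov::Core core;
    ov::CompiledModel compiled = core.compile_model(model, device);

    std::vector<ov::InferRequest> requests(batch);
    for (auto& request : requests)
        request = compiled.create_infer_request();
    return requests;
}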
1 change: 1 addition & 0 deletions src/cpp/include/openvino/genai/text2image/unet2d_condition_model.hpp
@@ -77,6 +77,7 @@ class OPENVINO_GENAI_EXPORTS UNet2DConditionModel {
size_t m_vae_scale_factor;

class UNetInferenceDynamic;
+class UNetInferenceStaticBS1;
};

} // namespace genai
7 changes: 6 additions & 1 deletion src/cpp/src/text2image/models/unet2d_condition_model.cpp
@@ -3,6 +3,7 @@

#include "openvino/genai/text2image/unet2d_condition_model.hpp"
#include "text2image/models/unet_inference_dynamic.hpp"
#include "text2image/models/unet_inference_static_bs1.hpp"

#include <fstream>

@@ -73,7 +74,11 @@ UNet2DConditionModel& UNet2DConditionModel::reshape(int batch_size, int height,
UNet2DConditionModel& UNet2DConditionModel::compile(const std::string& device, const ov::AnyMap& properties) {
OPENVINO_ASSERT(m_model, "Model has been already compiled. Cannot re-compile already compiled model");

-m_impl = std::make_shared<UNet2DConditionModel::UNetInferenceDynamic>();
+if (device == "NPU") {
+    m_impl = std::make_shared<UNet2DConditionModel::UNetInferenceStaticBS1>();
+} else {
+    m_impl = std::make_shared<UNet2DConditionModel::UNetInferenceDynamic>();
+}

std::optional<AdapterConfig> adapters;
if (auto filtered_properties = extract_adapters_from_properties(properties, &adapters)) {
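Caller-side, the dispatch above means NPU users must reshape to static dimensions before compiling. A hypothetical usage sketch of the public API follows; the constructor argument and reshape values are illustrative, and the fourth reshape parameter is assumed to be the tokenizer's max sequence length.

ov::genai::UNet2DConditionModel unet("./stable-diffusion/unet");

// NPU compilation needs static shapes, so reshape first; with batch size 2,
// UNetInferenceStaticBS1 will create two batch-1 infer requests internally.
unet.reshape(2 /*batch*/, 512 /*height*/, 512 /*width*/, 77 /*max sequence length, assumed*/);

unet.compile("NPU", ov::AnyMap{});  // device == "NPU" -> UNetInferenceStaticBS1
// unet.compile("CPU", ov::AnyMap{});  // any other device -> UNetInferenceDynamic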
31 changes: 0 additions & 31 deletions src/cpp/src/text2image/models/unet_inference.hpp
@@ -8,37 +8,6 @@
namespace ov {
namespace genai {

-//REMOVE THIS
-static inline void logBasicModelInfo(const std::shared_ptr<ov::Model>& model) {
-    std::cout << "Model name: " << model->get_friendly_name() << std::endl;
-
-    // Dump information about model inputs/outputs
-    ov::OutputVector inputs = model->inputs();
-    ov::OutputVector outputs = model->outputs();
-
-    std::cout << "\tInputs: " << std::endl;
-    for (const ov::Output<ov::Node>& input : inputs) {
-        const std::string name = input.get_any_name();
-        const ov::element::Type type = input.get_element_type();
-        const ov::PartialShape shape = input.get_partial_shape();
-        const ov::Layout layout = ov::layout::get_layout(input);
-
-        std::cout << "\t\t" << name << ", " << type << ", " << shape << ", " << layout.to_string() << std::endl;
-    }
-
-    std::cout << "\tOutputs: " << std::endl;
-    for (const ov::Output<ov::Node>& output : outputs) {
-        const std::string name = output.get_any_name();
-        const ov::element::Type type = output.get_element_type();
-        const ov::PartialShape shape = output.get_partial_shape();
-        const ov::Layout layout = ov::layout::get_layout(output);
-
-        std::cout << "\t\t" << name << ", " << type << ", " << shape << ", " << layout.to_string() << std::endl;
-    }
-
-    return;
-}

class UNet2DConditionModel::UNetInference {

public:
144 changes: 144 additions & 0 deletions src/cpp/src/text2image/models/unet_inference_static_bs1.hpp
@@ -0,0 +1,144 @@
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "lora_helper.hpp"
#include "text2image/models/unet_inference.hpp"
#include "utils.hpp"

namespace ov {
namespace genai {

// Static batch-size 1 variant of UNetInference
class UNet2DConditionModel::UNetInferenceStaticBS1 : public UNet2DConditionModel::UNetInference {
public:
    virtual void compile(std::shared_ptr<ov::Model> model,
                         const std::string& device,
                         const ov::AnyMap& properties) override {
        // All shapes for input/output tensors should be static.
        // Double check this and throw a runtime error if it's not the case.
        for (auto& input : model->inputs()) {
            if (input.get_partial_shape().is_dynamic()) {
                throw std::runtime_error(
                    "UNetInferenceStaticBS1::compile: input tensor " + input.get_any_name() +
                    " shape is dynamic. Tensors must be reshaped to be static before compile is invoked.");
            }
        }

        for (auto& output : model->outputs()) {
            if (output.get_partial_shape().is_dynamic()) {
                throw std::runtime_error(
                    "UNetInferenceStaticBS1::compile: output tensor " + output.get_any_name() +
                    " shape is dynamic. Tensors must be reshaped to be static before compile is invoked.");
            }
        }

        // We'll create a separate infer request for each batch.
        m_nativeBatchSize = model->input("sample").get_shape()[0];
        m_requests.resize(m_nativeBatchSize);

        // Reshape to batch 1.
        UNetInference::reshape(model, 1);

        ov::Core core = utils::singleton_core();
        ov::CompiledModel compiled_model = core.compile_model(model, device, properties);

        for (size_t i = 0; i < m_nativeBatchSize; i++) {
            m_requests[i] = compiled_model.create_infer_request();
        }
    }

    virtual void set_hidden_states(const std::string& tensor_name, ov::Tensor encoder_hidden_states) override {
        OPENVINO_ASSERT(m_nativeBatchSize && m_nativeBatchSize == m_requests.size(),
                        "UNet model must be compiled first");

        size_t encoder_hidden_states_bs = encoder_hidden_states.get_shape()[0];
        if (encoder_hidden_states_bs != m_nativeBatchSize) {
            throw std::runtime_error(
                "UNetInferenceStaticBS1::set_hidden_states: native batch size is " +
                std::to_string(m_nativeBatchSize) + ", but encoder_hidden_states has batch size of " +
                std::to_string(encoder_hidden_states_bs));
        }

        char* pHiddenStates = (char*)encoder_hidden_states.data();
        size_t hidden_states_batch_stride_bytes = encoder_hidden_states.get_strides()[0];

        for (size_t i = 0; i < m_nativeBatchSize; i++) {
            auto hidden_states_bs1 = m_requests[i].get_tensor(tensor_name);

            // Wrap the current pHiddenStates location as a batch-1 tensor.
            ov::Tensor bs1_wrapper(hidden_states_bs1.get_element_type(),
                                   hidden_states_bs1.get_shape(),
                                   pHiddenStates,
                                   encoder_hidden_states.get_strides());

            // Copy it to the infer request's batch-1 tensor.
            bs1_wrapper.copy_to(hidden_states_bs1);

            // Increment pHiddenStates to the start of the next batch (using the stride).
            pHiddenStates += hidden_states_batch_stride_bytes;
        }
    }

    virtual void set_adapters(AdapterController& adapter_controller, const AdapterConfig& adapters) override {
        OPENVINO_ASSERT(m_nativeBatchSize && m_nativeBatchSize == m_requests.size(),
                        "UNet model must be compiled first");
        for (size_t i = 0; i < m_nativeBatchSize; i++) {
            adapter_controller.apply(m_requests[i], adapters);
        }
    }

    virtual ov::Tensor infer(ov::Tensor sample, ov::Tensor timestep) override {
        OPENVINO_ASSERT(m_nativeBatchSize && m_nativeBatchSize == m_requests.size(),
                        "UNet model must be compiled first");

        char* pSample = (char*)sample.data();
        size_t sample_batch_stride_bytes = sample.get_strides()[0];

        for (size_t i = 0; i < m_nativeBatchSize; i++) {
            m_requests[i].set_tensor("timestep", timestep);

            auto sample_bs1 = m_requests[i].get_tensor("sample");

            // Wrap the current pSample location as a batch-1 tensor.
            ov::Tensor bs1_wrapper(sample_bs1.get_element_type(), sample_bs1.get_shape(), pSample, sample.get_strides());

            // Copy it to the infer request's batch-1 tensor.
            bs1_wrapper.copy_to(sample_bs1);

            // Increment pSample to the start of the next batch (using the stride).
            pSample += sample_batch_stride_bytes;

            // Kick off inference for this request.
            m_requests[i].start_async();
        }

        auto out_sample = ov::Tensor(sample.get_element_type(), sample.get_shape());

        char* pOutSample = (char*)out_sample.data();
        size_t out_sample_batch_stride_bytes = out_sample.get_strides()[0];
        for (size_t i = 0; i < m_nativeBatchSize; i++) {
            // Wait for inference to complete.
            m_requests[i].wait();

            auto out_sample_bs1 = m_requests[i].get_tensor("out_sample");
            ov::Tensor bs1_wrapper(out_sample_bs1.get_element_type(), out_sample_bs1.get_shape(), pOutSample, out_sample.get_strides());
            out_sample_bs1.copy_to(bs1_wrapper);

            pOutSample += out_sample_batch_stride_bytes;
        }

        return out_sample;
    }

private:
    std::vector<ov::InferRequest> m_requests;
    size_t m_nativeBatchSize = 0;
};

} // namespace genai
} // namespace ov
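
The batch splitting in set_hidden_states() and infer() above hinges on wrapping a raw pointer plus the source strides as a batch-1 ov::Tensor view. The following standalone sketch isolates that mechanism; it assumes only the public ov::Tensor constructor taking a host pointer and strides, and the helper name split_batches is illustrative, not part of the commit.

#include <openvino/openvino.hpp>

#include <vector>

// Sketch: split a batched tensor into dense batch-1 copies using the batch
// stride, mirroring the pointer arithmetic used by UNetInferenceStaticBS1.
std::vector<ov::Tensor> split_batches(const ov::Tensor& batched) {
    const size_t batch = batched.get_shape()[0];
    const size_t stride_bytes = batched.get_strides()[0];  // bytes per batch element
    char* src = static_cast<char*>(batched.data());

    ov::Shape bs1_shape = batched.get_shape();
    bs1_shape[0] = 1;

    std::vector<ov::Tensor> result;
    for (size_t i = 0; i < batch; ++i) {
        // Zero-copy view over batch element i: batch-1 shape, source strides.
        ov::Tensor view(batched.get_element_type(), bs1_shape,
                        src + i * stride_bytes, batched.get_strides());

        // Materialize a dense batch-1 copy; copy_to() resolves the strides.
        ov::Tensor dense(batched.get_element_type(), bs1_shape);
        view.copy_to(dense);
        result.push_back(dense);
    }
    return result;
}

One design note: because each batch element gets its own infer request, the infer() loop above can call start_async() on all requests before waiting on any of them, letting the batch-1 inferences overlap on the device instead of running strictly one after another.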
