Skip to content

Commit

Permalink
Project import generated by Copybara.
Browse files Browse the repository at this point in the history
GitOrigin-RevId: b66251317fbebfbb8e1f2ddc64ea5da84bceb7e5
  • Loading branch information
MediaPipe Team authored and jqtang committed May 7, 2022
1 parent 7fb37c8 commit 4a20e99
Show file tree
Hide file tree
Showing 12 changed files with 164 additions and 89 deletions.
81 changes: 50 additions & 31 deletions mediapipe/calculators/tensor/inference_calculator_gl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include "mediapipe/calculators/tensor/inference_calculator.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/util/tflite/config.h"
#include "tensorflow/lite/interpreter_builder.h"

#if MEDIAPIPE_TFLITE_GL_INFERENCE
#include "mediapipe/gpu/gl_calculator_helper.h"
Expand Down Expand Up @@ -53,11 +52,9 @@ class InferenceCalculatorGlImpl
private:
absl::Status ReadGpuCaches();
absl::Status SaveGpuCaches();
absl::Status InitInterpreter(CalculatorContext* cc);
absl::Status LoadDelegate(CalculatorContext* cc,
tflite::InterpreterBuilder* interpreter_builder);
absl::Status BindBuffersToTensors();
absl::Status AllocateTensors();
absl::Status LoadModel(CalculatorContext* cc);
absl::Status LoadDelegate(CalculatorContext* cc);
absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc);
absl::Status InitTFLiteGPURunner(CalculatorContext* cc);

// TfLite requires us to keep the model alive as long as the interpreter is.
Expand Down Expand Up @@ -140,11 +137,17 @@ absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) {
#endif // MEDIAPIPE_ANDROID
}

// When use_advanced_gpu_api_, model loading is handled in InitTFLiteGPURunner
// for everything.
if (!use_advanced_gpu_api_) {
MP_RETURN_IF_ERROR(LoadModel(cc));
}

MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
MP_RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status {
return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc)
: InitInterpreter(cc);
: LoadDelegateAndAllocateTensors(cc);
}));
return absl::OkStatus();
}
Expand Down Expand Up @@ -289,6 +292,9 @@ absl::Status InferenceCalculatorGlImpl::ReadGpuCaches() {

absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
CalculatorContext* cc) {
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
const auto& model = *model_packet_.Get();

// Create runner
tflite::gpu::InferenceOptions options;
options.priority1 = allow_precision_loss_
Expand Down Expand Up @@ -326,12 +332,17 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
break;
}
}
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
const auto& model = *model_packet_.Get();
ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc));
const auto& op_resolver = op_resolver_packet.Get();
MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel(
model, op_resolver, /*allow_quant_ops=*/true));
if (kSideInOpResolver(cc).IsConnected()) {
const tflite::OpResolver& op_resolver = kSideInOpResolver(cc).Get();
MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel(
model, op_resolver, /*allow_quant_ops=*/true));
} else {
tflite::ops::builtin::BuiltinOpResolver op_resolver =
kSideInCustomOpResolver(cc).GetOr(
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel(
model, op_resolver, /*allow_quant_ops=*/true));
}

// Create and bind OpenGL buffers for outputs.
// The buffers are created once and their ids are passed to calculator outputs
Expand All @@ -350,27 +361,35 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
return absl::OkStatus();
}

absl::Status InferenceCalculatorGlImpl::InitInterpreter(CalculatorContext* cc) {
absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) {
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
const auto& model = *model_packet_.Get();
ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc));
const auto& op_resolver = op_resolver_packet.Get();
tflite::InterpreterBuilder interpreter_builder(model, op_resolver);
MP_RETURN_IF_ERROR(LoadDelegate(cc, &interpreter_builder));
if (kSideInOpResolver(cc).IsConnected()) {
const tflite::OpResolver& op_resolver = kSideInOpResolver(cc).Get();
tflite::InterpreterBuilder(model, op_resolver)(&interpreter_);
} else {
tflite::ops::builtin::BuiltinOpResolver op_resolver =
kSideInCustomOpResolver(cc).GetOr(
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
tflite::InterpreterBuilder(model, op_resolver)(&interpreter_);
}
RET_CHECK(interpreter_);

#if defined(__EMSCRIPTEN__)
interpreter_builder.SetNumThreads(1);
interpreter_->SetNumThreads(1);
#else
interpreter_builder.SetNumThreads(
interpreter_->SetNumThreads(
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread());
#endif // __EMSCRIPTEN__
RET_CHECK_EQ(interpreter_builder(&interpreter_), kTfLiteOk);
RET_CHECK(interpreter_);
MP_RETURN_IF_ERROR(BindBuffersToTensors());
MP_RETURN_IF_ERROR(AllocateTensors());

return absl::OkStatus();
}

absl::Status InferenceCalculatorGlImpl::AllocateTensors() {
absl::Status InferenceCalculatorGlImpl::LoadDelegateAndAllocateTensors(
CalculatorContext* cc) {
MP_RETURN_IF_ERROR(LoadDelegate(cc));

// AllocateTensors() can be called only after ModifyGraphWithDelegate.
RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
// TODO: Support quantized tensors.
RET_CHECK_NE(
Expand All @@ -379,8 +398,7 @@ absl::Status InferenceCalculatorGlImpl::AllocateTensors() {
return absl::OkStatus();
}

absl::Status InferenceCalculatorGlImpl::LoadDelegate(
CalculatorContext* cc, tflite::InterpreterBuilder* interpreter_builder) {
absl::Status InferenceCalculatorGlImpl::LoadDelegate(CalculatorContext* cc) {
// Configure and create the delegate.
TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
options.compile_options.precision_loss_allowed =
Expand All @@ -391,11 +409,7 @@ absl::Status InferenceCalculatorGlImpl::LoadDelegate(
options.compile_options.inline_parameters = 1;
delegate_ = TfLiteDelegatePtr(TfLiteGpuDelegateCreate(&options),
&TfLiteGpuDelegateDelete);
interpreter_builder->AddDelegate(delegate_.get());
return absl::OkStatus();
}

absl::Status InferenceCalculatorGlImpl::BindBuffersToTensors() {
// Get input image sizes.
const auto& input_indices = interpreter_->inputs();
for (int i = 0; i < input_indices.size(); ++i) {
Expand Down Expand Up @@ -427,6 +441,11 @@ absl::Status InferenceCalculatorGlImpl::BindBuffersToTensors() {
output_indices[i]),
kTfLiteOk);
}

// Must call this last.
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
kTfLiteOk);

return absl::OkStatus();
}

Expand Down
8 changes: 5 additions & 3 deletions mediapipe/framework/packet.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ class Packet {
// Returns an error if the packet does not contain data of type T.
template <typename T>
absl::Status ValidateAsType() const {
return ValidateAsType(tool::TypeId<T>());
return ValidateAsType(tool::TypeInfo::Get<T>());
}

// Returns an error if the packet is not an instance of
Expand Down Expand Up @@ -428,7 +428,7 @@ StatusOr<std::vector<const proto_ns::MessageLite*>>
ConvertToVectorOfProtoMessageLitePtrs(const T* data,
/*is_proto_vector=*/std::false_type) {
return absl::InvalidArgumentError(absl::StrCat(
"The Packet stores \"", tool::TypeId<T>().name(), "\"",
"The Packet stores \"", tool::TypeInfo::Get<T>().name(), "\"",
"which is not convertible to vector<proto_ns::MessageLite*>."));
}

Expand Down Expand Up @@ -510,7 +510,9 @@ class Holder : public HolderBase {
HolderSupport<T>::EnsureStaticInit();
return *ptr_;
}
const tool::TypeInfo& GetTypeInfo() const final { return tool::TypeId<T>(); }
const tool::TypeInfo& GetTypeInfo() const final {
return tool::TypeInfo::Get<T>();
}
// Releases the underlying data pointer and transfers the ownership to a
// unique pointer.
// This method is dangerous and is only used by Packet::Consume() if the
Expand Down
4 changes: 2 additions & 2 deletions mediapipe/framework/packet_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,14 +259,14 @@ absl::Status ValidatePacketTypeSet(const PacketTypeSet& packet_type_set);

template <typename T>
PacketType& PacketType::Set() {
type_spec_ = &tool::TypeId<T>();
type_spec_ = &tool::TypeInfo::Get<T>();
return *this;
}

template <typename... T>
PacketType& PacketType::SetOneOf() {
static const NoDestructor<std::vector<const tool::TypeInfo*>> types{
{&tool::TypeId<T>()...}};
{&tool::TypeInfo::Get<T>()...}};
static const NoDestructor<std::string> name{TypeNameForOneOf(*types)};
type_spec_ = MultiType{*types, &*name};
return *this;
Expand Down
2 changes: 2 additions & 0 deletions mediapipe/framework/tool/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -761,9 +761,11 @@ cc_library(
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:advanced_proto",
"//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:gtest",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
Expand Down
6 changes: 3 additions & 3 deletions mediapipe/framework/tool/options_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,14 @@ class TypeMap {
public:
template <class T>
bool Has() const {
return content_.count(TypeId<T>()) > 0;
return content_.count(TypeInfo::Get<T>()) > 0;
}
template <class T>
T* Get() const {
if (!Has<T>()) {
content_[TypeId<T>()] = std::make_shared<T>();
content_[TypeInfo::Get<T>()] = std::make_shared<T>();
}
return static_cast<T*>(content_[TypeId<T>()].get());
return static_cast<T*>(content_[TypeInfo::Get<T>()].get());
}

private:
Expand Down
40 changes: 40 additions & 0 deletions mediapipe/framework/tool/test_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <memory>
#include <string>

#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
Expand All @@ -33,6 +34,7 @@
#include "mediapipe/framework/formats/image_format.pb.h"
#include "mediapipe/framework/port/advanced_proto_inc.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/proto_ns.h"
#include "mediapipe/framework/port/ret_check.h"
Expand Down Expand Up @@ -208,6 +210,27 @@ bool CompareImageFrames(const ImageFrame& image1, const ImageFrame& image2,
return false;
}

absl::Status CompareAndSaveImageOutput(
absl::string_view golden_image_path, const ImageFrame& actual,
const ImageFrameComparisonOptions& options) {
ASSIGN_OR_RETURN(auto output_img_path, SavePngTestOutput(actual, "output"));

auto expected = LoadTestImage(GetTestFilePath(golden_image_path));
if (!expected.ok()) {
return expected.status();
}
ASSIGN_OR_RETURN(auto expected_img_path,
SavePngTestOutput(**expected, "expected"));

std::unique_ptr<ImageFrame> diff_img;
auto status = CompareImageFrames(**expected, actual, options.max_color_diff,
options.max_alpha_diff, options.max_avg_diff,
diff_img);
ASSIGN_OR_RETURN(auto diff_img_path, SavePngTestOutput(*diff_img, "diff"));

return status;
}

std::string GetTestRootDir() {
return file::JoinPath(std::getenv("TEST_SRCDIR"), "mediapipe");
}
Expand Down Expand Up @@ -275,6 +298,23 @@ std::unique_ptr<ImageFrame> LoadTestPng(absl::string_view path,
return nullptr;
}

// Write an ImageFrame as PNG to the test undeclared outputs directory.
// The image's name will contain the given prefix and a timestamp.
// Returns the path to the output if successful.
absl::StatusOr<std::string> SavePngTestOutput(
const mediapipe::ImageFrame& image, absl::string_view prefix) {
std::string now_string = absl::FormatTime(absl::Now());
std::string output_relative_path =
absl::StrCat(prefix, "_", now_string, ".png");
std::string output_full_path =
file::JoinPath(GetTestOutputsDir(), output_relative_path);
RET_CHECK(stbi_write_png(output_full_path.c_str(), image.Width(),
image.Height(), image.NumberOfChannels(),
image.PixelData(), image.WidthStep()))
<< " path: " << output_full_path;
return output_relative_path;
}

bool LoadTestGraph(CalculatorGraphConfig* proto, const std::string& path) {
int fd = open(path.c_str(), O_RDONLY);
if (fd == -1) {
Expand Down
40 changes: 30 additions & 10 deletions mediapipe/framework/tool/test_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,33 @@
namespace mediapipe {
using mediapipe::CalculatorGraphConfig;

struct ImageFrameComparisonOptions {
// NOTE: these values are not normalized: use a value from 0 to 2^8-1
// for 8-bit data and a value from 0 to 2^16-1 for 16-bit data.
// Although these members are declared as floats,, all uint8/uint16
// values are exactly representable. (2^24 + 1 is the first non-representable
// positive integral value.)

// Maximum value difference allowed for non-alpha channels.
float max_color_diff;
// Maximum value difference allowed for alpha channel (if present).
float max_alpha_diff;
// Maximum difference for all channels, averaged across all pixels.
float max_avg_diff;
};

// Compares an output image with a golden file. Saves the output and difference
// to the undeclared test outputs.
// Returns ok if they are equal within the tolerances specified in options.
absl::Status CompareAndSaveImageOutput(
absl::string_view golden_image_path, const ImageFrame& actual,
const ImageFrameComparisonOptions& options);

// Checks if two image frames are equal within the specified tolerance.
// image1 and image2 may be of different-but-compatible image formats (e.g.,
// SRGB and SRGBA); in that case, only the channels available in both are
// compared.
// max_color_diff applies to the first 3 channels; i.e., R, G, B for sRGB and
// sRGBA, and the single gray channel for GRAY8 and GRAY16. It is the maximum
// pixel color value difference allowed; i.e., a value from 0 to 2^8-1 for 8-bit
// data and a value from 0 to 2^16-1 for 16-bit data.
// max_alpha_diff applies to the 4th (alpha) channel only, if present.
// max_avg_diff applies to all channels, normalized across all pixels.
//
// Note: Although max_color_diff and max_alpha_diff are floats, all uint8/uint16
// values are exactly representable. (2^24 + 1 is the first non-representable
// positive integral value.)
// The diff arguments are as in ImageFrameComparisonOptions.
absl::Status CompareImageFrames(const ImageFrame& image1,
const ImageFrame& image2,
const float max_color_diff,
Expand Down Expand Up @@ -77,6 +90,13 @@ absl::StatusOr<std::unique_ptr<ImageFrame>> LoadTestImage(
std::unique_ptr<ImageFrame> LoadTestPng(
absl::string_view path, ImageFormat::Format format = ImageFormat::SRGBA);

// Write an ImageFrame as PNG to the test undeclared outputs directory.
// The image's name will contain the given prefix and a timestamp.
// If successful, returns the path to the output file relative to the output
// directory.
absl::StatusOr<std::string> SavePngTestOutput(
const mediapipe::ImageFrame& image, absl::string_view prefix);

// Returns the luminance image of |original_image|.
// The format of |original_image| must be sRGB or sRGBA.
std::unique_ptr<ImageFrame> GenerateLuminanceImage(
Expand Down
8 changes: 1 addition & 7 deletions mediapipe/framework/tool/type_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,6 @@ class TypeIndex {
const TypeInfo& info_;
};

// Returns a unique identifier for type T.
template <typename T>
const TypeInfo& TypeId() {
return TypeInfo::Get<T>();
}

// Helper method that returns a hash code of the given type. This allows for
// typeid testing across multiple binaries, unlike FastTypeId which used a
// memory location that only works within the same binary. Moreover, we use this
Expand All @@ -94,7 +88,7 @@ const TypeInfo& TypeId() {
// as much as possible.
template <typename T>
size_t GetTypeHash() {
return TypeId<T>().hash_code();
return TypeInfo::Get<T>().hash_code();
}

} // namespace tool
Expand Down
2 changes: 1 addition & 1 deletion mediapipe/framework/type_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ inline std::string MediaPipeTypeStringOrDemangled(

template <typename T>
std::string MediaPipeTypeStringOrDemangled() {
return MediaPipeTypeStringOrDemangled(tool::TypeId<T>());
return MediaPipeTypeStringOrDemangled(tool::TypeInfo::Get<T>());
}

// Returns type hash id of type identified by type_string or NULL if not
Expand Down
Loading

0 comments on commit 4a20e99

Please sign in to comment.