From 4a20e9909d55838d5630366ce719844cf06ae85c Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 6 May 2022 14:39:20 -0700 Subject: [PATCH] Project import generated by Copybara. GitOrigin-RevId: b66251317fbebfbb8e1f2ddc64ea5da84bceb7e5 --- .../tensor/inference_calculator_gl.cc | 81 ++++++++++++------- mediapipe/framework/packet.h | 8 +- mediapipe/framework/packet_type.h | 4 +- mediapipe/framework/tool/BUILD | 2 + mediapipe/framework/tool/options_map.h | 6 +- mediapipe/framework/tool/test_util.cc | 40 +++++++++ mediapipe/framework/tool/test_util.h | 40 ++++++--- mediapipe/framework/tool/type_util.h | 8 +- mediapipe/framework/type_map.h | 2 +- mediapipe/gpu/gpu_buffer_test.cc | 36 +++------ .../mediapipe/framework/PacketGetter.java | 21 ++++- .../google/mediapipe/framework/ProtoUtil.java | 5 +- 12 files changed, 164 insertions(+), 89 deletions(-) diff --git a/mediapipe/calculators/tensor/inference_calculator_gl.cc b/mediapipe/calculators/tensor/inference_calculator_gl.cc index dfdf7382c1..eb6ab9f40b 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl.cc @@ -22,7 +22,6 @@ #include "mediapipe/calculators/tensor/inference_calculator.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/util/tflite/config.h" -#include "tensorflow/lite/interpreter_builder.h" #if MEDIAPIPE_TFLITE_GL_INFERENCE #include "mediapipe/gpu/gl_calculator_helper.h" @@ -53,11 +52,9 @@ class InferenceCalculatorGlImpl private: absl::Status ReadGpuCaches(); absl::Status SaveGpuCaches(); - absl::Status InitInterpreter(CalculatorContext* cc); - absl::Status LoadDelegate(CalculatorContext* cc, - tflite::InterpreterBuilder* interpreter_builder); - absl::Status BindBuffersToTensors(); - absl::Status AllocateTensors(); + absl::Status LoadModel(CalculatorContext* cc); + absl::Status LoadDelegate(CalculatorContext* cc); + absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc); absl::Status InitTFLiteGPURunner(CalculatorContext* cc); // TfLite requires us to keep the model alive as long as the interpreter is. @@ -140,11 +137,17 @@ absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) { #endif // MEDIAPIPE_ANDROID } + // When use_advanced_gpu_api_, model loading is handled in InitTFLiteGPURunner + // for everything. + if (!use_advanced_gpu_api_) { + MP_RETURN_IF_ERROR(LoadModel(cc)); + } + MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); MP_RETURN_IF_ERROR( gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status { return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc) - : InitInterpreter(cc); + : LoadDelegateAndAllocateTensors(cc); })); return absl::OkStatus(); } @@ -289,6 +292,9 @@ absl::Status InferenceCalculatorGlImpl::ReadGpuCaches() { absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner( CalculatorContext* cc) { + ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc)); + const auto& model = *model_packet_.Get(); + // Create runner tflite::gpu::InferenceOptions options; options.priority1 = allow_precision_loss_ @@ -326,12 +332,17 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner( break; } } - ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc)); - const auto& model = *model_packet_.Get(); - ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc)); - const auto& op_resolver = op_resolver_packet.Get(); - MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel( - model, op_resolver, /*allow_quant_ops=*/true)); + if (kSideInOpResolver(cc).IsConnected()) { + const tflite::OpResolver& op_resolver = kSideInOpResolver(cc).Get(); + MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel( + model, op_resolver, /*allow_quant_ops=*/true)); + } else { + tflite::ops::builtin::BuiltinOpResolver op_resolver = + kSideInCustomOpResolver(cc).GetOr( + tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates()); + MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel( + model, op_resolver, /*allow_quant_ops=*/true)); + } // Create and bind OpenGL buffers for outputs. // The buffers are created once and their ids are passed to calculator outputs @@ -350,27 +361,35 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner( return absl::OkStatus(); } -absl::Status InferenceCalculatorGlImpl::InitInterpreter(CalculatorContext* cc) { +absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) { ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc)); const auto& model = *model_packet_.Get(); - ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc)); - const auto& op_resolver = op_resolver_packet.Get(); - tflite::InterpreterBuilder interpreter_builder(model, op_resolver); - MP_RETURN_IF_ERROR(LoadDelegate(cc, &interpreter_builder)); + if (kSideInOpResolver(cc).IsConnected()) { + const tflite::OpResolver& op_resolver = kSideInOpResolver(cc).Get(); + tflite::InterpreterBuilder(model, op_resolver)(&interpreter_); + } else { + tflite::ops::builtin::BuiltinOpResolver op_resolver = + kSideInCustomOpResolver(cc).GetOr( + tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates()); + tflite::InterpreterBuilder(model, op_resolver)(&interpreter_); + } + RET_CHECK(interpreter_); + #if defined(__EMSCRIPTEN__) - interpreter_builder.SetNumThreads(1); + interpreter_->SetNumThreads(1); #else - interpreter_builder.SetNumThreads( + interpreter_->SetNumThreads( cc->Options().cpu_num_thread()); #endif // __EMSCRIPTEN__ - RET_CHECK_EQ(interpreter_builder(&interpreter_), kTfLiteOk); - RET_CHECK(interpreter_); - MP_RETURN_IF_ERROR(BindBuffersToTensors()); - MP_RETURN_IF_ERROR(AllocateTensors()); + return absl::OkStatus(); } -absl::Status InferenceCalculatorGlImpl::AllocateTensors() { +absl::Status InferenceCalculatorGlImpl::LoadDelegateAndAllocateTensors( + CalculatorContext* cc) { + MP_RETURN_IF_ERROR(LoadDelegate(cc)); + + // AllocateTensors() can be called only after ModifyGraphWithDelegate. RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); // TODO: Support quantized tensors. RET_CHECK_NE( @@ -379,8 +398,7 @@ absl::Status InferenceCalculatorGlImpl::AllocateTensors() { return absl::OkStatus(); } -absl::Status InferenceCalculatorGlImpl::LoadDelegate( - CalculatorContext* cc, tflite::InterpreterBuilder* interpreter_builder) { +absl::Status InferenceCalculatorGlImpl::LoadDelegate(CalculatorContext* cc) { // Configure and create the delegate. TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault(); options.compile_options.precision_loss_allowed = @@ -391,11 +409,7 @@ absl::Status InferenceCalculatorGlImpl::LoadDelegate( options.compile_options.inline_parameters = 1; delegate_ = TfLiteDelegatePtr(TfLiteGpuDelegateCreate(&options), &TfLiteGpuDelegateDelete); - interpreter_builder->AddDelegate(delegate_.get()); - return absl::OkStatus(); -} -absl::Status InferenceCalculatorGlImpl::BindBuffersToTensors() { // Get input image sizes. const auto& input_indices = interpreter_->inputs(); for (int i = 0; i < input_indices.size(); ++i) { @@ -427,6 +441,11 @@ absl::Status InferenceCalculatorGlImpl::BindBuffersToTensors() { output_indices[i]), kTfLiteOk); } + + // Must call this last. + RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), + kTfLiteOk); + return absl::OkStatus(); } diff --git a/mediapipe/framework/packet.h b/mediapipe/framework/packet.h index 4b0e48fbcf..82f0ec0872 100644 --- a/mediapipe/framework/packet.h +++ b/mediapipe/framework/packet.h @@ -180,7 +180,7 @@ class Packet { // Returns an error if the packet does not contain data of type T. template absl::Status ValidateAsType() const { - return ValidateAsType(tool::TypeId()); + return ValidateAsType(tool::TypeInfo::Get()); } // Returns an error if the packet is not an instance of @@ -428,7 +428,7 @@ StatusOr> ConvertToVectorOfProtoMessageLitePtrs(const T* data, /*is_proto_vector=*/std::false_type) { return absl::InvalidArgumentError(absl::StrCat( - "The Packet stores \"", tool::TypeId().name(), "\"", + "The Packet stores \"", tool::TypeInfo::Get().name(), "\"", "which is not convertible to vector.")); } @@ -510,7 +510,9 @@ class Holder : public HolderBase { HolderSupport::EnsureStaticInit(); return *ptr_; } - const tool::TypeInfo& GetTypeInfo() const final { return tool::TypeId(); } + const tool::TypeInfo& GetTypeInfo() const final { + return tool::TypeInfo::Get(); + } // Releases the underlying data pointer and transfers the ownership to a // unique pointer. // This method is dangerous and is only used by Packet::Consume() if the diff --git a/mediapipe/framework/packet_type.h b/mediapipe/framework/packet_type.h index 738141a29d..09d4d93451 100644 --- a/mediapipe/framework/packet_type.h +++ b/mediapipe/framework/packet_type.h @@ -259,14 +259,14 @@ absl::Status ValidatePacketTypeSet(const PacketTypeSet& packet_type_set); template PacketType& PacketType::Set() { - type_spec_ = &tool::TypeId(); + type_spec_ = &tool::TypeInfo::Get(); return *this; } template PacketType& PacketType::SetOneOf() { static const NoDestructor> types{ - {&tool::TypeId()...}}; + {&tool::TypeInfo::Get()...}}; static const NoDestructor name{TypeNameForOneOf(*types)}; type_spec_ = MultiType{*types, &*name}; return *this; diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index d44c8fe261..66f3061e06 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -761,9 +761,11 @@ cc_library( "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/port:advanced_proto", "//mediapipe/framework/port:file_helpers", + "//mediapipe/framework/port:gtest", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", diff --git a/mediapipe/framework/tool/options_map.h b/mediapipe/framework/tool/options_map.h index 242ffe161f..023e1dfb03 100644 --- a/mediapipe/framework/tool/options_map.h +++ b/mediapipe/framework/tool/options_map.h @@ -58,14 +58,14 @@ class TypeMap { public: template bool Has() const { - return content_.count(TypeId()) > 0; + return content_.count(TypeInfo::Get()) > 0; } template T* Get() const { if (!Has()) { - content_[TypeId()] = std::make_shared(); + content_[TypeInfo::Get()] = std::make_shared(); } - return static_cast(content_[TypeId()].get()); + return static_cast(content_[TypeInfo::Get()].get()); } private: diff --git a/mediapipe/framework/tool/test_util.cc b/mediapipe/framework/tool/test_util.cc index c77aed3770..2f8953ef87 100644 --- a/mediapipe/framework/tool/test_util.cc +++ b/mediapipe/framework/tool/test_util.cc @@ -20,6 +20,7 @@ #include #include +#include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/status/status.h" @@ -33,6 +34,7 @@ #include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/port/advanced_proto_inc.h" #include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/proto_ns.h" #include "mediapipe/framework/port/ret_check.h" @@ -208,6 +210,27 @@ bool CompareImageFrames(const ImageFrame& image1, const ImageFrame& image2, return false; } +absl::Status CompareAndSaveImageOutput( + absl::string_view golden_image_path, const ImageFrame& actual, + const ImageFrameComparisonOptions& options) { + ASSIGN_OR_RETURN(auto output_img_path, SavePngTestOutput(actual, "output")); + + auto expected = LoadTestImage(GetTestFilePath(golden_image_path)); + if (!expected.ok()) { + return expected.status(); + } + ASSIGN_OR_RETURN(auto expected_img_path, + SavePngTestOutput(**expected, "expected")); + + std::unique_ptr diff_img; + auto status = CompareImageFrames(**expected, actual, options.max_color_diff, + options.max_alpha_diff, options.max_avg_diff, + diff_img); + ASSIGN_OR_RETURN(auto diff_img_path, SavePngTestOutput(*diff_img, "diff")); + + return status; +} + std::string GetTestRootDir() { return file::JoinPath(std::getenv("TEST_SRCDIR"), "mediapipe"); } @@ -275,6 +298,23 @@ std::unique_ptr LoadTestPng(absl::string_view path, return nullptr; } +// Write an ImageFrame as PNG to the test undeclared outputs directory. +// The image's name will contain the given prefix and a timestamp. +// Returns the path to the output if successful. +absl::StatusOr SavePngTestOutput( + const mediapipe::ImageFrame& image, absl::string_view prefix) { + std::string now_string = absl::FormatTime(absl::Now()); + std::string output_relative_path = + absl::StrCat(prefix, "_", now_string, ".png"); + std::string output_full_path = + file::JoinPath(GetTestOutputsDir(), output_relative_path); + RET_CHECK(stbi_write_png(output_full_path.c_str(), image.Width(), + image.Height(), image.NumberOfChannels(), + image.PixelData(), image.WidthStep())) + << " path: " << output_full_path; + return output_relative_path; +} + bool LoadTestGraph(CalculatorGraphConfig* proto, const std::string& path) { int fd = open(path.c_str(), O_RDONLY); if (fd == -1) { diff --git a/mediapipe/framework/tool/test_util.h b/mediapipe/framework/tool/test_util.h index ae3de37060..71c096db7a 100644 --- a/mediapipe/framework/tool/test_util.h +++ b/mediapipe/framework/tool/test_util.h @@ -22,20 +22,33 @@ namespace mediapipe { using mediapipe::CalculatorGraphConfig; +struct ImageFrameComparisonOptions { + // NOTE: these values are not normalized: use a value from 0 to 2^8-1 + // for 8-bit data and a value from 0 to 2^16-1 for 16-bit data. + // Although these members are declared as floats,, all uint8/uint16 + // values are exactly representable. (2^24 + 1 is the first non-representable + // positive integral value.) + + // Maximum value difference allowed for non-alpha channels. + float max_color_diff; + // Maximum value difference allowed for alpha channel (if present). + float max_alpha_diff; + // Maximum difference for all channels, averaged across all pixels. + float max_avg_diff; +}; + +// Compares an output image with a golden file. Saves the output and difference +// to the undeclared test outputs. +// Returns ok if they are equal within the tolerances specified in options. +absl::Status CompareAndSaveImageOutput( + absl::string_view golden_image_path, const ImageFrame& actual, + const ImageFrameComparisonOptions& options); + // Checks if two image frames are equal within the specified tolerance. // image1 and image2 may be of different-but-compatible image formats (e.g., // SRGB and SRGBA); in that case, only the channels available in both are // compared. -// max_color_diff applies to the first 3 channels; i.e., R, G, B for sRGB and -// sRGBA, and the single gray channel for GRAY8 and GRAY16. It is the maximum -// pixel color value difference allowed; i.e., a value from 0 to 2^8-1 for 8-bit -// data and a value from 0 to 2^16-1 for 16-bit data. -// max_alpha_diff applies to the 4th (alpha) channel only, if present. -// max_avg_diff applies to all channels, normalized across all pixels. -// -// Note: Although max_color_diff and max_alpha_diff are floats, all uint8/uint16 -// values are exactly representable. (2^24 + 1 is the first non-representable -// positive integral value.) +// The diff arguments are as in ImageFrameComparisonOptions. absl::Status CompareImageFrames(const ImageFrame& image1, const ImageFrame& image2, const float max_color_diff, @@ -77,6 +90,13 @@ absl::StatusOr> LoadTestImage( std::unique_ptr LoadTestPng( absl::string_view path, ImageFormat::Format format = ImageFormat::SRGBA); +// Write an ImageFrame as PNG to the test undeclared outputs directory. +// The image's name will contain the given prefix and a timestamp. +// If successful, returns the path to the output file relative to the output +// directory. +absl::StatusOr SavePngTestOutput( + const mediapipe::ImageFrame& image, absl::string_view prefix); + // Returns the luminance image of |original_image|. // The format of |original_image| must be sRGB or sRGBA. std::unique_ptr GenerateLuminanceImage( diff --git a/mediapipe/framework/tool/type_util.h b/mediapipe/framework/tool/type_util.h index cd3540989e..4157389e43 100644 --- a/mediapipe/framework/tool/type_util.h +++ b/mediapipe/framework/tool/type_util.h @@ -78,12 +78,6 @@ class TypeIndex { const TypeInfo& info_; }; -// Returns a unique identifier for type T. -template -const TypeInfo& TypeId() { - return TypeInfo::Get(); -} - // Helper method that returns a hash code of the given type. This allows for // typeid testing across multiple binaries, unlike FastTypeId which used a // memory location that only works within the same binary. Moreover, we use this @@ -94,7 +88,7 @@ const TypeInfo& TypeId() { // as much as possible. template size_t GetTypeHash() { - return TypeId().hash_code(); + return TypeInfo::Get().hash_code(); } } // namespace tool diff --git a/mediapipe/framework/type_map.h b/mediapipe/framework/type_map.h index d07ad1024f..0b11959443 100644 --- a/mediapipe/framework/type_map.h +++ b/mediapipe/framework/type_map.h @@ -386,7 +386,7 @@ inline std::string MediaPipeTypeStringOrDemangled( template std::string MediaPipeTypeStringOrDemangled() { - return MediaPipeTypeStringOrDemangled(tool::TypeId()); + return MediaPipeTypeStringOrDemangled(tool::TypeInfo::Get()); } // Returns type hash id of type identified by type_string or NULL if not diff --git a/mediapipe/gpu/gpu_buffer_test.cc b/mediapipe/gpu/gpu_buffer_test.cc index daf64d9c53..c207acf60a 100644 --- a/mediapipe/gpu/gpu_buffer_test.cc +++ b/mediapipe/gpu/gpu_buffer_test.cc @@ -26,22 +26,6 @@ namespace mediapipe { namespace { -// Write an ImageFrame as PNG to the test undeclared outputs directory. -// The image's name will contain the given prefix and a timestamp. -// Returns the path to the output if successful. -std::string SavePngImage(const mediapipe::ImageFrame& image, - absl::string_view prefix) { - std::string output_dir = mediapipe::GetTestOutputsDir(); - std::string now_string = absl::FormatTime(absl::Now()); - std::string out_file_path = - absl::StrCat(output_dir, "/", prefix, "_", now_string, ".png"); - EXPECT_TRUE(stbi_write_png(out_file_path.c_str(), image.Width(), - image.Height(), image.NumberOfChannels(), - image.PixelData(), image.WidthStep())) - << " path: " << out_file_path; - return out_file_path; -} - void FillImageFrameRGBA(ImageFrame& image, uint8 r, uint8 g, uint8 b, uint8 a) { auto* data = image.MutablePixelData(); for (int y = 0; y < image.Height(); ++y) { @@ -143,8 +127,8 @@ TEST_F(GpuBufferTest, GlTextureView) { FillImageFrameRGBA(red, 255, 0, 0, 255); EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0)); - SavePngImage(red, "gltv_red_gold"); - SavePngImage(*view, "gltv_red_view"); + MP_EXPECT_OK(SavePngTestOutput(red, "gltv_red_gold")); + MP_EXPECT_OK(SavePngTestOutput(*view, "gltv_red_view")); } TEST_F(GpuBufferTest, ImageFrame) { @@ -178,8 +162,8 @@ TEST_F(GpuBufferTest, ImageFrame) { FillImageFrameRGBA(red, 255, 0, 0, 255); EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0)); - SavePngImage(red, "if_red_gold"); - SavePngImage(*view, "if_red_view"); + MP_EXPECT_OK(SavePngTestOutput(red, "if_red_gold")); + MP_EXPECT_OK(SavePngTestOutput(*view, "if_red_view")); } } @@ -212,8 +196,8 @@ TEST_F(GpuBufferTest, Overwrite) { FillImageFrameRGBA(red, 255, 0, 0, 255); EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0)); - SavePngImage(red, "ow_red_gold"); - SavePngImage(*view, "ow_red_view"); + MP_EXPECT_OK(SavePngTestOutput(red, "ow_red_gold")); + MP_EXPECT_OK(SavePngTestOutput(*view, "ow_red_view")); } { @@ -246,8 +230,8 @@ TEST_F(GpuBufferTest, Overwrite) { FillImageFrameRGBA(green, 0, 255, 0, 255); EXPECT_TRUE(mediapipe::CompareImageFrames(*view, green, 0.0, 0.0)); - SavePngImage(green, "ow_green_gold"); - SavePngImage(*view, "ow_green_view"); + MP_EXPECT_OK(SavePngTestOutput(green, "ow_green_gold")); + MP_EXPECT_OK(SavePngTestOutput(*view, "ow_green_view")); } { @@ -256,8 +240,8 @@ TEST_F(GpuBufferTest, Overwrite) { FillImageFrameRGBA(blue, 0, 0, 255, 255); EXPECT_TRUE(mediapipe::CompareImageFrames(*view, blue, 0.0, 0.0)); - SavePngImage(blue, "ow_blue_gold"); - SavePngImage(*view, "ow_blue_view"); + MP_EXPECT_OK(SavePngTestOutput(blue, "ow_blue_gold")); + MP_EXPECT_OK(SavePngTestOutput(*view, "ow_blue_view")); } } diff --git a/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java b/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java index 109240bb90..f22de08dce 100644 --- a/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java +++ b/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java @@ -17,6 +17,7 @@ import com.google.common.base.Preconditions; import com.google.common.flogger.FluentLogger; import com.google.mediapipe.framework.ProtoUtil.SerializedMessage; +import com.google.protobuf.Internal; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.MessageLite; import com.google.protobuf.Parser; @@ -119,11 +120,20 @@ public static byte[] getProtoBytes(final Packet packet) { return nativeGetProtoBytes(packet.getNativeHandle()); } - public static T getProto(final Packet packet, Class clazz) + public static T getProto(final Packet packet, T defaultInstance) throws InvalidProtocolBufferException { SerializedMessage result = new SerializedMessage(); nativeGetProto(packet.getNativeHandle(), result); - return ProtoUtil.unpack(result, clazz); + return ProtoUtil.unpack(result, defaultInstance); + } + + /** + * @deprecated {@link #getProto(Packet, MessageLite)} is safer to use in obfuscated builds. + */ + @Deprecated + public static T getProto(final Packet packet, Class clazz) + throws InvalidProtocolBufferException { + return getProto(packet, Internal.getDefaultInstance(clazz)); } public static short[] getInt16Vector(final Packet packet) { @@ -162,6 +172,13 @@ public static List getProtoVector(final Packet packet, Parser messageP } } + public static List getProtoVector( + final Packet packet, T defaultInstance) { + @SuppressWarnings("unchecked") + Parser parser = (Parser) defaultInstance.getParserForType(); + return getProtoVector(packet, parser); + } + public static int getImageWidth(final Packet packet) { return nativeGetImageWidth(packet.getNativeHandle()); } diff --git a/mediapipe/java/com/google/mediapipe/framework/ProtoUtil.java b/mediapipe/java/com/google/mediapipe/framework/ProtoUtil.java index 524ded5f08..331d1b209c 100644 --- a/mediapipe/java/com/google/mediapipe/framework/ProtoUtil.java +++ b/mediapipe/java/com/google/mediapipe/framework/ProtoUtil.java @@ -15,7 +15,6 @@ package com.google.mediapipe.framework; import com.google.protobuf.ExtensionRegistryLite; -import com.google.protobuf.Internal; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.MessageLite; import java.util.NoSuchElementException; @@ -52,10 +51,8 @@ public static SerializedMessage pack(T message) { } /** Deserializes a MessageLite from a SerializedMessage object. */ - public static T unpack( - SerializedMessage serialized, java.lang.Class clazz) + public static T unpack(SerializedMessage serialized, T defaultInstance) throws InvalidProtocolBufferException { - T defaultInstance = Internal.getDefaultInstance(clazz); String expectedType = ProtoUtil.getTypeName(defaultInstance.getClass()); if (!serialized.typeName.equals(expectedType)) { throw new InvalidProtocolBufferException(