diff --git a/.bazelversion b/.bazelversion index 0062ac9718..91ff57278e 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1 +1 @@ -5.0.0 +5.2.0 diff --git a/Dockerfile b/Dockerfile index 9da695ef93..0b096fc56d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -FROM ubuntu:18.04 +FROM ubuntu:20.04 MAINTAINER @@ -42,6 +42,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ software-properties-common && \ add-apt-repository -y ppa:openjdk-r/ppa && \ apt-get update && apt-get install -y openjdk-8-jdk && \ + apt-get install -y mesa-common-dev libegl1-mesa-dev libgles2-mesa-dev && \ + apt-get install -y mesa-utils && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* @@ -50,13 +52,13 @@ RUN pip3 install --upgrade setuptools RUN pip3 install wheel RUN pip3 install future RUN pip3 install six==1.14.0 -RUN pip3 install tensorflow==1.14.0 +RUN pip3 install tensorflow==2.2.0 RUN pip3 install tf_slim RUN ln -s /usr/bin/python3 /usr/bin/python # Install bazel -ARG BAZEL_VERSION=5.0.0 +ARG BAZEL_VERSION=5.2.0 RUN mkdir /bazel && \ wget --no-check-certificate -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/b\ azel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \ diff --git a/WORKSPACE b/WORKSPACE index e85e34d84a..7a75537db1 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -35,8 +35,9 @@ http_archive( http_archive( name = "rules_cc", - strip_prefix = "rules_cc-main", - urls = ["https://github.com/bazelbuild/rules_cc/archive/main.zip"], + strip_prefix = "rules_cc-2f8c04c04462ab83c545ab14c0da68c3b4c96191", +# The commit can be updated if the build passes. Last updated 6/23/22. + urls = ["https://github.com/bazelbuild/rules_cc/archive/2f8c04c04462ab83c545ab14c0da68c3b4c96191.zip"], ) http_archive( diff --git a/mediapipe/calculators/audio/BUILD b/mediapipe/calculators/audio/BUILD index ed6a509dc8..3f79575523 100644 --- a/mediapipe/calculators/audio/BUILD +++ b/mediapipe/calculators/audio/BUILD @@ -244,6 +244,7 @@ cc_test( "//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/tool:test_util", "@com_google_absl//absl/flags:flag", ], ) diff --git a/mediapipe/calculators/audio/audio_decoder_calculator_test.cc b/mediapipe/calculators/audio/audio_decoder_calculator_test.cc index 8e3babeb01..f8b07101cd 100644 --- a/mediapipe/calculators/audio/audio_decoder_calculator_test.cc +++ b/mediapipe/calculators/audio/audio_decoder_calculator_test.cc @@ -20,8 +20,12 @@ #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/tool/test_util.h" namespace mediapipe { +namespace { + +constexpr char kTestPackageRoot[] = "mediapipe/calculators/audio"; TEST(AudioDecoderCalculatorTest, TestWAV) { CalculatorGraphConfig::Node node_config = @@ -37,9 +41,8 @@ TEST(AudioDecoderCalculatorTest, TestWAV) { })pb"); CalculatorRunner runner(node_config); runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket( - file::JoinPath("./", - "/mediapipe/calculators/audio/" - "testdata/sine_wave_1k_44100_mono_2_sec_wav.audio")); + file::JoinPath(GetTestDataDir(kTestPackageRoot), + "sine_wave_1k_44100_mono_2_sec_wav.audio")); MP_ASSERT_OK(runner.Run()); MP_EXPECT_OK(runner.Outputs() .Tag("AUDIO_HEADER") @@ -68,9 +71,8 @@ 
TEST(AudioDecoderCalculatorTest, Test48KWAV) { })pb"); CalculatorRunner runner(node_config); runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket( - file::JoinPath("./", - "/mediapipe/calculators/audio/" - "testdata/sine_wave_1k_48000_stereo_2_sec_wav.audio")); + file::JoinPath(GetTestDataDir(kTestPackageRoot), + "sine_wave_1k_48000_stereo_2_sec_wav.audio")); MP_ASSERT_OK(runner.Run()); MP_EXPECT_OK(runner.Outputs() .Tag("AUDIO_HEADER") @@ -99,9 +101,8 @@ TEST(AudioDecoderCalculatorTest, TestMP3) { })pb"); CalculatorRunner runner(node_config); runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket( - file::JoinPath("./", - "/mediapipe/calculators/audio/" - "testdata/sine_wave_1k_44100_stereo_2_sec_mp3.audio")); + file::JoinPath(GetTestDataDir(kTestPackageRoot), + "sine_wave_1k_44100_stereo_2_sec_mp3.audio")); MP_ASSERT_OK(runner.Run()); MP_EXPECT_OK(runner.Outputs() .Tag("AUDIO_HEADER") @@ -130,9 +131,8 @@ TEST(AudioDecoderCalculatorTest, TestAAC) { })pb"); CalculatorRunner runner(node_config); runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket( - file::JoinPath("./", - "/mediapipe/calculators/audio/" - "testdata/sine_wave_1k_44100_stereo_2_sec_aac.audio")); + file::JoinPath(GetTestDataDir(kTestPackageRoot), + "sine_wave_1k_44100_stereo_2_sec_aac.audio")); MP_ASSERT_OK(runner.Run()); MP_EXPECT_OK(runner.Outputs() .Tag("AUDIO_HEADER") @@ -147,4 +147,5 @@ TEST(AudioDecoderCalculatorTest, TestAAC) { std::ceil(44100.0 * 2 / 1024)); } +} // namespace } // namespace mediapipe diff --git a/mediapipe/calculators/audio/spectrogram_calculator.cc b/mediapipe/calculators/audio/spectrogram_calculator.cc index bd2234f861..c038c0cd71 100644 --- a/mediapipe/calculators/audio/spectrogram_calculator.cc +++ b/mediapipe/calculators/audio/spectrogram_calculator.cc @@ -20,24 +20,22 @@ #include #include -#include "Eigen/Core" #include "absl/strings/string_view.h" #include "audio/dsp/spectrogram/spectrogram.h" #include "audio/dsp/window_functions.h" #include "mediapipe/calculators/audio/spectrogram_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/matrix.h" -#include "mediapipe/framework/formats/time_series_header.pb.h" -#include "mediapipe/framework/port/core_proto_inc.h" -#include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/logging.h" -#include "mediapipe/framework/port/ret_check.h" -#include "mediapipe/framework/port/source_location.h" #include "mediapipe/framework/port/status_builder.h" #include "mediapipe/util/time_series_util.h" namespace mediapipe { +namespace { +constexpr char kFrameDurationTag[] = "FRAME_DURATION"; +constexpr char kFrameOverlapTag[] = "FRAME_OVERLAP"; +} // namespace // MediaPipe Calculator for computing the "spectrogram" (short-time Fourier // transform squared-magnitude, by default) of a multichannel input // time series, including optionally overlapping frames. Options are @@ -46,11 +44,14 @@ namespace mediapipe { // // Result is a MatrixData record (for single channel input and when the // allow_multichannel_input flag is false), or a vector of MatrixData records, -// one for each channel (when the allow_multichannel_input flag is set). The -// rows of each spectrogram matrix correspond to the n_fft/2+1 unique complex -// values, or squared/linear/dB magnitudes, depending on the output_type option. 
-// Each input packet will result in zero or one output packets, each containing
-// one Matrix for each channel of the input, where each Matrix has one or more
+// one for each channel (when the allow_multichannel_input flag is set). Each
+// waveform frame is converted to frequency by a fast Fourier transform whose
+// size, n_fft, is the smallest power of two large enough to enclose the frame
+// length of round(frame_duration_seconds * sample_rate). The rows of each
+// spectrogram matrix (result) correspond to the n_fft/2+1 unique complex
+// values, or squared/linear/dB magnitudes, depending on output_type. Each
+// input packet will result in zero or one output packets, each containing one
+// Matrix for each channel of the input, where each Matrix has one or more
 // columns of spectral values, one for each complete frame of input samples. If
 // the input packet contains too few samples to trigger a new output frame, no
 // output packet is generated (since zero-length packets are not legal since
@@ -71,6 +72,22 @@ class SpectrogramCalculator : public CalculatorBase {
         // Input stream with TimeSeriesHeader.
     );

+    if (cc->InputSidePackets().HasTag(kFrameDurationTag)) {
+      cc->InputSidePackets()
+          .Tag(kFrameDurationTag)
+          .Set<double>(
+              // Optional side packet for frame_duration_seconds if provided.
+          );
+    }
+
+    if (cc->InputSidePackets().HasTag(kFrameOverlapTag)) {
+      cc->InputSidePackets()
+          .Tag(kFrameOverlapTag)
+          .Set<double>(
+              // Optional side packet for frame_overlap_seconds if provided.
+          );
+    }
+
     SpectrogramCalculatorOptions spectrogram_options =
         cc->Options<SpectrogramCalculatorOptions>();
     if (!spectrogram_options.allow_multichannel_input()) {
@@ -184,27 +201,47 @@ class SpectrogramCalculator : public CalculatorBase {
   // Fixed scale factor applied to output values (regardless of type).
   double output_scale_;

-  static const float kLnPowerToDb;
+  static const float kLnSquaredMagnitudeToDb;
 };
 REGISTER_CALCULATOR(SpectrogramCalculator);

-// Factor to convert ln(magnitude_squared) to deciBels = 10.0/ln(10.0).
-const float SpectrogramCalculator::kLnPowerToDb = 4.342944819032518;
+// DECIBELS = 20*log10(LINEAR_MAGNITUDE) = 10*log10(SQUARED_MAGNITUDE)
+//          = 10/ln(10)*ln(SQUARED_MAGNITUDE).
+// Factor to convert ln(SQUARED_MAGNITUDE) to deciBels = 10.0/ln(10.0).
+const float SpectrogramCalculator::kLnSquaredMagnitudeToDb = 4.342944819032518;

 absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) {
   SpectrogramCalculatorOptions spectrogram_options =
       cc->Options<SpectrogramCalculatorOptions>();
+  // frame_duration_seconds and frame_overlap_seconds are taken either from
+  // the static calculator options or dynamically from side packets; when a
+  // side packet is provided, it overrides the corresponding option.
+
+  double frame_duration_seconds = 0;
+  double frame_overlap_seconds = 0;
+  if (cc->InputSidePackets().HasTag(kFrameDurationTag)) {
+    frame_duration_seconds =
+        cc->InputSidePackets().Tag(kFrameDurationTag).Get<double>();
+  } else {
+    frame_duration_seconds = spectrogram_options.frame_duration_seconds();
+  }
+
+  if (cc->InputSidePackets().HasTag(kFrameOverlapTag)) {
+    frame_overlap_seconds =
+        cc->InputSidePackets().Tag(kFrameOverlapTag).Get<double>();
+  } else {
+    frame_overlap_seconds = spectrogram_options.frame_overlap_seconds();
+  }

   use_local_timestamp_ = spectrogram_options.use_local_timestamp();

-  if (spectrogram_options.frame_duration_seconds() <= 0.0) {
+  if (frame_duration_seconds <= 0.0) {
     // TODO: return an error.
  }
-  if (spectrogram_options.frame_overlap_seconds() >=
-      spectrogram_options.frame_duration_seconds()) {
+  if (frame_overlap_seconds >= frame_duration_seconds) {
    // TODO: return an error.
  }
-  if (spectrogram_options.frame_overlap_seconds() < 0.0) {
+  if (frame_overlap_seconds < 0.0) {
    // TODO: return an error.
  }
@@ -220,10 +257,8 @@ absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) {
    // TODO: return an error.
  }

-  frame_duration_samples_ =
-      round(spectrogram_options.frame_duration_seconds() * input_sample_rate_);
-  frame_overlap_samples_ =
-      round(spectrogram_options.frame_overlap_seconds() * input_sample_rate_);
+  frame_duration_samples_ = round(frame_duration_seconds * input_sample_rate_);
+  frame_overlap_samples_ = round(frame_overlap_seconds * input_sample_rate_);

  pad_final_packet_ = spectrogram_options.pad_final_packet();
  output_type_ = spectrogram_options.output_type();
@@ -419,7 +454,7 @@ absl::Status SpectrogramCalculator::ProcessVector(const Matrix& input_stream,
      return ProcessVectorToOutput(
          input_stream,
          +[](const Matrix& col) -> const Matrix {
-            return kLnPowerToDb * col.array().log().matrix();
+            return kLnSquaredMagnitudeToDb * col.array().log().matrix();
          },
          cc);
}
// clang-format on
diff --git a/mediapipe/calculators/audio/spectrogram_calculator.proto b/mediapipe/calculators/audio/spectrogram_calculator.proto
index b721117d44..8e1e18051e 100644
--- a/mediapipe/calculators/audio/spectrogram_calculator.proto
+++ b/mediapipe/calculators/audio/spectrogram_calculator.proto
@@ -32,7 +32,11 @@ message SpectrogramCalculatorOptions {
  // Duration of overlap between adjacent windows.
  // Hence, frame_rate = 1/(frame_duration_seconds - frame_overlap_seconds).
-  // Required that 0 <= frame_overlap_seconds < frame_duration_seconds.
+  // Note that frame_rate here is not the MediaPipe packet rate; a frame here
+  // means one Fourier-transform analysis waveform frame. The output MediaPipe
+  // packet rate is the same as the input packet rate, and if the frame rate
+  // is lower than the input packet rate, some output packets will be empty.
+  // Required that 0 <= frame_overlap_seconds < frame_duration_seconds.
  optional double frame_overlap_seconds = 2 [default = 0.0];

  // Whether to pad the final packet with zeros. If true, guarantees that
@@ -42,6 +46,11 @@
  // Output value type can be squared-magnitude, linear-magnitude,
  // deciBels (dB, = 20*log10(linear_magnitude)), or std::complex.
+ // Their relationship: + // COMPLEX c = Re + Im*i; + // SQUARED_MAGNITUDE = Re^2 + Im^2; + // LINEAR_MAGNITUDE = sqrt(SQUARED_MAGNITUDE); + // DECIBELS = 20*log10(LINEAR_MAGNITUDE) = 10*log10(SQUARED_MAGNITUDE); enum OutputType { SQUARED_MAGNITUDE = 0; LINEAR_MAGNITUDE = 1; diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index ff0a5d6639..e741ebad46 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -557,6 +557,22 @@ cc_library( alwayslink = 1, ) +cc_test( + name = "packet_cloner_calculator_test", + srcs = ["packet_cloner_calculator_test.cc"], + deps = [ + ":packet_cloner_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "//mediapipe/framework/stream_handler:immediate_input_stream_handler", + "//mediapipe/framework/tool:sink", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "packet_inner_join_calculator", srcs = ["packet_inner_join_calculator.cc"], diff --git a/mediapipe/calculators/core/concatenate_vector_calculator.cc b/mediapipe/calculators/core/concatenate_vector_calculator.cc index 20d6a3286a..1a6f9c36fd 100644 --- a/mediapipe/calculators/core/concatenate_vector_calculator.cc +++ b/mediapipe/calculators/core/concatenate_vector_calculator.cc @@ -73,8 +73,17 @@ typedef ConcatenateVectorCalculator<::mediapipe::NormalizedLandmark> ConcatenateLandmarkVectorCalculator; MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarkVectorCalculator); +typedef ConcatenateVectorCalculator<::mediapipe::LandmarkList> + ConcatenateLandmarkListVectorCalculator; +MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarkListVectorCalculator); + typedef ConcatenateVectorCalculator<::mediapipe::NormalizedLandmarkList> - ConcatenateLandmarListVectorCalculator; + ConcatenateNormalizedLandmarkListVectorCalculator; +MEDIAPIPE_REGISTER_NODE(ConcatenateNormalizedLandmarkListVectorCalculator); + +// For backwards compatibility, keep the version with the typo. +using ConcatenateLandmarListVectorCalculator = + ConcatenateNormalizedLandmarkListVectorCalculator; MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarListVectorCalculator); typedef ConcatenateVectorCalculator diff --git a/mediapipe/calculators/core/flow_limiter_calculator.cc b/mediapipe/calculators/core/flow_limiter_calculator.cc index b365121bc0..d209b1dbbf 100644 --- a/mediapipe/calculators/core/flow_limiter_calculator.cc +++ b/mediapipe/calculators/core/flow_limiter_calculator.cc @@ -32,8 +32,8 @@ constexpr char kOptionsTag[] = "OPTIONS"; // FlowLimiterCalculator is used to limit the number of frames in flight // by dropping input frames when necessary. // -// The input stream "FINISH" is used to signal the FlowLimiterCalculator -// when a frame is finished processing. Either a non-empty "FINISH" packet +// The input stream "FINISHED" is used to signal the FlowLimiterCalculator +// when a frame is finished processing. Either a non-empty "FINISHED" packet // or a timestamp bound should be received for each processed frame. 
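// For illustration, a minimal sketch of the loopback wiring this implies
// (the stream names here are assumed for the example, not taken from this
// change):
//
//   node {
//     calculator: "FlowLimiterCalculator"
//     input_stream: "input_frames"
//     input_stream: "FINISHED:detections"
//     input_stream_info: { tag_index: "FINISHED" back_edge: true }
//     output_stream: "throttled_frames"
//   }
//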
//
// The combination of `max_in_flight: 1` and `max_in_queue: 1` generally gives
diff --git a/mediapipe/calculators/core/packet_cloner_calculator.cc b/mediapipe/calculators/core/packet_cloner_calculator.cc
index ff55a87e7a..cc3e0ba2fa 100644
--- a/mediapipe/calculators/core/packet_cloner_calculator.cc
+++ b/mediapipe/calculators/core/packet_cloner_calculator.cc
@@ -16,9 +16,10 @@
 // For every packet that appears in B, outputs the most recent packet from each
 // of the A_i on a separate stream.

+#include
 #include

-#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
 #include "mediapipe/calculators/core/packet_cloner_calculator.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
@@ -34,7 +35,18 @@ namespace mediapipe {
 //   calculator: "PacketClonerCalculator"
 //   input_stream: "first_base_signal"
 //   input_stream: "second_base_signal"
-//   input_stream: "tick_signal"
+//   input_stream: "tick_signal"  # or input_stream: "TICK:tick_signal"
+//   output_stream: "cloned_first_base_signal"
+//   output_stream: "cloned_second_base_signal"
+// }
+//
+// Or you can use the "TICK" tag and put the corresponding input stream at any
+// location, for example at the very beginning:
+// node {
+//   calculator: "PacketClonerCalculator"
+//   input_stream: "TICK:tick_signal"
+//   input_stream: "first_base_signal"
+//   input_stream: "second_base_signal"
 //   output_stream: "cloned_first_base_signal"
 //   output_stream: "cloned_second_base_signal"
 // }
@@ -46,12 +58,13 @@ namespace mediapipe {
 class PacketClonerCalculator : public CalculatorBase {
  public:
   static absl::Status GetContract(CalculatorContract* cc) {
-    const int tick_signal_index = cc->Inputs().NumEntries() - 1;
-    for (int i = 0; i < tick_signal_index; ++i) {
-      cc->Inputs().Index(i).SetAny();
-      cc->Outputs().Index(i).SetSameAs(&cc->Inputs().Index(i));
+    const Ids ids = GetIds(*cc);
+    for (const auto& in_out : ids.inputs_outputs) {
+      auto& input = cc->Inputs().Get(in_out.in);
+      input.SetAny();
+      cc->Outputs().Get(in_out.out).SetSameAs(&input);
     }
-    cc->Inputs().Index(tick_signal_index).SetAny();
+    cc->Inputs().Get(ids.tick_id).SetAny();
     return absl::OkStatus();
   }
@@ -65,13 +78,15 @@ class PacketClonerCalculator : public CalculatorBase {
     output_empty_packets_before_all_inputs_received_ =
         calculator_options.output_packets_only_when_all_inputs_received();

-    // Parse input streams.
-    tick_signal_index_ = cc->Inputs().NumEntries() - 1;
-    current_.resize(tick_signal_index_);
+    // Prepare input and output ids.
+    ids_ = GetIds(*cc);
+    current_.resize(ids_.inputs_outputs.size());
+
     // Pass along the header for each stream if present.
-    for (int i = 0; i < tick_signal_index_; ++i) {
-      if (!cc->Inputs().Index(i).Header().IsEmpty()) {
-        cc->Outputs().Index(i).SetHeader(cc->Inputs().Index(i).Header());
+    for (const auto& in_out : ids_.inputs_outputs) {
+      auto& input = cc->Inputs().Get(in_out.in);
+      if (!input.Header().IsEmpty()) {
+        cc->Outputs().Get(in_out.out).SetHeader(input.Header());
       }
     }
     return absl::OkStatus();
@@ -79,17 +94,18 @@ class PacketClonerCalculator : public CalculatorBase {

   absl::Status Process(CalculatorContext* cc) final {
     // Store input signals.
-    for (int i = 0; i < tick_signal_index_; ++i) {
-      if (!cc->Inputs().Index(i).Value().IsEmpty()) {
-        current_[i] = cc->Inputs().Index(i).Value();
+    for (int i = 0; i < ids_.inputs_outputs.size(); ++i) {
+      const auto& input = cc->Inputs().Get(ids_.inputs_outputs[i].in);
+      if (!input.IsEmpty()) {
+        current_[i] = input.Value();
       }
     }

     // Output according to the TICK signal.
-    if (!cc->Inputs().Index(tick_signal_index_).Value().IsEmpty()) {
+    if (!cc->Inputs().Get(ids_.tick_id).IsEmpty()) {
       if (output_only_when_all_inputs_received_) {
         // Return if one of the inputs is null.
-        for (int i = 0; i < tick_signal_index_; ++i) {
+        for (int i = 0; i < ids_.inputs_outputs.size(); ++i) {
           if (current_[i].IsEmpty()) {
             if (output_empty_packets_before_all_inputs_received_) {
               SetAllNextTimestampBounds(cc);
@@ -99,12 +115,12 @@
         }
       }
       // Output each stream.
-      for (int i = 0; i < tick_signal_index_; ++i) {
+      for (int i = 0; i < ids_.inputs_outputs.size(); ++i) {
+        auto& output = cc->Outputs().Get(ids_.inputs_outputs[i].out);
         if (!current_[i].IsEmpty()) {
-          cc->Outputs().Index(i).AddPacket(
-              current_[i].At(cc->InputTimestamp()));
+          output.AddPacket(current_[i].At(cc->InputTimestamp()));
         } else {
-          cc->Outputs().Index(i).SetNextTimestampBound(
+          output.SetNextTimestampBound(
               cc->InputTimestamp().NextAllowedInStream());
         }
       }
@@ -113,15 +129,44 @@ class PacketClonerCalculator : public CalculatorBase {
   }

  private:
+  struct Ids {
+    struct InputOutput {
+      CollectionItemId in;
+      CollectionItemId out;
+    };
+    CollectionItemId tick_id;
+    std::vector<InputOutput> inputs_outputs;
+  };
+
+  template <typename CC>
+  static Ids GetIds(CC& cc) {
+    Ids ids;
+    static constexpr absl::string_view kEmptyTag = "";
+    int num_inputs_to_clone = cc.Inputs().NumEntries(kEmptyTag);
+    static constexpr absl::string_view kTickTag = "TICK";
+    if (cc.Inputs().HasTag(kTickTag)) {
+      ids.tick_id = cc.Inputs().GetId(kTickTag, 0);
+    } else {
+      --num_inputs_to_clone;
+      ids.tick_id = cc.Inputs().GetId(kEmptyTag, num_inputs_to_clone);
+    }
+    for (int i = 0; i < num_inputs_to_clone; ++i) {
+      ids.inputs_outputs.push_back({.in = cc.Inputs().GetId(kEmptyTag, i),
+                                    .out = cc.Outputs().GetId(kEmptyTag, i)});
+    }
+    return ids;
+  }
+
   void SetAllNextTimestampBounds(CalculatorContext* cc) {
-    for (int j = 0; j < tick_signal_index_; ++j) {
-      cc->Outputs().Index(j).SetNextTimestampBound(
-          cc->InputTimestamp().NextAllowedInStream());
+    for (const auto& in_out : ids_.inputs_outputs) {
+      cc->Outputs()
+          .Get(in_out.out)
+          .SetNextTimestampBound(cc->InputTimestamp().NextAllowedInStream());
     }
   }

   std::vector<Packet> current_;
-  int tick_signal_index_;
+  Ids ids_;
   bool output_only_when_all_inputs_received_;
   bool output_empty_packets_before_all_inputs_received_;
 };
diff --git a/mediapipe/calculators/core/packet_cloner_calculator_test.cc b/mediapipe/calculators/core/packet_cloner_calculator_test.cc
new file mode 100644
index 0000000000..becb700722
--- /dev/null
+++ b/mediapipe/calculators/core/packet_cloner_calculator_test.cc
@@ -0,0 +1,349 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+#include
+#include
+
+#include "absl/strings/str_cat.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/parse_text_proto.h"
+#include "mediapipe/framework/port/status.h"
+#include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/framework/timestamp.h"
+#include "mediapipe/framework/tool/sink.h"
+
+namespace mediapipe {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::Value;
+
+MATCHER_P2(IntPacket, value, ts, "") {
+  return Value(arg.template Get<int>(), Eq(value)) &&
+         Value(arg.Timestamp(), Eq(Timestamp(ts)));
+}
+
+MATCHER_P2(FloatPacket, value, ts, "") {
+  return Value(arg.template Get<float>(), Eq(value)) &&
+         Value(arg.Timestamp(), Eq(Timestamp(ts)));
+}
+
+template <typename T>
+absl::Status SendPacket(const std::string& input_name, T value, int ts,
+                        CalculatorGraph& graph) {
+  return graph.AddPacketToInputStream(input_name,
+                                      MakePacket<T>(value).At(Timestamp(ts)));
+}
+
+struct Params {
+  bool use_tick_tag = false;
+};
+
+class PacketClonerCalculatorTest : public testing::TestWithParam<Params> {};
+
+TEST_P(PacketClonerCalculatorTest, ClonesSingleInputSameTimestamps) {
+  CalculatorGraphConfig graph_config =
+      ParseTextProtoOrDie<CalculatorGraphConfig>([&]() {
+        if (GetParam().use_tick_tag) {
+          return R"pb(
+            input_stream: 'in1'
+            input_stream: 'tick'
+            node {
+              calculator: 'PacketClonerCalculator'
+              input_stream: 'in1'
+              input_stream: 'TICK:tick'
+              output_stream: 'out1'
+            })pb";
+        }
+        return R"pb(
+          input_stream: 'in1'
+          input_stream: 'tick'
+          node {
+            calculator: 'PacketClonerCalculator'
+            input_stream: 'in1'
+            input_stream: 'tick'
+            output_stream: 'out1'
+          })pb";
+      }());
+  std::vector<Packet> out1;
+  tool::AddVectorSink("out1", &graph_config, &out1);
+
+  CalculatorGraph graph;
+  MP_ASSERT_OK(graph.Initialize(graph_config, {}));
+  MP_ASSERT_OK(graph.StartRun({}));
+
+  MP_ASSERT_OK(SendPacket("in1", 1, /*ts=*/10000, graph));
+  MP_ASSERT_OK(SendPacket("tick", 1000, /*ts=*/10000, graph));
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+
+  EXPECT_THAT(out1, ElementsAre(IntPacket(1, 10000)));
+}
+
+TEST_P(PacketClonerCalculatorTest, ClonesSingleInputEarlierTimestamps) {
+  CalculatorGraphConfig graph_config =
+      ParseTextProtoOrDie<CalculatorGraphConfig>([&]() {
+        if (GetParam().use_tick_tag) {
+          return R"pb(
+            input_stream: 'in1'
+            input_stream: 'tick'
+            node {
+              calculator: 'PacketClonerCalculator'
+              input_stream: 'in1'
+              input_stream: 'TICK:tick'
+              output_stream: 'out1'
+            })pb";
+        }
+        return R"pb(
+          input_stream: 'in1'
+          input_stream: 'tick'
+          node {
+            calculator: 'PacketClonerCalculator'
+            input_stream: 'in1'
+            input_stream: 'tick'
+            output_stream: 'out1'
+          })pb";
+      }());
+  std::vector<Packet> out1;
+  tool::AddVectorSink("out1", &graph_config, &out1);
+
+  CalculatorGraph graph;
+  MP_ASSERT_OK(graph.Initialize(graph_config, {}));
+  MP_ASSERT_OK(graph.StartRun({}));
+
+  // PacketClonerCalculator uses the default (non-Immediate) input stream
+  // handler, so it waits for "in1" to arrive for ts=5000.
+  MP_ASSERT_OK(SendPacket("in1", 1, /*ts=*/5000, graph));
+  // Newer tick at ts=10000 should NOT trigger output for ts=5000;
+  // PacketClonerCalculator waits for "in1" to arrive for ts=10000.
+  MP_ASSERT_OK(SendPacket("tick", 1000, /*ts=*/10000, graph));
+  MP_ASSERT_OK(SendPacket("tick", 1001, /*ts=*/10001, graph));
+  MP_ASSERT_OK(SendPacket("tick", 1002, /*ts=*/10002, graph));
+  // Newer "in1" at ts=15000 should trigger output for ts=10000.
+  MP_ASSERT_OK(SendPacket("in1", 2,
/*ts=*/15000, graph)); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + EXPECT_THAT(out1, ElementsAre(IntPacket(1, 10000), IntPacket(1, 10001), + IntPacket(1, 10002))); +} + +TEST_P(PacketClonerCalculatorTest, ClonesFiveInputs) { + CalculatorGraphConfig graph_config = + ParseTextProtoOrDie([&]() { + if (GetParam().use_tick_tag) { + return R"pb( + input_stream: 'in1' + input_stream: 'in2' + input_stream: 'in3' + input_stream: 'in4' + input_stream: 'in5' + input_stream: 'tick' + node { + calculator: 'PacketClonerCalculator' + input_stream: 'in1' + input_stream: 'in2' + input_stream: 'in3' + input_stream: 'in4' + input_stream: 'in5' + output_stream: 'out1' + output_stream: 'out2' + output_stream: 'out3' + input_stream: 'TICK:tick' # arbitrary location + output_stream: 'out4' + output_stream: 'out5' + } + )pb"; + } + return R"pb( + input_stream: 'in1' + input_stream: 'in2' + input_stream: 'in3' + input_stream: 'in4' + input_stream: 'in5' + input_stream: 'tick' + node { + calculator: 'PacketClonerCalculator' + input_stream: 'in1' + input_stream: 'in2' + input_stream: 'in3' + input_stream: 'in4' + input_stream: 'in5' + input_stream: 'tick' + output_stream: 'out1' + output_stream: 'out2' + output_stream: 'out3' + output_stream: 'out4' + output_stream: 'out5' + } + )pb"; + }()); + constexpr int kNumToClone = 5; + std::array, kNumToClone> outs; + for (int i = 0; i < kNumToClone; ++i) { + tool::AddVectorSink(absl::StrCat("out", i + 1), &graph_config, &outs[i]); + } + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config, {})); + MP_ASSERT_OK(graph.StartRun({})); + + MP_ASSERT_OK(SendPacket("in1", 10, /*ts=*/10000, graph)); + MP_ASSERT_OK(SendPacket("in2", 20.0f, /*ts=*/10000, graph)); + MP_ASSERT_OK(SendPacket("in3", 30, /*ts=*/10000, graph)); + MP_ASSERT_OK(SendPacket("in4", 40.0f, /*ts=*/10000, graph)); + MP_ASSERT_OK(SendPacket("in5", 50, /*ts=*/10000, graph)); + MP_ASSERT_OK(SendPacket("tick", 1000, /*ts=*/10000, graph)); + // Below "tick" packets won't trigger output, until newer inputs are sent, + // because inputs are missing and ImmediateInputStreamHandler is not + // configured. + MP_ASSERT_OK(SendPacket("tick", 1001, /*ts=*/10001, graph)); + MP_ASSERT_OK(SendPacket("tick", 1002, /*ts=*/10002, graph)); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + EXPECT_THAT(outs, ElementsAre(ElementsAre(IntPacket(10, 10000)), + ElementsAre(FloatPacket(20.0f, 10000)), + ElementsAre(IntPacket(30, 10000)), + ElementsAre(FloatPacket(40.0f, 10000)), + ElementsAre(IntPacket(50, 10000)))); + + MP_ASSERT_OK(SendPacket("in1", 100, /*ts=*/20000, graph)); + MP_ASSERT_OK(SendPacket("in2", 200.0f, /*ts=*/20000, graph)); + MP_ASSERT_OK(SendPacket("in3", 300, /*ts=*/20000, graph)); + MP_ASSERT_OK(SendPacket("in4", 400.0f, /*ts=*/20000, graph)); + MP_ASSERT_OK(SendPacket("in5", 500, /*ts=*/20000, graph)); + MP_ASSERT_OK(SendPacket("tick", 2000, /*ts=*/20000, graph)); + // Below "tick" packets won't trigger output, because inputs are missing and + // ImmediateInputStreamHandler is not configured. 
+ MP_ASSERT_OK(SendPacket("tick", 2001, /*ts=*/20001, graph)); + MP_ASSERT_OK(SendPacket("tick", 2002, /*ts=*/20002, graph)); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + EXPECT_THAT( + outs, + ElementsAre( + ElementsAre(IntPacket(10, 10000), IntPacket(10, 10001), + IntPacket(10, 10002), IntPacket(100, 20000)), + ElementsAre(FloatPacket(20.0f, 10000), FloatPacket(20.0f, 10001), + FloatPacket(20.0f, 10002), FloatPacket(200.0f, 20000)), + ElementsAre(IntPacket(30, 10000), IntPacket(30, 10001), + IntPacket(30, 10002), IntPacket(300, 20000)), + ElementsAre(FloatPacket(40.0f, 10000), FloatPacket(40.0f, 10001), + FloatPacket(40.0f, 10002), FloatPacket(400.0f, 20000)), + ElementsAre(IntPacket(50, 10000), IntPacket(50, 10001), + IntPacket(50, 10002), IntPacket(500, 20000)))); +} + +TEST_P(PacketClonerCalculatorTest, + ClonesTwoInputsWithImmediateInputStreamHandler) { + CalculatorGraphConfig graph_config = + ParseTextProtoOrDie([&]() { + if (GetParam().use_tick_tag) { + return R"pb( + input_stream: 'in1' + input_stream: 'in2' + input_stream: 'tick' + node { + calculator: 'PacketClonerCalculator' + input_stream: 'TICK:tick' + input_stream: 'in1' + input_stream: 'in2' + output_stream: 'out1' + output_stream: 'out2' + input_stream_handler { + input_stream_handler: "ImmediateInputStreamHandler" + } + })pb"; + } + return R"pb( + input_stream: 'in1' + input_stream: 'in2' + input_stream: 'tick' + node { + calculator: 'PacketClonerCalculator' + input_stream: 'in1' + input_stream: 'in2' + input_stream: 'tick' + output_stream: 'out1' + output_stream: 'out2' + input_stream_handler { + input_stream_handler: "ImmediateInputStreamHandler" + } + })pb"; + }()); + constexpr int kNumToClone = 2; + std::array, kNumToClone> outs; + for (int i = 0; i < kNumToClone; ++i) { + tool::AddVectorSink(absl::StrCat("out", i + 1), &graph_config, &outs[i]); + } + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config, {})); + MP_ASSERT_OK(graph.StartRun({})); + + // No packets to clone. + MP_ASSERT_OK(SendPacket("tick", 0, /*ts=*/0, graph)); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Cloning current packets. + MP_ASSERT_OK(SendPacket("in1", 1, /*ts=*/10000, graph)); + MP_ASSERT_OK(SendPacket("in2", 10.0f, /*ts=*/10000, graph)); + MP_ASSERT_OK(SendPacket("tick", 1000, /*ts=*/10000, graph)); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Cloning past packets. + MP_ASSERT_OK(SendPacket("tick", 1500, /*ts=*/15000, graph)); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Cloning past packets. + MP_ASSERT_OK(SendPacket("in1", 2, /*ts=*/10001, graph)); + MP_ASSERT_OK(SendPacket("in2", 20.0f, /*ts=*/10001, graph)); + MP_ASSERT_OK(SendPacket("tick", 2000, /*ts=*/20000, graph)); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Cloning future packets. + MP_ASSERT_OK(SendPacket("in1", 3, /*ts=*/30000, graph)); + MP_ASSERT_OK(SendPacket("in2", 30.0f, /*ts=*/30000, graph)); + // Waiting to ensure newer packets (ts=30000) to clone would get into the + // cloner before tick (ts=25000) does. + MP_ASSERT_OK(graph.WaitUntilIdle()); + MP_ASSERT_OK(SendPacket("tick", 3000, /*ts=*/25000, graph)); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Cloning packets having different timestamps. 
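+  // Here in1 (ts=38000) and in2 (ts=39000) are the most recent packets when
+  // the tick at ts=40000 arrives, so both are cloned at the tick timestamp
+  // (40000), as the expectations below show.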
+ MP_ASSERT_OK(SendPacket("in1", 4, /*ts=*/38000, graph)); + MP_ASSERT_OK(SendPacket("in2", 40.0f, /*ts=*/39000, graph)); + MP_ASSERT_OK(SendPacket("tick", 4000, /*ts=*/40000, graph)); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + EXPECT_THAT( + outs, + ElementsAre( + ElementsAre(IntPacket(1, 10000), IntPacket(1, 15000), + IntPacket(2, 20000), IntPacket(3, 25000), + IntPacket(4, 40000)), + ElementsAre(FloatPacket(10.0f, 10000), FloatPacket(10.0f, 15000), + FloatPacket(20.0f, 20000), FloatPacket(30.0f, 25000), + FloatPacket(40.0f, 40000)))); +} + +INSTANTIATE_TEST_SUITE_P(PacketClonerCalculator, PacketClonerCalculatorTest, + testing::ValuesIn({Params{.use_tick_tag = false}, + Params{.use_tick_tag = true}})); +} // anonymous namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/core/packet_resampler_calculator.cc b/mediapipe/calculators/core/packet_resampler_calculator.cc index 81ccdbe654..76fb4f124e 100644 --- a/mediapipe/calculators/core/packet_resampler_calculator.cc +++ b/mediapipe/calculators/core/packet_resampler_calculator.cc @@ -157,9 +157,7 @@ absl::Status PacketResamplerCalculator::Process(CalculatorContext* cc) { } } - if (absl::Status status = strategy_->Process(cc); !status.ok()) { - return status; // Avoid MP_RETURN_IF_ERROR macro for external release. - } + MP_RETURN_IF_ERROR(strategy_->Process(cc)); last_packet_ = cc->Inputs().Get(input_data_id_).Value(); diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD index 5428f98fd5..458c5368b0 100644 --- a/mediapipe/calculators/image/BUILD +++ b/mediapipe/calculators/image/BUILD @@ -626,11 +626,8 @@ cc_library( "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", - "//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:image_opencv", "//mediapipe/framework/port:logging", - "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:status", "//mediapipe/framework/port:vector", ] + select({ @@ -641,6 +638,13 @@ cc_library( "//mediapipe/gpu:gl_quad_renderer", "//mediapipe/gpu:shader_util", ], + }) + select({ + "//mediapipe/framework/port:disable_opencv": [], + "//conditions:default": [ + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/formats:image_opencv", + "//mediapipe/framework/port:opencv_core", + ], }), alwayslink = 1, ) @@ -727,7 +731,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":affine_transformation", - ":affine_transformation_runner_opencv", ":warp_affine_calculator_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -745,6 +748,9 @@ cc_library( "//mediapipe/gpu:gpu_buffer", ":affine_transformation_runner_gl", ], + }) + select({ + "//mediapipe/framework/port:disable_opencv": [], + "//conditions:default": [":affine_transformation_runner_opencv"], }), alwayslink = 1, ) @@ -799,3 +805,21 @@ cc_test( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "yuv_to_image_calculator", + srcs = ["yuv_to_image_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_context", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:yuv_image", + "//third_party/libyuv", + "@com_google_absl//absl/status", + 
"@com_google_absl//absl/strings:str_format", + ], + alwayslink = 1, +) diff --git a/mediapipe/calculators/image/segmentation_smoothing_calculator.cc b/mediapipe/calculators/image/segmentation_smoothing_calculator.cc index 62d3b0d28b..81732f9044 100644 --- a/mediapipe/calculators/image/segmentation_smoothing_calculator.cc +++ b/mediapipe/calculators/image/segmentation_smoothing_calculator.cc @@ -21,10 +21,7 @@ #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/formats/image_frame.h" -#include "mediapipe/framework/formats/image_frame_opencv.h" -#include "mediapipe/framework/formats/image_opencv.h" #include "mediapipe/framework/port/logging.h" -#include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/vector.h" @@ -34,6 +31,12 @@ #include "mediapipe/gpu/shader_util.h" #endif // !MEDIAPIPE_DISABLE_GPU +#if !MEDIAPIPE_DISABLE_OPENCV +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/formats/image_opencv.h" +#include "mediapipe/framework/port/opencv_core_inc.h" +#endif // !MEDIAPIPE_DISABLE_OPENCV + namespace mediapipe { namespace { @@ -163,7 +166,11 @@ absl::Status SegmentationSmoothingCalculator::Process(CalculatorContext* cc) { return absl::InternalError("GPU processing is disabled."); #endif // !MEDIAPIPE_DISABLE_GPU } else { +#if !MEDIAPIPE_DISABLE_OPENCV MP_RETURN_IF_ERROR(RenderCpu(cc)); +#else + return absl::InternalError("OpenCV processing is disabled."); +#endif // !MEDIAPIPE_DISABLE_OPENCV } return absl::OkStatus(); @@ -181,6 +188,7 @@ absl::Status SegmentationSmoothingCalculator::Close(CalculatorContext* cc) { } absl::Status SegmentationSmoothingCalculator::RenderCpu(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_OPENCV // Setup source images. 
const auto& current_frame = cc->Inputs().Tag(kCurrentMaskTag).Get(); auto current_mat = mediapipe::formats::MatView(¤t_frame); @@ -245,6 +253,7 @@ absl::Status SegmentationSmoothingCalculator::RenderCpu(CalculatorContext* cc) { cc->Outputs() .Tag(kOutputMaskTag) .AddPacket(MakePacket(output_frame).At(cc->InputTimestamp())); +#endif // !MEDIAPIPE_DISABLE_OPENCV return absl::OkStatus(); } diff --git a/mediapipe/calculators/image/warp_affine_calculator.cc b/mediapipe/calculators/image/warp_affine_calculator.cc index e3d017a35d..615d1697c3 100644 --- a/mediapipe/calculators/image/warp_affine_calculator.cc +++ b/mediapipe/calculators/image/warp_affine_calculator.cc @@ -24,7 +24,9 @@ #endif // !MEDIAPIPE_DISABLE_GPU #include "absl/status/status.h" #include "absl/status/statusor.h" +#if !MEDIAPIPE_DISABLE_OPENCV #include "mediapipe/calculators/image/affine_transformation_runner_opencv.h" +#endif // !MEDIAPIPE_DISABLE_OPENCV #include "mediapipe/calculators/image/warp_affine_calculator.pb.h" #include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" @@ -54,6 +56,7 @@ AffineTransformation::BorderMode GetBorderMode( template class WarpAffineRunnerHolder {}; +#if !MEDIAPIPE_DISABLE_OPENCV template <> class WarpAffineRunnerHolder { public: @@ -69,6 +72,7 @@ class WarpAffineRunnerHolder { private: std::unique_ptr runner_; }; +#endif // !MEDIAPIPE_DISABLE_OPENCV #if !MEDIAPIPE_DISABLE_GPU template <> @@ -113,7 +117,9 @@ class WarpAffineRunnerHolder { mediapipe::Image> { public: absl::Status Open(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_OPENCV MP_RETURN_IF_ERROR(cpu_holder_.Open(cc)); +#endif // !MEDIAPIPE_DISABLE_OPENCV #if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(gpu_holder_.Open(cc)); #endif // !MEDIAPIPE_DISABLE_GPU @@ -133,20 +139,26 @@ class WarpAffineRunnerHolder { return absl::UnavailableError("GPU support is disabled"); #endif // !MEDIAPIPE_DISABLE_GPU } +#if !MEDIAPIPE_DISABLE_OPENCV ASSIGN_OR_RETURN(auto* runner, cpu_holder_.GetRunner()); const auto& frame_ptr = input.GetImageFrameSharedPtr(); // Wrap image into image frame. 
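      // Note: the empty lambda below serves as a no-op deleter, so the wrapped
      // image_frame is a non-owning view; frame_ptr retains ownership of the
      // underlying pixel data.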
const ImageFrame image_frame(frame_ptr->Format(), frame_ptr->Width(), frame_ptr->Height(), frame_ptr->WidthStep(), const_cast(frame_ptr->PixelData()), - [](uint8* data) {}); + [](uint8* data){}); ASSIGN_OR_RETURN(auto result, runner->Run(image_frame, matrix, size, border_mode)); return mediapipe::Image(std::make_shared(std::move(result))); +#else + return absl::UnavailableError("OpenCV support is disabled"); +#endif // !MEDIAPIPE_DISABLE_OPENCV } private: +#if !MEDIAPIPE_DISABLE_OPENCV WarpAffineRunnerHolder cpu_holder_; +#endif // !MEDIAPIPE_DISABLE_OPENCV #if !MEDIAPIPE_DISABLE_GPU WarpAffineRunnerHolder gpu_holder_; #endif // !MEDIAPIPE_DISABLE_GPU @@ -200,8 +212,10 @@ class WarpAffineCalculatorImpl : public mediapipe::api2::NodeImpl { } // namespace +#if !MEDIAPIPE_DISABLE_OPENCV MEDIAPIPE_NODE_IMPLEMENTATION( WarpAffineCalculatorImpl); +#endif // !MEDIAPIPE_DISABLE_OPENCV #if !MEDIAPIPE_DISABLE_GPU MEDIAPIPE_NODE_IMPLEMENTATION( WarpAffineCalculatorImpl); diff --git a/mediapipe/calculators/image/warp_affine_calculator.h b/mediapipe/calculators/image/warp_affine_calculator.h index 4a1b07030c..461a333d0b 100644 --- a/mediapipe/calculators/image/warp_affine_calculator.h +++ b/mediapipe/calculators/image/warp_affine_calculator.h @@ -70,11 +70,13 @@ class WarpAffineCalculatorIntf : public mediapipe::api2::NodeIntf { static constexpr mediapipe::api2::Output kOutImage{"IMAGE"}; }; +#if !MEDIAPIPE_DISABLE_OPENCV class WarpAffineCalculatorCpu : public WarpAffineCalculatorIntf { public: MEDIAPIPE_NODE_INTERFACE(WarpAffineCalculatorCpu, kInImage, kMatrix, kOutputSize, kOutImage); }; +#endif // !MEDIAPIPE_DISABLE_OPENCV #if !MEDIAPIPE_DISABLE_GPU class WarpAffineCalculatorGpu : public WarpAffineCalculatorIntf { diff --git a/mediapipe/calculators/image/yuv_to_image_calculator.cc b/mediapipe/calculators/image/yuv_to_image_calculator.cc new file mode 100644 index 0000000000..e84eee74e5 --- /dev/null +++ b/mediapipe/calculators/image/yuv_to_image_calculator.cc @@ -0,0 +1,123 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "libyuv/convert_argb.h" +#include "libyuv/video_common.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_context.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/yuv_image.h" + +namespace mediapipe { +namespace api2 { + +namespace { + +// Utility function to convert FourCC enum to string, for error messages. +std::string FourCCToString(libyuv::FourCC fourcc) { + char buf[5]; + buf[0] = (fourcc >> 24) & 0xff; + buf[1] = (fourcc >> 16) & 0xff; + buf[2] = (fourcc >> 8) & 0xff; + buf[3] = (fourcc)&0xff; + buf[4] = 0; + return std::string(buf); +} +} // namespace + +// Converts a `YUVImage` into an RGB `Image` using libyuv. 
+//
+// The input `YUVImage` is expected to be in the NV12, NV21, YV12 or I420 (aka
+// YV21) format (as per the `fourcc()` property). This covers the most commonly
+// used YUV image formats on mobile devices. Other formats are not supported
+// and will result in an `InvalidArgumentError`.
+class YUVToImageCalculator : public Node {
+ public:
+  static constexpr Input<YUVImage> kInput{"YUV_IMAGE"};
+  static constexpr Output<Image> kOutput{"IMAGE"};
+
+  MEDIAPIPE_NODE_CONTRACT(kInput, kOutput);
+
+  absl::Status Process(CalculatorContext* cc) override {
+    const auto& yuv_image = *kInput(cc);
+    // Check that the format is supported.
+    auto format = yuv_image.fourcc();
+    if (format != libyuv::FOURCC_NV12 && format != libyuv::FOURCC_NV21 &&
+        format != libyuv::FOURCC_YV12 && format != libyuv::FOURCC_I420) {
+      return absl::InvalidArgumentError(
+          absl::StrFormat("Unsupported YUVImage format: %s. Only NV12, NV21, "
+                          "YV12 and I420 (aka YV21) are supported.",
+                          FourCCToString(format)));
+    }
+    // Build a transient ImageFrameSharedPtr with default alignment to host
+    // conversion results.
+    ImageFrameSharedPtr image_frame = std::make_shared<ImageFrame>(
+        ImageFormat::SRGB, yuv_image.width(), yuv_image.height());
+    // Perform actual conversion.
+    switch (format) {
+      case libyuv::FOURCC_NV12:
+        // 8-bit Y plane followed by an interleaved 8-bit U/V plane with 2×2
+        // subsampling.
+        libyuv::NV12ToRAW(
+            yuv_image.data(0), yuv_image.stride(0), yuv_image.data(1),
+            yuv_image.stride(1), image_frame->MutablePixelData(),
+            image_frame->WidthStep(), yuv_image.width(), yuv_image.height());
+        break;
+      case libyuv::FOURCC_NV21:
+        // 8-bit Y plane followed by an interleaved 8-bit V/U plane with 2×2
+        // subsampling.
+        libyuv::NV21ToRAW(
+            yuv_image.data(0), yuv_image.stride(0), yuv_image.data(1),
+            yuv_image.stride(1), image_frame->MutablePixelData(),
+            image_frame->WidthStep(), yuv_image.width(), yuv_image.height());
+        break;
+      case libyuv::FOURCC_I420:
+        // Also known as YV21.
+        // 8-bit Y plane followed by 8-bit 2×2 subsampled U and V planes.
+        libyuv::I420ToRAW(
+            yuv_image.data(0), yuv_image.stride(0), yuv_image.data(1),
+            yuv_image.stride(1), yuv_image.data(2), yuv_image.stride(2),
+            image_frame->MutablePixelData(), image_frame->WidthStep(),
+            yuv_image.width(), yuv_image.height());
+        break;
+      case libyuv::FOURCC_YV12:
+        // 8-bit Y plane followed by 8-bit 2×2 subsampled V and U planes.
+        libyuv::I420ToRAW(
+            yuv_image.data(0), yuv_image.stride(0), yuv_image.data(2),
+            yuv_image.stride(2), yuv_image.data(1), yuv_image.stride(1),
+            image_frame->MutablePixelData(), image_frame->WidthStep(),
+            yuv_image.width(), yuv_image.height());
+        break;
+      default:
+        // This should never happen (caught by checks above).
+        return absl::InternalError("Unsupported YUVImage format.");
+    }
+    // Finally, build and send an Image object that takes ownership of the
+    // transient ImageFrameSharedPtr object.
+    kOutput(cc).Send(std::make_unique<Image>(std::move(image_frame)));
+    return absl::OkStatus();
+  }
+};
+MEDIAPIPE_REGISTER_NODE(YUVToImageCalculator);
+
+}  // namespace api2
+}  // namespace mediapipe
diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD
index 586fb0dd30..2529db4010 100644
--- a/mediapipe/calculators/tensor/BUILD
+++ b/mediapipe/calculators/tensor/BUILD
@@ -40,6 +40,63 @@ selects.config_setting_group(
     ],
 )

+mediapipe_proto_library(
+    name = "audio_to_tensor_calculator_proto",
+    srcs = ["audio_to_tensor_calculator.proto"],
+    visibility = [
+        "//mediapipe/framework:mediapipe_internal",
+    ],
+    deps = [
+        "//mediapipe/framework:calculator_options_proto",
+        "//mediapipe/framework:calculator_proto",
+    ],
+)
+
+cc_library(
+    name = "audio_to_tensor_calculator",
+    srcs = ["audio_to_tensor_calculator.cc"],
+    visibility = [
+        "//mediapipe/framework:mediapipe_internal",
+    ],
+    deps = [
+        ":audio_to_tensor_calculator_cc_proto",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework/api2:node",
+        "//mediapipe/framework/api2:packet",
+        "//mediapipe/framework/api2:port",
+        "//mediapipe/framework/formats:matrix",
+        "//mediapipe/framework/formats:tensor",
+        "//mediapipe/framework/formats:time_series_header_cc_proto",
+        "//mediapipe/util:time_series_util",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_audio_tools//audio/dsp:resampler_q",
+        "@org_tensorflow//tensorflow/lite/c:common",
+    ],
+    alwayslink = 1,
+)
+
+cc_test(
+    name = "audio_to_tensor_calculator_test",
+    srcs = ["audio_to_tensor_calculator_test.cc"],
+    deps = [
+        ":audio_to_tensor_calculator",
+        "//mediapipe/framework:calculator_cc_proto",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework:timestamp",
+        "//mediapipe/framework/api2:packet",
+        "//mediapipe/framework/formats:matrix",
+        "//mediapipe/framework/formats:tensor",
+        "//mediapipe/framework/port:gtest_main",
+        "//mediapipe/framework/port:parse_text_proto",
+        "@com_google_absl//absl/strings",
+        "@com_google_audio_tools//audio/dsp:resampler_q",
+        "@org_tensorflow//tensorflow/lite/c:common",
+    ],
+)
+
 mediapipe_proto_library(
     name = "inference_calculator_proto",
     srcs = ["inference_calculator.proto"],
@@ -50,6 +107,14 @@ mediapipe_proto_library(
     ],
 )

+# This target defines the "InferenceCalculator" component, which looks for the available concrete
+# implementations linked into the current binary and picks the one to use.
+# You can depend on :inference_calculator instead if you want to automatically include a default
+# set of implementations tailored for the current build configuration.
+# If you want to have precise control of which implementations to include (e.g. for strict binary
+# size concerns), depend on those implementations directly, and do not depend on
+# :inference_calculator.
+# In all cases, use "InferenceCalculator" in your graphs.
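+#
+# For example (a sketch only; the exact dependency set is a per-binary choice),
+# a size-constrained binary could depend on a single implementation such as
+#
+#     deps = [":inference_calculator_gl"],
+#
+# instead of :inference_calculator, while its graphs keep referring to
+# "InferenceCalculator".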
cc_library( name = "inference_calculator_interface", srcs = ["inference_calculator.cc"], @@ -62,8 +127,9 @@ cc_library( ], "//conditions:default": [], }), + visibility = ["//visibility:public"], deps = [ - ":inference_calculator_cc_proto", + ":inference_calculator_options_lib", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", "//mediapipe/framework/api2:packet", @@ -85,18 +151,31 @@ cc_library( name = "inference_calculator_gl", srcs = ["inference_calculator_gl.cc"], tags = ["nomac"], # config problem with cpuinfo via TF + visibility = ["//visibility:public"], deps = [ - "inference_calculator_interface", + ":inference_calculator_interface", + "//mediapipe/gpu:gl_calculator_helper", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@org_tensorflow//tensorflow/lite:framework_stable", + "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate", + ], + alwayslink = 1, +) + +cc_library( + name = "inference_calculator_gl_advanced", + srcs = ["inference_calculator_gl_advanced.cc"], + tags = ["nomac"], + visibility = ["//visibility:public"], + deps = [ + ":inference_calculator_interface", "//mediapipe/framework/deps:file_path", "//mediapipe/gpu:gl_calculator_helper", - "//mediapipe/gpu:gpu_buffer", - "//mediapipe/util/tflite:config", "//mediapipe/util/tflite:tflite_gpu_runner", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@org_tensorflow//tensorflow/lite:framework_stable", - "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate", - "@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape", ], alwayslink = 1, ) @@ -113,6 +192,7 @@ cc_library( "-framework MetalKit", ], tags = ["ios"], + visibility = ["//visibility:public"], deps = [ "inference_calculator_interface", "//mediapipe/gpu:MPPMetalHelper", @@ -142,6 +222,7 @@ cc_library( ], "//conditions:default": [], }), + visibility = ["//visibility:public"], deps = [ ":inference_calculator_interface", "@com_google_absl//absl/memory", @@ -161,9 +242,13 @@ cc_library( cc_library( name = "inference_calculator_gl_if_compute_shader_available", + visibility = ["//visibility:public"], deps = selects.with_or({ ":compute_shader_unavailable": [], - "//conditions:default": [":inference_calculator_gl"], + "//conditions:default": [ + ":inference_calculator_gl", + ":inference_calculator_gl_advanced", + ], }), ) @@ -484,6 +569,7 @@ cc_library( "//mediapipe/framework/formats:location", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/formats:tensor", + "//mediapipe/util:label_map_cc_proto", "//mediapipe/util:resource_util", ] + select({ "//mediapipe:android": [ @@ -506,6 +592,7 @@ mediapipe_proto_library( deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", + "//mediapipe/util:label_map_proto", ], ) @@ -672,6 +759,7 @@ cc_library( ], "//conditions:default": [], }), + visibility = ["//visibility:public"], deps = [ ":image_to_tensor_converter", ":image_to_tensor_utils", @@ -858,9 +946,7 @@ cc_library( "@com_google_absl//absl/types:span", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image_frame", - "//mediapipe/framework/formats:image_opencv", "//mediapipe/framework/formats:tensor", - "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:ret_check", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", @@ -890,6 +976,12 @@ cc_library( "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_texture", 
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl/converters:util", ], + }) + select({ + "//mediapipe/framework/port:disable_opencv": [], + "//conditions:default": [ + "//mediapipe/framework/formats:image_opencv", + "//mediapipe/framework/port:opencv_imgproc", + ], }), alwayslink = 1, ) diff --git a/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc b/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc new file mode 100644 index 0000000000..12820ed16e --- /dev/null +++ b/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc @@ -0,0 +1,401 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "audio/dsp/resampler_q.h" +#include "mediapipe/calculators/tensor/audio_to_tensor_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/packet.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/formats/time_series_header.pb.h" +#include "mediapipe/util/time_series_util.h" + +namespace mediapipe { +namespace api2 { + +// Converts audio buffers into tensors, possibly with resampling, buffering +// and framing, according to specified inputs and options. All input audio +// buffers will be first resampled from the input sample rate to the target +// sample rate if they are not equal. The resampled audio data (with the +// buffered samples from the previous runs in the streaming mode) will be broken +// into fixed-sized, possibly overlapping frames. Finally, all frames will be +// converted to and outputted as MediaPipe Tensors. The last output tensor will +// be zero-padding if the remaining samples are insufficient. +// +// This calculator assumes that the input timestamps refer to the first +// sample in each Matrix. The output timestamps follow this same convention. +// One Process() call may output multiple tensors packets. The timestamps of +// the output packets are determined by the timestamp of the previous output +// packet, the target sample rate, and the number of samples advanced after the +// previous output. +// +// The calculator has two running modes: +// Streaming mode: when "streaming_mode" is set to true in the calculator +// options, the calculator treats the input audio stream as a continuous +// stream. Thus, any samples that are not consumed in the previous runs will +// be cached in a global sample buffer. The audio data resampled from the +// current raw audio input will be appended to the global sample buffer. +// The calculator will process the global sample buffer and output as many +// tensors as possible. 
+//   Non-streaming mode: when "streaming_mode" is set to false in the
+//   calculator options, the calculator treats the packets in the input audio
+//   stream as a batch of unrelated audio buffers. In each Process() call, the
+//   input buffer will first be resampled, and framed as fixed-sized, possibly
+//   overlapping tensors. The last tensor produced by a Process() invocation
+//   will be zero-padded if the remaining samples are insufficient. As the
+//   calculator treats the input packets as unrelated, all samples will be
+//   processed immediately and no samples will be cached in the global sample
+//   buffer.
+//
+// Inputs:
+//   AUDIO - mediapipe::Matrix
+//     The audio data represented as mediapipe::Matrix.
+//   SAMPLE_RATE - double @Optional
+//     The sample rate of the corresponding audio data in the "AUDIO" stream.
+//     If a sample rate packet is provided at Timestamp::PreStream(), the
+//     sample rate will be used as the sample rate of every audio packet in
+//     the "AUDIO" stream. Note that one and only one of the "AUDIO" stream's
+//     time series header or the "SAMPLE_RATE" stream can exist.
+//
+// Outputs:
+//   TENSORS - std::vector<Tensor>
+//     Vector containing a single Tensor that represents a fixed-sized audio
+//     frame.
+//   TIMESTAMPS - std::vector<Timestamp> @Optional
+//     Vector containing the output timestamps emitted by the current
+//     Process() invocation. In the non-streaming mode, the vector contains
+//     all of the output timestamps for an input audio buffer.
+//
+// Example:
+// node {
+//   calculator: "AudioToTensorCalculator"
+//   input_stream: "AUDIO:audio"
+//   output_stream: "TENSORS:tensors"
+//   output_stream: "TIMESTAMPS:timestamps"
+//   options {
+//     [mediapipe.AudioToTensorCalculatorOptions.ext] {
+//       num_channels: 2
+//       num_samples: 512
+//       num_overlapping_samples: 64
+//       target_sample_rate: 16000
+//       streaming_mode: true  # or false
+//     }
+//   }
+// }
+class AudioToTensorCalculator : public Node {
+ public:
+  static constexpr Input<Matrix> kAudioIn{"AUDIO"};
+  // TODO: Remove this optional input stream when the "AUDIO" stream
+  // uses the new mediapipe audio data containers that carry audio metadata,
+  // such as sample rate.
+  static constexpr Input<double>::Optional kAudioSampleRateIn{"SAMPLE_RATE"};
+  static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
+  // A vector of the output timestamps emitted by the current Process()
+  // invocation. The packet timestamp is the last emitted timestamp.
+  static constexpr Output<std::vector<Timestamp>>::Optional kTimestampsOut{
+      "TIMESTAMPS"};
+  MEDIAPIPE_NODE_CONTRACT(kAudioIn, kAudioSampleRateIn, kTensorsOut,
+                          kTimestampsOut);
+
+  static absl::Status UpdateContract(CalculatorContract* cc);
+  absl::Status Open(CalculatorContext* cc);
+  absl::Status Process(CalculatorContext* cc);
+  absl::Status Close(CalculatorContext* cc);
+
+ private:
+  // The target number of channels.
+  int num_channels_;
+  // The target number of samples per channel.
+  int num_samples_;
+  // The number of samples per channel to advance after the current frame is
+  // processed.
+  int frame_step_;
+  bool streaming_mode_;
+  bool check_inconsistent_timestamps_;
+  Timestamp initial_timestamp_ = Timestamp::Unstarted();
+  int64 cumulative_input_samples_ = 0;
+  Timestamp next_output_timestamp_ = Timestamp::Unstarted();
+
+  double source_sample_rate_ = -1;
+  double target_sample_rate_ = -1;
+  // TODO: Configure QResamplerParams through calculator options.
+  audio_dsp::QResamplerParams params_;
+  // A QResampler instance to resample an audio stream.
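+  // (QResampler comes from the @com_google_audio_tools//audio/dsp:resampler_q
+  // dependency declared in the BUILD change above.)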
+  std::unique_ptr<audio_dsp::QResampler<float>> resampler_;
+  Matrix sample_buffer_;
+  int processed_buffer_cols_ = 0;
+
+  absl::Status ProcessStreamingData(CalculatorContext* cc);
+  absl::Status ProcessNonStreamingData(CalculatorContext* cc);
+
+  absl::Status SetupStreamingResampler(double input_sample_rate_);
+  void AppendToSampleBuffer(Matrix buffer_to_append);
+
+  absl::StatusOr<std::vector<Tensor>> ConvertToTensor(
+      const Matrix& frame_to_convert);
+  absl::Status OutputTensors(const Matrix& buffer, bool should_flush,
+                             CalculatorContext* cc);
+};
+
+absl::Status AudioToTensorCalculator::UpdateContract(CalculatorContract* cc) {
+  const auto& options =
+      cc->Options<mediapipe::AudioToTensorCalculatorOptions>();
+  if (!options.has_num_channels() || !options.has_num_samples() ||
+      !options.has_target_sample_rate()) {
+    return absl::InvalidArgumentError(
+        "AudioToTensorCalculatorOptions must specify "
+        "`num_channels`, `num_samples`, and `target_sample_rate`.");
+  }
+  if (options.streaming_mode()) {
+    // Explicitly disables timestamp offset to disallow the timestamp bound
+    // from the input streams to be propagated to the output streams.
+    // In the streaming mode, the output timestamp bound is based on
+    // next_output_timestamp_, which can be smaller than the current input
+    // timestamps.
+    cc->SetTimestampOffset(TimestampDiff::Unset());
+  }
+  return absl::OkStatus();
+}
+
+absl::Status AudioToTensorCalculator::Open(CalculatorContext* cc) {
+  const auto& options =
+      cc->Options<mediapipe::AudioToTensorCalculatorOptions>();
+  num_channels_ = options.num_channels();
+  num_samples_ = options.num_samples();
+  if (options.has_num_overlapping_samples()) {
+    RET_CHECK_GE(options.num_overlapping_samples(), 0);
+    RET_CHECK_LT(options.num_overlapping_samples(), num_samples_);
+    frame_step_ = num_samples_ - options.num_overlapping_samples();
+  } else {
+    frame_step_ = num_samples_;
+  }
+  target_sample_rate_ = options.target_sample_rate();
+  streaming_mode_ = options.streaming_mode();
+  if (streaming_mode_) {
+    check_inconsistent_timestamps_ = options.check_inconsistent_timestamps();
+    sample_buffer_.resize(num_channels_, Eigen::NoChange);
+  }
+
+  RET_CHECK(kAudioSampleRateIn(cc).IsConnected() ^
+            !kAudioIn(cc).Header().IsEmpty())
+      << "Must either specify the time series header of the \"AUDIO\" stream "
+         "or have the \"SAMPLE_RATE\" stream connected.";
+  if (!kAudioIn(cc).Header().IsEmpty()) {
+    mediapipe::TimeSeriesHeader input_header;
+    MP_RETURN_IF_ERROR(mediapipe::time_series_util::FillTimeSeriesHeaderIfValid(
+        kAudioIn(cc).Header(), &input_header));
+    if (streaming_mode_) {
+      MP_RETURN_IF_ERROR(SetupStreamingResampler(input_header.sample_rate()));
+    } else {
+      source_sample_rate_ = input_header.sample_rate();
+    }
+  }
+  return absl::OkStatus();
+}
+
+absl::Status AudioToTensorCalculator::Process(CalculatorContext* cc) {
+  if (cc->InputTimestamp() == Timestamp::PreStream()) {
+    double current_source_sample_rate = kAudioSampleRateIn(cc).Get();
+    if (cc->Options<mediapipe::AudioToTensorCalculatorOptions>()
+            .streaming_mode()) {
+      return SetupStreamingResampler(current_source_sample_rate);
+    } else {
+      source_sample_rate_ = current_source_sample_rate;
+      return absl::OkStatus();
+    }
+  }
+  // Sanity checks.
+  const auto& input_frame = kAudioIn(cc).Get();
+  if (input_frame.rows() != num_channels_) {
+    return absl::InvalidArgumentError(absl::StrFormat(
+        "Audio input has %d channel(s) but the model requires %d channel(s).",
+        input_frame.rows(), num_channels_));
+  }
+  if (num_channels_ > 1 && input_frame.IsRowMajor) {
+    return absl::InvalidArgumentError(
+        "The audio data should be stored in column-major.");
+  }
+  return streaming_mode_ ? ProcessStreamingData(cc)
+                         : ProcessNonStreamingData(cc);
+}
+
+absl::Status AudioToTensorCalculator::Close(CalculatorContext* cc) {
+  if (!streaming_mode_) {
+    return absl::OkStatus();
+  }
+  if (resampler_) {
+    Matrix resampled_buffer(num_channels_, 0);
+    resampler_->Flush(&resampled_buffer);
+    AppendToSampleBuffer(std::move(resampled_buffer));
+  }
+  return OutputTensors(sample_buffer_, /*should_flush=*/true, cc);
+}
+
+absl::Status AudioToTensorCalculator::ProcessStreamingData(
+    CalculatorContext* cc) {
+  const auto& input_buffer = kAudioIn(cc).Get();
+  if (initial_timestamp_ == Timestamp::Unstarted()) {
+    initial_timestamp_ = cc->InputTimestamp();
+    next_output_timestamp_ = initial_timestamp_;
+  }
+  if (source_sample_rate_ != -1 && check_inconsistent_timestamps_) {
+    mediapipe::time_series_util::LogWarningIfTimestampIsInconsistent(
+        cc->InputTimestamp(), initial_timestamp_, cumulative_input_samples_,
+        source_sample_rate_);
+    cumulative_input_samples_ += input_buffer.cols();
+  }
+  if (!kAudioSampleRateIn(cc).IsEmpty()) {
+    double current_source_sample_rate = kAudioSampleRateIn(cc).Get();
+    if (resampler_) {
+      RET_CHECK_EQ(current_source_sample_rate, source_sample_rate_);
+    } else {
+      MP_RETURN_IF_ERROR(SetupStreamingResampler(current_source_sample_rate));
+    }
+  }
+
+  if (resampler_) {
+    Matrix resampled_buffer(num_channels_, 0);
+    resampler_->ProcessSamples(input_buffer, &resampled_buffer);
+    AppendToSampleBuffer(std::move(resampled_buffer));
+  } else {
+    // Tries to consume the input matrix first to avoid extra data copy.
+    auto status_or_matrix = kAudioIn(cc).packet().Consume();
+    if (status_or_matrix.ok()) {
+      Matrix local_matrix(num_channels_, 0);
+      local_matrix.swap(*status_or_matrix.value());
+      AppendToSampleBuffer(std::move(local_matrix));
+    } else {
+      AppendToSampleBuffer(input_buffer);
+    }
+  }
+
+  MP_RETURN_IF_ERROR(OutputTensors(sample_buffer_, /*should_flush=*/false, cc));
+  // Removes the processed samples from the global sample buffer.
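+  // (Illustrative walk-through: with num_samples = 4, frame_step = 2, and an
+  // 11-column buffer, OutputTensors() above emits frames starting at columns
+  // 0, 2, 4, and 6 and sets processed_buffer_cols_ to 7, so the statement
+  // below keeps only the three unconsumed columns 8..10.)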
+  sample_buffer_ = Matrix(sample_buffer_.rightCols(sample_buffer_.cols() -
+                                                   processed_buffer_cols_ - 1));
+  return absl::OkStatus();
+}
+
+absl::Status AudioToTensorCalculator::ProcessNonStreamingData(
+    CalculatorContext* cc) {
+  initial_timestamp_ = cc->InputTimestamp();
+  next_output_timestamp_ = initial_timestamp_;
+  const auto& input_frame = kAudioIn(cc).Get();
+  double source_sample_rate = kAudioSampleRateIn(cc).GetOr(source_sample_rate_);
+
+  if (source_sample_rate != -1 && source_sample_rate != target_sample_rate_) {
+    std::vector<float> resampled = audio_dsp::QResampleSignal<float>(
+        source_sample_rate, target_sample_rate_, num_channels_, params_,
+        input_frame);
+    Eigen::Map<Matrix> matrix_mapping(resampled.data(), num_channels_,
+                                      resampled.size() / num_channels_);
+    return OutputTensors(matrix_mapping, /*should_flush=*/true, cc);
+  }
+  return OutputTensors(input_frame, /*should_flush=*/true, cc);
+}
+
+absl::Status AudioToTensorCalculator::SetupStreamingResampler(
+    double input_sample_rate) {
+  if (input_sample_rate == source_sample_rate_) {
+    return absl::OkStatus();
+  }
+  source_sample_rate_ = input_sample_rate;
+  if (source_sample_rate_ != target_sample_rate_) {
+    resampler_ = absl::make_unique<audio_dsp::QResampler<float>>(
+        source_sample_rate_, target_sample_rate_, num_channels_, params_);
+    if (!resampler_) {
+      return absl::InternalError("Failed to initialize resampler.");
+    }
+  }
+  return absl::OkStatus();
+}
+
+void AudioToTensorCalculator::AppendToSampleBuffer(Matrix buffer_to_append) {
+  sample_buffer_.conservativeResize(
+      Eigen::NoChange, sample_buffer_.cols() + buffer_to_append.cols());
+  sample_buffer_.rightCols(buffer_to_append.cols()).swap(buffer_to_append);
+}
+
+absl::StatusOr<std::vector<Tensor>> AudioToTensorCalculator::ConvertToTensor(
+    const Matrix& frame_to_convert) {
+  Tensor tensor(Tensor::ElementType::kFloat32,
+                Tensor::Shape({num_channels_, num_samples_}));
+  auto buffer_view = tensor.GetCpuWriteView();
+  if (frame_to_convert.size() < num_channels_ * num_samples_) {
+    std::memset(buffer_view.buffer<float>(), 0, tensor.bytes());
+  }
+  std::memcpy(buffer_view.buffer<float>(), frame_to_convert.data(),
+              frame_to_convert.size() * sizeof(float));
+  std::vector<Tensor> tensor_vector;
+  tensor_vector.push_back(std::move(tensor));
+  return tensor_vector;
+}
+
+absl::Status AudioToTensorCalculator::OutputTensors(const Matrix& buffer,
+                                                    bool should_flush,
+                                                    CalculatorContext* cc) {
+  int next_frame_first_col = 0;
+  std::vector<Timestamp> timestamps;
+  while ((!streaming_mode_ || !should_flush) &&
+         next_frame_first_col + num_samples_ <= buffer.cols()) {
+    ASSIGN_OR_RETURN(auto output_tensor, ConvertToTensor(buffer.block(
+                                             0, next_frame_first_col,
+                                             num_channels_, num_samples_)));
+    kTensorsOut(cc).Send(std::move(output_tensor), next_output_timestamp_);
+    timestamps.push_back(next_output_timestamp_);
+    next_output_timestamp_ += round(frame_step_ / target_sample_rate_ *
+                                    Timestamp::kTimestampUnitsPerSecond);
+    next_frame_first_col += frame_step_;
+  }
+  if (should_flush && next_frame_first_col < buffer.cols()) {
+    ASSIGN_OR_RETURN(auto output_tensor,
+                     ConvertToTensor(buffer.block(
+                         0, next_frame_first_col, num_channels_,
+                         std::min(num_samples_,
+                                  (int)buffer.cols() - next_frame_first_col))));
+    // In the streaming mode, the flush happens in Close() and a packet at
+    // Timestamp::Max() will be emitted. In the non-streaming mode, each
+    // Process() invocation will process the entire buffer completely.
+    Timestamp timestamp =
+        streaming_mode_ ? Timestamp::Max() : next_output_timestamp_;
+    timestamps.push_back(timestamp);
+    kTensorsOut(cc).Send(std::move(output_tensor), timestamp);
+  }
+  if (kTimestampsOut(cc).IsConnected()) {
+    Timestamp timestamp = timestamps.back();
+    kTimestampsOut(cc).Send(std::move(timestamps), timestamp);
+  }
+  processed_buffer_cols_ = next_frame_first_col - 1;
+  return absl::OkStatus();
+}
+
+MEDIAPIPE_REGISTER_NODE(AudioToTensorCalculator);
+
+}  // namespace api2
+}  // namespace mediapipe
diff --git a/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto b/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto
new file mode 100644
index 0000000000..c63991fc3f
--- /dev/null
+++ b/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto
@@ -0,0 +1,46 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto2";
+
+package mediapipe;
+
+import "mediapipe/framework/calculator.proto";
+
+message AudioToTensorCalculatorOptions {
+  extend mediapipe.CalculatorOptions {
+    optional AudioToTensorCalculatorOptions ext = 448635064;
+  }
+
+  // The required number of channels the output audio tensor has.
+  optional int64 num_channels = 1;
+
+  // The required number of samples per channel the output audio tensor has.
+  optional int64 num_samples = 2;
+
+  // The number of overlapping samples per channel the output audio tensor has.
+  optional int64 num_overlapping_samples = 3 [default = 0];
+
+  // The target number of samples per second (hertz) of the audio buffers that
+  // will be converted into tensors.
+  optional double target_sample_rate = 4;
+
+  // Whether to treat the input audio stream as a continuous stream or a batch
+  // of unrelated audio buffers.
+  optional bool streaming_mode = 5 [default = true];
+
+  // Set to false to disable checks for jitter in timestamp values. Useful with
+  // live audio input.
+  optional bool check_inconsistent_timestamps = 6 [default = true];
+}
diff --git a/mediapipe/calculators/tensor/audio_to_tensor_calculator_test.cc b/mediapipe/calculators/tensor/audio_to_tensor_calculator_test.cc
new file mode 100644
index 0000000000..1b8cb9c8d2
--- /dev/null
+++ b/mediapipe/calculators/tensor/audio_to_tensor_calculator_test.cc
@@ -0,0 +1,483 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
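+
+// Unit tests for AudioToTensorCalculator, covering the non-streaming and
+// streaming modes, with and without resampling.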
+ +#include +#include +#include +#include + +#include "absl/strings/substitute.h" +#include "audio/dsp/resampler_q.h" +#include "mediapipe/framework/api2/packet.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/timestamp.h" + +namespace mediapipe { +namespace { + +std::unique_ptr CreateTestMatrix(int num_channels, int num_samples, + int timestamp) { + auto matrix = std::make_unique(num_channels, num_samples); + for (int c = 0; c < num_channels; ++c) { + for (int i = 0; i < num_samples; ++i) { + // A float value with the sample, channel, and timestamp separated by a + // few orders of magnitude, for easy parsing by humans. + (*matrix)(c, i) = timestamp / 10000 + i + c / 100.0; + } + } + return matrix; +} + +std::unique_ptr ResampleBuffer(const Matrix& input_matrix, + double resampling_factor) { + audio_dsp::QResamplerParams params; + std::vector resampled; + int num_channels = input_matrix.rows(); + std::vector input_data(input_matrix.data(), + input_matrix.data() + input_matrix.size()); + resampled = audio_dsp::QResampleSignal( + 1, resampling_factor, num_channels, params, input_data); + Matrix res = Eigen::Map(resampled.data(), num_channels, + resampled.size() / num_channels); + return std::make_unique(std::move(res)); +} + +class AudioToTensorCalculatorNonStreamingModeTest : public ::testing::Test { + protected: + void SetUp() override {} + void Run(int num_samples, int num_overlapping_samples, + double resampling_factor, const Matrix& input_matrix) { + double input_sample_rate = 10000; + double target_sample_rate = input_sample_rate * resampling_factor; + auto graph_config = ParseTextProtoOrDie( + absl::Substitute(R"( + input_stream: "audio" + input_stream: "sample_rate" + output_stream: "tensors" + output_stream: "timestamps" + node { + calculator: "AudioToTensorCalculator" + input_stream: "AUDIO:audio" + input_stream: "SAMPLE_RATE:sample_rate" + output_stream: "TENSORS:tensors" + output_stream: "TIMESTAMPS:timestamps" + options { + [mediapipe.AudioToTensorCalculatorOptions.ext] { + num_channels: $0 + num_samples: $1 + num_overlapping_samples: $2 + target_sample_rate: $3 + streaming_mode: false + } + } + } + )", + /*$0=*/input_matrix.rows(), + /*$1=*/num_samples, /*$2=*/num_overlapping_samples, + /*$3=*/target_sample_rate)); + tool::AddVectorSink("tensors", &graph_config, &tensors_packets_); + tool::AddVectorSink("timestamps", &graph_config, ×tamps_packets_); + + // Run the graph. + MP_ASSERT_OK(graph_.Initialize(graph_config)); + MP_ASSERT_OK(graph_.StartRun({})); + // Run with the input matrix multiple times. 
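+    // Each iteration below is stamped one second apart
+    // (Timestamp::kTimestampUnitsPerSecond), so the expected per-iteration
+    // output timestamps are offsets from i * kTimestampUnitsPerSecond.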
+ for (int i = 0; i < num_iterations_; ++i) { + MP_ASSERT_OK(graph_.AddPacketToInputStream( + "audio", + MakePacket(input_matrix) + .At(Timestamp(i * Timestamp::kTimestampUnitsPerSecond)))); + MP_ASSERT_OK(graph_.AddPacketToInputStream( + "sample_rate", + MakePacket(input_sample_rate) + .At(Timestamp(i * Timestamp::kTimestampUnitsPerSecond)))); + } + MP_ASSERT_OK(graph_.CloseAllInputStreams()); + MP_ASSERT_OK(graph_.WaitUntilIdle()); + } + + void CheckTensorsOutputPackets(const Matrix& expected_matrix, + int sample_offset, int num_tensors_per_input) { + ASSERT_EQ(num_iterations_ * num_tensors_per_input, tensors_packets_.size()); + for (int i = 0; i < num_iterations_; ++i) { + for (int j = 0; j < num_tensors_per_input; ++j) { + CheckTensorsOutputPacket( + expected_matrix, tensors_packets_[i * num_tensors_per_input + j], + /*sample_offset*/ sample_offset * j, /*index=*/j); + } + } + } + + void CheckTensorsOutputPacket(const Matrix& expected_matrix, + const Packet& packet, int sample_offset, + int index) { + MP_ASSERT_OK(packet.ValidateAsType>()); + ASSERT_EQ(1, packet.Get>().size()); + const Tensor& output_tensor = packet.Get>()[0]; + auto* buffer = output_tensor.GetCpuReadView().buffer(); + int num_values = output_tensor.shape().num_elements(); + const std::vector output_floats(buffer, buffer + num_values); + for (int i = 0; i < num_values; ++i) { + if (i + sample_offset >= expected_matrix.size()) { + EXPECT_FLOAT_EQ(output_floats[i], 0); + } else { + EXPECT_FLOAT_EQ(output_floats[i], + expected_matrix.coeff((i + sample_offset) % 2, + (i + sample_offset) / 2)) + << "i=" << i << ", sample_offset=" << sample_offset; + } + } + } + + void CheckTimestampsOutputPackets( + std::vector expected_timestamp_values) { + ASSERT_EQ(num_iterations_, timestamps_packets_.size()); + for (int i = 0; i < timestamps_packets_.size(); ++i) { + const auto& p = timestamps_packets_[i]; + MP_ASSERT_OK(p.ValidateAsType>()); + auto output_timestamps = p.Get>(); + int64 base_timestamp = i * Timestamp::kTimestampUnitsPerSecond; + std::vector expected_timestamps; + expected_timestamps.resize(expected_timestamp_values.size()); + std::transform( + expected_timestamp_values.begin(), expected_timestamp_values.end(), + expected_timestamps.begin(), [base_timestamp](int64 v) -> Timestamp { + return Timestamp(v + base_timestamp); + }); + EXPECT_EQ(expected_timestamps, output_timestamps); + EXPECT_EQ(p.Timestamp(), expected_timestamps.back()); + } + } + + // Fully close graph at end, otherwise calculator+tensors are destroyed + // after calling WaitUntilDone(). 
+ void CloseGraph() { MP_EXPECT_OK(graph_.WaitUntilDone()); } + + private: + CalculatorGraph graph_; + int num_iterations_ = 10; + std::vector tensors_packets_; + std::vector timestamps_packets_; +}; + +TEST_F(AudioToTensorCalculatorNonStreamingModeTest, + ConvertToNoOverlappingFp32Tensors) { + auto input_matrix = CreateTestMatrix(2, 8, 0); + Run(/*num_samples=*/4, /*num_overlapping_samples=*/0, + /*resampling_factor=*/1.0f, *input_matrix); + CheckTensorsOutputPackets(*input_matrix, /*sample_offset=*/8, + /*num_tensors_per_input=*/2); + CheckTimestampsOutputPackets({0, 400}); + CloseGraph(); +} + +TEST_F(AudioToTensorCalculatorNonStreamingModeTest, + ConvertToOverlappingFp32Tensors) { + auto input_matrix = CreateTestMatrix(2, 8, 0); + Run(/*num_samples=*/4, /*num_overlapping_samples=*/2, + /*resampling_factor=*/1.0f, *input_matrix); + CheckTensorsOutputPackets(*input_matrix, /*sample_offset=*/4, + /*num_tensors_per_input=*/4); + CheckTimestampsOutputPackets({0, 200, 400, 600}); + CloseGraph(); +} + +TEST_F(AudioToTensorCalculatorNonStreamingModeTest, TensorsWithZeroPadding) { + auto input_matrix = CreateTestMatrix(2, 7, 0); + Run(/*num_samples=*/4, /*num_overlapping_samples=*/2, + /*resampling_factor=*/1.0f, *input_matrix); + CheckTensorsOutputPackets(*input_matrix, /*sample_offset=*/4, + /*num_tensors_per_input=*/3); + CheckTimestampsOutputPackets({0, 200, 400}); + CloseGraph(); +} + +TEST_F(AudioToTensorCalculatorNonStreamingModeTest, Downsampling) { + auto input_matrix = CreateTestMatrix(2, 1024, 0); + Run(/*num_samples=*/256, /*num_overlapping_samples=*/0, + /*resampling_factor=*/0.5f, *input_matrix); + auto expected_matrix = + ResampleBuffer(*input_matrix, /*resampling_factor=*/0.5f); + CheckTensorsOutputPackets(*expected_matrix, /*sample_offset=*/512, + /*num_tensors_per_input=*/3); + CheckTimestampsOutputPackets({0, 51200, 102400}); + CloseGraph(); +} + +TEST_F(AudioToTensorCalculatorNonStreamingModeTest, + DownsamplingWithOverlapping) { + auto input_matrix = CreateTestMatrix(2, 1024, 0); + Run(/*num_samples=*/256, /*num_overlapping_samples=*/64, + /*resampling_factor=*/0.5f, *input_matrix); + auto expected_matrix = + ResampleBuffer(*input_matrix, /*resampling_factor=*/0.5f); + CheckTensorsOutputPackets(*expected_matrix, /*sample_offset=*/384, + /*num_tensors_per_input=*/3); + CheckTimestampsOutputPackets({0, 38400, 76800}); + CloseGraph(); +} + +TEST_F(AudioToTensorCalculatorNonStreamingModeTest, Upsampling) { + auto input_matrix = CreateTestMatrix(2, 1024, 0); + Run(/*num_samples=*/256, /*num_overlapping_samples=*/0, + /*resampling_factor=*/2.0f, *input_matrix); + auto expected_matrix = + ResampleBuffer(*input_matrix, /*resampling_factor=*/2.0f); + CheckTensorsOutputPackets(*expected_matrix, + /*sample_offset=*/512, + /*num_tensors_per_input=*/9); + CheckTimestampsOutputPackets( + {0, 12800, 25600, 38400, 51200, 64000, 76800, 89600, 102400}); + CloseGraph(); +} + +TEST_F(AudioToTensorCalculatorNonStreamingModeTest, UpsamplingWithOverlapping) { + auto input_matrix = CreateTestMatrix(2, 256, 0); + Run(/*num_samples=*/256, /*num_overlapping_samples=*/64, + /*resampling_factor=*/2.0f, *input_matrix); + auto expected_matrix = + ResampleBuffer(*input_matrix, /*resampling_factor=*/2.0f); + CheckTensorsOutputPackets(*expected_matrix, + /*sample_offset=*/384, + /*num_tensors_per_input=*/3); + CheckTimestampsOutputPackets({0, 9600, 19200}); + CloseGraph(); +} + +class AudioToTensorCalculatorStreamingModeTest : public ::testing::Test { + protected: + void SetUp() override { sample_buffer_ = 
std::make_unique(2, 0); } + + void SetInputBufferNumSamplesPerChannel(int num_samples) { + input_buffer_num_samples_ = num_samples; + } + + void SetNumIterations(int num_iterations) { + num_iterations_ = num_iterations; + } + + int GetExpectedNumOfSamples() { + Matrix* expected_matrix = + resampled_buffer_ ? resampled_buffer_.get() : sample_buffer_.get(); + return expected_matrix->cols(); + } + + void Run(int num_samples, int num_overlapping_samples, + double resampling_factor) { + double input_sample_rate = 10000; + double target_sample_rate = input_sample_rate * resampling_factor; + auto graph_config = ParseTextProtoOrDie( + absl::Substitute(R"( + input_stream: "audio" + input_stream: "sample_rate" + output_stream: "tensors" + node { + calculator: "AudioToTensorCalculator" + input_stream: "AUDIO:audio" + input_stream: "SAMPLE_RATE:sample_rate" + output_stream: "TENSORS:tensors" + options { + [mediapipe.AudioToTensorCalculatorOptions.ext] { + num_channels: 2 + num_samples: $0 + num_overlapping_samples: $1 + target_sample_rate: $2 + streaming_mode:true + } + } + } + )", + /*$0=*/num_samples, /*$1=*/num_overlapping_samples, + /*$2=*/target_sample_rate)); + tool::AddVectorSink("tensors", &graph_config, &tensors_packets_); + + // Run the graph. + MP_ASSERT_OK(graph_.Initialize(graph_config)); + MP_ASSERT_OK(graph_.StartRun({})); + for (int i = 0; i < num_iterations_; ++i) { + Timestamp input_timestamp(Timestamp::kTimestampUnitsPerSecond * i); + auto new_data = CreateTestMatrix(2, input_buffer_num_samples_, + input_timestamp.Value()); + MP_ASSERT_OK(graph_.AddPacketToInputStream( + "audio", MakePacket(*new_data).At(input_timestamp))); + MP_ASSERT_OK(graph_.AddPacketToInputStream( + "sample_rate", + MakePacket(input_sample_rate).At(input_timestamp))); + sample_buffer_->conservativeResize( + Eigen::NoChange, sample_buffer_->cols() + new_data->cols()); + sample_buffer_->rightCols(new_data->cols()).swap(*new_data); + } + MP_ASSERT_OK(graph_.CloseAllInputStreams()); + MP_ASSERT_OK(graph_.WaitUntilIdle()); + if (resampling_factor != 1) { + resampled_buffer_ = ResampleBuffer(*sample_buffer_, resampling_factor); + } + } + + void CheckTensorsOutputPackets(int sample_offset, int num_packets, + int64 timestamp_interval, + bool output_last_at_close) { + ASSERT_EQ(num_packets, tensors_packets_.size()); + for (int i = 0; i < num_packets; ++i) { + if (i == num_packets - 1 && output_last_at_close) { + CheckTensorsOutputPacket(sample_offset * i, i, Timestamp::Max()); + } else { + CheckTensorsOutputPacket(sample_offset * i, i, + Timestamp(timestamp_interval * i)); + } + } + } + + void CheckTensorsOutputPacket(int sample_offset, int index, + Timestamp expected_timestamp) { + const Packet& p = tensors_packets_[index]; + MP_ASSERT_OK(p.ValidateAsType>()); + const Tensor& output_tensor = p.Get>()[0]; + auto buffer = output_tensor.GetCpuReadView().buffer(); + int num_values = output_tensor.shape().num_elements(); + std::vector output_floats(buffer, buffer + num_values); + Matrix* expected_matrix = + resampled_buffer_ ? 
resampled_buffer_.get() : sample_buffer_.get(); + for (int i = 0; i < num_values; ++i) { + if (i + sample_offset >= expected_matrix->size()) { + EXPECT_FLOAT_EQ(output_floats[i], 0); + } else { + EXPECT_NEAR(output_floats[i], + expected_matrix->coeff((i + sample_offset) % 2, + (i + sample_offset) / 2), + 0.001) + << "i=" << i << ", sample_offset=" << sample_offset + << ", packet index=" << index; + } + } + EXPECT_EQ(p.Timestamp(), expected_timestamp); + } + + // Fully close graph at end, otherwise calculator+tensors are destroyed + // after calling WaitUntilDone(). + void CloseGraph() { MP_EXPECT_OK(graph_.WaitUntilDone()); } + + private: + int input_buffer_num_samples_ = 10; + int num_iterations_ = 10; + CalculatorGraph graph_; + std::vector tensors_packets_; + std::unique_ptr sample_buffer_; + std::unique_ptr resampled_buffer_; +}; + +TEST_F(AudioToTensorCalculatorStreamingModeTest, + OutputNoOverlappingFp32Tensors) { + Run(/*num_samples=*/5, /*num_overlapping_samples=*/0, + /*resampling_factor=*/1.0f); + CheckTensorsOutputPackets( + /*sample_offset=*/10, + /*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 5), + /*timestamp_interval=*/500, + /*output_last_at_close=*/false); + CloseGraph(); +} + +TEST_F(AudioToTensorCalculatorStreamingModeTest, OutputRemainingInCloseMethod) { + Run(/*num_samples=*/6, /*num_overlapping_samples=*/0, + /*resampling_factor=*/1.0f); + CheckTensorsOutputPackets( + /*sample_offset=*/12, + /*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 6), + /*timestamp_interval=*/600, + /*output_last_at_close=*/true); + CloseGraph(); +} + +TEST_F(AudioToTensorCalculatorStreamingModeTest, OutputOverlappingFp32Tensors) { + SetInputBufferNumSamplesPerChannel(12); + Run(/*num_samples=*/10, /*num_overlapping_samples=*/2, + /*resampling_factor=*/1.0f); + CheckTensorsOutputPackets( + /*sample_offset=*/16, + /*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 8), + /*timestamp_interval=*/800, + /*output_last_at_close=*/true); + CloseGraph(); +} + +TEST_F(AudioToTensorCalculatorStreamingModeTest, Downsampling) { + SetInputBufferNumSamplesPerChannel(1000); + Run(/*num_samples=*/256, /*num_overlapping_samples=*/0, + /*resampling_factor=*/0.5f); + CheckTensorsOutputPackets( + /*sample_offset=*/512, + /*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 256), + /*timestamp_interval=*/51200, + /*output_last_at_close=*/true); + CloseGraph(); +} + +TEST_F(AudioToTensorCalculatorStreamingModeTest, DownsamplingWithOverlapping) { + SetInputBufferNumSamplesPerChannel(1024); + Run(/*num_samples=*/256, /*num_overlapping_samples=*/64, + /*resampling_factor=*/0.5f); + CheckTensorsOutputPackets( + /*sample_offset=*/384, + /*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 192), + /*timestamp_interval=*/38400, + /*output_last_at_close=*/true); + CloseGraph(); +} + +TEST_F(AudioToTensorCalculatorStreamingModeTest, Upsampling) { + SetInputBufferNumSamplesPerChannel(1000); + Run(/*num_samples=*/256, /*num_overlapping_samples=*/0, + /*resampling_factor=*/2.0f); + CheckTensorsOutputPackets( + /*sample_offset=*/512, + /*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 256), + /*timestamp_interval=*/12800, + /*output_last_at_close=*/true); + CloseGraph(); +} + +TEST_F(AudioToTensorCalculatorStreamingModeTest, UpsamplingWithOverlapping) { + SetInputBufferNumSamplesPerChannel(1024); + Run(/*num_samples=*/256, /*num_overlapping_samples=*/64, + /*resampling_factor=*/2.0f); + CheckTensorsOutputPackets( + /*sample_offset=*/384, + 
/*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 192), + /*timestamp_interval=*/9600, + /*output_last_at_close=*/true); + CloseGraph(); +} + +TEST_F(AudioToTensorCalculatorStreamingModeTest, + OnlyOutputInCloseIfNoSufficientSamples) { + SetNumIterations(1); + Run(/*num_samples=*/8, /*num_overlapping_samples=*/0, + /*resampling_factor=*/0.5f); + CheckTensorsOutputPackets( + /*sample_offset=*/0, + /*num_packets=*/1, + /*timestamp_interval=*/0, + /*output_last_at_close=*/true); + CloseGraph(); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/inference_calculator.cc b/mediapipe/calculators/tensor/inference_calculator.cc index c143c99010..e2c5c9006d 100644 --- a/mediapipe/calculators/tensor/inference_calculator.cc +++ b/mediapipe/calculators/tensor/inference_calculator.cc @@ -19,8 +19,8 @@ #include #include -#include "absl/memory/memory.h" #include "absl/strings/string_view.h" +#include "mediapipe/calculators/tensor/inference_calculator.pb.h" #include "mediapipe/framework/api2/packet.h" #include "mediapipe/framework/tool/subgraph_expansion.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -43,8 +43,19 @@ class InferenceCalculatorSelectorImpl !options.has_delegate() || // Use GPU delegate if not specified (options.has_delegate() && options.delegate().has_gpu()); if (should_use_gpu) { + const auto& api = options.delegate().gpu().api(); + using Gpu = ::mediapipe::InferenceCalculatorOptions::Delegate::Gpu; impls.emplace_back("Metal"); - impls.emplace_back("Gl"); + const bool prefer_gl_advanced = + options.delegate().gpu().use_advanced_gpu_api() && + (api == Gpu::ANY || api == Gpu::OPENGL || api == Gpu::OPENCL); + if (prefer_gl_advanced) { + impls.emplace_back("GlAdvanced"); + impls.emplace_back("Gl"); + } else { + impls.emplace_back("Gl"); + impls.emplace_back("GlAdvanced"); + } } impls.emplace_back("Cpu"); for (const auto& suffix : impls) { diff --git a/mediapipe/calculators/tensor/inference_calculator.h b/mediapipe/calculators/tensor/inference_calculator.h index b5f3a0a157..52425dd069 100644 --- a/mediapipe/calculators/tensor/inference_calculator.h +++ b/mediapipe/calculators/tensor/inference_calculator.h @@ -134,6 +134,10 @@ struct InferenceCalculatorGl : public InferenceCalculator { static constexpr char kCalculatorName[] = "InferenceCalculatorGl"; }; +struct InferenceCalculatorGlAdvanced : public InferenceCalculator { + static constexpr char kCalculatorName[] = "InferenceCalculatorGlAdvanced"; +}; + struct InferenceCalculatorMetal : public InferenceCalculator { static constexpr char kCalculatorName[] = "InferenceCalculatorMetal"; }; diff --git a/mediapipe/calculators/tensor/inference_calculator_face_detection_test.cc b/mediapipe/calculators/tensor/inference_calculator_face_detection_test.cc index bb383af714..03ee87d4de 100644 --- a/mediapipe/calculators/tensor/inference_calculator_face_detection_test.cc +++ b/mediapipe/calculators/tensor/inference_calculator_face_detection_test.cc @@ -75,9 +75,10 @@ const std::vector& GetParams() { class InferenceCalculatorTest : public testing::TestWithParam { protected: void SetDelegateForParam(mediapipe::CalculatorGraphConfig_Node* node) { - *node->mutable_options() - ->MutableExtension(mediapipe::InferenceCalculatorOptions::ext) - ->mutable_delegate() = GetParam().delegate; + auto options_map = tool::MutableOptionsMap().Initialize(*node); + auto options = options_map.Get(); + *options.mutable_delegate() = GetParam().delegate; + options_map.Set(options); } }; diff --git 
a/mediapipe/calculators/tensor/inference_calculator_gl.cc b/mediapipe/calculators/tensor/inference_calculator_gl.cc index eb6ab9f40b..55cb80c3a6 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl.cc @@ -20,22 +20,8 @@ #include "absl/memory/memory.h" #include "absl/status/status.h" #include "mediapipe/calculators/tensor/inference_calculator.h" -#include "mediapipe/framework/deps/file_path.h" -#include "mediapipe/util/tflite/config.h" - -#if MEDIAPIPE_TFLITE_GL_INFERENCE #include "mediapipe/gpu/gl_calculator_helper.h" -#include "mediapipe/gpu/gpu_buffer.h" -#include "mediapipe/util/tflite/tflite_gpu_runner.h" -#include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/gl_delegate.h" -#endif // MEDIAPIPE_TFLITE_GL_INFERENCE - -#if defined(MEDIAPIPE_ANDROID) -#include "mediapipe/util/android/file/base/file.h" -#include "mediapipe/util/android/file/base/filesystem.h" -#include "mediapipe/util/android/file/base/helpers.h" -#endif // ANDROID namespace mediapipe { namespace api2 { @@ -50,42 +36,22 @@ class InferenceCalculatorGlImpl absl::Status Close(CalculatorContext* cc) override; private: - absl::Status ReadGpuCaches(); - absl::Status SaveGpuCaches(); absl::Status LoadModel(CalculatorContext* cc); absl::Status LoadDelegate(CalculatorContext* cc); absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc); - absl::Status InitTFLiteGPURunner(CalculatorContext* cc); // TfLite requires us to keep the model alive as long as the interpreter is. Packet model_packet_; -#if MEDIAPIPE_TFLITE_GL_INFERENCE mediapipe::GlCalculatorHelper gpu_helper_; - std::unique_ptr tflite_gpu_runner_; bool allow_precision_loss_ = false; - mediapipe::InferenceCalculatorOptions::Delegate::Gpu::Api - tflite_gpu_runner_api_; - mediapipe::InferenceCalculatorOptions::Delegate::Gpu::InferenceUsage - tflite_gpu_runner_usage_; -#endif // MEDIAPIPE_TFLITE_GL_INFERENCE TfLiteDelegatePtr delegate_; std::unique_ptr interpreter_; -#if MEDIAPIPE_TFLITE_GPU_SUPPORTED std::vector output_shapes_; std::vector> gpu_buffers_in_; std::vector> gpu_buffers_out_; -#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED - - bool use_advanced_gpu_api_ = false; - bool use_gpu_delegate_ = false; - - bool use_kernel_caching_ = false; - std::string cached_kernel_filename_; - bool use_serialized_model_ = false; - std::string serialized_model_path_; }; absl::Status InferenceCalculatorGlImpl::UpdateContract(CalculatorContract* cc) { @@ -93,8 +59,7 @@ absl::Status InferenceCalculatorGlImpl::UpdateContract(CalculatorContract* cc) { RET_CHECK(!options.model_path().empty() ^ kSideInModel(cc).IsConnected()) << "Either model as side packet or model path in options is required."; - MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); - return absl::OkStatus(); + return mediapipe::GlCalculatorHelper::UpdateContract(cc); } absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) { @@ -110,46 +75,12 @@ absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) { << "for Gpu"; delegate.MergeFrom(input_side_packet_delegate); } - const bool has_delegate = options.has_delegate() || !kDelegate(cc).IsEmpty(); - use_advanced_gpu_api_ = has_delegate && delegate.has_gpu() && - delegate.gpu().use_advanced_gpu_api(); - allow_precision_loss_ = delegate.gpu().allow_precision_loss(); - tflite_gpu_runner_api_ = delegate.gpu().api(); - tflite_gpu_runner_usage_ = delegate.gpu().usage(); - use_kernel_caching_ = - use_advanced_gpu_api_ && 
delegate.gpu().has_cached_kernel_path(); - use_serialized_model_ = use_advanced_gpu_api_ && - delegate.gpu().has_serialized_model_dir() && - delegate.gpu().has_model_token(); - use_gpu_delegate_ = !use_advanced_gpu_api_; - - if (use_kernel_caching_) { -#ifdef MEDIAPIPE_ANDROID - cached_kernel_filename_ = delegate.gpu().cached_kernel_path() + - mediapipe::File::Basename(options.model_path()) + - ".ker"; -#endif // MEDIAPIPE_ANDROID - } - if (use_serialized_model_) { -#ifdef MEDIAPIPE_ANDROID - serialized_model_path_ = mediapipe::file::JoinPath( - delegate.gpu().serialized_model_dir(), delegate.gpu().model_token()); -#endif // MEDIAPIPE_ANDROID - } - - // When use_advanced_gpu_api_, model loading is handled in InitTFLiteGPURunner - // for everything. - if (!use_advanced_gpu_api_) { - MP_RETURN_IF_ERROR(LoadModel(cc)); - } + MP_RETURN_IF_ERROR(LoadModel(cc)); MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); - MP_RETURN_IF_ERROR( - gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status { - return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc) - : LoadDelegateAndAllocateTensors(cc); - })); - return absl::OkStatus(); + return gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status { + return LoadDelegateAndAllocateTensors(cc); + }); } absl::Status InferenceCalculatorGlImpl::Process(CalculatorContext* cc) { @@ -160,205 +91,53 @@ absl::Status InferenceCalculatorGlImpl::Process(CalculatorContext* cc) { RET_CHECK(!input_tensors.empty()); auto output_tensors = absl::make_unique>(); - if (use_advanced_gpu_api_) { - MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( - [this, &input_tensors, &output_tensors]() -> ::mediapipe::Status { - for (int i = 0; i < input_tensors.size(); ++i) { - MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor( - input_tensors[i].GetOpenGlBufferReadView().name(), i)); - } - output_tensors->reserve(output_shapes_.size()); - for (int i = 0; i < output_shapes_.size(); ++i) { - output_tensors->emplace_back(Tensor::ElementType::kFloat32, - output_shapes_[i]); - MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToOutputTensor( - output_tensors->back().GetOpenGlBufferWriteView().name(), i)); - } - return absl::OkStatus(); - })); - } else { - MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( - [this, &input_tensors]() -> ::mediapipe::Status { - // Explicitly copy input. - for (int i = 0; i < input_tensors.size(); ++i) { - glBindBuffer(GL_COPY_READ_BUFFER, - input_tensors[i].GetOpenGlBufferReadView().name()); - glBindBuffer(GL_COPY_WRITE_BUFFER, - gpu_buffers_in_[i]->GetOpenGlBufferWriteView().name()); - glCopyBufferSubData(GL_COPY_READ_BUFFER, GL_COPY_WRITE_BUFFER, 0, 0, - input_tensors[i].bytes()); - } - return absl::OkStatus(); - })); - } + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( + [this, &input_tensors]() -> ::mediapipe::Status { + // Explicitly copy input. + for (int i = 0; i < input_tensors.size(); ++i) { + glBindBuffer(GL_COPY_READ_BUFFER, + input_tensors[i].GetOpenGlBufferReadView().name()); + glBindBuffer(GL_COPY_WRITE_BUFFER, + gpu_buffers_in_[i]->GetOpenGlBufferWriteView().name()); + glCopyBufferSubData(GL_COPY_READ_BUFFER, GL_COPY_WRITE_BUFFER, 0, 0, + input_tensors[i].bytes()); + } + return absl::OkStatus(); + })); // Run inference. 
- if (use_advanced_gpu_api_) { - RET_CHECK(tflite_gpu_runner_->Invoke().ok()); - } else { - RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); - } - - if (use_gpu_delegate_) { - MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( - [this, &output_tensors]() -> ::mediapipe::Status { - output_tensors->reserve(output_shapes_.size()); - for (int i = 0; i < output_shapes_.size(); ++i) { - const auto& t = gpu_buffers_out_[i]; - output_tensors->emplace_back(Tensor::ElementType::kFloat32, - gpu_buffers_out_[i]->shape()); - auto read_view = t->GetOpenGlBufferReadView(); - glBindBuffer(GL_COPY_READ_BUFFER, read_view.name()); - auto write_view = output_tensors->back().GetOpenGlBufferWriteView(); - glBindBuffer(GL_COPY_WRITE_BUFFER, write_view.name()); - glCopyBufferSubData(GL_COPY_READ_BUFFER, GL_COPY_WRITE_BUFFER, 0, 0, - t->bytes()); - } - return absl::OkStatus(); - })); - } - // Output tensors are already bound if use_advanced_gpu_api_ is true. + RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); + + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( + [this, &output_tensors]() -> ::mediapipe::Status { + output_tensors->reserve(output_shapes_.size()); + for (int i = 0; i < output_shapes_.size(); ++i) { + const auto& t = gpu_buffers_out_[i]; + output_tensors->emplace_back(Tensor::ElementType::kFloat32, + gpu_buffers_out_[i]->shape()); + auto read_view = t->GetOpenGlBufferReadView(); + glBindBuffer(GL_COPY_READ_BUFFER, read_view.name()); + auto write_view = output_tensors->back().GetOpenGlBufferWriteView(); + glBindBuffer(GL_COPY_WRITE_BUFFER, write_view.name()); + glCopyBufferSubData(GL_COPY_READ_BUFFER, GL_COPY_WRITE_BUFFER, 0, 0, + t->bytes()); + } + return absl::OkStatus(); + })); kOutTensors(cc).Send(std::move(output_tensors)); return absl::OkStatus(); } -absl::Status InferenceCalculatorGlImpl::SaveGpuCaches() { -#ifdef MEDIAPIPE_ANDROID - if (use_kernel_caching_) { - // Save kernel file. - auto kernel_cache = absl::make_unique>( - tflite_gpu_runner_->GetSerializedBinaryCache()); - std::string cache_str(kernel_cache->begin(), kernel_cache->end()); - MP_RETURN_IF_ERROR( - mediapipe::file::SetContents(cached_kernel_filename_, cache_str)); - } - if (use_serialized_model_) { - // Save serialized model file. - ASSIGN_OR_RETURN(std::vector serialized_model_vec, - tflite_gpu_runner_->GetSerializedModel()); - absl::string_view serialized_model( - reinterpret_cast(serialized_model_vec.data()), - serialized_model_vec.size()); - MP_RETURN_IF_ERROR( - mediapipe::file::SetContents(serialized_model_path_, serialized_model)); - } -#endif // MEDIAPIPE_ANDROID - return absl::OkStatus(); -} - absl::Status InferenceCalculatorGlImpl::Close(CalculatorContext* cc) { - MP_RETURN_IF_ERROR(SaveGpuCaches()); - if (use_gpu_delegate_) { - MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status { - gpu_buffers_in_.clear(); - gpu_buffers_out_.clear(); - // Delegate must outlive the interpreter, hence the order is important. - interpreter_ = nullptr; - delegate_ = nullptr; - return absl::OkStatus(); - })); - } else { + return gpu_helper_.RunInGlContext([this]() -> absl::Status { + gpu_buffers_in_.clear(); + gpu_buffers_out_.clear(); // Delegate must outlive the interpreter, hence the order is important. interpreter_ = nullptr; delegate_ = nullptr; - } - - return absl::OkStatus(); -} - -absl::Status InferenceCalculatorGlImpl::ReadGpuCaches() { -#ifdef MEDIAPIPE_ANDROID - if (use_kernel_caching_ && File::Exists(cached_kernel_filename_)) { - // Load pre-compiled kernel file. 
- std::string cache_str; - MP_RETURN_IF_ERROR( - mediapipe::file::GetContents(cached_kernel_filename_, &cache_str)); - std::vector cache_vec(cache_str.begin(), cache_str.end()); - tflite_gpu_runner_->SetSerializedBinaryCache(std::move(cache_vec)); - } - if (use_serialized_model_ && File::Exists(serialized_model_path_)) { - // Load serialized model file. - std::string serialized_model_str; - MP_RETURN_IF_ERROR( - file::GetContents(serialized_model_path_, &serialized_model_str)); - std::vector serialized_model_vec(serialized_model_str.begin(), - serialized_model_str.end()); - tflite_gpu_runner_->SetSerializedModel(std::move(serialized_model_vec)); - } -#endif // MEDIAPIPE_ANDROID - return absl::OkStatus(); -} - -absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner( - CalculatorContext* cc) { - ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc)); - const auto& model = *model_packet_.Get(); - - // Create runner - tflite::gpu::InferenceOptions options; - options.priority1 = allow_precision_loss_ - ? tflite::gpu::InferencePriority::MIN_LATENCY - : tflite::gpu::InferencePriority::MAX_PRECISION; - options.priority2 = tflite::gpu::InferencePriority::AUTO; - options.priority3 = tflite::gpu::InferencePriority::AUTO; - switch (tflite_gpu_runner_usage_) { - case mediapipe::InferenceCalculatorOptions::Delegate::Gpu:: - FAST_SINGLE_ANSWER: { - options.usage = tflite::gpu::InferenceUsage::FAST_SINGLE_ANSWER; - break; - } - case mediapipe::InferenceCalculatorOptions::Delegate::Gpu:: - SUSTAINED_SPEED: { - options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED; - break; - } - case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::UNSPECIFIED: { - return absl::InternalError("inference usage need to be specified."); - } - } - tflite_gpu_runner_ = std::make_unique(options); - switch (tflite_gpu_runner_api_) { - case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::ANY: { - // Do not need to force any specific API. - break; - } - case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::OPENGL: { - tflite_gpu_runner_->ForceOpenGL(); - break; - } - case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::OPENCL: { - tflite_gpu_runner_->ForceOpenCL(); - break; - } - } - if (kSideInOpResolver(cc).IsConnected()) { - const tflite::OpResolver& op_resolver = kSideInOpResolver(cc).Get(); - MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel( - model, op_resolver, /*allow_quant_ops=*/true)); - } else { - tflite::ops::builtin::BuiltinOpResolver op_resolver = - kSideInCustomOpResolver(cc).GetOr( - tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates()); - MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel( - model, op_resolver, /*allow_quant_ops=*/true)); - } - - // Create and bind OpenGL buffers for outputs. 
- // The buffers are created once and their ids are passed to calculator outputs - output_shapes_.resize(tflite_gpu_runner_->outputs_size()); - for (int i = 0; i < tflite_gpu_runner_->outputs_size(); ++i) { - output_shapes_[i] = {tflite_gpu_runner_->GetOutputShapes()[i].b, - tflite_gpu_runner_->GetOutputShapes()[i].h, - tflite_gpu_runner_->GetOutputShapes()[i].w, - tflite_gpu_runner_->GetOutputShapes()[i].c}; - } - - MP_RETURN_IF_ERROR(ReadGpuCaches()); - - MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build()); - - return absl::OkStatus(); + return absl::OkStatus(); + }); } absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) { @@ -375,12 +154,8 @@ absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) { } RET_CHECK(interpreter_); -#if defined(__EMSCRIPTEN__) - interpreter_->SetNumThreads(1); -#else interpreter_->SetNumThreads( cc->Options().cpu_num_thread()); -#endif // __EMSCRIPTEN__ return absl::OkStatus(); } diff --git a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc new file mode 100644 index 0000000000..cdadc4e612 --- /dev/null +++ b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc @@ -0,0 +1,285 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "mediapipe/calculators/tensor/inference_calculator.h" +#include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/util/tflite/tflite_gpu_runner.h" + +#if defined(MEDIAPIPE_ANDROID) +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/util/android/file/base/file.h" +#include "mediapipe/util/android/file/base/filesystem.h" +#include "mediapipe/util/android/file/base/helpers.h" +#endif // ANDROID + +namespace mediapipe { +namespace api2 { + +// Runs TFLite GPU delegate API2 directly, bypassing interpreter usage, and +// allows choosing specific API. +// +// To trigger this code path: +// [mediapipe.InferenceCalculatorOptions.ext] { +// delegate { +// gpu { +// use_advanced_gpu_api: true +// api: OPENCL # or OPENGL or ANY +// } +// } +// } +class InferenceCalculatorGlAdvancedImpl + : public NodeImpl { + public: + static absl::Status UpdateContract(CalculatorContract* cc); + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; + + private: + absl::Status ReadGpuCaches(); + absl::Status SaveGpuCaches(); + absl::Status InitTFLiteGPURunner(CalculatorContext* cc); + + // TfLite requires us to keep the model alive as long as the interpreter is. 
+ Packet model_packet_; + + mediapipe::GlCalculatorHelper gpu_helper_; + std::unique_ptr tflite_gpu_runner_; + bool allow_precision_loss_ = false; + mediapipe::InferenceCalculatorOptions::Delegate::Gpu::Api + tflite_gpu_runner_api_; + mediapipe::InferenceCalculatorOptions::Delegate::Gpu::InferenceUsage + tflite_gpu_runner_usage_; + + std::vector output_shapes_; + + bool use_kernel_caching_ = false; + std::string cached_kernel_filename_; + bool use_serialized_model_ = false; + std::string serialized_model_path_; +}; + +absl::Status InferenceCalculatorGlAdvancedImpl::UpdateContract( + CalculatorContract* cc) { + const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>(); + RET_CHECK(!options.model_path().empty() ^ kSideInModel(cc).IsConnected()) + << "Either model as side packet or model path in options is required."; + + MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorGlAdvancedImpl::Open(CalculatorContext* cc) { + const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>(); + mediapipe::InferenceCalculatorOptions::Delegate delegate = options.delegate(); + if (!kDelegate(cc).IsEmpty()) { + mediapipe::InferenceCalculatorOptions::Delegate input_side_packet_delegate = + kDelegate(cc).Get(); + CHECK(input_side_packet_delegate.has_gpu() || + input_side_packet_delegate.delegate_case() == + mediapipe::InferenceCalculatorOptions::Delegate::DELEGATE_NOT_SET) + << "inference_calculator_gl_advanced only supports delegate input side " + "packet for Gpu"; + delegate.MergeFrom(input_side_packet_delegate); + } + allow_precision_loss_ = delegate.gpu().allow_precision_loss(); + tflite_gpu_runner_api_ = delegate.gpu().api(); + tflite_gpu_runner_usage_ = delegate.gpu().usage(); + use_kernel_caching_ = delegate.gpu().has_cached_kernel_path(); + use_serialized_model_ = delegate.gpu().has_serialized_model_dir() && + delegate.gpu().has_model_token(); + + if (use_kernel_caching_) { +#ifdef MEDIAPIPE_ANDROID + cached_kernel_filename_ = delegate.gpu().cached_kernel_path() + + mediapipe::File::Basename(options.model_path()) + + ".ker"; +#endif // MEDIAPIPE_ANDROID + } + if (use_serialized_model_) { +#ifdef MEDIAPIPE_ANDROID + serialized_model_path_ = mediapipe::file::JoinPath( + delegate.gpu().serialized_model_dir(), delegate.gpu().model_token()); +#endif // MEDIAPIPE_ANDROID + } + + MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); + return gpu_helper_.RunInGlContext( + [this, &cc]() -> absl::Status { return InitTFLiteGPURunner(cc); }); +} + +absl::Status InferenceCalculatorGlAdvancedImpl::Process(CalculatorContext* cc) { + if (kInTensors(cc).IsEmpty()) { + return absl::OkStatus(); + } + const auto& input_tensors = *kInTensors(cc); + RET_CHECK(!input_tensors.empty()); + auto output_tensors = absl::make_unique>(); + + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( + [this, &input_tensors, &output_tensors]() -> absl::Status { + for (int i = 0; i < input_tensors.size(); ++i) { + MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor( + input_tensors[i].GetOpenGlBufferReadView().name(), i)); + } + output_tensors->reserve(output_shapes_.size()); + for (int i = 0; i < output_shapes_.size(); ++i) { + output_tensors->emplace_back(Tensor::ElementType::kFloat32, + output_shapes_[i]); + MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToOutputTensor( + output_tensors->back().GetOpenGlBufferWriteView().name(), i)); + } + return absl::OkStatus(); + })); + + // Run inference. 
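+  // The input and output SSBOs were bound inside the GL context above, so
+  // Invoke() below reads from and writes to GPU buffers directly; no CPU
+  // copy of the tensors is needed on this path.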
+  MP_RETURN_IF_ERROR(tflite_gpu_runner_->Invoke());
+  kOutTensors(cc).Send(std::move(output_tensors));
+  return absl::OkStatus();
+}
+
+absl::Status InferenceCalculatorGlAdvancedImpl::SaveGpuCaches() {
+#ifdef MEDIAPIPE_ANDROID
+  if (use_kernel_caching_) {
+    // Save kernel file.
+    auto kernel_cache = absl::make_unique<std::vector<uint8_t>>(
+        tflite_gpu_runner_->GetSerializedBinaryCache());
+    std::string cache_str(kernel_cache->begin(), kernel_cache->end());
+    MP_RETURN_IF_ERROR(
+        mediapipe::file::SetContents(cached_kernel_filename_, cache_str));
+  }
+  if (use_serialized_model_) {
+    // Save serialized model file.
+    ASSIGN_OR_RETURN(std::vector<uint8_t> serialized_model_vec,
+                     tflite_gpu_runner_->GetSerializedModel());
+    absl::string_view serialized_model(
+        reinterpret_cast<char*>(serialized_model_vec.data()),
+        serialized_model_vec.size());
+    MP_RETURN_IF_ERROR(
+        mediapipe::file::SetContents(serialized_model_path_, serialized_model));
+  }
+#endif  // MEDIAPIPE_ANDROID
+  return absl::OkStatus();
+}
+
+absl::Status InferenceCalculatorGlAdvancedImpl::Close(CalculatorContext* cc) {
+  MP_RETURN_IF_ERROR(SaveGpuCaches());
+  return gpu_helper_.RunInGlContext([this]() -> absl::Status {
+    tflite_gpu_runner_.reset();
+    return absl::OkStatus();
+  });
+}
+
+absl::Status InferenceCalculatorGlAdvancedImpl::ReadGpuCaches() {
+#ifdef MEDIAPIPE_ANDROID
+  if (use_kernel_caching_ && File::Exists(cached_kernel_filename_)) {
+    // Load pre-compiled kernel file.
+    std::string cache_str;
+    MP_RETURN_IF_ERROR(
+        mediapipe::file::GetContents(cached_kernel_filename_, &cache_str));
+    std::vector<uint8_t> cache_vec(cache_str.begin(), cache_str.end());
+    tflite_gpu_runner_->SetSerializedBinaryCache(std::move(cache_vec));
+  }
+  if (use_serialized_model_ && File::Exists(serialized_model_path_)) {
+    // Load serialized model file.
+    std::string serialized_model_str;
+    MP_RETURN_IF_ERROR(
+        file::GetContents(serialized_model_path_, &serialized_model_str));
+    std::vector<uint8_t> serialized_model_vec(serialized_model_str.begin(),
+                                              serialized_model_str.end());
+    tflite_gpu_runner_->SetSerializedModel(std::move(serialized_model_vec));
+  }
+#endif  // MEDIAPIPE_ANDROID
+  return absl::OkStatus();
+}
+
+absl::Status InferenceCalculatorGlAdvancedImpl::InitTFLiteGPURunner(
+    CalculatorContext* cc) {
+  ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
+  const auto& model = *model_packet_.Get();
+
+  // Create runner
+  tflite::gpu::InferenceOptions options;
+  options.priority1 = allow_precision_loss_
+                          ? tflite::gpu::InferencePriority::MIN_LATENCY
+                          : tflite::gpu::InferencePriority::MAX_PRECISION;
+  options.priority2 = tflite::gpu::InferencePriority::AUTO;
+  options.priority3 = tflite::gpu::InferencePriority::AUTO;
+  switch (tflite_gpu_runner_usage_) {
+    case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::
+        FAST_SINGLE_ANSWER: {
+      options.usage = tflite::gpu::InferenceUsage::FAST_SINGLE_ANSWER;
+      break;
+    }
+    case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::
+        SUSTAINED_SPEED: {
+      options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED;
+      break;
+    }
+    case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::UNSPECIFIED: {
+      return absl::InternalError("inference usage needs to be specified.");
+    }
+  }
+  tflite_gpu_runner_ = std::make_unique<tflite::gpu::TFLiteGPURunner>(options);
+  switch (tflite_gpu_runner_api_) {
+    case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::ANY: {
+      // Do not need to force any specific API.
+ break; + } + case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::OPENGL: { + tflite_gpu_runner_->ForceOpenGL(); + break; + } + case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::OPENCL: { + tflite_gpu_runner_->ForceOpenCL(); + break; + } + } + if (kSideInOpResolver(cc).IsConnected()) { + const tflite::OpResolver& op_resolver = kSideInOpResolver(cc).Get(); + MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel( + model, op_resolver, /*allow_quant_ops=*/true)); + } else { + tflite::ops::builtin::BuiltinOpResolver op_resolver = + kSideInCustomOpResolver(cc).GetOr( + tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates()); + MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel( + model, op_resolver, /*allow_quant_ops=*/true)); + } + + // Create and bind OpenGL buffers for outputs. + // The buffers are created once and their ids are passed to calculator outputs + output_shapes_.resize(tflite_gpu_runner_->outputs_size()); + for (int i = 0; i < tflite_gpu_runner_->outputs_size(); ++i) { + output_shapes_[i] = {tflite_gpu_runner_->GetOutputShapes()[i].b, + tflite_gpu_runner_->GetOutputShapes()[i].h, + tflite_gpu_runner_->GetOutputShapes()[i].w, + tflite_gpu_runner_->GetOutputShapes()[i].c}; + } + + MP_RETURN_IF_ERROR(ReadGpuCaches()); + return tflite_gpu_runner_->Build(); +} + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/inference_calculator_test.cc b/mediapipe/calculators/tensor/inference_calculator_test.cc index 1cff995a3e..fe96c06628 100644 --- a/mediapipe/calculators/tensor/inference_calculator_test.cc +++ b/mediapipe/calculators/tensor/inference_calculator_test.cc @@ -38,61 +38,13 @@ #endif // defined(__APPLE__) namespace mediapipe { +namespace { -void DoSmokeTest(const std::string& graph_proto) { - const int width = 8; - const int height = 8; - const int channels = 3; - // Prepare input tensor. - auto input_vec = absl::make_unique>(); - input_vec->emplace_back(Tensor::ElementType::kFloat32, - Tensor::Shape{1, height, width, channels}); - { - auto view1 = input_vec->back().GetCpuWriteView(); - auto tensor_buffer = view1.buffer(); - ASSERT_NE(tensor_buffer, nullptr); - for (int i = 0; i < width * height * channels - 1; i++) { - tensor_buffer[i] = 1; - } - } - - // Prepare single calculator graph to and wait for packets. - CalculatorGraphConfig graph_config = - ParseTextProtoOrDie(graph_proto); - std::vector output_packets; - tool::AddVectorSink("tensor_out", &graph_config, &output_packets); - CalculatorGraph graph(graph_config); - MP_ASSERT_OK(graph.StartRun({})); - - // Push the tensor into the graph. - MP_ASSERT_OK(graph.AddPacketToInputStream( - "tensor_in", Adopt(input_vec.release()).At(Timestamp(0)))); - // Wait until the calculator done processing. - MP_ASSERT_OK(graph.WaitUntilIdle()); - ASSERT_EQ(1, output_packets.size()); - - // Get and process results. - const std::vector& result_vec = - output_packets[0].Get>(); - ASSERT_EQ(1, result_vec.size()); - - const Tensor& result = result_vec[0]; - auto view = result.GetCpuReadView(); - auto result_buffer = view.buffer(); - ASSERT_NE(result_buffer, nullptr); - for (int i = 0; i < width * height * channels - 1; i++) { - ASSERT_EQ(3, result_buffer[i]); - } +constexpr int kTensorWidth = 8; +constexpr int kTensorHeight = 8; +constexpr int kTensorChannels = 3; - // Fully close graph at end, otherwise calculator+tensors are destroyed - // after calling WaitUntilDone(). 
-  MP_ASSERT_OK(graph.CloseInputStream("tensor_in"));
-  MP_ASSERT_OK(graph.WaitUntilDone());
-}
-
-// Tests a simple add model that adds an input tensor to itself.
-TEST(InferenceCalculatorTest, SmokeTest) {
-  std::string graph_proto = R"(
+constexpr char kGraphWithModelPathInOption[] = R"(
     input_stream: "tensor_in"
     node {
       calculator: "InferenceCalculator"
@@ -106,18 +58,7 @@ TEST(InferenceCalculatorTest, SmokeTest) {
      }
    }
  )";
-  // Test CPU inference only.
-  DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll(
-      graph_proto, {{"$delegate", "delegate { tflite {} }"}}));
-  DoSmokeTest(absl::StrReplaceAll(graph_proto,
-                                  {{"$delegate", "delegate { xnnpack {} }"}}));
-  DoSmokeTest(absl::StrReplaceAll(
-      graph_proto,
-      {{"$delegate", "delegate { xnnpack { num_threads: 10 } }"}}));
-}
-
-TEST(InferenceCalculatorTest, SmokeTest_ModelAsInputSidePacket) {
-  std::string graph_proto = R"(
+constexpr char kGraphWithModelAsInputSidePacket[] = R"(
    input_stream: "tensor_in"

    node {
@@ -154,7 +95,84 @@ TEST(InferenceCalculatorTest, SmokeTest_ModelAsInputSidePacket) {
      }
    }
  )";
-  DoSmokeTest(graph_proto);
+
+std::vector<Tensor> CreateInputs() {
+  std::vector<Tensor> input_vec;
+  // Prepare input tensor.
+  input_vec.emplace_back(
+      Tensor::ElementType::kFloat32,
+      Tensor::Shape{1, kTensorHeight, kTensorWidth, kTensorChannels});
+  {
+    auto view = input_vec.back().GetCpuWriteView();
+    auto num_elements = input_vec.back().shape().num_elements();
+    auto tensor_buffer = view.buffer<float>();
+    for (int i = 0; i < num_elements; i++) {
+      tensor_buffer[i] = 1;
+    }
+  }
+
+  return input_vec;
+}
+
+void RunGraphThenClose(CalculatorGraph& graph, std::vector<Tensor> input_vec) {
+  MP_ASSERT_OK(graph.StartRun({}));
+
+  // Push the tensor into the graph.
+  MP_ASSERT_OK(graph.AddPacketToInputStream(
+      "tensor_in",
+      MakePacket<std::vector<Tensor>>(std::move(input_vec)).At(Timestamp(0))));
+  // Wait until the calculator is done processing.
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+
+  // Fully close graph at end, otherwise calculator+tensors are destroyed
+  // after calling WaitUntilDone().
+  MP_ASSERT_OK(graph.CloseInputStream("tensor_in"));
+  MP_ASSERT_OK(graph.WaitUntilDone());
+}
+
+void DoSmokeTest(const std::string& graph_proto) {
+  auto input_vec = CreateInputs();
+
+  // Prepare a single calculator graph and wait for packets.
+  CalculatorGraphConfig graph_config =
+      ParseTextProtoOrDie<CalculatorGraphConfig>(graph_proto);
+  std::vector<Packet> output_packets;
+  tool::AddVectorSink("tensor_out", &graph_config, &output_packets);
+  CalculatorGraph graph(graph_config);
+
+  RunGraphThenClose(graph, std::move(input_vec));
+
+  ASSERT_EQ(1, output_packets.size());
+
+  // Get and process results.
+  const std::vector<Tensor>& result_vec =
+      output_packets[0].Get<std::vector<Tensor>>();
+  ASSERT_EQ(1, result_vec.size());
+
+  const Tensor& result = result_vec[0];
+  auto view = result.GetCpuReadView();
+  auto result_buffer = view.buffer<float>();
+  ASSERT_NE(result_buffer, nullptr);
+  for (int i = 0; i < result.shape().num_elements(); i++) {
+    ASSERT_EQ(3, result_buffer[i]);
+  }
+}
+
+// Tests a simple add model that adds an input tensor to itself.
+TEST(InferenceCalculatorTest, SmokeTest) {
+  // Test CPU inference only.
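+  // Each call below substitutes the "$delegate" placeholder in the graph
+  // template with a concrete delegate config before parsing the graph proto.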
+  DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll(
+      kGraphWithModelPathInOption, {{"$delegate", "delegate { tflite {} }"}}));
+  DoSmokeTest(absl::StrReplaceAll(kGraphWithModelPathInOption,
+                                  {{"$delegate", "delegate { xnnpack {} }"}}));
+  DoSmokeTest(absl::StrReplaceAll(
+      kGraphWithModelPathInOption,
+      {{"$delegate", "delegate { xnnpack { num_threads: 10 } }"}}));
+}
+
+TEST(InferenceCalculatorTest, ModelAsInputSidePacketSmokeTest) {
+  DoSmokeTest(kGraphWithModelAsInputSidePacket);
 }
 
+}  // namespace
 }  // namespace mediapipe
diff --git a/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc b/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc
index 87216f4d27..2f73549586 100644
--- a/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc
+++ b/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc
@@ -16,7 +16,6 @@
 #include <unordered_map>
 #include <vector>
 
-#include "absl/container/node_hash_map.h"
 #include "absl/strings/str_format.h"
 #include "absl/types/span.h"
 #include "mediapipe/calculators/tensor/tensors_to_classification_calculator.pb.h"
@@ -25,6 +24,7 @@
 #include "mediapipe/framework/formats/classification.pb.h"
 #include "mediapipe/framework/formats/tensor.h"
 #include "mediapipe/framework/port/ret_check.h"
+#include "mediapipe/util/label_map.pb.h"
 #include "mediapipe/util/resource_util.h"
 #if defined(MEDIAPIPE_MOBILE)
 #include "mediapipe/util/android/file/base/file.h"
@@ -35,6 +35,17 @@
 namespace mediapipe {
 namespace api2 {
+namespace {
+
+void SetClassificationLabel(const LabelMapItem label_map_item,
+                            Classification* classification) {
+  classification->set_label(label_map_item.name());
+  if (label_map_item.has_display_name()) {
+    classification->set_display_name(label_map_item.display_name());
+  }
+}
+
+}  // namespace
 
 // Convert result tensors from classification models into MediaPipe
 // classifications.
@@ -54,7 +65,6 @@ namespace api2 {
 //   output_stream: "CLASSIFICATIONS:classifications"
 //   options: {
 //     [mediapipe.TensorsToClassificationCalculatorOptions.ext] {
-//       num_classes: 1024
 //       min_score_threshold: 0.1
 //       label_map_path: "labelmap.txt"
 //     }
@@ -72,22 +82,35 @@
   absl::Status Close(CalculatorContext* cc) override;
 
  private:
-  ::mediapipe::TensorsToClassificationCalculatorOptions options_;
   int top_k_ = 0;
-  absl::node_hash_map<int, std::string> label_map_;
+  bool sort_by_descending_score_ = false;
+  proto_ns::Map<int64, LabelMapItem> local_label_map_;
   bool label_map_loaded_ = false;
+  bool is_binary_classification_ = false;
+  float min_score_threshold_ = std::numeric_limits<float>::lowest();
+
+  // Set of allowed or ignored class indices.
+  struct ClassIndexSet {
+    absl::flat_hash_set<int> values;
+    bool is_allowlist;
+  };
+  // Allowed or ignored class indices based on provided options.
+  // These are used to filter out the output classification results.
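+  // An empty `values` set means no index-based filtering is applied.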
+  ClassIndexSet class_index_set_;
+  bool IsClassIndexAllowed(int class_index);
+  const proto_ns::Map<int64, LabelMapItem>& GetLabelMap(CalculatorContext* cc);
 };
 MEDIAPIPE_REGISTER_NODE(TensorsToClassificationCalculator);
 
 absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) {
-  options_ =
-      cc->Options<::mediapipe::TensorsToClassificationCalculatorOptions>();
+  const auto& options = cc->Options<TensorsToClassificationCalculatorOptions>();
 
-  top_k_ = options_.top_k();
-  if (options_.has_label_map_path()) {
+  top_k_ = options.top_k();
+  sort_by_descending_score_ = options.sort_by_descending_score();
+  if (options.has_label_map_path()) {
     std::string string_path;
     ASSIGN_OR_RETURN(string_path,
-                     PathToResourceAsFile(options_.label_map_path()));
+                     PathToResourceAsFile(options.label_map_path()));
     std::string label_map_string;
     MP_RETURN_IF_ERROR(
         mediapipe::GetResourceContents(string_path, &label_map_string));
@@ -96,18 +119,45 @@ absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) {
     std::string line;
     int i = 0;
     while (std::getline(stream, line)) {
-      label_map_[i++] = line;
+      LabelMapItem item;
+      item.set_name(line);
+      local_label_map_[i++] = item;
     }
     label_map_loaded_ = true;
-  } else if (options_.has_label_map()) {
-    for (int i = 0; i < options_.label_map().entries_size(); ++i) {
-      const auto& entry = options_.label_map().entries(i);
-      RET_CHECK(!label_map_.contains(entry.id()))
+  } else if (!options.label_items().empty()) {
+    label_map_loaded_ = true;
+  } else if (options.has_label_map()) {
+    for (int i = 0; i < options.label_map().entries_size(); ++i) {
+      const auto& entry = options.label_map().entries(i);
+      RET_CHECK(!local_label_map_.contains(entry.id()))
           << "Duplicate id found: " << entry.id();
-      label_map_[entry.id()] = entry.label();
+      LabelMapItem item;
+      item.set_name(entry.label());
+      local_label_map_[entry.id()] = item;
     }
     label_map_loaded_ = true;
   }
+  if (options.has_min_score_threshold()) {
+    min_score_threshold_ = options.min_score_threshold();
+  }
+  is_binary_classification_ = options.binary_classification();
+
+  if (is_binary_classification_) {
+    RET_CHECK(options.allow_classes().empty() &&
+              options.ignore_classes().empty());
+  }
+  if (!options.allow_classes().empty()) {
+    RET_CHECK(options.ignore_classes().empty());
+    class_index_set_.is_allowlist = true;
+    for (int i = 0; i < options.allow_classes_size(); ++i) {
+      class_index_set_.values.insert(options.allow_classes(i));
+    }
+  } else {
+    class_index_set_.is_allowlist = false;
+    for (int i = 0; i < options.ignore_classes_size(); ++i) {
+      class_index_set_.values.insert(options.ignore_classes(i));
+    }
+  }
   return absl::OkStatus();
 }
 
@@ -118,19 +168,19 @@ absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
   int num_classes = input_tensors[0].shape().num_elements();
 
-  if (options_.binary_classification()) {
+  if (is_binary_classification_) {
     RET_CHECK_EQ(num_classes, 1);
     // Number of classes for binary classification.
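    // A single sigmoid score expands into two entries: raw_scores[0] for the
    // first class and 1 - raw_scores[0] for the second.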
    num_classes = 2;
  }
  if (label_map_loaded_) {
-    RET_CHECK_EQ(num_classes, label_map_.size());
+    RET_CHECK_EQ(num_classes, GetLabelMap(cc).size());
  }
  auto view = input_tensors[0].GetCpuReadView();
  auto raw_scores = view.buffer<float>();

  auto classification_list = absl::make_unique<ClassificationList>();
-  if (options_.binary_classification()) {
+  if (is_binary_classification_) {
    Classification* class_first = classification_list->add_classification();
    Classification* class_second = classification_list->add_classification();
    class_first->set_index(0);
@@ -139,41 +189,48 @@ absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
    class_second->set_score(1. - raw_scores[0]);

    if (label_map_loaded_) {
-      class_first->set_label(label_map_[0]);
-      class_second->set_label(label_map_[1]);
+      SetClassificationLabel(GetLabelMap(cc).at(0), class_first);
+      SetClassificationLabel(GetLabelMap(cc).at(1), class_second);
    }
  } else {
    for (int i = 0; i < num_classes; ++i) {
-      if (options_.has_min_score_threshold() &&
-          raw_scores[i] < options_.min_score_threshold()) {
+      if (!IsClassIndexAllowed(i)) {
+        continue;
+      }
+      if (raw_scores[i] < min_score_threshold_) {
        continue;
      }
      Classification* classification =
          classification_list->add_classification();
      classification->set_index(i);
      classification->set_score(raw_scores[i]);
      if (label_map_loaded_) {
-        classification->set_label(label_map_[i]);
+        SetClassificationLabel(GetLabelMap(cc).at(i), classification);
      }
    }
  }

-  // Note that partial_sort will raise error when top_k_ >
-  // classification_list->classification_size().
-  CHECK_GE(classification_list->classification_size(), top_k_);
  auto raw_classification_list = classification_list->mutable_classification();
-  if (top_k_ > 0 && classification_list->classification_size() >= top_k_) {
+  if (top_k_ > 0) {
+    int desired_size =
+        std::min(classification_list->classification_size(), top_k_);
    std::partial_sort(raw_classification_list->begin(),
-                      raw_classification_list->begin() + top_k_,
+                      raw_classification_list->begin() + desired_size,
                      raw_classification_list->end(),
                      [](const Classification a, const Classification b) {
                        return a.score() > b.score();
                      });

-    // Resizes the underlying list to have only top_k_ classifications.
-    raw_classification_list->DeleteSubrange(
-        top_k_, raw_classification_list->size() - top_k_);
+    if (desired_size >= top_k_) {
+      // Resizes the underlying list to have only top_k_ classifications.
+      raw_classification_list->DeleteSubrange(
+          top_k_, raw_classification_list->size() - top_k_);
+    }
+  } else if (sort_by_descending_score_) {
+    std::sort(raw_classification_list->begin(), raw_classification_list->end(),
+              [](const Classification a, const Classification b) {
                return a.score() > b.score();
              });
  }
  kOutClassificationList(cc).Send(std::move(classification_list));
  return absl::OkStatus();
@@ -183,5 +240,24 @@ absl::Status TensorsToClassificationCalculator::Close(CalculatorContext* cc) {
  return absl::OkStatus();
}

+bool TensorsToClassificationCalculator::IsClassIndexAllowed(int class_index) {
+  if (class_index_set_.values.empty()) {
+    return true;
+  }
+  if (class_index_set_.is_allowlist) {
+    return class_index_set_.values.contains(class_index);
+  } else {
+    return !class_index_set_.values.contains(class_index);
+  }
+}
+
+const proto_ns::Map<int64, LabelMapItem>&
+TensorsToClassificationCalculator::GetLabelMap(CalculatorContext* cc) {
+  return !local_label_map_.empty()
+             ? local_label_map_
+             : cc->Options<TensorsToClassificationCalculatorOptions>()
+                   .label_items();
+}
+
 }  // namespace api2
 }  // namespace mediapipe
diff --git a/mediapipe/calculators/tensor/tensors_to_classification_calculator.proto b/mediapipe/calculators/tensor/tensors_to_classification_calculator.proto
index 3934a61012..32bc4b63ae 100644
--- a/mediapipe/calculators/tensor/tensors_to_classification_calculator.proto
+++ b/mediapipe/calculators/tensor/tensors_to_classification_calculator.proto
@@ -19,6 +19,7 @@ syntax = "proto2";
 package mediapipe;
 
 import "mediapipe/framework/calculator.proto";
+import "mediapipe/util/label_map.proto";
 
 message TensorsToClassificationCalculatorOptions {
   extend .mediapipe.CalculatorOptions {
@@ -38,16 +39,37 @@ message TensorsToClassificationCalculatorOptions {
   // Number of highest scoring labels to output. If top_k is not positive then
   // all labels are used.
   optional int32 top_k = 2;
+  // Whether results should be sorted by descending score. By default, results
+  // may or may not be sorted: setting this to true guarantees that the
+  // returned results will be sorted by descending score.
+  optional bool sort_by_descending_score = 9;
   // Path to a label map file for getting the actual name of class ids.
   optional string label_map_path = 3;
   // Label map. (Can be used instead of label_map_path.)
-  // NOTE: "label_map_path", if specified, takes precedence over "label_map".
+  // NOTE: either "label_map_path" or "label_items", if specified, takes
+  // precedence over "label_map".
+  // Deprecated: please use `label_items` instead.
   optional LabelMap label_map = 5;
+  // Label items. (Can be used instead of label_map_path.)
+  // NOTE: "label_map_path", if specified, takes precedence over "label_items".
+  map<int64, LabelMapItem> label_items = 6;
+
   // Whether the input is a single float for binary classification.
   // When true, only a single float is expected in the input tensor and the
   // label map, if provided, is expected to have exactly two labels.
   // The single score (float) represents the probability of the first label,
   // and 1 - score is the probability of the second label.
   optional bool binary_classification = 4;
+
+  // The ids of classes that should be ignored during decoding the score for
+  // each classification. If `ignore_classes` is specified, all the other
+  // classes that are not in the `ignore_classes` field will be considered
+  // during decoding. `ignore_classes` and `allow_classes` are mutually
+  // exclusive.
+  repeated int32 ignore_classes = 7 [packed = true];
+  // The ids of classes that will be allowed during decoding the score for
+  // each classification. If `allow_classes` is specified, all the other
+  // classes that are not in the `allow_classes` field will be completely
+  // ignored. `ignore_classes` and `allow_classes` are mutually exclusive.
+  repeated int32 allow_classes = 8 [packed = true];
 }
diff --git a/mediapipe/calculators/tensor/tensors_to_classification_calculator_test.cc b/mediapipe/calculators/tensor/tensors_to_classification_calculator_test.cc
index 92b20629d0..9634635f06 100644
--- a/mediapipe/calculators/tensor/tensors_to_classification_calculator_test.cc
+++ b/mediapipe/calculators/tensor/tensors_to_classification_calculator_test.cc
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include <limits>
+#include <string>
 #include <vector>
 
 #include "absl/memory/memory.h"
@@ -206,4 +208,119 @@ TEST_F(TensorsToClassificationCalculatorTest, CorrectOutputWithTopK) {
   }
 }
 
+TEST_F(TensorsToClassificationCalculatorTest,
+       CorrectOutputWithSortByDescendingScore) {
+  mediapipe::CalculatorRunner runner(
+      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
+        calculator: "TensorsToClassificationCalculator"
+        input_stream: "TENSORS:tensors"
+        output_stream: "CLASSIFICATIONS:classifications"
+        options {
+          [mediapipe.TensorsToClassificationCalculatorOptions.ext] {
+            sort_by_descending_score: true
+          }
+        }
+      )pb"));
+
+  BuildGraph(&runner, {0, 0.5, 1});
+  MP_ASSERT_OK(runner.Run());
+
+  const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets;
+
+  EXPECT_EQ(1, output_packets_.size());
+
+  const auto& classification_list =
+      output_packets_[0].Get<ClassificationList>();
+
+  // Verify results are sorted by descending score.
+  EXPECT_EQ(3, classification_list.classification_size());
+  float score = std::numeric_limits<float>::max();
+  for (int i = 0; i < classification_list.classification_size(); ++i) {
+    EXPECT_LE(classification_list.classification(i).score(), score);
+    score = classification_list.classification(i).score();
+  }
+}
+
+TEST_F(TensorsToClassificationCalculatorTest,
+       ClassNameAllowlistWithLabelItems) {
+  mediapipe::CalculatorRunner runner(
+      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
+        calculator: "TensorsToClassificationCalculator"
+        input_stream: "TENSORS:tensors"
+        output_stream: "CLASSIFICATIONS:classifications"
+        options {
+          [mediapipe.TensorsToClassificationCalculatorOptions.ext] {
+            label_items {
+              key: 0
+              value { name: "ClassA" }
+            }
+            label_items {
+              key: 1
+              value { name: "ClassB" }
+            }
+            label_items {
+              key: 2
+              value { name: "ClassC" }
+            }
+            allow_classes: 1
+          }
+        }
+      )pb"));
+
+  BuildGraph(&runner, {0, 0.5, 1});
+  MP_ASSERT_OK(runner.Run());
+
+  const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets;
+
+  EXPECT_EQ(1, output_packets_.size());
+
+  const auto& classification_list =
+      output_packets_[0].Get<ClassificationList>();
+  EXPECT_EQ(1, classification_list.classification_size());
+  EXPECT_EQ(1, classification_list.classification(0).index());
+  EXPECT_EQ(0.5, classification_list.classification(0).score());
+  ASSERT_TRUE(classification_list.classification(0).has_label());
+}
+
+TEST_F(TensorsToClassificationCalculatorTest,
+       ClassNameIgnorelistWithLabelItems) {
+  mediapipe::CalculatorRunner runner(
+      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
+        calculator: "TensorsToClassificationCalculator"
+        input_stream: "TENSORS:tensors"
+        output_stream: "CLASSIFICATIONS:classifications"
+        options {
+          [mediapipe.TensorsToClassificationCalculatorOptions.ext] {
+            label_items {
+              key: 0
+              value { name: "ClassA" }
+            }
+            label_items {
+              key: 1
+              value { name: "ClassB" }
+            }
+            label_items {
+              key: 2
+              value { name: "ClassC" }
+            }
+            ignore_classes: 1
+          }
+        }
+      )pb"));
+
+  BuildGraph(&runner, {0, 0.5, 1});
+  MP_ASSERT_OK(runner.Run());
+
+  const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets;
+
+  EXPECT_EQ(1, output_packets_.size());
+
+  const auto& classification_list =
+      output_packets_[0].Get<ClassificationList>();
+  EXPECT_EQ(2, classification_list.classification_size());
+  EXPECT_EQ(0, classification_list.classification(0).index());
+  EXPECT_EQ(0, classification_list.classification(0).score());
+  ASSERT_TRUE(classification_list.classification(0).has_label());
+  EXPECT_EQ(2, classification_list.classification(1).index());
+  EXPECT_EQ(1, classification_list.classification(1).score());
+  ASSERT_TRUE(classification_list.classification(1).has_label());
+}
+
 }  // namespace mediapipe
diff --git a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc
index a03a60189a..21f9838942 100644
--- a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc
+++ b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc
@@ -20,10 +20,8 @@
 #include "mediapipe/framework/calculator_context.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/formats/image.h"
-#include "mediapipe/framework/formats/image_opencv.h"
 #include "mediapipe/framework/formats/tensor.h"
 #include "mediapipe/framework/port.h"
-#include "mediapipe/framework/port/opencv_imgproc_inc.h"
 #include "mediapipe/framework/port/ret_check.h"
 #include "mediapipe/framework/port/statusor.h"
 #include "mediapipe/gpu/gpu_origin.pb.h"
@@ -37,6 +35,11 @@
 #include "mediapipe/gpu/shader_util.h"
 #endif  // !MEDIAPIPE_DISABLE_GPU
 
+#if !MEDIAPIPE_DISABLE_OPENCV
+#include "mediapipe/framework/formats/image_opencv.h"
+#include "mediapipe/framework/port/opencv_imgproc_inc.h"
+#endif  // !MEDIAPIPE_DISABLE_OPENCV
+
 #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
 #include "tensorflow/lite/delegates/gpu/gl/converters/util.h"
 #include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
@@ -159,9 +162,10 @@ class TensorsToSegmentationCalculator : public CalculatorBase {
     return options_.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT;
   }
 
+#if !MEDIAPIPE_DISABLE_OPENCV
   template <class T>
   absl::Status ApplyActivation(cv::Mat& tensor_mat, cv::Mat* small_mask_mat);
-
+#endif  // !MEDIAPIPE_DISABLE_OPENCV
   ::mediapipe::TensorsToSegmentationCalculatorOptions options_;
 
 #if !MEDIAPIPE_DISABLE_GPU
@@ -283,7 +287,11 @@ absl::Status TensorsToSegmentationCalculator::Process(CalculatorContext* cc) {
     RET_CHECK_FAIL() << "GPU processing disabled.";
 #endif  // !MEDIAPIPE_DISABLE_GPU
   } else {
+#if !MEDIAPIPE_DISABLE_OPENCV
     MP_RETURN_IF_ERROR(ProcessCpu(cc));
+#else
+    RET_CHECK_FAIL() << "OpenCV processing disabled.";
+#endif  // !MEDIAPIPE_DISABLE_OPENCV
   }
 
   return absl::OkStatus();
@@ -311,6 +319,7 @@ absl::Status TensorsToSegmentationCalculator::Close(CalculatorContext* cc) {
 
 absl::Status TensorsToSegmentationCalculator::ProcessCpu(
     CalculatorContext* cc) {
+#if !MEDIAPIPE_DISABLE_OPENCV
   // Get input streams, and dimensions.
   const auto& input_tensors =
       cc->Inputs().Tag(kTensorsTag).Get<std::vector<Tensor>>();
@@ -360,10 +369,12 @@ absl::Status TensorsToSegmentationCalculator::ProcessCpu(
   cv::resize(small_mask_mat, *output_mat,
              cv::Size(output_width, output_height));
   cc->Outputs().Tag(kMaskTag).Add(output_mask.release(), cc->InputTimestamp());
+#endif  // !MEDIAPIPE_DISABLE_OPENCV
 
   return absl::OkStatus();
 }
 
+#if !MEDIAPIPE_DISABLE_OPENCV
 template <class T>
 absl::Status TensorsToSegmentationCalculator::ApplyActivation(
     cv::Mat& tensor_mat, cv::Mat* small_mask_mat) {
@@ -411,6 +422,7 @@ absl::Status TensorsToSegmentationCalculator::ApplyActivation(
 
   return absl::OkStatus();
 }
+#endif  // !MEDIAPIPE_DISABLE_OPENCV
 
 // Steps:
 // 1. receive tensor
diff --git a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc
index 1db886a367..5eddd3c2e1 100644
--- a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc
+++ b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc
@@ -300,17 +300,26 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
     RET_CHECK(options_.batch_size() == 1 ||
               options_.recurrent_tag_pair().empty())
         << "To use recurrent_tag_pairs, batch_size must be 1.";
+
+    // Helper for StrJoin. Prints key (tag) of map.
+    auto TagFormatter =
+        absl::PairFormatter(absl::StreamFormatter(), "",
+                            [](std::string* out, const std::string& second) {});
+
     for (const auto& tag_pair : options_.recurrent_tag_pair()) {
       const std::vector<std::string> tags = absl::StrSplit(tag_pair, ':');
       RET_CHECK_EQ(tags.size(), 2) << "recurrent_tag_pair must be a colon "
                                       "separated string with two components: "
                                    << tag_pair;
+
       RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[0]))
           << "Can't find tag '" << tags[0] << "' in signature "
-          << options_.signature_name();
+          << options_.signature_name() << "; instead found tags "
+          << absl::StrJoin(tag_to_tensor_map_, ", ", TagFormatter);
       RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[1]))
           << "Can't find tag '" << tags[1] << "' in signature "
-          << options_.signature_name();
+          << options_.signature_name() << "; instead found tags "
+          << absl::StrJoin(tag_to_tensor_map_, ", ", TagFormatter);
       recurrent_feed_tags_.insert(tags[0]);
       recurrent_fetch_tags_to_feed_tags_[tags[1]] = tags[0];
     }
@@ -319,12 +328,14 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
     for (const std::string& tag : cc->Inputs().GetTags()) {
       RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag))
           << "Can't find tag '" << tag << "' in signature "
-          << options_.signature_name();
+          << options_.signature_name() << "; instead found tags "
+          << absl::StrJoin(tag_to_tensor_map_, ", ", TagFormatter);
     }
     for (const std::string& tag : cc->Outputs().GetTags()) {
       RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag))
          << "Can't find tag '" << tag << "' in signature "
-          << options_.signature_name();
+          << options_.signature_name() << "; instead found tags "
+          << absl::StrJoin(tag_to_tensor_map_, ", ", TagFormatter);
     }
 
     {
diff --git a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc
index cc1d15043a..70487b26eb 100644
--- a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc
+++ b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc
@@ -38,6 +38,9 @@
 
 namespace mediapipe {
 
+using ::testing::AllOf;
+using ::testing::HasSubstr;
+
 namespace tf = ::tensorflow;
 
 namespace {
@@ -199,8 +202,8 @@ TEST_F(TensorflowInferenceCalculatorTest, GetComputed) {
   auto run_status = runner_->Run();
   ASSERT_FALSE(run_status.ok());
   EXPECT_THAT(run_status.ToString(),
-              testing::HasSubstr("TensorFlowInferenceCalculator"));
-  EXPECT_THAT(run_status.ToString(), testing::HasSubstr("Tag B"));
+              HasSubstr("TensorFlowInferenceCalculator"));
+  EXPECT_THAT(run_status.ToString(), HasSubstr("Tag B"));
 }
 
 TEST_F(TensorflowInferenceCalculatorTest, GetComputed_MaxInFlight) {
@@ -238,8 +241,8 @@ TEST_F(TensorflowInferenceCalculatorTest, GetComputed_MaxInFlight) {
   auto run_status = runner_->Run();
   ASSERT_FALSE(run_status.ok());
   EXPECT_THAT(run_status.ToString(),
testing::HasSubstr("TensorFlowInferenceCalculator")); - EXPECT_THAT(run_status.ToString(), testing::HasSubstr("Tag B")); + HasSubstr("TensorFlowInferenceCalculator")); + EXPECT_THAT(run_status.ToString(), HasSubstr("Tag B")); } TEST_F(TensorflowInferenceCalculatorTest, BadTag) { @@ -255,7 +258,12 @@ TEST_F(TensorflowInferenceCalculatorTest, BadTag) { runner_ = absl::make_unique(config); AddSessionInputSidePacket(); - ASSERT_FALSE(runner_->Run().ok()); + absl::Status status = runner_->Run(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT( + status.message(), + AllOf(HasSubstr("Can't find tag 'BAD' in signature"), + HasSubstr("instead found tags A, B, EXPENSIVE, MULTIPLIED"))); } TEST_F(TensorflowInferenceCalculatorTest, GetMultiBatchComputed) { @@ -740,7 +748,7 @@ TEST_F(TensorflowInferenceCalculatorTest, BatchedInputTooBigBatch) { ASSERT_FALSE(status.ok()); EXPECT_THAT( status.message(), - ::testing::HasSubstr( + HasSubstr( "has more packets than batch capacity. batch_size: 2 packets: 3")); } diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index d00fd09ff2..e4e6cf9128 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -301,6 +301,8 @@ cc_library( ":detection_label_id_to_text_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", "@com_google_absl//absl/container:node_hash_map", + "//mediapipe/framework/port:core_proto", + "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:status", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:packet", diff --git a/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc b/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc index 20d1c1cbd4..0b8dde20d5 100644 --- a/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc +++ b/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc @@ -16,6 +16,8 @@ #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/proto_ns.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/resource_util.h" @@ -55,9 +57,9 @@ class DetectionLabelIdToTextCalculator : public CalculatorBase { private: // Local label map built from the calculator options' `label_map_path` or // `label` field. 
-  LabelMap local_label_map_;
+  proto_ns::Map<int64, LabelMapItem> local_label_map_;
   bool keep_label_id_;
-  const LabelMap& GetLabelMap(CalculatorContext* cc);
+  const proto_ns::Map<int64, LabelMapItem>& GetLabelMap(CalculatorContext* cc);
 };
 
 REGISTER_CALCULATOR(DetectionLabelIdToTextCalculator);
@@ -72,13 +74,12 @@ absl::Status DetectionLabelIdToTextCalculator::GetContract(
 
 absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) {
   cc->SetOffset(TimestampDiff(0));
-  const auto& options =
-      cc->Options<::mediapipe::DetectionLabelIdToTextCalculatorOptions>();
+  const auto& options = cc->Options<DetectionLabelIdToTextCalculatorOptions>();
 
   if (options.has_label_map_path()) {
-    RET_CHECK(!options.has_label_map() && options.label().empty())
+    RET_CHECK(options.label_items().empty() && options.label().empty())
         << "Can only set one of the following fields in the CalculatorOptions: "
-           "label_map_path, label, and label_map.";
+           "label_map_path, label, and label_items.";
     std::string string_path;
     ASSIGN_OR_RETURN(string_path,
                      PathToResourceAsFile(options.label_map_path()));
@@ -91,16 +92,16 @@ absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) {
     while (std::getline(stream, line)) {
       LabelMapItem item;
       item.set_name(line);
-      (*local_label_map_.mutable_index_to_item())[i++] = item;
+      local_label_map_[i++] = item;
     }
   } else if (!options.label().empty()) {
-    RET_CHECK(!options.has_label_map())
+    RET_CHECK(options.label_items().empty())
        << "Can only set one of the following fields in the CalculatorOptions: "
-           "label_map_path, label, and label_map.";
+           "label_map_path, label, and label_items.";
     for (int i = 0; i < options.label_size(); ++i) {
       LabelMapItem item;
       item.set_name(options.label(i));
-      (*local_label_map_.mutable_index_to_item())[i] = item;
+      local_label_map_[i] = item;
     }
   }
   keep_label_id_ = options.keep_label_id();
@@ -115,9 +116,8 @@ absl::Status DetectionLabelIdToTextCalculator::Process(CalculatorContext* cc) {
     Detection& output_detection = output_detections.back();
     bool has_text_label = false;
     for (const int32 label_id : output_detection.label_id()) {
-      if (GetLabelMap(cc).index_to_item().find(label_id) !=
-          GetLabelMap(cc).index_to_item().end()) {
-        auto item = GetLabelMap(cc).index_to_item().at(label_id);
+      if (GetLabelMap(cc).contains(label_id)) {
+        auto item = GetLabelMap(cc).at(label_id);
         output_detection.add_label(item.name());
         if (item.has_display_name()) {
           output_detection.add_display_name(item.display_name());
@@ -136,13 +136,12 @@ absl::Status DetectionLabelIdToTextCalculator::Process(CalculatorContext* cc) {
   return absl::OkStatus();
 }
 
-const LabelMap& DetectionLabelIdToTextCalculator::GetLabelMap(
-    CalculatorContext* cc) {
-  return !local_label_map_.index_to_item().empty()
+const proto_ns::Map<int64, LabelMapItem>&
+DetectionLabelIdToTextCalculator::GetLabelMap(CalculatorContext* cc) {
+  return !local_label_map_.empty()
             ? local_label_map_
-             : cc->Options<
-                   ::mediapipe::DetectionLabelIdToTextCalculatorOptions>()
-                   .label_map();
+             : cc->Options<DetectionLabelIdToTextCalculatorOptions>()
+                   .label_items();
 }
 
 }  // namespace mediapipe
diff --git a/mediapipe/calculators/util/detection_label_id_to_text_calculator.proto b/mediapipe/calculators/util/detection_label_id_to_text_calculator.proto
index bb1cf6098b..eedc5e7046 100644
--- a/mediapipe/calculators/util/detection_label_id_to_text_calculator.proto
+++ b/mediapipe/calculators/util/detection_label_id_to_text_calculator.proto
@@ -38,6 +38,6 @@ message DetectionLabelIdToTextCalculatorOptions {
   // output detections.
   optional bool keep_label_id = 3;
 
-  // Label map.
-  optional LabelMap label_map = 4;
+  // Identifying information for each classification label.
+  map<int64, LabelMapItem> label_items = 4;
 }
diff --git a/mediapipe/calculators/video/BUILD b/mediapipe/calculators/video/BUILD
index 806b9f1fa5..53d9681516 100644
--- a/mediapipe/calculators/video/BUILD
+++ b/mediapipe/calculators/video/BUILD
@@ -426,6 +426,7 @@ cc_test(
         "//mediapipe/framework/port:logging",
         "//mediapipe/framework/port:opencv_core",
         "//mediapipe/framework/port:parse_text_proto",
+        "//mediapipe/framework/tool:test_util",
         "@com_google_absl//absl/flags:flag",
     ],
 )
@@ -451,6 +452,7 @@ cc_test(
         "//mediapipe/framework/port:opencv_imgproc",
         "//mediapipe/framework/port:opencv_video",
         "//mediapipe/framework/port:parse_text_proto",
+        "//mediapipe/framework/tool:test_util",
         "@com_google_absl//absl/flags:flag",
     ],
 )
@@ -534,6 +536,7 @@ cc_test(
         "//mediapipe/framework/port:status",
         "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
         "//mediapipe/framework/stream_handler:sync_set_input_stream_handler",
+        "//mediapipe/framework/tool:test_util",
         "//mediapipe/util/tracking:box_tracker_cc_proto",
         "//mediapipe/util/tracking:tracking_cc_proto",
         "@com_google_absl//absl/flags:flag",
diff --git a/mediapipe/calculators/video/opencv_video_decoder_calculator.cc b/mediapipe/calculators/video/opencv_video_decoder_calculator.cc
index 52905d837f..9e04f33cb5 100644
--- a/mediapipe/calculators/video/opencv_video_decoder_calculator.cc
+++ b/mediapipe/calculators/video/opencv_video_decoder_calculator.cc
@@ -120,7 +120,7 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase {
     // back. To get correct image format, we read the first frame from the video
     // and get the number of channels.
     cv::Mat frame;
-    cap_->read(frame);
+    ReadFrame(frame);
     if (frame.empty()) {
       return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
              << "Failed to read any frames from the video file at "
@@ -193,13 +193,13 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase {
     Timestamp timestamp(cap_->get(cv::CAP_PROP_POS_MSEC) * 1000);
     if (format_ == ImageFormat::GRAY8) {
       cv::Mat frame = formats::MatView(image_frame.get());
-      cap_->read(frame);
+      ReadFrame(frame);
       if (frame.empty()) {
         return tool::StatusStop();
       }
     } else {
       cv::Mat tmp_frame;
-      cap_->read(tmp_frame);
+      ReadFrame(tmp_frame);
       if (tmp_frame.empty()) {
         return tool::StatusStop();
       }
@@ -234,6 +234,14 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase {
     return absl::OkStatus();
   }
 
+  // Sometimes an empty frame is returned even though there are more frames.
+  void ReadFrame(cv::Mat& frame) {
+    cap_->read(frame);
+    if (frame.empty()) {
+      cap_->read(frame);  // Try again.
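+      // If the retry also returns an empty frame, callers treat it as a real
+      // failure: Open() reports an error and Process() ends the stream.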
+    }
+  }
+
 private:
  std::unique_ptr<cv::VideoCapture> cap_;
  int width_;
diff --git a/mediapipe/calculators/video/opencv_video_decoder_calculator_test.cc b/mediapipe/calculators/video/opencv_video_decoder_calculator_test.cc
index 035e5a8c9d..2a6d66f159 100644
--- a/mediapipe/calculators/video/opencv_video_decoder_calculator_test.cc
+++ b/mediapipe/calculators/video/opencv_video_decoder_calculator_test.cc
@@ -24,6 +24,7 @@
 #include "mediapipe/framework/port/opencv_core_inc.h"
 #include "mediapipe/framework/port/parse_text_proto.h"
 #include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/framework/tool/test_util.h"
 
 namespace mediapipe {
 
@@ -32,6 +33,7 @@ namespace {
 constexpr char kVideoTag[] = "VIDEO";
 constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM";
 constexpr char kInputFilePathTag[] = "INPUT_FILE_PATH";
+constexpr char kTestPackageRoot[] = "mediapipe/calculators/video";
 
 TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) {
   CalculatorGraphConfig::Node node_config =
@@ -41,10 +43,9 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) {
        output_stream: "VIDEO:video"
        output_stream: "VIDEO_PRESTREAM:video_prestream")pb");
   CalculatorRunner runner(node_config);
-  runner.MutableSidePackets()->Tag(kInputFilePathTag) = MakePacket<std::string>(
-      file::JoinPath("./",
-                     "/mediapipe/calculators/video/"
-                     "testdata/format_MP4_AVC720P_AAC.video"));
+  runner.MutableSidePackets()->Tag(kInputFilePathTag) =
+      MakePacket<std::string>(file::JoinPath(GetTestDataDir(kTestPackageRoot),
+                                             "format_MP4_AVC720P_AAC.video"));
 
   MP_EXPECT_OK(runner.Run());
   EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1);
@@ -87,10 +88,9 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestFlvH264Video) {
        output_stream: "VIDEO:video"
        output_stream: "VIDEO_PRESTREAM:video_prestream")pb");
   CalculatorRunner runner(node_config);
-  runner.MutableSidePackets()->Tag(kInputFilePathTag) = MakePacket<std::string>(
-      file::JoinPath("./",
-                     "/mediapipe/calculators/video/"
-                     "testdata/format_FLV_H264_AAC.video"));
+  runner.MutableSidePackets()->Tag(kInputFilePathTag) =
+      MakePacket<std::string>(file::JoinPath(GetTestDataDir(kTestPackageRoot),
+                                             "format_FLV_H264_AAC.video"));
 
   MP_EXPECT_OK(runner.Run());
   EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1);
@@ -131,10 +131,9 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMkvVp8Video) {
        output_stream: "VIDEO:video"
        output_stream: "VIDEO_PRESTREAM:video_prestream")pb");
   CalculatorRunner runner(node_config);
-  runner.MutableSidePackets()->Tag(kInputFilePathTag) = MakePacket<std::string>(
-      file::JoinPath("./",
-                     "/mediapipe/calculators/video/"
-                     "testdata/format_MKV_VP8_VORBIS.video"));
+  runner.MutableSidePackets()->Tag(kInputFilePathTag) =
+      MakePacket<std::string>(file::JoinPath(GetTestDataDir(kTestPackageRoot),
+                                             "format_MKV_VP8_VORBIS.video"));
 
   MP_EXPECT_OK(runner.Run());
   EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1);
diff --git a/mediapipe/calculators/video/opencv_video_encoder_calculator_test.cc b/mediapipe/calculators/video/opencv_video_encoder_calculator_test.cc
index 7e946d82d8..b30c785d7d 100644
--- a/mediapipe/calculators/video/opencv_video_encoder_calculator_test.cc
+++ b/mediapipe/calculators/video/opencv_video_encoder_calculator_test.cc
@@ -28,10 +28,14 @@
 #include "mediapipe/framework/port/opencv_video_inc.h"
 #include "mediapipe/framework/port/parse_text_proto.h"
 #include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/framework/tool/test_util.h"
 
 namespace mediapipe {
 namespace {
+
+constexpr char kTestPackageRoot[] = "mediapipe/calculators/video";
+
 // Temporarily disable the test.
 // TODO: Investigate the "Could not open codec 'libx264'" error with
 // opencv2.
@@ -59,10 +63,9 @@ TEST(OpenCvVideoEncoderCalculatorTest, DISABLED_TestMp4Avc720pVideo) {
   }
 )pb");
   std::map<std::string, Packet> input_side_packets;
-  input_side_packets["input_file_path"] = MakePacket<std::string>(
-      file::JoinPath("./",
-                     "/mediapipe/calculators/video/"
-                     "testdata/format_MP4_AVC720P_AAC.video"));
+  input_side_packets["input_file_path"] =
+      MakePacket<std::string>(file::JoinPath(GetTestDataDir(kTestPackageRoot),
+                                             "format_MP4_AVC720P_AAC.video"));
   const std::string output_file_path = "/tmp/tmp_video.mp4";
   DeletingFile deleting_file(output_file_path, true);
   input_side_packets["output_file_path"] =
@@ -120,10 +123,9 @@ TEST(OpenCvVideoEncoderCalculatorTest, TestFlvH264Video) {
   }
 )pb");
   std::map<std::string, Packet> input_side_packets;
-  input_side_packets["input_file_path"] = MakePacket<std::string>(
-      file::JoinPath("./",
-                     "/mediapipe/calculators/video/"
-                     "testdata/format_FLV_H264_AAC.video"));
+  input_side_packets["input_file_path"] =
+      MakePacket<std::string>(file::JoinPath(GetTestDataDir(kTestPackageRoot),
+                                             "format_FLV_H264_AAC.video"));
   const std::string output_file_path = "/tmp/tmp_video.avi";
   DeletingFile deleting_file(output_file_path, true);
   input_side_packets["output_file_path"] =
@@ -183,10 +185,9 @@ TEST(OpenCvVideoEncoderCalculatorTest, TestMkvVp8Video) {
   }
 )pb");
   std::map<std::string, Packet> input_side_packets;
-  input_side_packets["input_file_path"] = MakePacket<std::string>(
-      file::JoinPath("./",
-                     "/mediapipe/calculators/video/"
-                     "testdata/format_MKV_VP8_VORBIS.video"));
+  input_side_packets["input_file_path"] =
+      MakePacket<std::string>(file::JoinPath(GetTestDataDir(kTestPackageRoot),
+                                             "format_MKV_VP8_VORBIS.video"));
   const std::string output_file_path = "/tmp/tmp_video.mkv";
   DeletingFile deleting_file(output_file_path, true);
   input_side_packets["output_file_path"] =
diff --git a/mediapipe/calculators/video/tracking_graph_test.cc b/mediapipe/calculators/video/tracking_graph_test.cc
index 6516bd7da3..d6529da1df 100644
--- a/mediapipe/calculators/video/tracking_graph_test.cc
+++ b/mediapipe/calculators/video/tracking_graph_test.cc
@@ -33,39 +33,16 @@
 #include "mediapipe/framework/port/proto_ns.h"
 #include "mediapipe/framework/port/status.h"
 #include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/framework/tool/test_util.h"
 #include "mediapipe/util/tracking/box_tracker.pb.h"
 #include "mediapipe/util/tracking/tracking.pb.h"
 
-#ifdef __APPLE__
-#include <CoreFoundation/CoreFoundation.h>
-#endif  // defined(__APPLE__)
-
 namespace mediapipe {
 namespace {
 
 using ::testing::FloatNear;
 using ::testing::Test;
 
-std::string GetTestDir() {
-#ifdef __APPLE__
-  char path[1024];
-  CFURLRef bundle_url = CFBundleCopyBundleURL(CFBundleGetMainBundle());
-  CFURLGetFileSystemRepresentation(
-      bundle_url, true, reinterpret_cast<UInt8*>(path), sizeof(path));
-  CFRelease(bundle_url);
-  return mediapipe::file::JoinPath(path, "testdata");
-#elif defined(__ANDROID__)
-  char path[1024];
-  getcwd(path, sizeof(path));
-  return mediapipe::file::JoinPath(path,
-                                   "mediapipe/calculators/video/testdata");
-#else
-  return mediapipe::file::JoinPath(
-      "./",
-      // This should match the path of the output files
-      // of the genrule() that generates test model files.
- "mediapipe/calculators/video/testdata"); -#endif // defined(__APPLE__) -} +constexpr char kTestPackageRoot[] = "mediapipe/calculators/video"; bool LoadBinaryTestGraph(const std::string& graph_path, CalculatorGraphConfig* config) { @@ -85,7 +62,7 @@ class TrackingGraphTest : public Test { TrackingGraphTest() {} void SetUp() override { - test_dir_ = GetTestDir(); + test_dir_ = mediapipe::GetTestDataDir(kTestPackageRoot); const auto graph_path = file::JoinPath(test_dir_, "tracker.binarypb"); ASSERT_TRUE(LoadBinaryTestGraph(graph_path, &config_)); diff --git a/mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/FaceDetectionResultGlRenderer.java b/mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/FaceDetectionResultGlRenderer.java index df18471786..ef350f71db 100644 --- a/mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/FaceDetectionResultGlRenderer.java +++ b/mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/FaceDetectionResultGlRenderer.java @@ -15,10 +15,10 @@ package com.google.mediapipe.examples.facedetection; import android.opengl.GLES20; -import com.google.mediapipe.formats.proto.DetectionProto.Detection; import com.google.mediapipe.solutioncore.ResultGlRenderer; import com.google.mediapipe.solutions.facedetection.FaceDetectionResult; import com.google.mediapipe.solutions.facedetection.FaceKeypoint; +import com.google.mediapipe.formats.proto.DetectionProto.Detection; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.FloatBuffer; diff --git a/mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/FaceDetectionResultImageView.java b/mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/FaceDetectionResultImageView.java index 3da3a467a8..f82692799d 100644 --- a/mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/FaceDetectionResultImageView.java +++ b/mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/FaceDetectionResultImageView.java @@ -23,9 +23,9 @@ import android.graphics.Matrix; import android.graphics.Paint; import androidx.appcompat.widget.AppCompatImageView; -import com.google.mediapipe.formats.proto.DetectionProto.Detection; import com.google.mediapipe.solutions.facedetection.FaceDetectionResult; import com.google.mediapipe.solutions.facedetection.FaceKeypoint; +import com.google.mediapipe.formats.proto.DetectionProto.Detection; /** An ImageView implementation for displaying {@link FaceDetectionResult}. 
*/ public class FaceDetectionResultImageView extends AppCompatImageView { diff --git a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc index 0cc9232e70..585bddbcd1 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc @@ -279,35 +279,45 @@ mediapipe::autoflip::RectF ShiftDetection( } absl::Status UpdateRanges(const SalientRegion& region, const float shift_vertical, - const float shift_horizontal, float* xmin, - float* xmax, float* ymin, float* ymax) { + const float shift_horizontal, + const float pad_vertical, const float pad_horizontal, + float* xmin, float* xmax, float* ymin, float* ymax) { if (!region.has_location_normalized()) { return mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) << "SalientRegion did not have location normalized set."; } auto location = ShiftDetection(region.location_normalized(), shift_vertical, shift_horizontal); - *xmin = fmin(*xmin, location.x()); - *xmax = fmax(*xmax, location.x() + location.width()); - *ymin = fmin(*ymin, location.y()); - *ymax = fmax(*ymax, location.y() + location.height()); + + const float x_padding = pad_horizontal * location.width(); + const float y_padding = pad_vertical * location.height(); + + *xmin = fmin(*xmin, location.x() - x_padding); + *xmax = fmax(*xmax, location.x() + location.width() + x_padding); + *ymin = fmin(*ymin, location.y() - y_padding); + *ymax = fmax(*ymax, location.y() + location.height() + y_padding); return absl::OkStatus(); } absl::Status UpdateRanges(const mediapipe::Detection& detection, const float shift_vertical, - const float shift_horizontal, float* xmin, - float* xmax, float* ymin, float* ymax) { + const float shift_horizontal, + const float pad_vertical, const float pad_horizontal, + float* xmin, float* xmax, float* ymin, float* ymax) { RET_CHECK(detection.location_data().format() == mediapipe::LocationData::RELATIVE_BOUNDING_BOX) << "Face detection input is lacking required relative_bounding_box()"; const auto& location = ShiftDetection(detection.location_data().relative_bounding_box(), shift_vertical, shift_horizontal); - *xmin = fmin(*xmin, location.xmin()); - *xmax = fmax(*xmax, location.xmin() + location.width()); - *ymin = fmin(*ymin, location.ymin()); - *ymax = fmax(*ymax, location.ymin() + location.height()); + + const float x_padding = pad_horizontal * location.width(); + const float y_padding = pad_vertical * location.height(); + + *xmin = fmin(*xmin, location.xmin() - x_padding); + *xmax = fmax(*xmax, location.xmin() + location.width() + x_padding); + *ymin = fmin(*ymin, location.ymin() - y_padding); + *ymax = fmax(*ymax, location.ymin() + location.height() + y_padding); return absl::OkStatus(); } @@ -818,7 +828,9 @@ absl::Status ContentZoomingCalculator::GetDetectionsBox( *only_required_found = true; MP_RETURN_IF_ERROR(UpdateRanges( region, options_.detection_shift_vertical(), - options_.detection_shift_horizontal(), xmin, xmax, ymin, ymax)); + options_.detection_shift_horizontal(), + options_.extra_vertical_padding(), + options_.extra_horizontal_padding(), xmin, xmax, ymin, ymax)); } } @@ -864,7 +876,9 @@ absl::Status ContentZoomingCalculator::GetDetectionsBox( *only_required_found = true; MP_RETURN_IF_ERROR(UpdateRanges( detection, options_.detection_shift_vertical(), - options_.detection_shift_horizontal(), xmin, xmax, ymin, ymax)); + 
options_.detection_shift_horizontal(), + options_.extra_vertical_padding(), + options_.extra_horizontal_padding(), xmin, xmax, ymin, ymax)); } } } diff --git a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto index 124551304f..1d08fe8128 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto +++ b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto @@ -19,7 +19,7 @@ package mediapipe.autoflip; import "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto"; import "mediapipe/framework/calculator.proto"; -// NextTag: 19 +// NextTag: 21 message ContentZoomingCalculatorOptions { extend mediapipe.CalculatorOptions { optional ContentZoomingCalculatorOptions ext = 313091992; @@ -45,12 +45,17 @@ message ContentZoomingCalculatorOptions { optional int64 height = 2; } optional Size target_size = 8; - // Amount to shift an input detection as a ratio of the size (positive: + + // Amount to shift an input detection, as a ratio of its size (positive: // down/right, negative: up/left). Use a negative value to increase padding // above/left of an object, positive to increase padding below/right of an - // object. + // object. (Applies to one side only) optional float detection_shift_vertical = 11 [default = 0.0]; optional float detection_shift_horizontal = 12 [default = 0.0]; + // Amount to pad around an input detection, as a ratio of its size. + // (Applies to both sides) + optional float extra_vertical_padding = 19 [default = 0.0]; + optional float extra_horizontal_padding = 20 [default = 0.0]; // Defines the smallest value in degrees the camera is permitted to zoom. 
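  // A smaller limit permits tighter (more zoomed-in) framing.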
  optional float max_zoom_value_deg = 13 [default = 35];
diff --git a/mediapipe/examples/ios/common/BUILD b/mediapipe/examples/ios/common/BUILD
index 0f3d34cd1c..8db4699a5d 100644
--- a/mediapipe/examples/ios/common/BUILD
+++ b/mediapipe/examples/ios/common/BUILD
@@ -35,7 +35,9 @@ objc_library(
         "CoreMedia",
         "UIKit",
     ],
-    visibility = ["//mediapipe:__subpackages__"],
+    visibility = [
+        "//mediapipe:__subpackages__",
+    ],
     deps = [
         "//mediapipe/objc:mediapipe_framework_ios",
         "//mediapipe/objc:mediapipe_input_sources_ios",
diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD
index 1166c2a33d..befc6809d0 100644
--- a/mediapipe/framework/BUILD
+++ b/mediapipe/framework/BUILD
@@ -115,7 +115,10 @@ mediapipe_proto_library(
     name = "packet_test_proto",
     testonly = 1,
     srcs = ["packet_test.proto"],
-    visibility = ["//mediapipe/framework:__subpackages__"],
+    visibility = [
+        ":mediapipe_internal",
+        "//mediapipe/framework:__subpackages__",
+    ],
 )
 
 mediapipe_proto_library(
@@ -973,6 +976,7 @@ cc_library(
         ],
     }),
     visibility = [
+        "//fitbit/research/sensing/mobisense:__subpackages__",
         "//mediapipe/calculators:__subpackages__",
         "//mediapipe/framework:__subpackages__",
         "//mediapipe/framework/port:__pkg__",
@@ -1427,6 +1431,7 @@ cc_test(
         "//mediapipe/framework/stream_handler:timestamp_align_input_stream_handler",
         "//mediapipe/framework/tool:sink",
         "//mediapipe/framework/tool:status_util",
+        "//mediapipe/gpu:graph_support",
         "@com_google_absl//absl/container:fixed_array",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
diff --git a/mediapipe/framework/api2/BUILD b/mediapipe/framework/api2/BUILD
index 73a3e5e5d7..768fc86c86 100644
--- a/mediapipe/framework/api2/BUILD
+++ b/mediapipe/framework/api2/BUILD
@@ -149,6 +149,7 @@ cc_library(
         "//mediapipe/framework:calculator_contract",
         "//mediapipe/framework:output_side_packet",
         "//mediapipe/framework/port:logging",
+        "//mediapipe/framework/tool:type_util",
         "@com_google_absl//absl/strings",
     ],
 )
diff --git a/mediapipe/framework/api2/port.h b/mediapipe/framework/api2/port.h
index fc74ba6097..8d1e79c6fc 100644
--- a/mediapipe/framework/api2/port.h
+++ b/mediapipe/framework/api2/port.h
@@ -27,41 +27,34 @@
 #include "mediapipe/framework/calculator_contract.h"
 #include "mediapipe/framework/output_side_packet.h"
 #include "mediapipe/framework/port/logging.h"
+#include "mediapipe/framework/tool/type_util.h"
 
 namespace mediapipe {
 namespace api2 {
 
-// typeid is not constexpr, but a pointer to this is.
-template <typename T>
-size_t get_type_hash() {
-  return typeid(T).hash_code();
-}
-
-using type_id_fptr = size_t (*)();
-
 // This is a base class for various types of port. It is not meant to be used
 // directly by node code.
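 // It stores the tag, the optional/multiple flags, and the payload's TypeId.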
 class PortBase {
  public:
-  constexpr PortBase(std::size_t tag_size, const char* tag,
-                     type_id_fptr get_type_id, bool optional, bool multiple)
+  constexpr PortBase(std::size_t tag_size, const char* tag, TypeId type_id,
+                     bool optional, bool multiple)
       : tag_(tag_size, tag),
         optional_(optional),
         multiple_(multiple),
-        type_id_getter_(get_type_id) {}
+        type_id_(type_id) {}
 
   bool IsOptional() const { return optional_; }
   bool IsMultiple() const { return multiple_; }
   const char* Tag() const { return tag_.data(); }
-  size_t type_id() const { return type_id_getter_(); }
+  TypeId type_id() const { return type_id_; }
 
   const const_str tag_;
   const bool optional_;
   const bool multiple_;
 
  protected:
-  type_id_fptr type_id_getter_;
+  TypeId type_id_;
 };
 
 // These four base classes are used to distinguish between ports of different
@@ -340,7 +333,7 @@ class PortCommon : public Base {
 
   template <std::size_t N>
   explicit constexpr PortCommon(const char (&tag)[N])
-      : Base(N, tag, &get_type_hash<ActualPayloadT>, IsOptionalV, IsMultipleV) {}
+      : Base(N, tag, kTypeId<ActualPayloadT>, IsOptionalV, IsMultipleV) {}
 
   using PayloadT = ActualPayloadT;
 
@@ -428,7 +421,7 @@ class SideFallbackT : public Base {
 
   template <std::size_t N>
   explicit constexpr SideFallbackT(const char (&tag)[N])
-      : Base(N, tag, &get_type_hash<PayloadT>, IsOptionalV, IsMultipleV),
+      : Base(N, tag, kTypeId<PayloadT>, IsOptionalV, IsMultipleV),
         stream_port(tag),
         side_port(tag) {}
 
diff --git a/mediapipe/framework/api2/port_test.cc b/mediapipe/framework/api2/port_test.cc
index c09e38452e..6676e44f0a 100644
--- a/mediapipe/framework/api2/port_test.cc
+++ b/mediapipe/framework/api2/port_test.cc
@@ -8,7 +8,7 @@ namespace {
 
 TEST(PortTest, IntInput) {
   static constexpr auto port = Input<int>("FOO");
-  EXPECT_EQ(port.type_id(), typeid(int).hash_code());
+  EXPECT_EQ(port.type_id(), kTypeId<int>);
 }
 
 TEST(PortTest, OptionalInput) {
diff --git a/mediapipe/framework/calculator_contract.h b/mediapipe/framework/calculator_contract.h
index fd0507becb..2162f84e78 100644
--- a/mediapipe/framework/calculator_contract.h
+++ b/mediapipe/framework/calculator_contract.h
@@ -59,7 +59,7 @@ class CalculatorContract {
   const CalculatorOptions& Options() const { return node_config_->options(); }
 
   // Returns the name given to this node.
-  const std::string& GetNodeName() { return node_name_; }
+  const std::string& GetNodeName() const { return node_name_; }
 
   // Returns the options given to this calculator.  Template argument T must
  // be the type of the protobuf extension message or the protobuf::Any
diff --git a/mediapipe/framework/calculator_graph.cc b/mediapipe/framework/calculator_graph.cc
index 4f6755364e..c17a2e1e2c 100644
--- a/mediapipe/framework/calculator_graph.cc
+++ b/mediapipe/framework/calculator_graph.cc
@@ -120,10 +120,10 @@ CalculatorGraph::CalculatorGraph()
   counter_factory_ = absl::make_unique<BasicCounterFactory>();
 }
 
-CalculatorGraph::CalculatorGraph(const CalculatorGraphConfig& config)
+CalculatorGraph::CalculatorGraph(CalculatorGraphConfig config)
     : CalculatorGraph() {
   counter_factory_ = absl::make_unique<BasicCounterFactory>();
-  MEDIAPIPE_CHECK_OK(Initialize(config));
+  MEDIAPIPE_CHECK_OK(Initialize(std::move(config)));
 }
 
 // Defining the destructor here lets us use incomplete types in the header;
@@ -429,18 +429,17 @@ absl::Status CalculatorGraph::Initialize(
   return absl::OkStatus();
 }
 
-absl::Status CalculatorGraph::Initialize(
-    const CalculatorGraphConfig& input_config) {
-  return Initialize(input_config, {});
+absl::Status CalculatorGraph::Initialize(CalculatorGraphConfig input_config) {
+  return Initialize(std::move(input_config), {});
 }
 
 absl::Status CalculatorGraph::Initialize(
-    const CalculatorGraphConfig& input_config,
+    CalculatorGraphConfig input_config,
     const std::map<std::string, Packet>& side_packets) {
   auto validated_graph = absl::make_unique<ValidatedGraphConfig>();
   MP_RETURN_IF_ERROR(validated_graph->Initialize(
-      input_config, /*graph_registry=*/nullptr, /*graph_options=*/nullptr,
-      &service_manager_));
+      std::move(input_config), /*graph_registry=*/nullptr,
+      /*graph_options=*/nullptr, &service_manager_));
   return Initialize(std::move(validated_graph), side_packets);
 }
 
@@ -675,6 +674,7 @@ absl::Status CalculatorGraph::PrepareForRun(
 #endif  // !MEDIAPIPE_DISABLE_GPU
   MP_RETURN_IF_ERROR(PrepareServices());
 #if !MEDIAPIPE_DISABLE_GPU
+  // TODO: should we do this on each run, or only once?
   MP_RETURN_IF_ERROR(PrepareGpu());
   additional_side_packets = MaybeCreateLegacyGpuSidePacket(legacy_sp);
 #endif  // !MEDIAPIPE_DISABLE_GPU
@@ -1251,7 +1251,9 @@ void CalculatorGraph::Resume() { scheduler_.Resume(); }
 
 absl::Status CalculatorGraph::SetExecutorInternal(
     const std::string& name, std::shared_ptr<Executor> executor) {
-  if (!executors_.emplace(name, executor).second) {
+  auto [it, inserted] = executors_.emplace(name, executor);
+  if (!inserted) {
+    if (it->second == executor) return absl::OkStatus();
     return mediapipe::AlreadyExistsErrorBuilder(MEDIAPIPE_LOC)
            << "SetExecutor must be called only once for the executor \""
           << name << "\"";
diff --git a/mediapipe/framework/calculator_graph.h b/mediapipe/framework/calculator_graph.h
index 406317fb98..c514761022 100644
--- a/mediapipe/framework/calculator_graph.h
+++ b/mediapipe/framework/calculator_graph.h
@@ -119,17 +119,17 @@ class CalculatorGraph {
 
   // Initializes the graph from its proto description (using Initialize())
   // and crashes if something goes wrong.
-  explicit CalculatorGraph(const CalculatorGraphConfig& config);
+  explicit CalculatorGraph(CalculatorGraphConfig config);
   virtual ~CalculatorGraph();
 
   // Initializes the graph from its proto description.
   // side_packets that are provided at this stage are common across all Run()
   // invocations and could be used to execute PacketGenerators immediately.
-  absl::Status Initialize(const CalculatorGraphConfig& config,
+  absl::Status Initialize(CalculatorGraphConfig config,
                           const std::map<std::string, Packet>& side_packets);
 
   // Convenience version which does not take side packets.
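  // Both overloads take the config by value so callers can std::move() it
  // in and avoid copying a potentially large CalculatorGraphConfig.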
- absl::Status Initialize(const CalculatorGraphConfig& config); + absl::Status Initialize(CalculatorGraphConfig config); // Initializes the CalculatorGraph from the specified graph and subgraph // configs. Template graph and subgraph configs can be specified through @@ -272,7 +272,6 @@ class CalculatorGraph { absl::Status CloseInputStream(const std::string& stream_name); // Closes all the graph input streams. - // TODO: deprecate this function in favor of CloseAllPacketSources. absl::Status CloseAllInputStreams(); // Closes all the graph input streams and source calculator nodes. diff --git a/mediapipe/framework/calculator_graph_test.cc b/mediapipe/framework/calculator_graph_test.cc index af3655c221..f982400cf6 100644 --- a/mediapipe/framework/calculator_graph_test.cc +++ b/mediapipe/framework/calculator_graph_test.cc @@ -60,6 +60,7 @@ #include "mediapipe/framework/tool/sink.h" #include "mediapipe/framework/tool/status_util.h" #include "mediapipe/framework/type_map.h" +#include "mediapipe/gpu/graph_support.h" namespace mediapipe { @@ -2059,6 +2060,26 @@ TEST(CalculatorGraph, HandlersRun) { input_side_packets.at("unavailable_input_counter2"))); } +TEST(CalculatorGraph, CalculatorGraphConfigCopyElision) { + CalculatorGraph graph; + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie(R"pb( + input_stream: 'in' + node { + calculator: 'PassThroughCalculator' + input_stream: 'in' + output_stream: 'out' + } + )pb"); + // config is consumed and never copied, which avoids copying data. + MP_ASSERT_OK(graph.Initialize(std::move(config))); + MP_EXPECT_OK(graph.StartRun({})); + MP_EXPECT_OK( + graph.AddPacketToInputStream("in", MakePacket(1).At(Timestamp(1)))); + MP_EXPECT_OK(graph.CloseInputStream("in")); + MP_EXPECT_OK(graph.WaitUntilDone()); +} + // Test that calling SetOffset() in Calculator::Process() results in the // absl::StatusCode::kFailedPrecondition error. TEST(CalculatorGraph, SetOffsetInProcess) { diff --git a/mediapipe/framework/calculator_profile.proto b/mediapipe/framework/calculator_profile.proto index 723178ae4b..066d433d6b 100644 --- a/mediapipe/framework/calculator_profile.proto +++ b/mediapipe/framework/calculator_profile.proto @@ -11,10 +11,6 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// -// Forked from mediapipe/framework/calculator_profile.proto. -// The forked proto must remain identical to the original proto and should be -// ONLY used by mediapipe open source project. syntax = "proto2"; @@ -24,6 +20,7 @@ import "mediapipe/framework/calculator.proto"; option java_package = "com.google.mediapipe.proto"; option java_outer_classname = "CalculatorProfileProto"; +option objc_class_prefix = "MediaPipe"; // Stores the profiling information. // diff --git a/mediapipe/framework/deps/BUILD b/mediapipe/framework/deps/BUILD index 27052fcf49..d1e544e681 100644 --- a/mediapipe/framework/deps/BUILD +++ b/mediapipe/framework/deps/BUILD @@ -88,7 +88,10 @@ cc_library( testonly = True, hdrs = ["message_matchers.h"], # Use this library through "mediapipe/framework/port:gtest_main".
- visibility = ["//mediapipe/framework/port:__pkg__"], + visibility = [ + "//mediapipe/framework/port:__pkg__", + "//third_party/visionai/algorithms/tracking:__pkg__", + ], deps = [ "//mediapipe/framework/port:core_proto", "@com_google_googletest//:gtest", @@ -137,7 +140,6 @@ cc_library( hdrs = ["image_resizer.h"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:opencv_imgproc", ], ) diff --git a/mediapipe/framework/deps/image_resizer.h b/mediapipe/framework/deps/image_resizer.h index 6e1215a698..e8c541a3cf 100644 --- a/mediapipe/framework/deps/image_resizer.h +++ b/mediapipe/framework/deps/image_resizer.h @@ -15,7 +15,6 @@ #ifndef MEDIAPIPE_DEPS_IMAGE_RESIZER_H_ #define MEDIAPIPE_DEPS_IMAGE_RESIZER_H_ -#include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" namespace mediapipe { diff --git a/mediapipe/framework/encode_binary_proto.bzl b/mediapipe/framework/encode_binary_proto.bzl index 3af435f759..6807a94f9f 100644 --- a/mediapipe/framework/encode_binary_proto.bzl +++ b/mediapipe/framework/encode_binary_proto.bzl @@ -140,9 +140,22 @@ _encode_binary_proto = rule( ) def encode_binary_proto(name, input, message_type, deps, **kwargs): + if type(input) == type("string"): + input_label = input + textproto_srcs = [input] + elif type(input) == type(dict()): + # We cannot accept a select, as macros are unable to manipulate selects. + input_label = select(input) + srcs_dict = dict() + for k, v in input.items(): + srcs_dict[k] = [v] + textproto_srcs = select(srcs_dict) + else: + fail("input should be a string or a dict, got %s" % input) + _encode_binary_proto( name = name, - input = input, + input = input_label, message_type = message_type, deps = deps, **kwargs diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index 2362692603..c47c25a949 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -448,6 +448,7 @@ cc_library( srcs = [ "tensor.cc", + "tensor_ahwb.cc", ], hdrs = ["tensor.h"], copts = select({ @@ -463,6 +464,9 @@ cc_library( "-framework MetalKit", ], "//conditions:default": [], + "//mediapipe:android": [ + "-landroid", + ], }), visibility = ["//visibility:public"], deps = [ diff --git a/mediapipe/framework/formats/body_rig.proto b/mediapipe/framework/formats/body_rig.proto index 094b6cb81b..5420ccc10c 100644 --- a/mediapipe/framework/formats/body_rig.proto +++ b/mediapipe/framework/formats/body_rig.proto @@ -19,7 +19,9 @@ package mediapipe; // Joint of a 3D human model (e.g. elbow, knee, wrist). Contains 3D rotation of // the joint and its visibility. message Joint { - // Joint rotation in 6D contineous representation. + // Joint rotation in 6D continuous representation ordered as + // [a1, b1, a2, b2, a3, b3]. + // + // Such representation is more suitable for NN model training and can be // converted to quaternions and Euler angles if needed. Details can be found // in https://arxiv.org/abs/1812.07035.
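For readers of the body_rig.proto change above: the 6D values decode into a rotation matrix by Gram-Schmidt orthogonalization, as described in the linked paper. A minimal sketch follows; it is not part of this change, RotationFrom6D is a hypothetical helper, and the interleaved [a1, b1, a2, b2, a3, b3] ordering is taken from the comment above:

#include "Eigen/Core"
#include "Eigen/Geometry"

// Decodes the interleaved 6D value into a rotation matrix: normalize the
// first vector, orthonormalize the second against it, and take their cross
// product as the third column (https://arxiv.org/abs/1812.07035).
Eigen::Matrix3f RotationFrom6D(const float v[6]) {
  const Eigen::Vector3f a(v[0], v[2], v[4]);
  const Eigen::Vector3f b(v[1], v[3], v[5]);
  const Eigen::Vector3f r1 = a.normalized();
  const Eigen::Vector3f r2 = (b - r1.dot(b) * r1).normalized();
  const Eigen::Vector3f r3 = r1.cross(r2);
  Eigen::Matrix3f m;
  m.col(0) = r1;
  m.col(1) = r2;
  m.col(2) = r3;
  return m;
}

The resulting matrix can then be converted to a quaternion or Euler angles with Eigen's standard utilities, matching the "converted to quaternions and Euler angles if needed" note above.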
diff --git a/mediapipe/framework/formats/tensor.cc b/mediapipe/framework/formats/tensor.cc index 9212060aa1..b028ee8c04 100644 --- a/mediapipe/framework/formats/tensor.cc +++ b/mediapipe/framework/formats/tensor.cc @@ -20,6 +20,9 @@ #include "absl/synchronization/mutex.h" #include "mediapipe/framework/port.h" #include "mediapipe/framework/port/logging.h" +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 +#include "mediapipe/gpu/gl_base.h" +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #if MEDIAPIPE_METAL_ENABLED #include @@ -319,28 +322,41 @@ void Tensor::AllocateOpenGlTexture2d() const { Tensor::OpenGlBufferView Tensor::GetOpenGlBufferReadView() const { LOG_IF(FATAL, valid_ == kValidNone) << "Tensor must be written prior to read from."; - LOG_IF(FATAL, !(valid_ & (kValidCpu | kValidOpenGlBuffer))) - << "Tensor conversion between different GPU resources is not supported " - "yet."; + LOG_IF(FATAL, !(valid_ & (kValidCpu | +#ifdef MEDIAPIPE_TENSOR_USE_AHWB + kValidAHardwareBuffer | +#endif // MEDIAPIPE_TENSOR_USE_AHWB + kValidOpenGlBuffer))) + << "Tensor conversion between different GPU resources is not supported."; auto lock(absl::make_unique(&view_mutex_)); AllocateOpenGlBuffer(); if (!(valid_ & kValidOpenGlBuffer)) { - glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_); - void* ptr = - glMapBufferRange(GL_SHADER_STORAGE_BUFFER, 0, bytes(), - GL_MAP_INVALIDATE_BUFFER_BIT | GL_MAP_WRITE_BIT); - std::memcpy(ptr, cpu_buffer_, bytes()); - glUnmapBuffer(GL_SHADER_STORAGE_BUFFER); + // If the call succeeds then AHWB -> SSBO are synchronized so any usage of + // the SSBO is correct after this call. + if (!InsertAhwbToSsboFence()) { + glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_); + void* ptr = + glMapBufferRange(GL_SHADER_STORAGE_BUFFER, 0, bytes(), + GL_MAP_INVALIDATE_BUFFER_BIT | GL_MAP_WRITE_BIT); + std::memcpy(ptr, cpu_buffer_, bytes()); + glUnmapBuffer(GL_SHADER_STORAGE_BUFFER); + } valid_ |= kValidOpenGlBuffer; } - return {opengl_buffer_, std::move(lock)}; + return {opengl_buffer_, std::move(lock), +#ifdef MEDIAPIPE_TENSOR_USE_AHWB + &ssbo_read_ +#else + nullptr +#endif // MEDIAPIPE_TENSOR_USE_AHWB + }; } Tensor::OpenGlBufferView Tensor::GetOpenGlBufferWriteView() const { auto lock(absl::make_unique(&view_mutex_)); AllocateOpenGlBuffer(); valid_ = kValidOpenGlBuffer; - return {opengl_buffer_, std::move(lock)}; + return {opengl_buffer_, std::move(lock), nullptr}; } void Tensor::AllocateOpenGlBuffer() const { @@ -349,7 +365,10 @@ void Tensor::AllocateOpenGlBuffer() const { LOG_IF(FATAL, !gl_context_) << "GlContext is not bound to the thread."; glGenBuffers(1, &opengl_buffer_); glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_); - glBufferData(GL_SHADER_STORAGE_BUFFER, bytes(), NULL, GL_STREAM_COPY); + if (!AllocateAhwbMapToSsbo()) { + glBufferData(GL_SHADER_STORAGE_BUFFER, bytes(), NULL, GL_STREAM_COPY); + } + glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); } } #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 @@ -377,6 +396,8 @@ void Tensor::Move(Tensor* src) { src->metal_buffer_ = nil; #endif // MEDIAPIPE_METAL_ENABLED + MoveAhwbStuff(src); + #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 gl_context_ = std::move(src->gl_context_); frame_buffer_ = src->frame_buffer_; @@ -395,27 +416,31 @@ void Tensor::Move(Tensor* src) { Tensor::Tensor(ElementType element_type, const Shape& shape) : element_type_(element_type), shape_(shape) {} +#if MEDIAPIPE_METAL_ENABLED +void Tensor::Invalidate() { + absl::MutexLock lock(&view_mutex_); + // If
memory is allocated and not owned by the metal buffer. + // TODO: Re-design cpu buffer memory management. + if (cpu_buffer_ && !metal_buffer_) { + DeallocateVirtualMemory(cpu_buffer_, AlignToPageSize(bytes())); + } + metal_buffer_ = nil; + cpu_buffer_ = nullptr; +} + +#else + void Tensor::Invalidate() { #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 GLuint cleanup_gl_tex = GL_INVALID_INDEX; GLuint cleanup_gl_fb = GL_INVALID_INDEX; +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 GLuint cleanup_gl_buf = GL_INVALID_INDEX; +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 { absl::MutexLock lock(&view_mutex_); -#if MEDIAPIPE_METAL_ENABLED - // If memory is allocated and not owned by the metal buffer. - // TODO: Re-design cpu buffer memory management. - if (cpu_buffer_ && !metal_buffer_) { - DeallocateVirtualMemory(cpu_buffer_, AlignToPageSize(bytes())); - } - metal_buffer_ = nil; -#else - if (cpu_buffer_) { - free(cpu_buffer_); - } -#endif // MEDIAPIPE_METAL_ENABLED - cpu_buffer_ = nullptr; + ReleaseAhwbStuff(); // Don't need to wait for the resource to be deleted because it will be // released on last reference deletion inside the OpenGL driver. @@ -429,28 +454,44 @@ void Tensor::Invalidate() { } // Do not hold the view mutex while invoking GlContext::RunWithoutWaiting, // since that method may acquire the context's own lock. -#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 - if (cleanup_gl_tex != GL_INVALID_INDEX || cleanup_gl_fb != GL_INVALID_INDEX || - cleanup_gl_buf != GL_INVALID_INDEX) - gl_context_->RunWithoutWaiting([cleanup_gl_tex, cleanup_gl_fb #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 - , - cleanup_gl_buf -#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 - ]() { + if (cleanup_gl_tex != GL_INVALID_INDEX || cleanup_gl_fb != GL_INVALID_INDEX || + cleanup_gl_buf != GL_INVALID_INDEX) { + gl_context_->RunWithoutWaiting( + [cleanup_gl_tex, cleanup_gl_fb, cleanup_gl_buf]() { + glDeleteTextures(1, &cleanup_gl_tex); + glDeleteFramebuffers(1, &cleanup_gl_fb); + glDeleteBuffers(1, &cleanup_gl_buf); + }); + } +#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 + if (cleanup_gl_tex != GL_INVALID_INDEX || cleanup_gl_fb != GL_INVALID_INDEX) { + gl_context_->RunWithoutWaiting([cleanup_gl_tex, cleanup_gl_fb]() { glDeleteTextures(1, &cleanup_gl_tex); glDeleteFramebuffers(1, &cleanup_gl_fb); -#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 - glDeleteBuffers(1, &cleanup_gl_buf); -#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 }); -#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 + } +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + + if (cpu_buffer_) { + free(cpu_buffer_); + } + cpu_buffer_ = nullptr; } +#endif // MEDIAPIPE_METAL_ENABLED Tensor::CpuReadView Tensor::GetCpuReadView() const { auto lock = absl::make_unique(&view_mutex_); LOG_IF(FATAL, valid_ == kValidNone) << "Tensor must be written prior to read from."; +#ifdef MEDIAPIPE_TENSOR_USE_AHWB + void* ptr = MapAhwbToCpuRead(); + if (ptr) { + valid_ |= kValidCpu; + return {ptr, ahwb_, nullptr, std::move(lock)}; + } +#endif // MEDIAPIPE_TENSOR_USE_AHWB + AllocateCpuBuffer(); if (!(valid_ & kValidCpu)) { // GPU-to-CPU synchronization and read-back.
@@ -512,18 +553,33 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const { #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 valid_ |= kValidCpu; } +#ifdef MEDIAPIPE_TENSOR_USE_AHWB + return {cpu_buffer_, nullptr, nullptr, std::move(lock)}; +#else return {cpu_buffer_, std::move(lock)}; +#endif // MEDIAPIPE_TENSOR_USE_AHWB } Tensor::CpuWriteView Tensor::GetCpuWriteView() const { auto lock = absl::make_unique(&view_mutex_); AllocateCpuBuffer(); valid_ = kValidCpu; +#ifdef MEDIAPIPE_TENSOR_USE_AHWB + void* ptr = MapAhwbToCpuWrite(); + if (ptr) { + return {ptr, ahwb_, &fence_fd_, std::move(lock)}; + } + return {cpu_buffer_, nullptr, nullptr, std::move(lock)}; +#else return {cpu_buffer_, std::move(lock)}; +#endif // MEDIAPIPE_TENSOR_USE_AHWB } void Tensor::AllocateCpuBuffer() const { if (!cpu_buffer_) { +#ifdef MEDIAPIPE_TENSOR_USE_AHWB + if (AllocateAHardwareBuffer()) return; +#endif // MEDIAPIPE_TENSOR_USE_AHWB #if MEDIAPIPE_METAL_ENABLED cpu_buffer_ = AllocateVirtualMemory(bytes()); #else @@ -532,4 +588,10 @@ void Tensor::AllocateCpuBuffer() const { } } +void Tensor::SetPreferredStorageType(StorageType type) { +#ifdef MEDIAPIPE_TENSOR_USE_AHWB + use_ahwb_ = type == StorageType::kAhwb; +#endif // MEDIAPIPE_TENSOR_USE_AHWB +} + } // namespace mediapipe diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index adb7dca6ea..d60052affc 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -30,6 +30,16 @@ #import #endif // MEDIAPIPE_METAL_ENABLED +#ifdef MEDIAPIPE_TENSOR_USE_AHWB +#if __ANDROID_API__ < 26 +#error MEDIAPIPE_TENSOR_USE_AHWB requires NDK version 26 or higher to be specified. +#endif // __ANDROID_API__ < 26 +#include + +#include "third_party/GL/gl/include/EGL/egl.h" +#include "third_party/GL/gl/include/EGL/eglext.h" +#endif // MEDIAPIPE_TENSOR_USE_AHWB + #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #include "mediapipe/gpu/gl_base.h" #include "mediapipe/gpu/gl_context.h" @@ -108,14 +118,37 @@ class Tensor { return static_cast::value, std::tuple >::type>(buffer_); } - CpuView(CpuView&& src) : View(std::move(src)), buffer_(src.buffer_) { - src.buffer_ = nullptr; + CpuView(CpuView&& src) : View(std::move(src)) { + buffer_ = std::exchange(src.buffer_, nullptr); +#ifdef MEDIAPIPE_TENSOR_USE_AHWB + ahwb_ = std::exchange(src.ahwb_, nullptr); + fence_fd_ = std::exchange(src.fence_fd_, nullptr); +#endif // MEDIAPIPE_TENSOR_USE_AHWB + } +#ifdef MEDIAPIPE_TENSOR_USE_AHWB + ~CpuView() { + if (ahwb_) { + auto error = AHardwareBuffer_unlock(ahwb_, fence_fd_); + CHECK(error == 0) << "AHardwareBuffer_unlock " << error; + } } +#endif // MEDIAPIPE_TENSOR_USE_AHWB protected: friend class Tensor; +#ifdef MEDIAPIPE_TENSOR_USE_AHWB + CpuView(T* buffer, AHardwareBuffer* ahwb, int* fence_fd, + std::unique_ptr&& lock) + : View(std::move(lock)), + buffer_(buffer), + fence_fd_(fence_fd), + ahwb_(ahwb) {} + AHardwareBuffer* ahwb_; + int* fence_fd_; +#else CpuView(T* buffer, std::unique_ptr&& lock) : View(std::move(lock)), buffer_(buffer) {} +#endif // MEDIAPIPE_TENSOR_USE_AHWB T* buffer_; }; using CpuReadView = CpuView; @@ -150,6 +183,60 @@ class Tensor { MtlBufferView GetMtlBufferWriteView(id device) const; #endif // MEDIAPIPE_METAL_ENABLED +#ifdef MEDIAPIPE_TENSOR_USE_AHWB + class AHardwareBufferView : public View { + public: + AHardwareBuffer* handle() const { return handle_; } + AHardwareBufferView(AHardwareBufferView&& src) : View(std::move(src)) { + handle_ = std::exchange(src.handle_, nullptr); + 
file_descriptor_ = src.file_descriptor_; + fence_fd_ = std::exchange(src.fence_fd_, nullptr); + ahwb_written_ = std::exchange(src.ahwb_written_, nullptr); + release_callback_ = std::exchange(src.release_callback_, nullptr); + } + int file_descriptor() const { return file_descriptor_; } + void SetReadingFinishedFunc(std::function&& func) { + CHECK(ahwb_written_) + << "AHWB write view can't accept 'reading finished callback'"; + *ahwb_written_ = std::move(func); + } + void SetWritingFinishedFD(int fd) { + CHECK(fence_fd_) + << "AHWB read view can't accept 'writing finished file descriptor'"; + *fence_fd_ = fd; + } + // The function is called when the tensor is released. + void SetReleaseCallback(std::function callback) { + *release_callback_ = std::move(callback); + } + + protected: + friend class Tensor; + AHardwareBufferView(AHardwareBuffer* handle, int file_descriptor, + int* fence_fd, std::function* ahwb_written, + std::function* release_callback, + std::unique_ptr&& lock) + : View(std::move(lock)), + handle_(handle), + file_descriptor_(file_descriptor), + fence_fd_(fence_fd), + ahwb_written_(ahwb_written), + release_callback_(release_callback) {} + AHardwareBuffer* handle_; + int file_descriptor_; + // The view sets some of the Tensor's fields. The view is released prior to the tensor. + int* fence_fd_; + std::function* ahwb_written_; + std::function* release_callback_; + }; + AHardwareBufferView GetAHardwareBufferReadView() const; + // size_alignment is an optional argument to tell the API to allocate + // a buffer that is padded to multiples of size_alignment bytes. + // size_alignment must be a power of 2, e.g. 2, 4, 8, 16, 64, etc. + // If size_alignment is 0, then the buffer will not be padded. + AHardwareBufferView GetAHardwareBufferWriteView(int size_alignment = 0) const; +#endif // MEDIAPIPE_TENSOR_USE_AHWB + #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 // TODO: Use GlTextureView instead. // Only float32 textures are supported with 1/2/3/4 depths. @@ -188,16 +275,23 @@ class Tensor { class OpenGlBufferView : public View { public: GLuint name() const { return name_; } - OpenGlBufferView(OpenGlBufferView&& src) - : View(std::move(src)), name_(src.name_) { - src.name_ = GL_INVALID_INDEX; + OpenGlBufferView(OpenGlBufferView&& src) : View(std::move(src)) { + name_ = std::exchange(src.name_, GL_INVALID_INDEX); + ssbo_read_ = std::exchange(src.ssbo_read_, nullptr); + } + ~OpenGlBufferView() { + if (ssbo_read_) { + *ssbo_read_ = glFenceSync(GL_SYNC_GPU_COMMANDS_COMPLETE, 0); + } } protected: friend class Tensor; - OpenGlBufferView(GLuint name, std::unique_ptr&& lock) - : View(std::move(lock)), name_(name) {} + OpenGlBufferView(GLuint name, std::unique_ptr&& lock, + GLsync* ssbo_read) + : View(std::move(lock)), name_(name), ssbo_read_(ssbo_read) {} GLuint name_; + GLsync* ssbo_read_; }; // A valid OpenGL context must be bound to the calling thread due to possible // GPU resource allocation.
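A minimal sketch of how the AHardwareBufferView API above is intended to be used, assuming an Android build with MEDIAPIPE_TENSOR_USE_AHWB defined; EnqueueGpuWrite, WriteDoneFd, and ConsumeAhwb are hypothetical stand-ins for a delegate that produces a sync fence FD and for a downstream consumer:

#include <android/hardware_buffer.h>

#include "mediapipe/framework/formats/tensor.h"

// Hypothetical delegate hooks, declared only so the sketch is self-contained.
void EnqueueGpuWrite(AHardwareBuffer* buffer);
int WriteDoneFd();  // fence FD produced by the writing delegate
void ConsumeAhwb(AHardwareBuffer* buffer);

void ProducerConsumerSketch() {
  // Prefer AHWB-backed storage for tensors allocated after this call
  // (see Tensor::SetPreferredStorageType below).
  mediapipe::Tensor::SetPreferredStorageType(
      mediapipe::Tensor::StorageType::kAhwb);
  mediapipe::Tensor tensor(mediapipe::Tensor::ElementType::kFloat32,
                           mediapipe::Tensor::Shape{1, 256});
  {
    // Producer: fill the AHWB on the GPU, then hand over the fence FD that
    // signals when the write completes.
    auto view = tensor.GetAHardwareBufferWriteView();
    EnqueueGpuWrite(view.handle());
    view.SetWritingFinishedFD(WriteDoneFd());
  }
  {
    // Consumer: read the same AHWB and report when it is safe to release it.
    auto view = tensor.GetAHardwareBufferReadView();
    ConsumeAhwb(view.handle());
    view.SetReadingFinishedFunc([] { return true; });  // done immediately here
  }
}

Later CPU or SSBO reads of the same tensor wait on the writer's fence through MapAhwbToCpuRead() and InsertAhwbToSsboFence() further below.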
@@ -223,16 +317,26 @@ class Tensor { } int bytes() const { return shape_.num_elements() * element_size(); } - bool ready_on_cpu() const { return valid_ & kValidCpu; } + bool ready_on_cpu() const { + return valid_ & (kValidAHardwareBuffer | kValidCpu); + } bool ready_on_gpu() const { - return valid_ & - (kValidMetalBuffer | kValidOpenGlBuffer | kValidOpenGlTexture2d); + return valid_ & (kValidMetalBuffer | kValidOpenGlBuffer | + kValidAHardwareBuffer | kValidOpenGlTexture2d); } bool ready_as_metal_buffer() const { return valid_ & kValidMetalBuffer; } - bool ready_as_opengl_buffer() const { return valid_ & kValidOpenGlBuffer; } + bool ready_as_opengl_buffer() const { + return valid_ & (kValidAHardwareBuffer | kValidOpenGlBuffer); + } bool ready_as_opengl_texture_2d() const { return valid_ & kValidOpenGlTexture2d; } + // Sets the type of underlying resource that is going to be allocated. + enum class StorageType { + kDefault, + kAhwb, + }; + static void SetPreferredStorageType(StorageType type); private: void Move(Tensor*); @@ -248,6 +352,7 @@ class Tensor { kValidMetalBuffer = 1 << 1, kValidOpenGlBuffer = 1 << 2, kValidOpenGlTexture2d = 1 << 3, + kValidAHardwareBuffer = 1 << 5, }; // A list of resources which are currently allocated and synchronized with // each other: valid_ = kValidCpu | kValidMetalBuffer; @@ -264,6 +369,34 @@ class Tensor { void AllocateMtlBuffer(id device) const; #endif // MEDIAPIPE_METAL_ENABLED +#ifdef MEDIAPIPE_TENSOR_USE_AHWB + mutable AHardwareBuffer* ahwb_ = nullptr; + // Signals when the GPU has finished writing into the SSBO so the AHWB can be used. Or + // signals when writing into the AHWB has finished so the GPU can read from the SSBO. + // Sync and FD are bound together. + mutable EGLSyncKHR fence_sync_ = EGL_NO_SYNC_KHR; + // This FD signals when the writing into the SSBO has been finished. + mutable int ssbo_written_ = -1; + // An externally set FD that is wrapped with an EGL sync to synchronize + // AHWB -> OpenGL SSBO. + mutable int fence_fd_ = -1; + // Reading from the SSBO has finished so the SSBO can be released. + mutable GLsync ssbo_read_ = 0; + // An externally set function that signals when it is safe to release AHWB. + mutable std::function ahwb_written_; + mutable std::function release_callback_; + bool AllocateAHardwareBuffer(int size_alignment = 0) const; + void CreateEglSyncAndFd() const; + // Use Ahwb for other views: OpenGL / CPU buffer.
+ static inline bool use_ahwb_ = false; +#endif // MEDIAPIPE_TENSOR_USE_AHWB + bool AllocateAhwbMapToSsbo() const; + bool InsertAhwbToSsboFence() const; + void MoveAhwbStuff(Tensor* src); + void ReleaseAhwbStuff(); + void* MapAhwbToCpuRead() const; + void* MapAhwbToCpuWrite() const; + #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 mutable std::shared_ptr gl_context_; mutable GLuint opengl_texture2d_ = GL_INVALID_INDEX; diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc new file mode 100644 index 0000000000..53722ff9b3 --- /dev/null +++ b/mediapipe/framework/formats/tensor_ahwb.cc @@ -0,0 +1,382 @@ +#include +#include + +#include "mediapipe/framework/formats/tensor.h" + +#ifdef MEDIAPIPE_TENSOR_USE_AHWB +#include "absl/synchronization/mutex.h" +#include "mediapipe/framework/port.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/gpu/gl_base.h" +#include "third_party/GL/gl/include/EGL/egl.h" +#include "third_party/GL/gl/include/EGL/eglext.h" +#endif // MEDIAPIPE_TENSOR_USE_AHWB + +namespace mediapipe { +#ifdef MEDIAPIPE_TENSOR_USE_AHWB + +namespace { +PFNGLBUFFERSTORAGEEXTERNALEXTPROC glBufferStorageExternalEXT; +PFNEGLGETNATIVECLIENTBUFFERANDROIDPROC eglGetNativeClientBufferANDROID; +PFNEGLDUPNATIVEFENCEFDANDROIDPROC eglDupNativeFenceFDANDROID; +PFNEGLCREATESYNCKHRPROC eglCreateSyncKHR; +PFNEGLWAITSYNCKHRPROC eglWaitSyncKHR; +PFNEGLCLIENTWAITSYNCKHRPROC eglClientWaitSyncKHR; +PFNEGLDESTROYSYNCKHRPROC eglDestroySyncKHR; + +bool IsGlSupported() { + static const bool extensions_allowed = [] { + eglGetNativeClientBufferANDROID = + reinterpret_cast( + eglGetProcAddress("eglGetNativeClientBufferANDROID")); + glBufferStorageExternalEXT = + reinterpret_cast( + eglGetProcAddress("glBufferStorageExternalEXT")); + eglDupNativeFenceFDANDROID = + reinterpret_cast( + eglGetProcAddress("eglDupNativeFenceFDANDROID")); + eglCreateSyncKHR = reinterpret_cast( + eglGetProcAddress("eglCreateSyncKHR")); + eglWaitSyncKHR = reinterpret_cast( + eglGetProcAddress("eglWaitSyncKHR")); + eglClientWaitSyncKHR = reinterpret_cast( + eglGetProcAddress("eglClientWaitSyncKHR")); + eglDestroySyncKHR = reinterpret_cast( + eglGetProcAddress("eglDestroySyncKHR")); + return eglClientWaitSyncKHR && eglWaitSyncKHR && + eglGetNativeClientBufferANDROID && glBufferStorageExternalEXT && + eglCreateSyncKHR && eglDupNativeFenceFDANDROID && eglDestroySyncKHR; + }(); + return extensions_allowed; +} + +absl::Status MapAHardwareBufferToGlBuffer(AHardwareBuffer* handle, size_t size, + GLuint name) { + if (!IsGlSupported()) { + return absl::UnknownError( + "No GL extension functions found to bind AHardwareBuffer and " + "OpenGL buffer"); + } + EGLClientBuffer native_buffer = eglGetNativeClientBufferANDROID(handle); + if (!native_buffer) { + return absl::UnknownError("Can't get native buffer"); + } + glBufferStorageExternalEXT(GL_SHADER_STORAGE_BUFFER, 0, size, native_buffer, + GL_MAP_READ_BIT | GL_MAP_WRITE_BIT | + GL_MAP_COHERENT_BIT_EXT | + GL_MAP_PERSISTENT_BIT_EXT); + if (glGetError() == GL_NO_ERROR) { + return absl::OkStatus(); + } else { + return absl::InternalError("Error in glBufferStorageExternalEXT"); + } +} + +static inline int AlignedToPowerOf2(int value, int alignment) { + // alignment must be a power of 2 + return ((value - 1) | (alignment - 1)) + 1; +} + +// This class keeps tensor's resources while the tensor is in use on GPU or TPU +// but is already released on CPU. 
When a regular OpenGL buffer is bound to the +// GPU queue for execution and released on the client side, the buffer is still +// not released because it is being used by the GPU. The OpenGL driver keeps track of +// that. When an OpenGL buffer is built on top of an AHWB, the tracking is done +// by DelayedReleaser, which keeps a record of all AHWBs allocated +// and releases each of them once it is no longer in use. EGL/GL fences are used to check +// the status of a buffer. +class DelayedReleaser { + public: + // Non-copyable + DelayedReleaser(const DelayedReleaser&) = delete; + DelayedReleaser& operator=(const DelayedReleaser&) = delete; + // Non-movable + DelayedReleaser(DelayedReleaser&&) = delete; + DelayedReleaser& operator=(DelayedReleaser&&) = delete; + + static void Add(AHardwareBuffer* ahwb, GLuint opengl_buffer, + EGLSyncKHR ssbo_sync, GLsync ssbo_read, + std::function&& ahwb_written, + std::shared_ptr gl_context, + std::function&& callback) { + static absl::Mutex mutex; + absl::MutexLock lock(&mutex); + // Using `new` to access a non-public constructor. + to_release_.emplace_back(absl::WrapUnique(new DelayedReleaser( + ahwb, opengl_buffer, ssbo_sync, ssbo_read, std::move(ahwb_written), + gl_context, std::move(callback)))); + for (auto it = to_release_.begin(); it != to_release_.end();) { + if ((*it)->IsSignaled()) { + it = to_release_.erase(it); + } else { + ++it; + } + } + } + ~DelayedReleaser() { + AHardwareBuffer_release(ahwb_); + if (release_callback_) release_callback_(); + } + + bool IsSignaled() { + CHECK(!(ssbo_read_ && ahwb_written_)) + << "ssbo_read_ and ahwb_written_ cannot both be set"; + if (ahwb_written_) { + if (!ahwb_written_()) return false; + gl_context_->Run([this]() { + if (fence_sync_ != EGL_NO_SYNC_KHR && IsGlSupported()) { + auto egl_display = eglGetDisplay(EGL_DEFAULT_DISPLAY); + if (egl_display != EGL_NO_DISPLAY) { + eglDestroySyncKHR(egl_display, fence_sync_); + } + fence_sync_ = EGL_NO_SYNC_KHR; + } + glDeleteBuffers(1, &opengl_buffer_); + opengl_buffer_ = GL_INVALID_INDEX; + }); + return true; + } + + gl_context_->Run([this]() { + if (ssbo_read_ != 0) { + GLenum status = glClientWaitSync(ssbo_read_, 0, + /* timeout ns = */ 0); + if (status != GL_CONDITION_SATISFIED && status != GL_ALREADY_SIGNALED) { + return; + } + glDeleteSync(ssbo_read_); + ssbo_read_ = 0; + + // Don't wait on ssbo_sync because it is ahead of ssbo_read_sync. + if (fence_sync_ != EGL_NO_SYNC_KHR && IsGlSupported()) { + auto egl_display = eglGetDisplay(EGL_DEFAULT_DISPLAY); + if (egl_display != EGL_NO_DISPLAY) { + eglDestroySyncKHR(egl_display, fence_sync_); + } + } + fence_sync_ = EGL_NO_SYNC_KHR; + + glDeleteBuffers(1, &opengl_buffer_); + opengl_buffer_ = GL_INVALID_INDEX; + } + }); + return opengl_buffer_ == GL_INVALID_INDEX; + } + + protected: + AHardwareBuffer* ahwb_; + GLuint opengl_buffer_; + // TODO: use wrapper instead. + EGLSyncKHR fence_sync_; + // TODO: use wrapper instead.
+ GLsync ssbo_read_; + std::function ahwb_written_; + std::shared_ptr gl_context_; + std::function release_callback_; + static inline std::deque> to_release_; + + DelayedReleaser(AHardwareBuffer* ahwb, GLuint opengl_buffer, + EGLSyncKHR fence_sync, GLsync ssbo_read, + std::function&& ahwb_written, + std::shared_ptr gl_context, + std::function&& callback) + : ahwb_(ahwb), + opengl_buffer_(opengl_buffer), + fence_sync_(fence_sync), + ssbo_read_(ssbo_read), + ahwb_written_(std::move(ahwb_written)), + gl_context_(gl_context), + release_callback_(std::move(callback)) {} +}; +} // namespace + +Tensor::AHardwareBufferView Tensor::GetAHardwareBufferReadView() const { + auto lock(absl::make_unique(&view_mutex_)); + CHECK(valid_ != kValidNone) << "Tensor must be written prior to read from."; + CHECK(!(valid_ & kValidOpenGlTexture2d)) + << "Tensor conversion between OpenGL texture and AHardwareBuffer is not " + "supported."; + CHECK(ahwb_ || !(valid_ & kValidOpenGlBuffer)) + << "Interoperability between OpenGL buffer and AHardwareBuffer is not " + "supported on the target system."; + CHECK(AllocateAHardwareBuffer()) + << "AHardwareBuffer is not supported on the target system."; + valid_ |= kValidAHardwareBuffer; + if (valid_ & kValidOpenGlBuffer) CreateEglSyncAndFd(); + return {ahwb_, + ssbo_written_, + &fence_fd_, // The FD is created for SSBO -> AHWB synchronization. + &ahwb_written_, // Filled by SetReadingFinishedFunc. + &release_callback_, + std::move(lock)}; +} + +void Tensor::CreateEglSyncAndFd() const { + gl_context_->Run([this]() { + if (IsGlSupported()) { + auto egl_display = eglGetDisplay(EGL_DEFAULT_DISPLAY); + if (egl_display != EGL_NO_DISPLAY) { + fence_sync_ = eglCreateSyncKHR(egl_display, + EGL_SYNC_NATIVE_FENCE_ANDROID, nullptr); + if (fence_sync_ != EGL_NO_SYNC_KHR) { + ssbo_written_ = eglDupNativeFenceFDANDROID(egl_display, fence_sync_); + if (ssbo_written_ == -1) { + eglDestroySyncKHR(egl_display, fence_sync_); + fence_sync_ = EGL_NO_SYNC_KHR; + } + } + } + } + // Can't use Sync object. + if (fence_sync_ == EGL_NO_SYNC_KHR) glFinish(); + }); +} + +Tensor::AHardwareBufferView Tensor::GetAHardwareBufferWriteView( + int size_alignment) const { + auto lock(absl::make_unique(&view_mutex_)); + CHECK(AllocateAHardwareBuffer(size_alignment)) + << "AHardwareBuffer is not supported on the target system."; + valid_ = kValidAHardwareBuffer; + return {ahwb_, + /*ssbo_written=*/-1, + &fence_fd_, // For SetWritingFinishedFD. + /*ahwb_written=*/nullptr, // The lifetime is managed by SSBO. + &release_callback_, + std::move(lock)}; +} + +bool Tensor::AllocateAHardwareBuffer(int size_alignment) const { + if (!use_ahwb_) return false; + if (ahwb_ == nullptr) { + AHardwareBuffer_Desc desc = {}; + if (size_alignment == 0) { + desc.width = bytes(); + } else { + // We expect allocations to be page-aligned, implicitly satisfying any + // requirements from Edge TPU. No need to add a check for this, + // since Edge TPU will check for us.
+ desc.width = AlignedToPowerOf2(bytes(), size_alignment); + } + desc.height = 1; + desc.layers = 1; + desc.format = AHARDWAREBUFFER_FORMAT_BLOB; + desc.usage = AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN | + AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN | + AHARDWAREBUFFER_USAGE_GPU_DATA_BUFFER; + return AHardwareBuffer_allocate(&desc, &ahwb_) == 0; + } + return true; +} + +bool Tensor::AllocateAhwbMapToSsbo() const { + if (AllocateAHardwareBuffer()) { + if (MapAHardwareBufferToGlBuffer(ahwb_, bytes(), opengl_buffer_).ok()) { + glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); + return true; + } + // Unable to make OpenGL <-> AHWB binding. Use regular SSBO instead. + AHardwareBuffer_release(ahwb_); + ahwb_ = nullptr; + } + return false; +} + +// SSBO is created on top of AHWB. A fence is inserted into the GPU queue before +// the GPU task that is going to read from the SSBO. When the writing into AHWB +// is finished then the GPU reads from the SSBO. +bool Tensor::InsertAhwbToSsboFence() const { + if (!ahwb_) return false; + if (fence_fd_ != -1) { + // Can't wait for FD to be signaled on GPU. + // TODO: wait on CPU instead. + if (!IsGlSupported()) return true; + + // Server-side fence. + auto egl_display = eglGetDisplay(EGL_DEFAULT_DISPLAY); + if (egl_display == EGL_NO_DISPLAY) return true; + EGLint sync_attribs[] = {EGL_SYNC_NATIVE_FENCE_FD_ANDROID, + (EGLint)fence_fd_, EGL_NONE}; + fence_sync_ = eglCreateSyncKHR(egl_display, EGL_SYNC_NATIVE_FENCE_ANDROID, + sync_attribs); + if (fence_sync_ != EGL_NO_SYNC_KHR) { + eglWaitSyncKHR(egl_display, fence_sync_, 0); + } + } + return true; +} + +void Tensor::MoveAhwbStuff(Tensor* src) { + ahwb_ = std::exchange(src->ahwb_, nullptr); + fence_sync_ = std::exchange(src->fence_sync_, EGL_NO_SYNC_KHR); + ssbo_read_ = std::exchange(src->ssbo_read_, static_cast(0)); + ssbo_written_ = std::exchange(src->ssbo_written_, -1); + fence_fd_ = std::exchange(src->fence_fd_, -1); + ahwb_written_ = std::move(src->ahwb_written_); + release_callback_ = std::move(src->release_callback_); +} + +void Tensor::ReleaseAhwbStuff() { + if (fence_fd_ != -1) { + close(fence_fd_); + fence_fd_ = -1; + } + if (ahwb_) { + if (ssbo_read_ != 0 || fence_sync_ != EGL_NO_SYNC_KHR) { + if (ssbo_written_ != -1) close(ssbo_written_); + DelayedReleaser::Add(ahwb_, opengl_buffer_, fence_sync_, ssbo_read_, + std::move(ahwb_written_), gl_context_, + std::move(release_callback_)); + opengl_buffer_ = GL_INVALID_INDEX; + } else { + AHardwareBuffer_release(ahwb_); + } + } +} + +void* Tensor::MapAhwbToCpuRead() const { + if (ahwb_) { + if (!(valid_ & kValidCpu) && (valid_ & kValidOpenGlBuffer) && + ssbo_written_ == -1) { + // EGLSync failed. Use another synchronization method. + // TODO: Use tflite::gpu::GlBufferSync and GlActiveSync. + glFinish(); + } + void* ptr; + auto error = + AHardwareBuffer_lock(ahwb_, AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN, + ssbo_written_, nullptr, &ptr); + CHECK(error == 0) << "AHardwareBuffer_lock " << error; + close(ssbo_written_); + ssbo_written_ = -1; + return ptr; + } + return nullptr; +} + +void* Tensor::MapAhwbToCpuWrite() const { + if (ahwb_) { + // TODO: If the previously acquired view is a GPU write view then we need to + // be sure that writing is finished. That's a warning: two consecutive write + // views should be interleaved with a read view.
+ void* ptr; + auto error = AHardwareBuffer_lock( + ahwb_, AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN, -1, nullptr, &ptr); + CHECK(error == 0) << "AHardwareBuffer_lock " << error; + return ptr; + } + return nullptr; +} + +#else // MEDIAPIPE_TENSOR_USE_AHWB + +bool Tensor::AllocateAhwbMapToSsbo() const { return false; } +bool Tensor::InsertAhwbToSsboFence() const { return false; } +void Tensor::MoveAhwbStuff(Tensor* src) {} +void Tensor::ReleaseAhwbStuff() {} +void* Tensor::MapAhwbToCpuRead() const { return nullptr; } +void* Tensor::MapAhwbToCpuWrite() const { return nullptr; } + +#endif // MEDIAPIPE_TENSOR_USE_AHWB + +} // namespace mediapipe diff --git a/mediapipe/framework/formats/tensor_ahwb_test.cc b/mediapipe/framework/formats/tensor_ahwb_test.cc new file mode 100644 index 0000000000..805dce1d89 --- /dev/null +++ b/mediapipe/framework/formats/tensor_ahwb_test.cc @@ -0,0 +1,59 @@ +#include "mediapipe/gpu/gpu_test_base.h" +#include "testing/base/public/gmock.h" +#include "testing/base/public/gunit.h" + +#ifdef MEDIAPIPE_TENSOR_USE_AHWB +#include + +#include "mediapipe/framework/formats/tensor.h" + +namespace mediapipe { + +#if !MEDIAPIPE_DISABLE_GPU +class TensorAhwbTest : public mediapipe::GpuTestBase { + public: +}; + +TEST_F(TensorAhwbTest, TestCpuThenAHWB) { + Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1}); + { + auto ptr = tensor.GetCpuWriteView().buffer(); + EXPECT_NE(ptr, nullptr); + } + { + auto ahwb = tensor.GetAHardwareBufferReadView().handle(); + EXPECT_NE(ahwb, nullptr); + } +} + +TEST_F(TensorAhwbTest, TestAHWBThenCpu) { + Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1}); + { + auto ahwb = tensor.GetAHardwareBufferWriteView().handle(); + EXPECT_NE(ahwb, nullptr); + } + { + auto ptr = tensor.GetCpuReadView().buffer(); + EXPECT_NE(ptr, nullptr); + } +} + +TEST_F(TensorAhwbTest, TestCpuThenGl) { + RunInGlContext([] { + Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1}); + { + auto ptr = tensor.GetCpuWriteView().buffer(); + EXPECT_NE(ptr, nullptr); + } + { + auto ssbo = tensor.GetOpenGlBufferReadView().name(); + EXPECT_GT(ssbo, 0); + } + }); +} + +} // namespace mediapipe + +#endif // !MEDIAPIPE_DISABLE_GPU + +#endif // MEDIAPIPE_TENSOR_USE_AHWB diff --git a/mediapipe/framework/formats/time_series_header.proto b/mediapipe/framework/formats/time_series_header.proto index a16f47d9b9..6fd56c198c 100644 --- a/mediapipe/framework/formats/time_series_header.proto +++ b/mediapipe/framework/formats/time_series_header.proto @@ -21,6 +21,8 @@ syntax = "proto2"; package mediapipe; +option objc_class_prefix = "MediaPipe"; + // Header for a uniformly sampled time series stream. Each Packet in // the stream is a Matrix, and each column is a (vector-valued) sample of // the series, i.e. each column corresponds to a distinct sample in time. diff --git a/mediapipe/framework/input_stream_manager.cc b/mediapipe/framework/input_stream_manager.cc index f47259877b..1af2e2cc8b 100644 --- a/mediapipe/framework/input_stream_manager.cc +++ b/mediapipe/framework/input_stream_manager.cc @@ -204,6 +204,8 @@ absl::Status InputStreamManager::SetNextTimestampBound(const Timestamp bound, // untimed scheduling policies. if (bound > next_timestamp_bound_) { next_timestamp_bound_ = bound; + VLOG(3) << "Next timestamp bound for input " << name_ << " is " + << next_timestamp_bound_; if (queue_.empty()) { // If the queue was not empty then a change to the next_timestamp_bound_ // is not detectable by the consumer. 
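For context on the two VLOG(3) additions above and below: timestamp bounds propagate from output streams to the input streams of downstream calculators. A minimal sketch of a calculator that advances the bound without emitting packets, illustrative only and not code from this change:

#include "mediapipe/framework/calculator_framework.h"

// Forwards timestamp bounds but never emits packets; running it makes the
// new VLOG(3) lines in OutputStreamManager and InputStreamManager fire.
class BoundOnlyCalculator : public mediapipe::CalculatorBase {
 public:
  static absl::Status GetContract(mediapipe::CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<int>();
    cc->Outputs().Index(0).Set<int>();
    return absl::OkStatus();
  }
  absl::Status Process(mediapipe::CalculatorContext* cc) override {
    // Emit nothing, but promise that no packet earlier than the next input
    // timestamp will ever appear on the output stream.
    cc->Outputs().Index(0).SetNextTimestampBound(
        cc->InputTimestamp().NextAllowedInStream());
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(BoundOnlyCalculator);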
diff --git a/mediapipe/framework/output_stream_manager.cc b/mediapipe/framework/output_stream_manager.cc index 0784bdccc9..b092313e21 100644 --- a/mediapipe/framework/output_stream_manager.cc +++ b/mediapipe/framework/output_stream_manager.cc @@ -168,6 +168,8 @@ void OutputStreamManager::PropagateUpdatesToMirrors( if (next_timestamp_bound != Timestamp::Unset()) { absl::MutexLock lock(&stream_mutex_); next_timestamp_bound_ = next_timestamp_bound; + VLOG(3) << "Next timestamp bound for output " << output_stream_spec_.name + << " is " << next_timestamp_bound_; } } std::list* packets_to_propagate = output_stream_shard->OutputQueue(); diff --git a/mediapipe/framework/packet.cc b/mediapipe/framework/packet.cc index 1fbd55e974..05d3c6c52c 100644 --- a/mediapipe/framework/packet.cc +++ b/mediapipe/framework/packet.cc @@ -106,19 +106,17 @@ std::string Packet::DebugString() const { return result; } -absl::Status Packet::ValidateAsType(const tool::TypeInfo& type_info) const { +absl::Status Packet::ValidateAsType(TypeId type_id) const { if (ABSL_PREDICT_FALSE(IsEmpty())) { - return absl::InternalError( - absl::StrCat("Expected a Packet of type: ", - MediaPipeTypeStringOrDemangled(type_info), - ", but received an empty Packet.")); + return absl::InternalError(absl::StrCat( + "Expected a Packet of type: ", MediaPipeTypeStringOrDemangled(type_id), + ", but received an empty Packet.")); } - bool holder_is_right_type = - holder_->GetTypeInfo().hash_code() == type_info.hash_code(); + bool holder_is_right_type = holder_->GetTypeId() == type_id; if (ABSL_PREDICT_FALSE(!holder_is_right_type)) { return absl::InvalidArgumentError(absl::StrCat( "The Packet stores \"", holder_->DebugTypeName(), "\", but \"", - MediaPipeTypeStringOrDemangled(type_info), "\" was requested.")); + MediaPipeTypeStringOrDemangled(type_id), "\" was requested.")); } return absl::OkStatus(); } diff --git a/mediapipe/framework/packet.h b/mediapipe/framework/packet.h index 82f0ec0872..1024cbc154 100644 --- a/mediapipe/framework/packet.h +++ b/mediapipe/framework/packet.h @@ -21,7 +21,6 @@ #include #include #include -#include #include "absl/base/macros.h" #include "absl/memory/memory.h" @@ -69,7 +68,7 @@ absl::StatusOr PacketFromDynamicProto(const std::string& type_name, // The preferred method of creating a Packet is with MakePacket(). // The Packet typically owns the object that it contains, but // PointToForeign allows a Packet to be constructed which does not -// own it's data. +// own its data. // // This class is thread compatible. class Packet { @@ -180,7 +179,7 @@ class Packet { // Returns an error if the packet does not contain data of type T. template absl::Status ValidateAsType() const { - return ValidateAsType(tool::TypeInfo::Get()); + return ValidateAsType(kTypeId); } // Returns an error if the packet is not an instance of @@ -189,11 +188,7 @@ class Packet { // Get the type id for the underlying type stored in the Packet. // Crashes if IsEmpty() == true. - size_t GetTypeId() const { return GetTypeInfo().hash_code(); } - - // Get the type info for the underlying type stored in the Packet. - // Crashes if IsEmpty() == true. - const tool::TypeInfo& GetTypeInfo() const; + TypeId GetTypeId() const; // Returns the timestamp. 
class Timestamp Timestamp() const; @@ -225,7 +220,7 @@ class Packet { packet_internal::GetHolderShared(Packet&& packet); friend class PacketType; - absl::Status ValidateAsType(const tool::TypeInfo& type_info) const; + absl::Status ValidateAsType(TypeId type_id) const; std::shared_ptr holder_; class Timestamp timestamp_; @@ -369,7 +364,7 @@ class HolderBase { virtual ~HolderBase(); template bool PayloadIsOfType() const { - return GetTypeInfo().hash_code() == tool::GetTypeHash(); + return GetTypeId() == kTypeId; } // Returns a printable string identifying the type stored in the holder. virtual const std::string DebugTypeName() const = 0; @@ -377,7 +372,7 @@ class HolderBase { // empty string. virtual const std::string RegisteredTypeName() const = 0; // Get the type id of the underlying data type. - virtual const tool::TypeInfo& GetTypeInfo() const = 0; + virtual TypeId GetTypeId() const = 0; // Downcasts this to Holder. Returns nullptr if deserialization // failed or if the requested type is not what is stored. template @@ -428,7 +423,7 @@ StatusOr> ConvertToVectorOfProtoMessageLitePtrs(const T* data, /*is_proto_vector=*/std::false_type) { return absl::InvalidArgumentError(absl::StrCat( - "The Packet stores \"", tool::TypeInfo::Get().name(), "\"", + "The Packet stores \"", kTypeId.name(), "\"", "which is not convertible to vector.")); } @@ -510,9 +505,7 @@ class Holder : public HolderBase { HolderSupport::EnsureStaticInit(); return *ptr_; } - const tool::TypeInfo& GetTypeInfo() const final { - return tool::TypeInfo::Get(); - } + TypeId GetTypeId() const final { return kTypeId; } // Releases the underlying data pointer and transfers the ownership to a // unique pointer. // This method is dangerous and is only used by Packet::Consume() if the @@ -748,9 +741,9 @@ inline Packet& Packet::operator=(Packet&& packet) { inline bool Packet::IsEmpty() const { return holder_ == nullptr; } -inline const tool::TypeInfo& Packet::GetTypeInfo() const { +inline TypeId Packet::GetTypeId() const { CHECK(holder_); - return holder_->GetTypeInfo(); + return holder_->GetTypeId(); } template diff --git a/mediapipe/framework/packet_test.proto b/mediapipe/framework/packet_test.proto index bccfd6b5fa..3f10911ab1 100644 --- a/mediapipe/framework/packet_test.proto +++ b/mediapipe/framework/packet_test.proto @@ -18,6 +18,8 @@ syntax = "proto2"; package mediapipe; +option objc_class_prefix = "MediaPipe"; + message PacketTestProto { // Tests that the tags used to encode the timestamp do not interfere with // proto tags. 
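A sketch of what the TypeId migration in packet.h above enables at call sites, written in the style of the port_test.cc change earlier in this patch; it assumes the kTypeId constant from tool/type_util.h, whose visibility is widened later in this change:

#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/type_util.h"

namespace mediapipe {
namespace {

TEST(PacketTypeIdTest, ComparesTypeIdsDirectly) {
  Packet p = MakePacket<int>(42);
  // TypeId values compare directly; no type_info hash codes are involved.
  EXPECT_EQ(p.GetTypeId(), kTypeId<int>);
  MP_EXPECT_OK(p.ValidateAsType<int>());
  EXPECT_FALSE(p.ValidateAsType<float>().ok());
}

}  // namespace
}  // namespace mediapipe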
diff --git a/mediapipe/framework/packet_type.cc b/mediapipe/framework/packet_type.cc index c633d17a87..d3d0df2962 100644 --- a/mediapipe/framework/packet_type.cc +++ b/mediapipe/framework/packet_type.cc @@ -127,13 +127,13 @@ bool PacketType::IsOneOf() const { } bool PacketType::IsExactType() const { - return absl::holds_alternative(type_spec_); + return absl::holds_alternative(type_spec_); } const std::string* PacketType::RegisteredTypeName() const { if (auto* same_as = SameAsPtr()) return same_as->RegisteredTypeName(); - if (auto* type_info = absl::get_if(&type_spec_)) - return MediaPipeTypeStringFromTypeId((**type_info).hash_code()); + if (auto* type_id = absl::get_if(&type_spec_)) + return MediaPipeTypeStringFromTypeId(*type_id); if (auto* multi_type = absl::get_if(&type_spec_)) return multi_type->registered_type_name; return nullptr; @@ -141,8 +141,8 @@ const std::string* PacketType::RegisteredTypeName() const { namespace internal { -struct TypeInfoFormatter { - void operator()(std::string* out, const tool::TypeInfo& t) const { +struct TypeIdFormatter { + void operator()(std::string* out, TypeId t) const { absl::StrAppend(out, MediaPipeTypeStringOrDemangled(t)); } }; @@ -167,12 +167,9 @@ explicit QuoteFormatter(Formatter f) -> QuoteFormatter; } // namespace internal -std::string PacketType::TypeNameForOneOf(TypeInfoSpan types) { +std::string PacketType::TypeNameForOneOf(TypeIdSpan types) { return absl::StrCat( - "OneOf<", - absl::StrJoin(types, ", ", - absl::DereferenceFormatter(internal::TypeInfoFormatter())), - ">"); + "OneOf<", absl::StrJoin(types, ", ", internal::TypeIdFormatter()), ">"); } std::string PacketType::DebugTypeName() const { @@ -185,8 +182,8 @@ std::string PacketType::DebugTypeName() const { if (auto* special = absl::get_if(&type_spec_)) { return special->name_; } - if (auto* type_info = absl::get_if(&type_spec_)) { - return MediaPipeTypeStringOrDemangled(**type_info); + if (auto* type_id = absl::get_if(&type_spec_)) { + return MediaPipeTypeStringOrDemangled(*type_id); } if (auto* multi_type = absl::get_if(&type_spec_)) { return TypeNameForOneOf(multi_type->types); @@ -194,11 +191,11 @@ std::string PacketType::DebugTypeName() const { return "[Undefined Type]"; } -static bool HaveCommonType(absl::Span types1, - absl::Span types2) { +static bool HaveCommonType(absl::Span types1, + absl::Span types2) { for (const auto& first : types1) { for (const auto& second : types2) { - if (first->hash_code() == second->hash_code()) { + if (first == second) { return true; } } @@ -216,35 +213,34 @@ absl::Status PacketType::Validate(const Packet& packet) const { // in SetSameAs(). 
return GetSameAs()->Validate(packet); } - if (auto* type_info = absl::get_if(&type_spec_)) { - return packet.ValidateAsType(**type_info); + if (auto* type_id = absl::get_if(&type_spec_)) { + return packet.ValidateAsType(*type_id); } if (packet.IsEmpty()) { return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Empty packets are not allowed for type: " << DebugTypeName(); } if (auto* multi_type = absl::get_if(&type_spec_)) { - auto* packet_type = &packet.GetTypeInfo(); + auto packet_type = packet.GetTypeId(); if (HaveCommonType(multi_type->types, absl::MakeSpan(&packet_type, 1))) { return absl::OkStatus(); } else { return absl::InvalidArgumentError(absl::StrCat( "The Packet stores \"", packet.DebugTypeName(), "\", but one of ", absl::StrJoin(multi_type->types, ", ", - absl::DereferenceFormatter(internal::QuoteFormatter( - internal::TypeInfoFormatter()))), + internal::QuoteFormatter(internal::TypeIdFormatter())), " was requested.")); } } if (auto* special = absl::get_if(&type_spec_)) { - return special->accept_fn_(&packet.GetTypeInfo()); + return special->accept_fn_(packet.GetTypeId()); } return absl::OkStatus(); } -PacketType::TypeInfoSpan PacketType::GetTypeSpan(const TypeSpec& type_spec) { - if (auto* type_info = absl::get_if(&type_spec)) - return absl::MakeSpan(type_info, 1); +PacketType::TypeIdSpan PacketType::GetTypeSpan(const TypeSpec& type_spec) { + if (auto* type_id = absl::get_if(&type_spec)) + return absl::MakeSpan(type_id, 1); if (auto* multi_type = absl::get_if(&type_spec)) return multi_type->types; return {}; @@ -254,8 +250,8 @@ bool PacketType::IsConsistentWith(const PacketType& other) const { const PacketType* type1 = GetSameAs(); const PacketType* type2 = other.GetSameAs(); - TypeInfoSpan types1 = GetTypeSpan(type1->type_spec_); - TypeInfoSpan types2 = GetTypeSpan(type2->type_spec_); + TypeIdSpan types1 = GetTypeSpan(type1->type_spec_); + TypeIdSpan types2 = GetTypeSpan(type2->type_spec_); if (!types1.empty() && !types2.empty()) { return HaveCommonType(types1, types2); } diff --git a/mediapipe/framework/packet_type.h b/mediapipe/framework/packet_type.h index 09d4d93451..9b4bbd36cf 100644 --- a/mediapipe/framework/packet_type.h +++ b/mediapipe/framework/packet_type.h @@ -121,15 +121,15 @@ class PacketType { // We don't do union-find optimizations in order to avoid a mutex. const PacketType* other; }; - using TypeInfoSpan = absl::Span; + using TypeIdSpan = absl::Span; struct MultiType { - TypeInfoSpan types; + TypeIdSpan types; // TODO: refactor RegisteredTypeName, remove. 
const std::string* registered_type_name; }; struct SpecialType; - using TypeSpec = absl::variant; + using TypeSpec = + absl::variant; typedef absl::Status (*AcceptsTypeFn)(const TypeSpec& type); struct SpecialType { std::string name_; @@ -140,8 +140,8 @@ class PacketType { static absl::Status AcceptNone(const TypeSpec& type); const PacketType* SameAsPtr() const; - static TypeInfoSpan GetTypeSpan(const TypeSpec& type_spec); - static std::string TypeNameForOneOf(TypeInfoSpan types); + static TypeIdSpan GetTypeSpan(const TypeSpec& type_spec); + static std::string TypeNameForOneOf(TypeIdSpan types); TypeSpec type_spec_; @@ -259,14 +259,13 @@ absl::Status ValidatePacketTypeSet(const PacketTypeSet& packet_type_set); template PacketType& PacketType::Set() { - type_spec_ = &tool::TypeInfo::Get(); + type_spec_ = kTypeId; return *this; } template PacketType& PacketType::SetOneOf() { - static const NoDestructor> types{ - {&tool::TypeInfo::Get()...}}; + static const NoDestructor> types{{kTypeId...}}; static const NoDestructor name{TypeNameForOneOf(*types)}; type_spec_ = MultiType{*types, &*name}; return *this; diff --git a/mediapipe/framework/profiler/graph_profiler.cc b/mediapipe/framework/profiler/graph_profiler.cc index a5c3254b31..05a8425b8c 100644 --- a/mediapipe/framework/profiler/graph_profiler.cc +++ b/mediapipe/framework/profiler/graph_profiler.cc @@ -43,7 +43,7 @@ const int kDefaultLogFileCount = 2; const char kDefaultLogFilePrefix[] = "mediapipe_trace_"; // The number of recent timestamps tracked for each input stream. -const int kPacketInfoRecentCount = 100; +const int kPacketInfoRecentCount = 400; std::string PacketIdToString(const PacketId& packet_id) { return absl::Substitute("stream_name: $0, timestamp_usec: $1", @@ -507,8 +507,8 @@ int64 GraphProfiler::AddInputStreamTimeSamples( // This is a condition rather than a failure CHECK because // under certain conditions the consumer calculator's Process() // can start before the producer calculator's Process() is finished. - LOG_EVERY_N(WARNING, 100) << "Expected packet info is missing for: " - << PacketIdToString(packet_id); + LOG_FIRST_N(WARNING, 10) << "Expected packet info is missing for: " + << PacketIdToString(packet_id); continue; } AddTimeSample( diff --git a/mediapipe/framework/subgraph.h b/mediapipe/framework/subgraph.h index 3b83d2addb..b3e7d958bd 100644 --- a/mediapipe/framework/subgraph.h +++ b/mediapipe/framework/subgraph.h @@ -36,7 +36,7 @@ class SubgraphContext { public: SubgraphContext() : SubgraphContext(nullptr, nullptr) {} // @node and/or @service_manager can be nullptr. - SubgraphContext(const CalculatorGraphConfig::Node* node, + SubgraphContext(CalculatorGraphConfig::Node* node, const GraphServiceManager* service_manager) : default_node_(node ? absl::nullopt : absl::optional( @@ -48,14 +48,19 @@ class SubgraphContext { : absl::optional(GraphServiceManager())), service_manager_(service_manager ? *service_manager : default_service_manager_.value()), - options_map_(std::move(tool::OptionsMap().Initialize(original_node_))) { - } + options_map_( + std::move(tool::MutableOptionsMap().Initialize(original_node_))) {} template const T& Options() { return options_map_.Get(); } + template + T* MutableOptions() { + return options_map_.GetMutable(); + } + const CalculatorGraphConfig::Node& OriginalNode() const { return original_node_; } @@ -67,16 +72,16 @@ class SubgraphContext { private: // Populated if node is not provided during construction. 
- const absl::optional default_node_; + absl::optional default_node_; - const CalculatorGraphConfig::Node& original_node_; + CalculatorGraphConfig::Node& original_node_; // Populated if service manager is not provided during construction. const absl::optional default_service_manager_; const GraphServiceManager& service_manager_; - tool::OptionsMap options_map_; + tool::MutableOptionsMap options_map_; }; // Instances of this class are responsible for providing a subgraph config. diff --git a/mediapipe/framework/test_calculators.proto b/mediapipe/framework/test_calculators.proto index af75dc13a5..77dde80b40 100644 --- a/mediapipe/framework/test_calculators.proto +++ b/mediapipe/framework/test_calculators.proto @@ -22,6 +22,8 @@ package mediapipe; import "mediapipe/framework/calculator.proto"; +option objc_class_prefix = "MediaPipe"; + message RandomMatrixCalculatorOptions { extend CalculatorOptions { optional RandomMatrixCalculatorOptions ext = 52056136; diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index 66f3061e06..28fa3ea55e 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -198,6 +198,7 @@ cc_library( ":name_util", ":options_registry", ":proto_util_lite", + ":type_util", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:packet", "//mediapipe/framework:packet_type", @@ -277,9 +278,12 @@ cc_library( hdrs = ["options_registry.h"], visibility = ["//visibility:public"], deps = [ + ":field_data_cc_proto", + ":proto_util_lite", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:advanced_proto", "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:ret_check", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/synchronization", ], @@ -334,6 +338,7 @@ cc_library( hdrs = ["proto_util_lite.h"], visibility = ["//visibility:public"], deps = [ + ":field_data_cc_proto", "//mediapipe/framework:type_map", "//mediapipe/framework/port:advanced_proto_lite", "//mediapipe/framework/port:integral_types", @@ -518,9 +523,11 @@ cc_library( cc_library( name = "type_util", hdrs = ["type_util.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:demangle", "//mediapipe/framework:port", + "@com_google_absl//absl/base:core_headers", ], ) diff --git a/mediapipe/framework/tool/calculator_graph_template.proto b/mediapipe/framework/tool/calculator_graph_template.proto index c2fc6aa869..27153f3f78 100644 --- a/mediapipe/framework/tool/calculator_graph_template.proto +++ b/mediapipe/framework/tool/calculator_graph_template.proto @@ -3,6 +3,7 @@ syntax = "proto2"; package mediapipe; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/framework/deps/proto_descriptor.proto"; option java_package = "com.google.mediapipe.proto"; diff --git a/mediapipe/framework/tool/options_field_util.cc b/mediapipe/framework/tool/options_field_util.cc index 0fdbc47abb..da90919e96 100644 --- a/mediapipe/framework/tool/options_field_util.cc +++ b/mediapipe/framework/tool/options_field_util.cc @@ -7,6 +7,7 @@ #include #include "absl/status/status.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "mediapipe/framework/packet.h" @@ -18,6 +19,7 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/tool/name_util.h" #include "mediapipe/framework/tool/proto_util_lite.h" 
+#include "mediapipe/framework/tool/type_util.h" namespace mediapipe { namespace tool { @@ -41,165 +43,39 @@ FieldType AsFieldType(proto_ns::FieldDescriptorProto::Type type) { return static_cast(type); } -absl::Status WriteValue(const FieldData& value, FieldType field_type, - std::string* field_bytes) { - StringOutputStream sos(field_bytes); - CodedOutputStream out(&sos); - switch (field_type) { - case WireFormatLite::TYPE_INT32: - WireFormatLite::WriteInt32NoTag(value.int32_value(), &out); - break; - case WireFormatLite::TYPE_SINT32: - WireFormatLite::WriteSInt32NoTag(value.int32_value(), &out); - break; - case WireFormatLite::TYPE_INT64: - WireFormatLite::WriteInt64NoTag(value.int64_value(), &out); - break; - case WireFormatLite::TYPE_SINT64: - WireFormatLite::WriteSInt64NoTag(value.int64_value(), &out); - break; - case WireFormatLite::TYPE_UINT32: - WireFormatLite::WriteUInt32NoTag(value.uint32_value(), &out); - break; - case WireFormatLite::TYPE_UINT64: - WireFormatLite::WriteUInt64NoTag(value.uint64_value(), &out); - break; - case WireFormatLite::TYPE_DOUBLE: - WireFormatLite::WriteDoubleNoTag(value.uint64_value(), &out); - break; - case WireFormatLite::TYPE_FLOAT: - WireFormatLite::WriteFloatNoTag(value.float_value(), &out); - break; - case WireFormatLite::TYPE_BOOL: - WireFormatLite::WriteBoolNoTag(value.bool_value(), &out); - break; - case WireFormatLite::TYPE_ENUM: - WireFormatLite::WriteEnumNoTag(value.enum_value(), &out); - break; - case WireFormatLite::TYPE_STRING: - out.WriteString(value.string_value()); - break; - case WireFormatLite::TYPE_MESSAGE: - out.WriteString(value.message_value().value()); - break; - default: - return absl::UnimplementedError( - absl::StrCat("Cannot write type: ", field_type)); - } - return absl::OkStatus(); -} - // Serializes a packet value. 
absl::Status WriteField(const FieldData& packet, const FieldDescriptor* field, std::string* result) { - FieldType field_type = AsFieldType(field->type()); - return WriteValue(packet, field_type, result); -} - -template -static ValueT ReadValue(absl::string_view field_bytes, absl::Status* status) { - ArrayInputStream ais(field_bytes.data(), field_bytes.size()); - CodedInputStream input(&ais); - ValueT result; - if (!WireFormatLite::ReadPrimitive(&input, &result)) { - status->Update(mediapipe::InvalidArgumentError(absl::StrCat( - "Bad serialized value: ", MediaPipeTypeStringOrDemangled(), - "."))); - } - return result; -} - -absl::Status ReadValue(absl::string_view field_bytes, FieldType field_type, - absl::string_view message_type, FieldData* result) { - absl::Status status; - result->Clear(); - switch (field_type) { - case WireFormatLite::TYPE_INT32: - result->set_int32_value( - ReadValue(field_bytes, &status)); - break; - case WireFormatLite::TYPE_SINT32: - result->set_int32_value( - ReadValue(field_bytes, &status)); - break; - case WireFormatLite::TYPE_INT64: - result->set_int64_value( - ReadValue(field_bytes, &status)); - break; - case WireFormatLite::TYPE_SINT64: - result->set_int64_value( - ReadValue(field_bytes, &status)); - break; - case WireFormatLite::TYPE_UINT32: - result->set_uint32_value( - ReadValue(field_bytes, &status)); - break; - case WireFormatLite::TYPE_UINT64: - result->set_uint64_value( - ReadValue(field_bytes, &status)); - break; - case WireFormatLite::TYPE_DOUBLE: - result->set_double_value( - ReadValue(field_bytes, &status)); - break; - case WireFormatLite::TYPE_FLOAT: - result->set_float_value( - ReadValue(field_bytes, &status)); - break; - case WireFormatLite::TYPE_BOOL: - result->set_bool_value( - ReadValue(field_bytes, &status)); - break; - case WireFormatLite::TYPE_ENUM: - result->set_enum_value( - ReadValue(field_bytes, &status)); - break; - case WireFormatLite::TYPE_STRING: - result->set_string_value(std::string(field_bytes)); - break; - case WireFormatLite::TYPE_MESSAGE: - result->mutable_message_value()->set_value(std::string(field_bytes)); - result->mutable_message_value()->set_type_url(TypeUrl(message_type)); - break; - default: - status = absl::UnimplementedError( - absl::StrCat("Cannot read type: ", field_type)); - break; - } - return status; + return ProtoUtilLite::WriteValue(packet, field->type(), result); } // Deserializes a packet from a protobuf field. -absl::Status ReadField(absl::string_view bytes, const FieldDescriptor* field, +absl::Status ReadField(absl::string_view bytes, const FieldDescriptor& field, FieldData* result) { - RET_CHECK_NE(field, nullptr); - FieldType field_type = AsFieldType(field->type()); - std::string message_type = (field_type == WireFormatLite::TYPE_MESSAGE) - ? field->message_type()->full_name() + std::string message_type = (field.type() == WireFormatLite::TYPE_MESSAGE) + ? field.message_type()->full_name() : ""; - return ReadValue(bytes, field_type, message_type, result); + return ProtoUtilLite::ReadValue(bytes, field.type(), message_type, result); } // Reads all values from a repeated field. 
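// (A ProtoPath entry is a {field_number, index} pair, so {{field.number(), 0}}
// below addresses the first occurrence of the field; GetFieldCount then
// reports how many occurrences are present.)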
-absl::Status GetFieldValues(const FieldData& message_data, - const FieldDescriptor& field, - std::vector* result) { +absl::StatusOr> GetFieldValues( + const FieldData& message_data, const FieldDescriptor& field) { + std::vector result; const std::string& message_bytes = message_data.message_value().value(); - FieldType field_type = AsFieldType(field.type()); - ProtoUtilLite proto_util; ProtoUtilLite::ProtoPath proto_path = {{field.number(), 0}}; int count; - MP_RETURN_IF_ERROR( - proto_util.GetFieldCount(message_bytes, proto_path, field_type, &count)); + MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldCount(message_bytes, proto_path, + field.type(), &count)); std::vector field_values; - MP_RETURN_IF_ERROR(proto_util.GetFieldRange(message_bytes, proto_path, count, - field_type, &field_values)); - for (int i = 0; i < count; ++i) { + MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldRange( + message_bytes, proto_path, count, field.type(), &field_values)); + for (int i = 0; i < field_values.size(); ++i) { FieldData r; - MP_RETURN_IF_ERROR(ReadField(field_values[i], &field, &r)); - result->push_back(std::move(r)); + MP_RETURN_IF_ERROR(ReadField(field_values[i], field, &r)); + result.push_back(std::move(r)); } - return absl::OkStatus(); + return result; } // Reads one value from a field. @@ -207,42 +83,70 @@ absl::Status GetFieldValue(const FieldData& message_data, const FieldPathEntry& entry, FieldData* result) { RET_CHECK_NE(entry.field, nullptr); const std::string& message_bytes = message_data.message_value().value(); - FieldType field_type = AsFieldType(entry.field->type()); - ProtoUtilLite proto_util; - ProtoUtilLite::ProtoPath proto_path = {{entry.field->number(), entry.index}}; + FieldType field_type = entry.field->type(); + int index = std::max(0, entry.index); + ProtoUtilLite::ProtoPath proto_path = {{entry.field->number(), index}}; std::vector field_values; - MP_RETURN_IF_ERROR(proto_util.GetFieldRange(message_bytes, proto_path, 1, - field_type, &field_values)); - MP_RETURN_IF_ERROR(ReadField(field_values[0], entry.field, result)); + MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldRange(message_bytes, proto_path, 1, + field_type, &field_values)); + MP_RETURN_IF_ERROR(ReadField(field_values[0], *entry.field, result)); return absl::OkStatus(); } // Writes one value to a field. -absl::Status SetFieldValue(const FieldPathEntry& entry, const FieldData& value, - FieldData* result) { - std::vector field_values; - ProtoUtilLite proto_util; - FieldType field_type = AsFieldType(entry.field->type()); - ProtoUtilLite::ProtoPath proto_path = {{entry.field->number(), entry.index}}; - std::string* message_bytes = result->mutable_message_value()->mutable_value(); +absl::Status SetFieldValue(FieldData& result, const FieldPathEntry& entry, + const FieldData& value) { + int index = std::max(0, entry.index); + ProtoUtilLite::ProtoPath proto_path = {{entry.field->number(), index}}; + std::string* message_bytes = result.mutable_message_value()->mutable_value(); int field_count; - MP_RETURN_IF_ERROR(proto_util.GetFieldCount(*message_bytes, proto_path, - field_type, &field_count)); - if (entry.index > field_count) { + MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldCount( + *message_bytes, proto_path, entry.field->type(), &field_count)); + if (index > field_count) { return absl::OutOfRangeError( - absl::StrCat("Option field index out of range: ", entry.index)); + absl::StrCat("Option field index out of range: ", index)); } - int replace_length = entry.index < field_count ? 1 : 0; + int replace_length = index < field_count ? 
1 : 0; std::string field_value; MP_RETURN_IF_ERROR(WriteField(value, entry.field, &field_value)); - MP_RETURN_IF_ERROR(proto_util.ReplaceFieldRange( - message_bytes, proto_path, replace_length, field_type, {field_value})); + MP_RETURN_IF_ERROR(ProtoUtilLite::ReplaceFieldRange( + message_bytes, proto_path, replace_length, entry.field->type(), + {field_value})); + return absl::OkStatus(); +} + +// Writes several values to a repeated field. +// The specified |values| replace the specified |entry| index, +// or if no index is specified all field values are replaced. +absl::Status SetFieldValues(FieldData& result, const FieldPathEntry& entry, + const std::vector& values) { + if (entry.field == nullptr) { + return absl::InvalidArgumentError("Field not found."); + } + FieldType field_type = entry.field->type(); + ProtoUtilLite::ProtoPath proto_path = {{entry.field->number(), 0}}; + std::string* message_bytes = result.mutable_message_value()->mutable_value(); + int field_count; + MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldCount(*message_bytes, proto_path, + field_type, &field_count)); + int replace_start = 0, replace_length = field_count; + if (entry.index > -1) { + replace_start = entry.index; + replace_length = 1; + } + std::vector field_values(values.size()); + for (int i = 0; i < values.size(); ++i) { + MP_RETURN_IF_ERROR(WriteField(values[i], entry.field, &field_values[i])); + } + proto_path = {{entry.field->number(), replace_start}}; + MP_RETURN_IF_ERROR(ProtoUtilLite::ReplaceFieldRange( + message_bytes, proto_path, replace_length, field_type, field_values)); return absl::OkStatus(); } // Returns true for a field of type "google.protobuf.Any". bool IsProtobufAny(const FieldDescriptor* field) { - return AsFieldType(field->type()) == FieldType::TYPE_MESSAGE && + return field->type() == FieldType::TYPE_MESSAGE && field->message_type()->full_name() == kGoogleProtobufAny; } @@ -275,9 +179,7 @@ StatusOr FindExtensionIndex(const FieldData& message_data, } std::string& extension_type = entry->extension_type; std::vector field_values; - RET_CHECK_NE(entry->field, nullptr); - MP_RETURN_IF_ERROR( - GetFieldValues(message_data, *entry->field, &field_values)); + ASSIGN_OR_RETURN(field_values, GetFieldValues(message_data, *entry->field)); for (int i = 0; i < field_values.size(); ++i) { FieldData extension = ParseProtobufAny(field_values[i]); if (extension_type == "*" || @@ -290,9 +192,9 @@ StatusOr FindExtensionIndex(const FieldData& message_data, // Returns true if the value of a field is available. bool HasField(const FieldPath& field_path, const FieldData& message_data) { - FieldData value; - return GetField(field_path, message_data, &value).ok() && - value.value_case() != mediapipe::FieldData::VALUE_NOT_SET; + auto value = GetField(message_data, field_path); + return value.ok() && + value->value_case() != mediapipe::FieldData::VALUE_NOT_SET; } // Returns the extension field containing the specified extension-type. @@ -330,43 +232,24 @@ void SetOptionsMessage( *options_any->mutable_value() = node_options.message_value().value(); } -// Returns the count of values in a repeated field. 
-int FieldCount(const FieldData& message_data, const FieldDescriptor* field) { - const std::string& message_bytes = message_data.message_value().value(); - FieldType field_type = AsFieldType(field->type()); - ProtoUtilLite proto_util; - ProtoUtilLite::ProtoPath proto_path = {{field->number(), 0}}; - int count; - if (proto_util.GetFieldCount(message_bytes, proto_path, field_type, &count) - .ok()) { - return count; - } - return 0; -} - } // anonymous namespace // Deserializes a packet containing a MessageLite value. -absl::Status ReadMessage(const std::string& value, const std::string& type_name, - Packet* result) { - auto packet = packet_internal::PacketFromDynamicProto(type_name, value); - if (packet.ok()) { - *result = *packet; - } - return packet.status(); +absl::StatusOr ReadMessage(const std::string& value, + const std::string& type_name) { + return packet_internal::PacketFromDynamicProto(type_name, value); } // Merge two options FieldData values. -absl::Status MergeMessages(const FieldData& base, const FieldData& over, - FieldData* result) { +absl::StatusOr MergeMessages(const FieldData& base, + const FieldData& over) { + FieldData result; absl::Status status; if (over.value_case() == FieldData::VALUE_NOT_SET) { - *result = base; - return status; + return base; } if (base.value_case() == FieldData::VALUE_NOT_SET) { - *result = over; - return status; + return over; } if (over.value_case() != base.value_case()) { return absl::InvalidArgumentError(absl::StrCat( @@ -382,10 +265,9 @@ absl::Status MergeMessages(const FieldData& base, const FieldData& over, absl::Cord merged_value; merged_value.Append(base.message_value().value()); merged_value.Append(over.message_value().value()); - result->mutable_message_value()->set_type_url( - base.message_value().type_url()); - result->mutable_message_value()->set_value(std::string(merged_value)); - return status; + result.mutable_message_value()->set_type_url(base.message_value().type_url()); + result.mutable_message_value()->set_value(std::string(merged_value)); + return result; } // Returns either the extension field or the repeated protobuf.Any field index @@ -439,51 +321,48 @@ FieldPath GetExtensionPath(const std::string& parent_type, } // Returns the requested options protobuf for a graph node. -absl::Status GetNodeOptions(const FieldData& message_data, - const std::string& extension_type, - FieldData* result) { +absl::StatusOr GetNodeOptions(const FieldData& message_data, + const std::string& extension_type) { constexpr char kOptionsName[] = "options"; constexpr char kNodeOptionsName[] = "node_options"; std::string parent_type = options_field_util::ParseTypeUrl( std::string(message_data.message_value().type_url())); FieldPath path; - Status status; + absl::Status status; path = GetExtensionPath(parent_type, extension_type, kOptionsName, false); - status = GetField(path, message_data, result); - if (status.ok()) { - return status; + auto result = GetField(message_data, path); + if (result.ok()) { + return result; } path = GetExtensionPath(parent_type, extension_type, kNodeOptionsName, true); - status = GetField(path, message_data, result); - return status; + return GetField(message_data, path); } // Returns the requested options protobuf for a graph. 
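// Both options accessors probe the "options" extension field first, then the
// corresponding protobuf.Any field ("node_options" or "graph_options"), and
// return the first value found.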
-absl::Status GetGraphOptions(const FieldData& message_data, - const std::string& extension_type, - FieldData* result) { +absl::StatusOr GetGraphOptions(const FieldData& message_data, + const std::string& extension_type) { constexpr char kOptionsName[] = "options"; constexpr char kGraphOptionsName[] = "graph_options"; std::string parent_type = options_field_util::ParseTypeUrl( std::string(message_data.message_value().type_url())); FieldPath path; - Status status; + absl::Status status; path = GetExtensionPath(parent_type, extension_type, kOptionsName, false); - status = GetField(path, message_data, result); - if (status.ok()) { - return status; + auto result = GetField(message_data, path); + if (result.ok()) { + return result; } path = GetExtensionPath(parent_type, extension_type, kGraphOptionsName, true); - status = GetField(path, message_data, result); - return status; + return GetField(message_data, path); } -// Reads a FieldData value from a protobuf field. -absl::Status GetField(const FieldPath& field_path, - const FieldData& message_data, FieldData* result) { +// Reads the FieldData values from a protobuf field. +absl::StatusOr> GetFieldValues( + const FieldData& message_data, const FieldPath& field_path) { + std::vector results; if (field_path.empty()) { - *result->mutable_message_value() = message_data.message_value(); - return absl::OkStatus(); + results.push_back(message_data); + return results; } FieldPathEntry head = field_path.front(); FieldPath tail = field_path; @@ -491,65 +370,101 @@ absl::Status GetField(const FieldPath& field_path, if (!head.extension_type.empty()) { MP_RETURN_IF_ERROR(FindExtension(message_data, &head)); } - if (tail.empty() && FieldCount(message_data, head.field) == 0) { - return absl::OkStatus(); - } - MP_RETURN_IF_ERROR(GetFieldValue(message_data, head, result)); + RET_CHECK_NE(head.field, nullptr); + ASSIGN_OR_RETURN(results, GetFieldValues(message_data, *head.field)); if (IsProtobufAny(head.field)) { - *result = ParseProtobufAny(*result); + for (int i = 0; i < results.size(); ++i) { + results[i] = ParseProtobufAny(results[i]); + } + } + int index = tail.empty() ? head.index : std::max(0, head.index); + if ((int)results.size() <= index) { + return absl::OutOfRangeError(absl::StrCat( + "Missing field value: ", head.field ? head.field->name() : "#", + " at index: ", index)); } if (!tail.empty()) { - FieldData child = *result; - MP_RETURN_IF_ERROR(GetField(tail, child, result)); + FieldData child = results.at(index); + ASSIGN_OR_RETURN(results, GetFieldValues(child, tail)); + } else if (index > -1) { + FieldData child = results.at(index); + results.clear(); + results.push_back(child); } - return absl::OkStatus(); + return results; } -// Writes a FieldData value into protobuf field. -absl::Status SetField(const FieldPath& field_path, const FieldData& value, - FieldData* message_data) { +// Reads a FieldData value from a protobuf field. +absl::StatusOr GetField(const FieldData& message_data, + const FieldPath& field_path) { + std::vector results; + ASSIGN_OR_RETURN(results, GetFieldValues(message_data, field_path)); + if (results.empty()) { + FieldPathEntry tail = field_path.back(); + return absl::OutOfRangeError(absl::StrCat( + "Missing field value: ", tail.field ? tail.field->name() : "##", + " at index: ", tail.index)); + } + return results[0]; +} + +// Writes FieldData values into a protobuf field.
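// (When the final path entry carries an explicit index, only that element is
// replaced; with the default index of -1, the entire repeated field is
// overwritten by |values|.)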
+absl::Status SetFieldValues(FieldData& message_data, + const FieldPath& field_path, + const std::vector& values) { if (field_path.empty()) { - *message_data->mutable_message_value() = value.message_value(); + if (values.empty()) { + return absl::InvalidArgumentError("Missing field value."); + } + message_data = values[0]; return absl::OkStatus(); } + FieldPathEntry head = field_path.front(); FieldPath tail = field_path; tail.erase(tail.begin()); if (!head.extension_type.empty()) { - MP_RETURN_IF_ERROR(FindExtension(*message_data, &head)); + MP_RETURN_IF_ERROR(FindExtension(message_data, &head)); } if (tail.empty()) { - MP_RETURN_IF_ERROR(SetFieldValue(head, value, message_data)); - } else { - FieldData child; - MP_RETURN_IF_ERROR(GetFieldValue(*message_data, head, &child)); - MP_RETURN_IF_ERROR(SetField(tail, value, &child)); - if (IsProtobufAny(head.field)) { - child = SerializeProtobufAny(child); - } - MP_RETURN_IF_ERROR(SetFieldValue(head, child, message_data)); + MP_RETURN_IF_ERROR(SetFieldValues(message_data, head, values)); + return absl::OkStatus(); + } + FieldData child; + MP_RETURN_IF_ERROR(GetFieldValue(message_data, head, &child)); + MP_RETURN_IF_ERROR(SetFieldValues(child, tail, values)); + if (IsProtobufAny(head.field)) { + child = SerializeProtobufAny(child); } + MP_RETURN_IF_ERROR(SetFieldValue(message_data, head, child)); return absl::OkStatus(); } -// Merges a packet value into nested protobuf Message. -absl::Status MergeField(const FieldPath& field_path, const FieldData& value, - FieldData* message_data) { +// Writes a FieldData value into a protobuf field. +absl::Status SetField(FieldData& message_data, const FieldPath& field_path, + const FieldData& value) { + return SetFieldValues(message_data, field_path, {value}); +} + +// Merges FieldData values into a nested protobuf Message. +// For each new field index, any previous value is merged with the new value. +absl::Status MergeFieldValues(FieldData& message_data, + const FieldPath& field_path, + const std::vector& values) { absl::Status status; - FieldType field_type = field_path.empty() - ? FieldType::TYPE_MESSAGE - : AsFieldType(field_path.back().field->type()); - std::string message_type = - (value.has_message_value()) - ? ParseTypeUrl(std::string(value.message_value().type_url())) - : ""; - FieldData v = value; + FieldType field_type = field_path.empty() ?
FieldType::TYPE_MESSAGE + : field_path.back().field->type(); + std::vector results = values; + std::vector prevs; + ASSIGN_OR_RETURN(prevs, GetFieldValues(message_data, field_path)); if (field_type == FieldType::TYPE_MESSAGE) { - FieldData b; - status.Update(GetField(field_path, *message_data, &b)); - status.Update(MergeMessages(b, v, &v)); + for (int i = 0; i < std::min(values.size(), prevs.size()); ++i) { + FieldData& v = results[i]; + FieldData& b = prevs[i]; + ASSIGN_OR_RETURN(v, MergeMessages(b, v)); + } } - status.Update(SetField(field_path, v, message_data)); + status.Update(SetFieldValues(message_data, field_path, results)); return status; } @@ -576,34 +491,35 @@ struct ProtoEnum { int32 value; }; -absl::Status AsPacket(const FieldData& data, Packet* result) { +absl::StatusOr AsPacket(const FieldData& data) { + Packet result; switch (data.value_case()) { case FieldData::ValueCase::kInt32Value: - *result = MakePacket(data.int32_value()); + result = MakePacket(data.int32_value()); break; case FieldData::ValueCase::kInt64Value: - *result = MakePacket(data.int64_value()); + result = MakePacket(data.int64_value()); break; case FieldData::ValueCase::kUint32Value: - *result = MakePacket(data.uint32_value()); + result = MakePacket(data.uint32_value()); break; case FieldData::ValueCase::kUint64Value: - *result = MakePacket(data.uint64_value()); + result = MakePacket(data.uint64_value()); break; case FieldData::ValueCase::kDoubleValue: - *result = MakePacket(data.double_value()); + result = MakePacket(data.double_value()); break; case FieldData::ValueCase::kFloatValue: - *result = MakePacket(data.float_value()); + result = MakePacket(data.float_value()); break; case FieldData::ValueCase::kBoolValue: - *result = MakePacket(data.bool_value()); + result = MakePacket(data.bool_value()); break; case FieldData::ValueCase::kEnumValue: - *result = MakePacket(data.enum_value()); + result = MakePacket(data.enum_value()); break; case FieldData::ValueCase::kStringValue: - *result = MakePacket(data.string_value()); + result = MakePacket(data.string_value()); break; case FieldData::ValueCase::kMessageValue: { auto r = packet_internal::PacketFromDynamicProto( @@ -612,32 +528,33 @@ absl::Status AsPacket(const FieldData& data, Packet* result) { if (!r.ok()) { return r.status(); } - *result = r.value(); + result = r.value(); break; } case FieldData::VALUE_NOT_SET: - *result = Packet(); + result = Packet(); } - return absl::OkStatus(); + return result; } -absl::Status AsFieldData(Packet packet, FieldData* result) { - static const auto* kTypeIds = new std::map{ - {tool::GetTypeHash(), WireFormatLite::CPPTYPE_INT32}, - {tool::GetTypeHash(), WireFormatLite::CPPTYPE_INT64}, - {tool::GetTypeHash(), WireFormatLite::CPPTYPE_UINT32}, - {tool::GetTypeHash(), WireFormatLite::CPPTYPE_UINT64}, - {tool::GetTypeHash(), WireFormatLite::CPPTYPE_DOUBLE}, - {tool::GetTypeHash(), WireFormatLite::CPPTYPE_FLOAT}, - {tool::GetTypeHash(), WireFormatLite::CPPTYPE_BOOL}, - {tool::GetTypeHash(), WireFormatLite::CPPTYPE_ENUM}, - {tool::GetTypeHash(), WireFormatLite::CPPTYPE_STRING}, +absl::StatusOr AsFieldData(Packet packet) { + static const auto* kTypeIds = new std::map{ + {kTypeId, WireFormatLite::CPPTYPE_INT32}, + {kTypeId, WireFormatLite::CPPTYPE_INT64}, + {kTypeId, WireFormatLite::CPPTYPE_UINT32}, + {kTypeId, WireFormatLite::CPPTYPE_UINT64}, + {kTypeId, WireFormatLite::CPPTYPE_DOUBLE}, + {kTypeId, WireFormatLite::CPPTYPE_FLOAT}, + {kTypeId, WireFormatLite::CPPTYPE_BOOL}, + {kTypeId, WireFormatLite::CPPTYPE_ENUM}, + {kTypeId, 
WireFormatLite::CPPTYPE_STRING}, }; + FieldData result; if (packet.ValidateAsProtoMessageLite().ok()) { - result->mutable_message_value()->set_value( + result.mutable_message_value()->set_value( packet.GetProtoMessageLite().SerializeAsString()); - result->mutable_message_value()->set_type_url( + result.mutable_message_value()->set_type_url( TypeUrl(packet.GetProtoMessageLite().GetTypeName())); return absl::OkStatus(); } @@ -649,48 +566,42 @@ absl::Status AsFieldData(Packet packet, FieldData* result) { switch (kTypeIds->at(packet.GetTypeId())) { case WireFormatLite::CPPTYPE_INT32: - result->set_int32_value(packet.Get()); + result.set_int32_value(packet.Get()); break; case WireFormatLite::CPPTYPE_INT64: - result->set_int64_value(packet.Get()); + result.set_int64_value(packet.Get()); break; case WireFormatLite::CPPTYPE_UINT32: - result->set_uint32_value(packet.Get()); + result.set_uint32_value(packet.Get()); break; case WireFormatLite::CPPTYPE_UINT64: - result->set_uint64_value(packet.Get()); + result.set_uint64_value(packet.Get()); break; case WireFormatLite::CPPTYPE_DOUBLE: - result->set_double_value(packet.Get()); + result.set_double_value(packet.Get()); break; case WireFormatLite::CPPTYPE_FLOAT: - result->set_float_value(packet.Get()); + result.set_float_value(packet.Get()); break; case WireFormatLite::CPPTYPE_BOOL: - result->set_bool_value(packet.Get()); + result.set_bool_value(packet.Get()); break; case WireFormatLite::CPPTYPE_ENUM: - result->set_enum_value(packet.Get().value); + result.set_enum_value(packet.Get().value); break; case WireFormatLite::CPPTYPE_STRING: - result->set_string_value(packet.Get()); + result.set_string_value(packet.Get()); break; } - return absl::OkStatus(); + return result; } std::string TypeUrl(absl::string_view type_name) { - constexpr std::string_view kTypeUrlPrefix = "type.googleapis.com/"; - return absl::StrCat(std::string(kTypeUrlPrefix), std::string(type_name)); + return ProtoUtilLite::TypeUrl(type_name); } std::string ParseTypeUrl(absl::string_view type_url) { - constexpr std::string_view kTypeUrlPrefix = "type.googleapis.com/"; - if (std::string(type_url).rfind(kTypeUrlPrefix, 0) == 0) { - return std::string( - type_url.substr(kTypeUrlPrefix.length(), std::string::npos)); - } - return std::string(type_url); + return ProtoUtilLite::ParseTypeUrl(type_url); } } // namespace options_field_util diff --git a/mediapipe/framework/tool/options_field_util.h b/mediapipe/framework/tool/options_field_util.h index f3c82e95d2..ed5e011b27 100644 --- a/mediapipe/framework/tool/options_field_util.h +++ b/mediapipe/framework/tool/options_field_util.h @@ -34,30 +34,38 @@ absl::Status SetField(const FieldPath& field_path, const FieldData& value, FieldData* message_data); // Reads a field value from a protobuf field. -absl::Status GetField(const FieldPath& field_path, - const FieldData& message_data, FieldData* result); +absl::StatusOr GetField(const FieldData& message_data, + const FieldPath& field_path); -// Merges a field value into nested protobuf Message. -absl::Status MergeField(const FieldPath& field_path, const FieldData& value, - FieldData* message_data); +// Reads one or all FieldData values from a protobuf field. +absl::StatusOr> GetFieldValues( + const FieldData& message_data, const FieldPath& field_path); + +// Writes FieldData values into a protobuf field. +absl::Status SetFieldValues(FieldData& message_data, + const FieldPath& field_path, + const std::vector& values); + +// Merges FieldData values into a protobuf field. 
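// (Message-typed values are merged element-wise with any existing values via
// MergeMessages; non-message values are simply overwritten.)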
+absl::Status MergeFieldValues(FieldData& message_data, + const FieldPath& field_path, + const std::vector& values); // Deserializes a packet containing a MessageLite value. -absl::Status ReadMessage(const std::string& value, const std::string& type_name, - Packet* result); +absl::StatusOr ReadMessage(const std::string& value, + const std::string& type_name); // Merge two options protobuf field values. -absl::Status MergeMessages(const FieldData& base, const FieldData& over, - FieldData* result); +absl::StatusOr MergeMessages(const FieldData& base, + const FieldData& over); // Returns the requested options protobuf for a graph. -absl::Status GetNodeOptions(const FieldData& message_data, - const std::string& extension_type, - FieldData* result); +absl::StatusOr GetNodeOptions(const FieldData& message_data, + const std::string& extension_type); // Returns the requested options protobuf for a graph node. -absl::Status GetGraphOptions(const FieldData& message_data, - const std::string& extension_type, - FieldData* result); +absl::StatusOr GetGraphOptions(const FieldData& message_data, + const std::string& extension_type); // Sets the node_options field in a Node, and clears the options field. void SetOptionsMessage(const FieldData& node_options, @@ -67,10 +75,10 @@ void SetOptionsMessage(const FieldData& node_options, FieldData AsFieldData(const proto_ns::MessageLite& message); // Constructs a Packet for a FieldData proto. -absl::Status AsPacket(const FieldData& data, Packet* result); +absl::StatusOr AsPacket(const FieldData& data); // Constructs a FieldData proto for a Packet. -absl::Status AsFieldData(Packet packet, FieldData* result); +absl::StatusOr AsFieldData(Packet packet); // Returns the protobuf type-url for a protobuf type-name. std::string TypeUrl(absl::string_view type_name); diff --git a/mediapipe/framework/tool/options_lib_template.cc b/mediapipe/framework/tool/options_lib_template.cc index 6244bc4a38..5199496d85 100644 --- a/mediapipe/framework/tool/options_lib_template.cc +++ b/mediapipe/framework/tool/options_lib_template.cc @@ -25,11 +25,12 @@ constexpr char kDescriptorContents[] = #include "{{DESCRIPTOR_INC_FILE_PATH}}" ; // NOLINT(whitespace/semicolon) -mediapipe::proto_ns::FileDescriptorSet ParseFileDescriptorSet( - const std::string& pb) { - mediapipe::proto_ns::FileDescriptorSet files; - files.ParseFromString(pb); - return files; +mediapipe::FieldData ReadFileDescriptorSet(const std::string& pb) { + mediapipe::FieldData result; + *result.mutable_message_value()->mutable_type_url() = + "proto2.FileDescriptorSet"; + *result.mutable_message_value()->mutable_value() = pb; + return result; } } // namespace @@ -39,6 +40,6 @@ namespace mediapipe { template <> const RegistrationToken tool::OptionsRegistry::registration_token< MP_OPTION_TYPE_NS::MP_OPTION_TYPE_NAME> = - tool::OptionsRegistry::Register(ParseFileDescriptorSet( + tool::OptionsRegistry::Register(ReadFileDescriptorSet( std::string(kDescriptorContents, sizeof(kDescriptorContents) - 1))); } // namespace mediapipe diff --git a/mediapipe/framework/tool/options_map.h b/mediapipe/framework/tool/options_map.h index 023e1dfb03..782d0f2400 100644 --- a/mediapipe/framework/tool/options_map.h +++ b/mediapipe/framework/tool/options_map.h @@ -30,15 +30,26 @@ struct IsExtension { template ::value, int>::type = 0> -void GetExtension(const CalculatorOptions& options, T* result) { +T* GetExtension(CalculatorOptions& options) { if (options.HasExtension(T::ext)) { - *result = options.GetExtension(T::ext); + return 
options.MutableExtension(T::ext); } + return nullptr; } template ::value, int>::type = 0> -void GetExtension(const CalculatorOptions& options, T* result) {} +T* GetExtension(const CalculatorOptions& options) { + return nullptr; +} + +template +void GetExtension(const CalculatorOptions& options, T* result) { + T* r = GetExtension(*const_cast(&options)); + if (r) { + *result = *r; + } +} template void GetNodeOptions(const CalculatorGraphConfig::Node& node_config, T* result) { @@ -53,23 +64,39 @@ void GetNodeOptions(const CalculatorGraphConfig::Node& node_config, T* result) { #endif } +template +void SetNodeOptions(CalculatorGraphConfig::Node& node_config, const T& value) { +#if defined(MEDIAPIPE_PROTO_LITE) && defined(MEDIAPIPE_PROTO_THIRD_PARTY) + // protobuf::Any is unavailable with third_party/protobuf:protobuf-lite. +#else + for (mediapipe::protobuf::Any& options : + *node_config.mutable_node_options()) { + if (options.Is()) { + options.PackFrom(value); + return; + } + } + node_config.add_node_options()->PackFrom(value); +#endif +} + // A map from object type to object. class TypeMap { public: template bool Has() const { - return content_.count(TypeInfo::Get()) > 0; + return content_.count(kTypeId) > 0; } template T* Get() const { if (!Has()) { - content_[TypeInfo::Get()] = std::make_shared(); + content_[kTypeId] = std::make_shared(); } - return static_cast(content_[TypeInfo::Get()].get()); + return static_cast(content_[kTypeId].get()); } private: - mutable std::map> content_; + mutable std::map> content_; }; // Extracts the options message of a specified type from a @@ -77,7 +104,7 @@ class TypeMap { class OptionsMap { public: OptionsMap& Initialize(const CalculatorGraphConfig::Node& node_config) { - node_config_ = &node_config; + node_config_ = const_cast(&node_config); return *this; } @@ -97,10 +124,40 @@ class OptionsMap { return *result; } - const CalculatorGraphConfig::Node* node_config_; + CalculatorGraphConfig::Node* node_config_; TypeMap options_; }; +class MutableOptionsMap : public OptionsMap { + public: + MutableOptionsMap& Initialize(CalculatorGraphConfig::Node& node_config) { + node_config_ = &node_config; + return *this; + } + template + void Set(const T& value) const { + *options_.Get() = value; + if (node_config_->has_options()) { + *GetExtension(*node_config_->mutable_options()) = value; + } else { + SetNodeOptions(*node_config_, value); + } + } + + template + T* GetMutable() const { + if (options_.Has()) { + return options_.Get(); + } + if (node_config_->has_options()) { + return GetExtension(*node_config_->mutable_options()); + } + T* result = options_.Get(); + GetNodeOptions(*node_config_, result); + return result; + } +}; + } // namespace tool } // namespace mediapipe diff --git a/mediapipe/framework/tool/options_registry.cc b/mediapipe/framework/tool/options_registry.cc index b65cc9fedc..f6858be0ac 100644 --- a/mediapipe/framework/tool/options_registry.cc +++ b/mediapipe/framework/tool/options_registry.cc @@ -1,6 +1,11 @@ #include "mediapipe/framework/tool/options_registry.h" +#include +#include + #include "absl/synchronization/mutex.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/tool/proto_util_lite.h" namespace mediapipe { namespace tool { @@ -9,37 +14,135 @@ namespace { // Returns a canonical message type name, with any leading "." removed. std::string CanonicalTypeName(const std::string& type_name) { - return (type_name.rfind('.', 0) == 0) ? type_name.substr(1) : type_name; + return (absl::StartsWith(type_name, ".")) ? 
type_name.substr(1) : type_name; +} + +// Returns the values from a protobuf field as typed FieldData. +absl::StatusOr> GetFieldValues( + const FieldData& message_data, std::string field_name) { + std::string type_name = + ProtoUtilLite::ParseTypeUrl(message_data.message_value().type_url()); + const Descriptor* descriptor = + OptionsRegistry::GetProtobufDescriptor(type_name); + RET_CHECK_NE(descriptor, nullptr); + const FieldDescriptor* field = descriptor->FindFieldByName(field_name); + if (field == nullptr) { + return std::vector(); + } + ProtoUtilLite::ProtoPath proto_path = {{field->number(), 0}}; + ProtoUtilLite::FieldValue message_bytes = message_data.message_value().value(); + int count; + MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldCount(message_bytes, proto_path, + field->type(), &count)); + std::vector field_values; + MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldRange( + message_bytes, proto_path, count, field->type(), &field_values)); + std::vector result; + for (int i = 0; i < field_values.size(); ++i) { + FieldData r; + std::string message_type = + field->message_type() ? field->message_type()->full_name() : ""; + MP_RETURN_IF_ERROR(ProtoUtilLite::ReadValue(field_values[i], field->type(), + message_type, &r)); + result.push_back(std::move(r)); + } + return result; +} + +// Returns a single value from a protobuf string field. +std::string GetFieldString(const FieldData& message_data, + std::string field_name) { + auto values = GetFieldValues(message_data, field_name); + if (!values->empty()) { + return values->front().string_value(); + } + return ""; +} + +// Registers the descriptors for the descriptor protobufs. These four +// descriptors are required to deserialize descriptors for other protobufs. +// This implementation avoids a code size problem introduced by +// proto_ns::DescriptorProto.
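// (Only the fields needed to walk a serialized FileDescriptorSet are declared
// below: its files, each file's package and message types, and each message's
// name, fields, extensions, and nested types.)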
+void RegisterDescriptorProtos( + absl::flat_hash_map& result) { + std::vector descriptors = { + {"proto2.FileDescriptorSet", + { + {"file", 1, FieldType::TYPE_MESSAGE, "proto2.FileDescriptorProto"}, + }}, + {"proto2.FileDescriptorProto", + { + {"package", 2, FieldType::TYPE_STRING, ""}, + {"message_type", 4, FieldType::TYPE_MESSAGE, + "proto2.DescriptorProto"}, + }}, + {"proto2.DescriptorProto", + { + {"name", 1, FieldType::TYPE_STRING, ""}, + {"field", 2, FieldType::TYPE_MESSAGE, "proto2.FieldDescriptorProto"}, + {"extension", 6, FieldType::TYPE_MESSAGE, + "proto2.FieldDescriptorProto"}, + {"nested_type", 3, FieldType::TYPE_MESSAGE, + "proto2.DescriptorProto"}, + }}, + {"proto2.FieldDescriptorProto", + { + {"name", 1, FieldType::TYPE_STRING, ""}, + {"number", 3, FieldType::TYPE_INT32, ""}, + {"type", 5, FieldType::TYPE_ENUM, ""}, + {"type_name", 6, FieldType::TYPE_STRING, ""}, + {"extendee", 2, FieldType::TYPE_STRING, ""}, + }}, + }; + for (const auto& descriptor : descriptors) { + result[descriptor.full_name()] = descriptor; + } } } // namespace RegistrationToken OptionsRegistry::Register( - const proto_ns::FileDescriptorSet& files) { - absl::MutexLock lock(&mutex()); - for (auto& file : files.file()) { - for (auto& message_type : file.message_type()) { - Register(message_type, file.package()); + const FieldData& file_descriptor_set) { + auto files = GetFieldValues(file_descriptor_set, "file"); + for (auto& file : *files) { + std::string package_name = GetFieldString(file, "package"); + auto message_types = GetFieldValues(file, "message_type"); + for (auto& message_type : *message_types) { + Register(message_type, package_name); } } return RegistrationToken([]() {}); } -void OptionsRegistry::Register(const proto_ns::DescriptorProto& message_type, +void OptionsRegistry::Register(const FieldData& message_type, const std::string& parent_name) { - auto full_name = absl::StrCat(parent_name, ".", message_type.name()); - descriptors()[full_name] = Descriptor(message_type, full_name); - for (auto& nested : message_type.nested_type()) { + std::string name = GetFieldString(message_type, "name"); + std::string full_name = absl::StrCat(parent_name, ".", name); + Descriptor descriptor(full_name, message_type); + { + absl::MutexLock lock(&mutex()); + descriptors()[full_name] = descriptor; + } + auto nested_types = GetFieldValues(message_type, "nested_type"); + for (auto& nested : *nested_types) { Register(nested, full_name); } - for (auto& extension : message_type.extension()) { - extensions()[CanonicalTypeName(extension.extendee())].push_back( - FieldDescriptor(extension)); + auto exts = GetFieldValues(message_type, "extension"); + for (auto& extension : *exts) { + FieldDescriptor field(extension); + std::string extendee = GetFieldString(extension, "extendee"); + { + absl::MutexLock lock(&mutex()); + extensions()[CanonicalTypeName(extendee)].push_back(field); + } } } const Descriptor* OptionsRegistry::GetProtobufDescriptor( const std::string& type_name) { + if (descriptors().count("proto2.DescriptorProto") == 0) { + RegisterDescriptorProtos(descriptors()); + } absl::ReaderMutexLock lock(&mutex()); auto it = descriptors().find(CanonicalTypeName(type_name)); return (it == descriptors().end()) ? 
nullptr : &it->second; @@ -73,11 +176,21 @@ absl::Mutex& OptionsRegistry::mutex() { return *mutex; } -Descriptor::Descriptor(const proto_ns::DescriptorProto& proto, - const std::string& full_name) +Descriptor::Descriptor(const std::string& full_name, + const FieldData& descriptor_proto) + : full_name_(full_name) { + auto fields = GetFieldValues(descriptor_proto, "field"); + for (const auto& field : *fields) { + FieldDescriptor f(field); + fields_[f.name()] = f; + } +} + +Descriptor::Descriptor(const std::string& full_name, + const std::vector& fields) : full_name_(full_name) { - for (auto& field : proto.field()) { - fields_[field.name()] = FieldDescriptor(field); + for (const auto& field : fields) { + fields_[field.name()] = field; } } @@ -89,20 +202,22 @@ const FieldDescriptor* Descriptor::FindFieldByName( return (it != fields_.end()) ? &it->second : nullptr; } -FieldDescriptor::FieldDescriptor(const proto_ns::FieldDescriptorProto& proto) { - name_ = proto.name(); - message_type_ = CanonicalTypeName(proto.type_name()); - type_ = proto.type(); - number_ = proto.number(); +FieldDescriptor::FieldDescriptor(const FieldData& field_proto) { + name_ = GetFieldString(field_proto, "name"); + number_ = GetFieldValues(field_proto, "number")->front().int32_value(); + type_ = (FieldType)GetFieldValues(field_proto, "type")->front().enum_value(); + message_type_ = CanonicalTypeName(GetFieldString(field_proto, "type_name")); } +FieldDescriptor::FieldDescriptor(std::string name, int number, FieldType type, + std::string message_type) + : name_(name), number_(number), type_(type), message_type_(message_type) {} + const std::string& FieldDescriptor::name() const { return name_; } int FieldDescriptor::number() const { return number_; } -proto_ns::FieldDescriptorProto::Type FieldDescriptor::type() const { - return type_; -} +FieldType FieldDescriptor::type() const { return type_; } const Descriptor* FieldDescriptor::message_type() const { return OptionsRegistry::GetProtobufDescriptor(message_type_); diff --git a/mediapipe/framework/tool/options_registry.h b/mediapipe/framework/tool/options_registry.h index 34f04ede61..b843b113ac 100644 --- a/mediapipe/framework/tool/options_registry.h +++ b/mediapipe/framework/tool/options_registry.h @@ -1,15 +1,20 @@ #ifndef MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_REGISTRY_H_ #define MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_REGISTRY_H_ +#include + #include "absl/container/flat_hash_map.h" #include "mediapipe/framework/deps/registration.h" #include "mediapipe/framework/port/advanced_proto_inc.h" +#include "mediapipe/framework/tool/field_data.pb.h" namespace mediapipe { namespace tool { class Descriptor; class FieldDescriptor; +using FieldType = mediapipe::proto_ns::internal::WireFormatLite::FieldType; +using mediapipe::FieldData; // A static registry that stores descriptors for protobufs used in MediaPipe // calculator options. Lite-proto builds do not normally include descriptors. @@ -17,8 +22,8 @@ class FieldDescriptor; // referenced and specified separately within CalculatorGraphConfigs. class OptionsRegistry { public: - // Registers the protobuf descriptors for a MessageLite. - static RegistrationToken Register(const proto_ns::FileDescriptorSet& files); + // Registers the protobuf descriptors for a FileDescriptorSet. + static RegistrationToken Register(const FieldData& file_descriptor_set); // Finds the descriptor for a protobuf. 
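// Illustrative use (a hypothetical snippet; the options type comes from the
// tests in this change):
//   const Descriptor* d = OptionsRegistry::GetProtobufDescriptor(
//       "mediapipe.NightLightCalculatorOptions");
//   const FieldDescriptor* f = d ? d->FindFieldByName("sub_options") : nullptr;
// A leading "." is stripped, so ".mediapipe.Foo" and "mediapipe.Foo" resolve
// to the same entry.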
static const Descriptor* GetProtobufDescriptor(const std::string& type_name); @@ -28,8 +33,8 @@ class OptionsRegistry { std::vector* result); private: - // Registers protobuf descriptors a MessageLite and nested types. - static void Register(const proto_ns::DescriptorProto& message_type, + // Registers protobuf descriptors for a message type and nested types. + static void Register(const FieldData& message_type, const std::string& parent_name); static absl::flat_hash_map& descriptors(); @@ -46,9 +51,10 @@ class OptionsRegistry { // avoids a code size problem introduced by proto_ns::FieldDescriptor. class Descriptor { public: - Descriptor() {} - Descriptor(const proto_ns::DescriptorProto& proto, - const std::string& full_name); + Descriptor() = default; + Descriptor(const std::string& full_name, const FieldData& descriptor_proto); + Descriptor(const std::string& full_name, + const std::vector& fields); const std::string& full_name() const; const FieldDescriptor* FindFieldByName(const std::string& name) const; @@ -61,18 +67,20 @@ class Descriptor { // avoids a code size problem introduced by proto_ns::FieldDescriptor. class FieldDescriptor { public: - FieldDescriptor() {} - FieldDescriptor(const proto_ns::FieldDescriptorProto& proto); + FieldDescriptor() = default; + FieldDescriptor(const FieldData& field_proto); + FieldDescriptor(std::string name, int number, FieldType type, + std::string message_type); const std::string& name() const; int number() const; - proto_ns::FieldDescriptorProto::Type type() const; + FieldType type() const; const Descriptor* message_type() const; private: std::string name_; - std::string message_type_; - proto_ns::FieldDescriptorProto::Type type_; int number_; + FieldType type_; + std::string message_type_; }; } // namespace tool diff --git a/mediapipe/framework/tool/options_syntax_util.cc b/mediapipe/framework/tool/options_syntax_util.cc index e51b0ac599..8a1eb3d9a0 100644 --- a/mediapipe/framework/tool/options_syntax_util.cc +++ b/mediapipe/framework/tool/options_syntax_util.cc @@ -91,8 +91,7 @@ class OptionsSyntaxUtil::OptionsSyntaxHelper { int index; if (absl::SimpleAtoi(option_name, &index)) { result.back().index = index; - } - if (!ExtensionType(option_name).empty()) { + } else if (!ExtensionType(option_name).empty()) { std::string extension_type = std::string(ExtensionType(option_name)); result.push_back({nullptr, 0, extension_type}); descriptor = OptionsRegistry::GetProtobufDescriptor(extension_type); @@ -102,7 +101,7 @@ class OptionsSyntaxUtil::OptionsSyntaxHelper { } auto field = descriptor->FindFieldByName(std::string(option_name)); descriptor = field ? field->message_type() : nullptr; - result.push_back({std::move(field), 0}); + result.push_back({std::move(field), -1}); } } return result; diff --git a/mediapipe/framework/tool/options_util.cc b/mediapipe/framework/tool/options_util.cc index 30aa8f88d6..80b2f9d15e 100644 --- a/mediapipe/framework/tool/options_util.cc +++ b/mediapipe/framework/tool/options_util.cc @@ -26,10 +26,9 @@ namespace mediapipe { namespace tool { using options_field_util::FieldPath; -using options_field_util::GetField; using options_field_util::GetGraphOptions; using options_field_util::GetNodeOptions; -using options_field_util::MergeField; +using options_field_util::MergeFieldValues; using options_field_util::MergeMessages; // Returns the type for the root options message if specified. 
@@ -56,10 +55,19 @@ std::string MessageType(FieldData message) { std::string(message.message_value().type_url())); } +// Assigns the value from a StatusOr if available. +#define ASSIGN_IF_OK(lhs, rexpr) \ + { \ + auto statusor = (rexpr); \ + if (statusor.ok()) { \ + lhs = statusor.value(); \ + } \ + } + // Copy literal options from graph_options to node_options. absl::Status CopyLiteralOptions(CalculatorGraphConfig::Node parent_node, CalculatorGraphConfig* config) { - Status status; + absl::Status status; FieldData graph_data = options_field_util::AsFieldData(*config); FieldData parent_data = options_field_util::AsFieldData(parent_node); @@ -75,25 +83,26 @@ absl::Status CopyLiteralOptions(CalculatorGraphConfig::Node parent_node, std::string node_tag = syntax_util.OptionFieldsTag(tag_and_name[0]); std::string node_extension_type = ExtensionType(node_tag); FieldData graph_options; - GetGraphOptions(graph_data, graph_extension_type, &graph_options) - .IgnoreError(); + ASSIGN_IF_OK(graph_options, + GetGraphOptions(graph_data, graph_extension_type)); FieldData parent_options; - GetNodeOptions(parent_data, graph_extension_type, &parent_options) - .IgnoreError(); - status.Update( - MergeMessages(graph_options, parent_options, &graph_options)); + ASSIGN_IF_OK(parent_options, + GetNodeOptions(parent_data, graph_extension_type)); + ASSIGN_OR_RETURN(graph_options, + MergeMessages(graph_options, parent_options)); FieldData node_options; - status.Update( - GetNodeOptions(node_data, node_extension_type, &node_options)); + ASSIGN_OR_RETURN(node_options, + GetNodeOptions(node_data, node_extension_type)); if (!node_options.has_message_value() || !graph_options.has_message_value()) { continue; } FieldPath graph_path = GetPath(graph_tag, MessageType(graph_options)); FieldPath node_path = GetPath(node_tag, MessageType(node_options)); - FieldData packet_data; - status.Update(GetField(graph_path, graph_options, &packet_data)); - status.Update(MergeField(node_path, packet_data, &node_options)); + std::vector packet_data; + ASSIGN_OR_RETURN(packet_data, GetFieldValues(graph_options, graph_path)); + MP_RETURN_IF_ERROR( + MergeFieldValues(node_options, node_path, packet_data)); options_field_util::SetOptionsMessage(node_options, &node); } node.clear_option_value(); @@ -105,7 +114,7 @@ absl::Status CopyLiteralOptions(CalculatorGraphConfig::Node parent_node, absl::Status DefineGraphOptions(const CalculatorGraphConfig::Node& parent_node, CalculatorGraphConfig* config) { MP_RETURN_IF_ERROR(CopyLiteralOptions(parent_node, config)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace tool diff --git a/mediapipe/framework/tool/options_util_test.cc b/mediapipe/framework/tool/options_util_test.cc index 870865b1f4..ad9bc9d423 100644 --- a/mediapipe/framework/tool/options_util_test.cc +++ b/mediapipe/framework/tool/options_util_test.cc @@ -13,8 +13,10 @@ // limitations under the License. #include +#include #include +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/gmock.h" @@ -30,23 +32,27 @@ namespace mediapipe { namespace { -using ::mediapipe::proto_ns::FieldDescriptorProto; using FieldType = ::mediapipe::proto_ns::FieldDescriptorProto::Type; +using ::testing::HasSubstr; + +// Assigns the value from a StatusOr if available.
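// (Unlike the ASSIGN_IF_OK helper in options_util.cc above, this test macro
// fails the test when the StatusOr carries an error instead of leaving |lhs|
// unchanged.)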
+#define ASSERT_AND_ASSIGN(lhs, rexpr) \ + { \ + auto statusor = (rexpr); \ + MP_ASSERT_OK(statusor); \ + lhs = statusor.value(); \ + } // A test Calculator using DeclareOptions and DefineOptions. class NightLightCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } - absl::Status Open(CalculatorContext* cc) final { - return mediapipe::OkStatus(); - } + absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); } - absl::Status Process(CalculatorContext* cc) final { - return mediapipe::OkStatus(); - } + absl::Status Process(CalculatorContext* cc) final { return absl::OkStatus(); } private: NightLightCalculatorOptions options_; @@ -124,7 +130,7 @@ TEST_F(OptionsUtilTest, CopyLiteralOptions) { CalculatorGraph graph; graph_config.set_num_threads(4); - MP_EXPECT_OK(graph.Initialize({subgraph_config, graph_config}, {}, {})); + MP_ASSERT_OK(graph.Initialize({subgraph_config, graph_config}, {}, {})); CalculatorGraphConfig expanded_config = graph.Config(); expanded_config.clear_executor(); @@ -236,8 +242,8 @@ TEST_F(OptionsUtilTest, FindOptionsMessage) { tool::options_field_util::FieldPath field_path = syntax_util.OptionFieldPath(split[1], descriptor); EXPECT_EQ(field_path.size(), 2); - EXPECT_TRUE(Equals(field_path[0], "sub_options", 0, "")); - EXPECT_TRUE(Equals(field_path[1], "num_lights", 0, "")); + EXPECT_TRUE(Equals(field_path[0], "sub_options", -1, "")); + EXPECT_TRUE(Equals(field_path[1], "num_lights", -1, "")); { // NightLightCalculatorOptions in Node.options. @@ -252,11 +258,11 @@ TEST_F(OptionsUtilTest, FindOptionsMessage) { auto path = field_path; std::string node_extension_type = ExtensionType(std::string(split[1])); FieldData node_options; - MP_EXPECT_OK(tool::options_field_util::GetNodeOptions( - node_data, node_extension_type, &node_options)); + ASSERT_AND_ASSIGN(node_options, tool::options_field_util::GetNodeOptions( + node_data, node_extension_type)); FieldData packet_data; - MP_EXPECT_OK(tool::options_field_util::GetField(field_path, node_options, - &packet_data)); + ASSERT_AND_ASSIGN(packet_data, tool::options_field_util::GetField( + node_options, field_path)); EXPECT_EQ(packet_data.value_case(), FieldData::kInt32Value); EXPECT_EQ(packet_data.int32_value(), 33); } @@ -273,11 +279,11 @@ TEST_F(OptionsUtilTest, FindOptionsMessage) { auto path = field_path; std::string node_extension_type = ExtensionType(std::string(split[1])); FieldData node_options; - MP_EXPECT_OK(tool::options_field_util::GetNodeOptions( - node_data, node_extension_type, &node_options)); + ASSERT_AND_ASSIGN(node_options, tool::options_field_util::GetNodeOptions( + node_data, node_extension_type)); FieldData packet_data; - MP_EXPECT_OK(tool::options_field_util::GetField(field_path, node_options, - &packet_data)); + ASSERT_AND_ASSIGN(packet_data, tool::options_field_util::GetField( + node_options, field_path)); EXPECT_EQ(packet_data.value_case(), FieldData::kInt32Value); EXPECT_EQ(packet_data.int32_value(), 33); } @@ -285,5 +291,333 @@ TEST_F(OptionsUtilTest, FindOptionsMessage) { // TODO: Test with specified extension_type. } +// Constructs the field path for a string of field names. 
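// For example, with NightLightCalculatorOptions registered, the tag
// "OPTIONS/sub_options/num_lights/1" yields a path addressing element 1 of the
// repeated num_lights field; an entry index of -1 addresses the entire field.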
+FieldPath MakeFieldPath(std::string tag, FieldData message_data) { + tool::OptionsSyntaxUtil syntax_util; + const tool::Descriptor* descriptor = + tool::OptionsRegistry::GetProtobufDescriptor( + tool::options_field_util::ParseTypeUrl( + message_data.message_value().type_url())); + return syntax_util.OptionFieldPath(tag, descriptor); +} + +// Returns the field path addressing the entire specified field. +FieldPath EntireField(FieldPath field_path) { + field_path.back().index = -1; + return field_path; +} + +// Converts an int to a FieldData record. +FieldData AsFieldData(int v) { + return tool::options_field_util::AsFieldData(MakePacket(v)).value(); +} + +// Equality comparison for field contents. +template +absl::Status Equals(const T& v1, const T& v2) { + RET_CHECK_EQ(v1, v2); + return absl::OkStatus(); +} + +// Equality comparison for protobuf field contents. +// The generic Equals() fails because MessageLite lacks operator==(). +// The protobuf comparison is performed on serialized protobuf contents. +using LightBundle = NightLightCalculatorOptions::LightBundle; +template <> +absl::Status Equals(const LightBundle& v1, const LightBundle& v2) { + std::string s_1, s_2; + v1.SerializeToString(&s_1); + v2.SerializeToString(&s_2); + RET_CHECK(s_1 == s_2); + return absl::OkStatus(); +} + +// Equality comparison for FieldData vectors. +template +absl::Status Equals(std::vector b1, std::vector b2) { + using tool::options_field_util::AsPacket; + RET_CHECK_EQ(b1.size(), b2.size()); + for (int i = 0; i < b1.size(); ++i) { + ASSIGN_OR_RETURN(Packet p1, AsPacket(b1.at(i))); + ASSIGN_OR_RETURN(Packet p2, AsPacket(b2.at(i))); + MP_RETURN_IF_ERROR(Equals(p1.Get(), p2.Get())); + } + return absl::OkStatus(); +} + +// Unit-tests for graph options field accessors from options_field_util. +class OptionsFieldUtilTest : public ::testing::Test { + protected: + void SetUp() override {} + void TearDown() override {} +}; + +// Tests empty FieldPaths applied to empty options. +TEST_F(OptionsFieldUtilTest, EmptyFieldPaths) { + FieldData graph_options; + FieldData node_options; + FieldPath graph_path; + FieldPath node_path; + std::vector packet_data; + ASSERT_AND_ASSIGN(packet_data, GetFieldValues(graph_options, graph_path)); + MP_EXPECT_OK(MergeFieldValues(node_options, node_path, packet_data)); +} + +// Tests GetFieldValues applied to an int field. +TEST_F(OptionsFieldUtilTest, GetFieldValuesInt) { + NightLightCalculatorOptions node_proto; + node_proto.mutable_sub_options(); + node_proto.mutable_sub_options()->add_num_lights(33); + node_proto.mutable_sub_options()->add_num_lights(44); + FieldData node_data = tool::options_field_util::AsFieldData(node_proto); + + // Read an entire populated repeated field. + FieldPath path = MakeFieldPath("OPTIONS/sub_options/num_lights", node_data); + MP_EXPECT_OK(Equals(GetFieldValues(node_data, path).value(), + {AsFieldData(33), AsFieldData(44)})); + + // Read a specific populated repeated field index. + path = MakeFieldPath("OPTIONS/sub_options/num_lights/1", node_data); + MP_EXPECT_OK( + Equals(GetFieldValues(node_data, path).value(), {AsFieldData(44)})); +} + +// Tests GetFieldValues applied to a protobuf field.
+TEST_F(OptionsFieldUtilTest, GetFieldValuesProtobuf) { + using tool::options_field_util::AsFieldData; + using LightBundle = NightLightCalculatorOptions::LightBundle; + NightLightCalculatorOptions node_proto; + node_proto.mutable_sub_options(); + node_proto.mutable_sub_options()->add_bundle(); + *node_proto.mutable_sub_options()->mutable_bundle(0)->mutable_room_id() = + "111"; + node_proto.mutable_sub_options() + ->mutable_bundle(0) + ->add_room_lights() + ->set_frame_rate(11.1); + node_proto.mutable_sub_options() + ->mutable_bundle(0) + ->add_room_lights() + ->set_frame_rate(22.1); + FieldData node_data = AsFieldData(node_proto); + + // Read all values from a repeated protobuf field. + LightBundle expected_proto; + *expected_proto.mutable_room_id() = "111"; + expected_proto.add_room_lights()->set_frame_rate(11.1); + expected_proto.add_room_lights()->set_frame_rate(22.1); + FieldData expected_data = AsFieldData(expected_proto); + FieldPath path = MakeFieldPath("OPTIONS/sub_options/bundle", node_data); + MP_EXPECT_OK(Equals(GetFieldValues(node_data, path).value(), + {expected_data})); + + // Read a specific index from a repeated protobuf field. + path = MakeFieldPath("OPTIONS/sub_options/bundle/0", node_data); + MP_EXPECT_OK(Equals(GetFieldValues(node_data, path).value(), + {expected_data})); +} + +// Tests SetFieldValues applied to an int field. +TEST_F(OptionsFieldUtilTest, SetFieldValuesInt) { + NightLightCalculatorOptions node_proto; + node_proto.mutable_sub_options(); + FieldData node_data = tool::options_field_util::AsFieldData(node_proto); + + // Replace an entire empty repeated field. + FieldPath path = MakeFieldPath("OPTIONS/sub_options/num_lights", node_data); + MP_ASSERT_OK(SetFieldValues(node_data, path, {AsFieldData(33)})); + MP_EXPECT_OK( + Equals(GetFieldValues(node_data, path).value(), {AsFieldData(33)})); + + // Replace an entire populated repeated field. + MP_ASSERT_OK(SetFieldValues(node_data, path, {AsFieldData(44)})); + MP_EXPECT_OK( + Equals(GetFieldValues(node_data, path).value(), {AsFieldData(44)})); + + // Replace an entire repeated field with a new list of values. + MP_ASSERT_OK( + SetFieldValues(node_data, path, {AsFieldData(33), AsFieldData(44)})); + MP_EXPECT_OK(Equals(GetFieldValues(node_data, path).value(), + {AsFieldData(33), AsFieldData(44)})); + + // Replace a single field index with a new list of values. + path = MakeFieldPath("OPTIONS/sub_options/num_lights/1", node_data); + MP_ASSERT_OK( + SetFieldValues(node_data, path, {AsFieldData(55), AsFieldData(66)})); + MP_EXPECT_OK( + Equals(GetFieldValues(node_data, EntireField(path)).value(), + {AsFieldData(33), AsFieldData(55), AsFieldData(66)})); + + // Replace a single field middle index with a new list of values. + path = MakeFieldPath("OPTIONS/sub_options/num_lights/1", node_data); + MP_ASSERT_OK( + SetFieldValues(node_data, path, {AsFieldData(11), AsFieldData(12)})); + MP_EXPECT_OK(Equals( + GetFieldValues(node_data, EntireField(path)).value(), + {AsFieldData(33), AsFieldData(11), AsFieldData(12), AsFieldData(66)})); + + // Replace field index 0 with a new value. + path = MakeFieldPath("OPTIONS/sub_options/num_lights/0", node_data); + MP_ASSERT_OK(SetFieldValues(node_data, path, {AsFieldData(77)})); + MP_EXPECT_OK(Equals( + GetFieldValues(node_data, EntireField(path)).value(), + {AsFieldData(77), AsFieldData(11), AsFieldData(12), AsFieldData(66)})); + + // Replace field index 0 with an empty list of values. 
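  // (An empty replacement list deletes the addressed element; the remaining
  // elements shift down, as the expected values below show.)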
+  MP_ASSERT_OK(SetFieldValues(node_data, path, {}));
+  MP_EXPECT_OK(
+      Equals<int>(GetFieldValues(node_data, EntireField(path)).value(),
+                  {AsFieldData(11), AsFieldData(12), AsFieldData(66)}));
+
+  // Replace an entire populated field with an empty list of values.
+  path = MakeFieldPath("OPTIONS/sub_options/num_lights", node_data);
+  MP_ASSERT_OK(SetFieldValues(node_data, path, {}));
+  MP_ASSERT_OK(
+      Equals<int>(GetFieldValues(node_data, EntireField(path)).value(), {}));
+
+  // Replace a missing field index with new values.
+  path = MakeFieldPath("OPTIONS/sub_options/num_lights/1", node_data);
+  absl::Status status =
+      SetFieldValues(node_data, path, {AsFieldData(55), AsFieldData(66)});
+  EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
+  // TODO: status.message() appears empty on KokoroGCPDocker.
+  // EXPECT_THAT(status.message(),
+  //             HasSubstr("index >= 0 && index <= v.size()"));
+}
+
+// Tests SetFieldValues applied to a protobuf field.
+TEST_F(OptionsFieldUtilTest, SetFieldValuesProtobuf) {
+  using tool::options_field_util::AsFieldData;
+  using LightBundle = NightLightCalculatorOptions::LightBundle;
+  NightLightCalculatorOptions node_proto;
+  node_proto.mutable_sub_options();
+  FieldData node_data = AsFieldData(node_proto);
+
+  // Replace an empty repeated protobuf field.
+  LightBundle bundle_proto;
+  *bundle_proto.mutable_room_id() = "222";
+  bundle_proto.add_room_lights()->set_frame_rate(22.1);
+  FieldData bundle_data = AsFieldData(bundle_proto);
+  FieldData expected_data = bundle_data;
+  FieldPath path = MakeFieldPath("OPTIONS/sub_options/bundle", node_data);
+  MP_ASSERT_OK(SetFieldValues(node_data, path, {bundle_data}));
+  MP_EXPECT_OK(Equals<LightBundle>(
+      GetFieldValues(node_data, EntireField(path)).value(), {expected_data}));
+
+  // Replace a populated repeated protobuf field.
+  *bundle_proto.mutable_room_id() = "333";
+  bundle_proto.mutable_room_lights(0)->set_frame_rate(33.1);
+  bundle_data = AsFieldData(bundle_proto);
+  LightBundle expected_proto;
+  *expected_proto.mutable_room_id() = "333";
+  expected_proto.add_room_lights()->set_frame_rate(33.1);
+  expected_data = AsFieldData(expected_proto);
+  MP_ASSERT_OK(SetFieldValues(node_data, path, {bundle_data}));
+  MP_EXPECT_OK(Equals<LightBundle>(
+      GetFieldValues(node_data, EntireField(path)).value(), {expected_data}));
+}
+
+// Tests MergeFieldValues applied to an int field.
+TEST_F(OptionsFieldUtilTest, MergeFieldValuesInt) {
+  NightLightCalculatorOptions node_proto;
+  node_proto.mutable_sub_options();
+  FieldData node_data = tool::options_field_util::AsFieldData(node_proto);
+
+  // Replace an entire empty repeated field.
+  FieldPath path = MakeFieldPath("OPTIONS/sub_options/num_lights", node_data);
+  MP_ASSERT_OK(MergeFieldValues(node_data, path, {AsFieldData(33)}));
+  MP_EXPECT_OK(
+      Equals<int>(GetFieldValues(node_data, path).value(), {AsFieldData(33)}));
+
+  // Replace an entire populated repeated field.
+  MP_ASSERT_OK(MergeFieldValues(node_data, path, {AsFieldData(44)}));
+  MP_EXPECT_OK(
+      Equals<int>(GetFieldValues(node_data, path).value(), {AsFieldData(44)}));
+
+  // Replace an entire repeated field with a new list of values.
+  MP_ASSERT_OK(
+      MergeFieldValues(node_data, path, {AsFieldData(33), AsFieldData(44)}));
+  MP_EXPECT_OK(Equals<int>(GetFieldValues(node_data, path).value(),
+                           {AsFieldData(33), AsFieldData(44)}));
+
+  // Replace a single field index with a new list of values.
+  path = MakeFieldPath("OPTIONS/sub_options/num_lights/1", node_data);
+  MP_ASSERT_OK(
+      MergeFieldValues(node_data, path, {AsFieldData(55), AsFieldData(66)}));
+  MP_EXPECT_OK(
+      Equals<int>(GetFieldValues(node_data, EntireField(path)).value(),
+                  {AsFieldData(33), AsFieldData(55), AsFieldData(66)}));
+
+  // Replace a single field middle index with a new list of values.
+  path = MakeFieldPath("OPTIONS/sub_options/num_lights/1", node_data);
+  MP_ASSERT_OK(
+      MergeFieldValues(node_data, path, {AsFieldData(11), AsFieldData(12)}));
+  MP_EXPECT_OK(Equals<int>(
+      GetFieldValues(node_data, EntireField(path)).value(),
+      {AsFieldData(33), AsFieldData(11), AsFieldData(12), AsFieldData(66)}));
+
+  // Replace field index 0 with a new value.
+  path = MakeFieldPath("OPTIONS/sub_options/num_lights/0", node_data);
+  MP_ASSERT_OK(MergeFieldValues(node_data, path, {AsFieldData(77)}));
+  MP_EXPECT_OK(Equals<int>(
+      GetFieldValues(node_data, EntireField(path)).value(),
+      {AsFieldData(77), AsFieldData(11), AsFieldData(12), AsFieldData(66)}));
+
+  // Replace field index 0 with an empty list of values.
+  MP_ASSERT_OK(MergeFieldValues(node_data, path, {}));
+  MP_EXPECT_OK(
+      Equals<int>(GetFieldValues(node_data, EntireField(path)).value(),
+                  {AsFieldData(11), AsFieldData(12), AsFieldData(66)}));
+
+  // Replace an entire populated field with an empty list of values.
+  path = MakeFieldPath("OPTIONS/sub_options/num_lights", node_data);
+  MP_ASSERT_OK(MergeFieldValues(node_data, path, {}));
+  MP_EXPECT_OK(
+      Equals<int>(GetFieldValues(node_data, EntireField(path)).value(), {}));
+
+  // Merge into a missing field index.
+  path = MakeFieldPath("OPTIONS/sub_options/num_lights/1", node_data);
+  absl::Status status =
+      MergeFieldValues(node_data, path, {AsFieldData(55), AsFieldData(66)});
+  EXPECT_EQ(status.code(), absl::StatusCode::kOutOfRange);
+  EXPECT_THAT(status.message(),
+              HasSubstr("Missing feild value: num_lights at index: 1"));
+}
+
+// Tests MergeFieldValues applied to a protobuf field.
+TEST_F(OptionsFieldUtilTest, MergeFieldValuesProtobuf) {
+  using tool::options_field_util::AsFieldData;
+  using LightBundle = NightLightCalculatorOptions::LightBundle;
+  NightLightCalculatorOptions node_proto;
+  node_proto.mutable_sub_options();
+  FieldData node_data = AsFieldData(node_proto);
+
+  // Merge an empty repeated protobuf field.
+  LightBundle bundle_proto;
+  *bundle_proto.mutable_room_id() = "222";
+  bundle_proto.add_room_lights()->set_frame_rate(22.1);
+  FieldData bundle_data = AsFieldData(bundle_proto);
+  FieldData expected_data = bundle_data;
+  FieldPath path = MakeFieldPath("OPTIONS/sub_options/bundle", node_data);
+  MP_ASSERT_OK(MergeFieldValues(node_data, path, {bundle_data}));
+  MP_EXPECT_OK(Equals<LightBundle>(
+      GetFieldValues(node_data, EntireField(path)).value(), {expected_data}));
+
+  // Merge a populated repeated protobuf field.
+  // "LightBundle.room_id" merges to "333".
+  // "LightBundle.room_lights" merges to {{22.1}, {33.1}}.
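+  // (This mirrors proto2 MergeFrom semantics under the expectations below:
+  // the singular room_id field is overwritten, while the repeated
+  // room_lights field is concatenated.)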
+  *bundle_proto.mutable_room_id() = "333";
+  bundle_proto.mutable_room_lights(0)->set_frame_rate(33.1);
+  bundle_data = AsFieldData(bundle_proto);
+  LightBundle expected_proto;
+  *expected_proto.mutable_room_id() = "333";
+  expected_proto.add_room_lights()->set_frame_rate(22.1);
+  expected_proto.add_room_lights()->set_frame_rate(33.1);
+  expected_data = AsFieldData(expected_proto);
+  MP_ASSERT_OK(MergeFieldValues(node_data, path, {bundle_data}));
+  MP_EXPECT_OK(Equals<LightBundle>(
+      GetFieldValues(node_data, EntireField(path)).value(), {expected_data}));
+}
+
 }  // namespace
 }  // namespace mediapipe
diff --git a/mediapipe/framework/tool/proto_util_lite.cc b/mediapipe/framework/tool/proto_util_lite.cc
index b9649ce5b0..4628815eab 100644
--- a/mediapipe/framework/tool/proto_util_lite.cc
+++ b/mediapipe/framework/tool/proto_util_lite.cc
@@ -16,11 +16,13 @@
 #include
 
+#include "absl/strings/match.h"
 #include "absl/strings/numbers.h"
 #include "absl/strings/str_cat.h"
 #include "mediapipe/framework/port/canonical_errors.h"
 #include "mediapipe/framework/port/logging.h"
 #include "mediapipe/framework/port/ret_check.h"
+#include "mediapipe/framework/tool/field_data.pb.h"
 #include "mediapipe/framework/type_map.h"
 
 #define RET_CHECK_NO_LOG(cond) RET_CHECK(cond).SetNoLogging()
@@ -37,6 +39,7 @@
 using FieldAccess = ProtoUtilLite::FieldAccess;
 using FieldValue = ProtoUtilLite::FieldValue;
 using ProtoPath = ProtoUtilLite::ProtoPath;
 using FieldType = ProtoUtilLite::FieldType;
+using mediapipe::FieldData;
 
 // Returns true if a wire type includes a length indicator.
 bool IsLengthDelimited(WireFormatLite::WireType wire_type) {
@@ -408,5 +411,149 @@
   return absl::OkStatus();
 }
 
+absl::Status ProtoUtilLite::WriteValue(const FieldData& value,
+                                       FieldType field_type,
+                                       std::string* field_bytes) {
+  StringOutputStream sos(field_bytes);
+  CodedOutputStream out(&sos);
+  switch (field_type) {
+    case WireFormatLite::TYPE_INT32:
+      WireFormatLite::WriteInt32NoTag(value.int32_value(), &out);
+      break;
+    case WireFormatLite::TYPE_SINT32:
+      WireFormatLite::WriteSInt32NoTag(value.int32_value(), &out);
+      break;
+    case WireFormatLite::TYPE_INT64:
+      WireFormatLite::WriteInt64NoTag(value.int64_value(), &out);
+      break;
+    case WireFormatLite::TYPE_SINT64:
+      WireFormatLite::WriteSInt64NoTag(value.int64_value(), &out);
+      break;
+    case WireFormatLite::TYPE_UINT32:
+      WireFormatLite::WriteUInt32NoTag(value.uint32_value(), &out);
+      break;
+    case WireFormatLite::TYPE_UINT64:
+      WireFormatLite::WriteUInt64NoTag(value.uint64_value(), &out);
+      break;
+    case WireFormatLite::TYPE_DOUBLE:
+      WireFormatLite::WriteDoubleNoTag(value.double_value(), &out);
+      break;
+    case WireFormatLite::TYPE_FLOAT:
+      WireFormatLite::WriteFloatNoTag(value.float_value(), &out);
+      break;
+    case WireFormatLite::TYPE_BOOL:
+      WireFormatLite::WriteBoolNoTag(value.bool_value(), &out);
+      break;
+    case WireFormatLite::TYPE_ENUM:
+      WireFormatLite::WriteEnumNoTag(value.enum_value(), &out);
+      break;
+    case WireFormatLite::TYPE_STRING:
+      out.WriteString(value.string_value());
+      break;
+    case WireFormatLite::TYPE_MESSAGE:
+      out.WriteString(value.message_value().value());
+      break;
+    default:
+      return absl::UnimplementedError(
+          absl::StrCat("Cannot write type: ", field_type));
+  }
+  return absl::OkStatus();
+}
+
+template <typename ValueT, FieldType kFieldType>
+static ValueT ReadValue(absl::string_view field_bytes, absl::Status* status) {
+  ArrayInputStream ais(field_bytes.data(), field_bytes.size());
+  CodedInputStream input(&ais);
+  ValueT result;
+  if (!WireFormatLite::ReadPrimitive<ValueT, kFieldType>(&input,
+                                                         &result)) {
+    status->Update(absl::InvalidArgumentError(absl::StrCat(
+        "Bad serialized value: ", MediaPipeTypeStringOrDemangled<ValueT>(),
+        ".")));
+  }
+  return result;
+}
+
+absl::Status ReadValue(absl::string_view field_bytes, FieldType field_type,
+                       absl::string_view message_type, FieldData* result) {
+  absl::Status status;
+  result->Clear();
+  switch (field_type) {
+    case WireFormatLite::TYPE_INT32:
+      result->set_int32_value(
+          ReadValue<int32, WireFormatLite::TYPE_INT32>(field_bytes, &status));
+      break;
+    case WireFormatLite::TYPE_SINT32:
+      result->set_int32_value(
+          ReadValue<int32, WireFormatLite::TYPE_SINT32>(field_bytes, &status));
+      break;
+    case WireFormatLite::TYPE_INT64:
+      result->set_int64_value(
+          ReadValue<int64, WireFormatLite::TYPE_INT64>(field_bytes, &status));
+      break;
+    case WireFormatLite::TYPE_SINT64:
+      result->set_int64_value(
+          ReadValue<int64, WireFormatLite::TYPE_SINT64>(field_bytes, &status));
+      break;
+    case WireFormatLite::TYPE_UINT32:
+      result->set_uint32_value(
+          ReadValue<uint32, WireFormatLite::TYPE_UINT32>(field_bytes, &status));
+      break;
+    case WireFormatLite::TYPE_UINT64:
+      result->set_uint64_value(
+          ReadValue<uint64, WireFormatLite::TYPE_UINT64>(field_bytes, &status));
+      break;
+    case WireFormatLite::TYPE_DOUBLE:
+      result->set_double_value(
+          ReadValue<double, WireFormatLite::TYPE_DOUBLE>(field_bytes, &status));
+      break;
+    case WireFormatLite::TYPE_FLOAT:
+      result->set_float_value(
+          ReadValue<float, WireFormatLite::TYPE_FLOAT>(field_bytes, &status));
+      break;
+    case WireFormatLite::TYPE_BOOL:
+      result->set_bool_value(
+          ReadValue<bool, WireFormatLite::TYPE_BOOL>(field_bytes, &status));
+      break;
+    case WireFormatLite::TYPE_ENUM:
+      result->set_enum_value(
+          ReadValue<int, WireFormatLite::TYPE_ENUM>(field_bytes, &status));
+      break;
+    case WireFormatLite::TYPE_STRING:
+      result->set_string_value(std::string(field_bytes));
+      break;
+    case WireFormatLite::TYPE_MESSAGE:
+      result->mutable_message_value()->set_value(std::string(field_bytes));
+      result->mutable_message_value()->set_type_url(
+          ProtoUtilLite::TypeUrl(message_type));
+      break;
+    default:
+      status = absl::UnimplementedError(
+          absl::StrCat("Cannot read type: ", field_type));
+      break;
+  }
+  return status;
+}
+
+absl::Status ProtoUtilLite::ReadValue(absl::string_view field_bytes,
+                                      FieldType field_type,
+                                      absl::string_view message_type,
+                                      FieldData* result) {
+  return mediapipe::tool::ReadValue(field_bytes, field_type, message_type,
+                                    result);
+}
+
+std::string ProtoUtilLite::TypeUrl(absl::string_view type_name) {
+  constexpr absl::string_view kTypeUrlPrefix = "type.googleapis.com/";
+  return absl::StrCat(kTypeUrlPrefix, type_name);
+}
+
+std::string ProtoUtilLite::ParseTypeUrl(absl::string_view type_url) {
+  constexpr absl::string_view kTypeUrlPrefix = "type.googleapis.com/";
+  if (absl::StartsWith(type_url, kTypeUrlPrefix)) {
+    return std::string(type_url.substr(kTypeUrlPrefix.length()));
+  }
+  return std::string(type_url);
+}
+
 }  // namespace tool
 }  // namespace mediapipe
diff --git a/mediapipe/framework/tool/proto_util_lite.h b/mediapipe/framework/tool/proto_util_lite.h
index 71221291fa..7d3a263f37 100644
--- a/mediapipe/framework/tool/proto_util_lite.h
+++ b/mediapipe/framework/tool/proto_util_lite.h
@@ -23,10 +23,12 @@
 #include "mediapipe/framework/port/integral_types.h"
 #include "mediapipe/framework/port/proto_ns.h"
 #include "mediapipe/framework/port/status.h"
+#include "mediapipe/framework/tool/field_data.pb.h"
 
 namespace mediapipe {
 namespace tool {
 
+// TODO: Replace this class with a namespace following Google style.
 class ProtoUtilLite {
  public:
  // Defines field types and tag formats.
@@ -89,6 +91,23 @@ class ProtoUtilLite {
   static absl::Status Deserialize(const std::vector<FieldValue>& field_values,
                                   FieldType field_type,
                                   std::vector<std::string>* result);
+
+  // Writes a protobuf field value from a typed FieldData value.
+  static absl::Status WriteValue(const mediapipe::FieldData& value,
+                                 FieldType field_type,
+                                 std::string* field_bytes);
+
+  // Reads a protobuf field value into a typed FieldData value.
+  static absl::Status ReadValue(absl::string_view field_bytes,
+                                FieldType field_type,
+                                absl::string_view message_type,
+                                mediapipe::FieldData* result);
+
+  // Returns the protobuf type-url for a protobuf type-name.
+  static std::string TypeUrl(absl::string_view type_name);
+
+  // Returns the protobuf type-name for a protobuf type-url.
+  static std::string ParseTypeUrl(absl::string_view type_url);
 };
 
 }  // namespace tool
diff --git a/mediapipe/framework/tool/status_util.cc b/mediapipe/framework/tool/status_util.cc
index 57faa3899f..0e3a592462 100644
--- a/mediapipe/framework/tool/status_util.cc
+++ b/mediapipe/framework/tool/status_util.cc
@@ -59,7 +59,8 @@ absl::Status CombinedStatus(const std::string& general_comment,
     }
   }
   if (error_code == StatusCode::kOk) return OkStatus();
-  Status combined = absl::Status(
+  Status combined;
+  combined = absl::Status(
       error_code,
      absl::StrCat(general_comment, "\n", absl::StrJoin(errors, "\n")));
   return combined;
diff --git a/mediapipe/framework/tool/status_util_test.cc b/mediapipe/framework/tool/status_util_test.cc
index c7e845aa6c..005ee26342 100644
--- a/mediapipe/framework/tool/status_util_test.cc
+++ b/mediapipe/framework/tool/status_util_test.cc
@@ -28,8 +28,11 @@
 namespace mediapipe {
 namespace {
 
 using testing::ContainerEq;
+using testing::Eq;
 using testing::HasSubstr;
 using testing::IsEmpty;
+using testing::Matches;
+using testing::Pointwise;
 
 TEST(StatusTest, StatusStopIsNotOk) { EXPECT_FALSE(tool::StatusStop().ok()); }
diff --git a/mediapipe/framework/tool/subgraph_expansion.cc b/mediapipe/framework/tool/subgraph_expansion.cc
index 354c1fd0a1..9f81153f1d 100644
--- a/mediapipe/framework/tool/subgraph_expansion.cc
+++ b/mediapipe/framework/tool/subgraph_expansion.cc
@@ -293,7 +293,7 @@ absl::Status ExpandSubgraphs(CalculatorGraphConfig* config,
     if (subgraph_nodes_start == nodes->end()) break;
     std::vector<CalculatorGraphConfig> subgraphs;
     for (auto it = subgraph_nodes_start; it != nodes->end(); ++it) {
-      const auto& node = *it;
+      auto& node = *it;
       int node_id = it - nodes->begin();
       std::string node_name = CanonicalNodeName(*config, node_id);
       MP_RETURN_IF_ERROR(ValidateSubgraphFields(node));
diff --git a/mediapipe/framework/tool/type_util.h b/mediapipe/framework/tool/type_util.h
index 4157389e43..9c955f2a36 100644
--- a/mediapipe/framework/tool/type_util.h
+++ b/mediapipe/framework/tool/type_util.h
@@ -16,79 +16,129 @@
 #define MEDIAPIPE_FRAMEWORK_TOOL_TYPE_UTIL_H_
 
 #include <cstddef>
+#include <string>
 #include <typeinfo>
+#include <utility>
 
+#include "absl/base/attributes.h"
+#include "mediapipe/framework/demangle.h"
 #include "mediapipe/framework/port.h"
 
 namespace mediapipe {
-namespace tool {
 
-#if !MEDIAPIPE_HAS_RTTI
-// A unique identifier for type T.
-class TypeInfo {
+// An identifier for a type. This class is lightweight and is meant to be
+// passed by value.
+// To get the TypeId for SomeType, write kTypeId<SomeType>.
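+// Example usage (illustrative):
+//   TypeId id = kTypeId<float>;
+//   bool same = (id == kTypeId<float>);  // true
+//   size_t hash = id.hash_code();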
+class TypeId {
  public:
-  size_t hash_code() const { return reinterpret_cast<size_t>(this); }
-  bool operator==(const TypeInfo& other) const { return &other == this; }
-  bool operator<(const TypeInfo& other) const { return &other < this; }
-  const char* name() const { return "<unknown>"; }
-  template <typename T>
-  static const TypeInfo& Get() {
-    static TypeInfo* static_type_info = new TypeInfo;
-    return *static_type_info;
+  size_t hash_code() const { return impl_.hash_code(); }
+  std::string name() const { return impl_.name(); }
+  bool operator==(const TypeId& other) const { return impl_ == other.impl_; }
+  bool operator<(const TypeId& other) const { return impl_ < other.impl_; }
+
+  template <typename H>
+  friend H AbslHashValue(H h, const TypeId& r) {
+    return H::combine(std::move(h), r.hash_code());
+  }
+
+  template <class T>
+  static constexpr inline TypeId Of() {
+    return TypeId{Impl::Get<T>()};
   }
 
  private:
-  TypeInfo() {}
-  TypeInfo(const TypeInfo&) = delete;
-};
+  // This implementation uses no RTTI. It distinguishes types, but does not
+  // know their names.
+  // TODO: record compile-time type string for (some or all) types.
+  template <class T>
+  struct TypeTag {
+    static constexpr char dummy = 0;
+  };
+  struct NoRttiImpl {
+    template <class T>
+    static constexpr inline NoRttiImpl Get() {
+      return {&TypeTag<T>::dummy};
+    }
+    size_t hash_code() const { return reinterpret_cast<size_t>(tag_); }
+    std::string name() const { return "<unknown>"; }
+    bool operator==(const NoRttiImpl& other) const {
+      return tag_ == other.tag_;
+    }
+    bool operator<(const NoRttiImpl& other) const { return tag_ < other.tag_; }
 
-#else  // MEDIAPIPE_HAS_RTTI
-// The std unique identifier for type T.
-class TypeInfo {
- public:
-  size_t hash_code() const { return info_.hash_code(); }
-  bool operator==(const TypeInfo& o) const { return info_ == o.info_; }
-  bool operator<(const TypeInfo& o) const { return info_.before(o.info_); }
-  const char* name() const { return info_.name(); }
-  template <typename T>
-  static const TypeInfo& Get() {
-    static TypeInfo* static_type_info = new TypeInfo(typeid(T));
-    return *static_type_info;
+    const void* tag_;
+  };
+
+#if MEDIAPIPE_HAS_RTTI
+  template <class T>
+  static const std::type_info& GetTypeInfo() {
+    return typeid(T);
   }
+  // This implementation uses RTTI, and delegates all operations to
+  // std::type_info. In order to support constexpr construction, we don't
+  // store a type_info directly (which is not constexpr), but a pointer to a
+  // function returning it (which is). This implementation is a bit slower
+  // than the others. The only potential advantage would be the ability to
+  // match types across multiple dynamic libraries, but we don't support that
+  // setup anyway. This is provided for completeness.
+  struct FullRttiImpl {
+    template <class T>
+    static constexpr inline FullRttiImpl Get() {
+      return {GetTypeInfo<T>};
+    }
+    size_t hash_code() const { return get_().hash_code(); }
+    std::string name() const { return Demangle(get_().name()); }
+    bool operator==(const FullRttiImpl& other) const {
+      return get_ == other.get_ || get_() == other.get_();
+    }
+    bool operator<(const FullRttiImpl& other) const {
+      return get_().before(other.get_());
+    }
 
- private:
-  TypeInfo(const std::type_info& info) : info_(info) {}
-  TypeInfo(const TypeInfo&) = delete;
+    decltype(&GetTypeInfo<void>) get_;
+  };
 
- private:
-  const std::type_info& info_;
-  friend class TypeIndex;
-};
-#endif
+  // This implementation also stores a pointer to a std::type_info getter
+  // function, but it only invokes it to get the type's name. It's equivalent
+  // to NoRttiImpl for most operations, but it allows getting the type's name.
+  struct FastRttiImpl {
+    template <class T>
+    static constexpr inline FastRttiImpl Get() {
+      return {GetTypeInfo<T>};
+    }
+    size_t hash_code() const { return reinterpret_cast<size_t>(get_); }
+    std::string name() const { return Demangle(get_().name()); }
+    bool operator==(const FastRttiImpl& other) const {
+      return get_ == other.get_;
+    }
+    bool operator<(const FastRttiImpl& other) const {
+      return reinterpret_cast<size_t>(get_) <
+             reinterpret_cast<size_t>(other.get_);
+    }
 
-// An associative key for TypeInfo.
-class TypeIndex {
- public:
-  TypeIndex(const TypeInfo& info) : info_(info) {}
-  size_t hash_code() const { return info_.hash_code(); }
-  bool operator==(const TypeIndex& other) const { return info_ == other.info_; }
-  bool operator<(const TypeIndex& other) const { return info_ < other.info_; }
+    decltype(&GetTypeInfo<void>) get_;
+  };
 
- private:
-  const TypeInfo& info_;
+  using Impl = FastRttiImpl;
+#else
+  using Impl = NoRttiImpl;
+#endif  // MEDIAPIPE_HAS_RTTI
+  constexpr explicit TypeId(Impl impl) : impl_(impl) {}
+
+  Impl impl_;
 };
 
-// Helper method that returns a hash code of the given type. This allows for
-// typeid testing across multiple binaries, unlike FastTypeId which used a
-// memory location that only works within the same binary. Moreover, we use
-// this for supporting multiple .so binaries in a single Android app built
-// using the same compiler and C++ libraries.
-// Note that std::type_info may still generate the same hash code for
-// different types, although the C++ standard recommends that implementations
-// avoid this as much as possible.
+template <typename T>
+static constexpr TypeId kTypeId = TypeId::Of<T>();
+
+namespace tool {
+
+// Helper method that returns a hash code of the given type.
+// Superseded by TypeId.
 template <typename T>
+ABSL_DEPRECATED("Use TypeId directly instead.")
 size_t GetTypeHash() {
-  return TypeInfo::Get<T>().hash_code();
+  return kTypeId<T>.hash_code();
 }
 
 }  // namespace tool
diff --git a/mediapipe/framework/type_map.h b/mediapipe/framework/type_map.h
index 0b11959443..e26efa039c 100644
--- a/mediapipe/framework/type_map.h
+++ b/mediapipe/framework/type_map.h
@@ -361,32 +361,30 @@
 DEFINE_MEDIAPIPE_TYPE_MAP(PacketTypeStringToMediaPipeTypeData, std::string);
 // End define MEDIAPIPE_REGISTER_TYPE_WITH_PROXY.
 
 // Helper functions to retrieve registration data.
-inline const std::string* MediaPipeTypeStringFromTypeId(const size_t type_id) {
+inline const std::string* MediaPipeTypeStringFromTypeId(TypeId type_id) {
   const MediaPipeTypeData* value =
-      PacketTypeIdToMediaPipeTypeData::GetValue(type_id);
+      PacketTypeIdToMediaPipeTypeData::GetValue(type_id.hash_code());
   return (value) ? &value->type_string : nullptr;
 }
 
 // Returns string identifier of type or NULL if not registered.
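 // For example (illustrative): for a registered type T,
 // MediaPipeTypeString<T>() yields a pointer to the registered name; for an
 // unregistered type it returns nullptr, and
 // MediaPipeTypeStringOrDemangled<T>() falls back to the demangled C++ name.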
 template <typename T>
 inline const std::string* MediaPipeTypeString() {
-  return MediaPipeTypeStringFromTypeId(tool::GetTypeHash<T>());
+  return MediaPipeTypeStringFromTypeId(kTypeId<T>);
 }
 
-inline std::string MediaPipeTypeStringOrDemangled(
-    const tool::TypeInfo& type_info) {
-  const std::string* type_string =
-      MediaPipeTypeStringFromTypeId(type_info.hash_code());
+inline std::string MediaPipeTypeStringOrDemangled(TypeId type_id) {
+  const std::string* type_string = MediaPipeTypeStringFromTypeId(type_id);
   if (type_string) {
     return *type_string;
   } else {
-    return mediapipe::Demangle(type_info.name());
+    return type_id.name();
   }
 }
 
 template <typename T>
 std::string MediaPipeTypeStringOrDemangled() {
-  return MediaPipeTypeStringOrDemangled(tool::TypeInfo::Get<T>());
+  return MediaPipeTypeStringOrDemangled(kTypeId<T>);
 }
 
 // Returns type hash id of type identified by type_string or NULL if not
diff --git a/mediapipe/framework/validated_graph_config.cc b/mediapipe/framework/validated_graph_config.cc
index 8057acec60..16aad6e9bb 100644
--- a/mediapipe/framework/validated_graph_config.cc
+++ b/mediapipe/framework/validated_graph_config.cc
@@ -14,6 +14,8 @@
 
 #include "mediapipe/framework/validated_graph_config.h"
 
+#include <utility>
+
 #include "absl/container/flat_hash_set.h"
 #include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
@@ -140,35 +142,6 @@
   return absl::OkStatus();
 }
 
-absl::Status PerformBasicTransforms(
-    const CalculatorGraphConfig& input_graph_config,
-    const GraphRegistry* graph_registry,
-    const Subgraph::SubgraphOptions* graph_options,
-    const GraphServiceManager* service_manager,
-    CalculatorGraphConfig* output_graph_config) {
-  *output_graph_config = input_graph_config;
-  MP_RETURN_IF_ERROR(tool::ExpandSubgraphs(output_graph_config, graph_registry,
-                                           graph_options, service_manager));
-
-  MP_RETURN_IF_ERROR(AddPredefinedExecutorConfigs(output_graph_config));
-
-  // Populate each node with the graph level input stream handler if a
-  // stream handler wasn't explicitly provided.
-  // TODO: Instead of pre-populating, handle the graph level
-  // default appropriately within CalculatorGraph.
-  if (output_graph_config->has_input_stream_handler()) {
-    const auto& graph_level_input_stream_handler =
-        output_graph_config->input_stream_handler();
-    for (auto& node : *output_graph_config->mutable_node()) {
-      if (!node.has_input_stream_handler()) {
-        *node.mutable_input_stream_handler() = graph_level_input_stream_handler;
-      }
-    }
-  }
-
-  return absl::OkStatus();
-}
-
 }  // namespace
 
 // static
@@ -346,8 +319,7 @@
 }
 
 absl::Status ValidatedGraphConfig::Initialize(
-    const CalculatorGraphConfig& input_config,
-    const GraphRegistry* graph_registry,
+    CalculatorGraphConfig input_config, const GraphRegistry* graph_registry,
     const Subgraph::SubgraphOptions* graph_options,
     const GraphServiceManager* service_manager) {
   RET_CHECK(!initialized_)
@@ -358,9 +330,9 @@
       << input_config.DebugString();
 #endif
 
-  MP_RETURN_IF_ERROR(PerformBasicTransforms(
-      input_config, graph_registry, graph_options, service_manager, &config_));
-
+  config_ = std::move(input_config);
+  MP_RETURN_IF_ERROR(
+      PerformBasicTransforms(graph_registry, graph_options, service_manager));
   // Initialize the basic node information.
   MP_RETURN_IF_ERROR(InitializeGeneratorInfo());
   MP_RETURN_IF_ERROR(InitializeCalculatorInfo());
@@ -441,7 +413,12 @@
       const GraphServiceManager* service_manager) {
   graph_registry =
       graph_registry ? graph_registry : &GraphRegistry::global_graph_registry;
-  SubgraphContext subgraph_context(graph_options, service_manager);
+  Subgraph::SubgraphOptions local_graph_options;
+  if (graph_options) {
+    local_graph_options = *graph_options;
+  }
+  SubgraphContext subgraph_context =
+      SubgraphContext(&local_graph_options, service_manager);
   auto status_or_config =
       graph_registry->CreateByName("", graph_type, &subgraph_context);
   MP_RETURN_IF_ERROR(status_or_config.status());
@@ -466,6 +443,32 @@
                     service_manager);
 }
 
+absl::Status ValidatedGraphConfig::PerformBasicTransforms(
+    const GraphRegistry* graph_registry,
+    const Subgraph::SubgraphOptions* graph_options,
+    const GraphServiceManager* service_manager) {
+  MP_RETURN_IF_ERROR(tool::ExpandSubgraphs(&config_, graph_registry,
+                                           graph_options, service_manager));
+
+  MP_RETURN_IF_ERROR(AddPredefinedExecutorConfigs(&config_));
+
+  // Populate each node with the graph level input stream handler if a
+  // stream handler wasn't explicitly provided.
+  // TODO: Instead of pre-populating, handle the graph level
+  // default appropriately within CalculatorGraph.
+  if (config_.has_input_stream_handler()) {
+    const auto& graph_level_input_stream_handler =
+        config_.input_stream_handler();
+    for (auto& node : *config_.mutable_node()) {
+      if (!node.has_input_stream_handler()) {
+        *node.mutable_input_stream_handler() = graph_level_input_stream_handler;
+      }
+    }
+  }
+
+  return absl::OkStatus();
+}
+
 absl::Status ValidatedGraphConfig::InitializeCalculatorInfo() {
   std::vector<absl::Status> statuses;
   calculators_.reserve(config_.node_size());
@@ -690,6 +693,7 @@
       if (!need_sorting_ptr) {
         LOG(WARNING) << "Input Stream \"" << name
                      << "\" for node with sorted index " << node_index
+                     << " name " << node_type_info->Contract().GetNodeName()
                      << " is marked as a back edge, but its output stream is "
                         "already available. This means it was not necessary "
                         "to mark it as a back edge.";
@@ -701,6 +705,7 @@
     if (edge_info.back_edge) {
       VLOG(1) << "Encountered expected behavior: the back edge \"" << name
               << "\" for node with (possibly sorted) index " << node_index
+              << " name " << node_type_info->Contract().GetNodeName()
               << " has an output stream which we have not yet seen.";
     } else if (need_sorting_ptr) {
       *need_sorting_ptr = true;
@@ -709,7 +714,9 @@
     } else {
       return mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC)
              << "Input Stream \"" << name << "\" for node with sorted index "
-             << node_index << " does not have a corresponding output stream.";
+             << node_index << " name "
+             << node_type_info->Contract().GetNodeName()
+             << " does not have a corresponding output stream.";
     }
   }
diff --git a/mediapipe/framework/validated_graph_config.h b/mediapipe/framework/validated_graph_config.h
index aee605f980..11f9553cd2 100644
--- a/mediapipe/framework/validated_graph_config.h
+++ b/mediapipe/framework/validated_graph_config.h
@@ -195,7 +195,7 @@
   // before any other functions. Subgraphs are specified through the
   // global graph registry or an optional local graph registry.
   absl::Status Initialize(
-      const CalculatorGraphConfig& input_config,
+      CalculatorGraphConfig input_config,
       const GraphRegistry* graph_registry = nullptr,
       const Subgraph::SubgraphOptions* graph_options = nullptr,
       const GraphServiceManager* service_manager = nullptr);
@@ -302,6 +302,13 @@
   }
 
  private:
+  // Performs transforms such as converting legacy features, expanding
+  // subgraphs, and populating the input stream handler.
+  absl::Status PerformBasicTransforms(
+      const GraphRegistry* graph_registry,
+      const Subgraph::SubgraphOptions* graph_options,
+      const GraphServiceManager* service_manager);
+
   // Initialize the PacketGenerator information.
   absl::Status InitializeGeneratorInfo();
   // Initialize the Calculator information.
diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD
index 8c9c433b04..de9b755a9b 100644
--- a/mediapipe/gpu/BUILD
+++ b/mediapipe/gpu/BUILD
@@ -53,6 +53,12 @@
     deps = ["//mediapipe/framework:graph_service"],
 )
 
+cc_library(
+    name = "attachments",
+    hdrs = ["attachments.h"],
+    visibility = ["//visibility:public"],
+)
+
 GL_BASE_LINK_OPTS = select({
     "//conditions:default": [],
     "//mediapipe:android": [
@@ -172,6 +178,7 @@
     }),
     visibility = ["//visibility:public"],
     deps = [
+        ":attachments",
         ":gl_base",
         ":gl_thread_collector",
         ":gpu_buffer_format",
diff --git a/mediapipe/gpu/MPPGraphGPUData.h b/mediapipe/gpu/MPPGraphGPUData.h
index 4745026190..3d8fc0c949 100644
--- a/mediapipe/gpu/MPPGraphGPUData.h
+++ b/mediapipe/gpu/MPPGraphGPUData.h
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#ifndef MEDIAPIPE_GPU_DRISHTIGRAPHGPUDATA_H_
-#define MEDIAPIPE_GPU_DRISHTIGRAPHGPUDATA_H_
+#ifndef MEDIAPIPE_GPU_MPPGRAPHGPUDATA_H_
+#define MEDIAPIPE_GPU_MPPGRAPHGPUDATA_H_
 
 #import
 #import
@@ -68,4 +68,4 @@
 
 @end
 
-#endif  // MEDIAPIPE_GPU_DRISHTIGRAPHGPUDATA_H_
+#endif  // MEDIAPIPE_GPU_MPPGRAPHGPUDATA_H_
diff --git a/mediapipe/gpu/attachments.h b/mediapipe/gpu/attachments.h
new file mode 100644
index 0000000000..ca9f074c46
--- /dev/null
+++ b/mediapipe/gpu/attachments.h
@@ -0,0 +1,64 @@
+#ifndef MEDIAPIPE_GPU_ATTACHMENTS_H_
+#define MEDIAPIPE_GPU_ATTACHMENTS_H_
+
+#include <functional>
+#include <memory>
+#include <type_traits>
+
+namespace mediapipe {
+namespace internal {
+
+// Unique pointer with a type-erased destructor.
+template <class T>
+using AttachmentPtr = std::unique_ptr<T, std::function<void(void*)>>;
+
+// Like make_unique.
+template <class T, class... Args>
+static std::enable_if_t<std::is_constructible<T, Args...>::value,
+                        AttachmentPtr<T>>
+MakeAttachmentPtr(Args&&... args) {
+  return {new T(std::forward<Args>(args)...),
+          [](void* ptr) { delete static_cast<T*>(ptr); }};
+}
+
+template <class Context>
+class AttachmentBase {};
+
+// A cacheable resource that can be associated with a context.
+// Attachments are defined as constants.
+// When access to an attachment is requested, it will be retrieved from the
+// context if already created, or the factory function will be invoked to
+// create it. The factory function for a given attachment is invoked at most
+// once per context. The lifetime of the object it returns is managed by the
+// context.
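+// Example usage (illustrative sketch; MyContext is hypothetical and must
+// implement GetCachedAttachment()):
+//   const Attachment<MyContext, int> kCounter(
+//       [](MyContext&) { return MakeAttachmentPtr<int>(0); });
+//   int& counter = kCounter.Get(ctx);  // created once per context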
+template <class Context, class T>
+class Attachment : public AttachmentBase<Context> {
+ public:
+  using FactoryT = std::function<AttachmentPtr<T>(Context&)>;
+  Attachment(FactoryT factory) : factory_(factory) {}
+
+  Attachment(const Attachment&) = delete;
+  Attachment(Attachment&&) = delete;
+  Attachment& operator=(const Attachment&) = delete;
+  Attachment& operator=(Attachment&&) = delete;
+
+  T& Get(Context& ctx) const { return ctx.GetCachedAttachment(*this); }
+
+  const FactoryT& factory() const { return factory_; }
+
+  // Ptr and MakePtr here make it more convenient to define new types of
+  // attachment contexts, since you only need a using declaration for
+  // Attachment and can refer to Ptr from it.
+  using Ptr = AttachmentPtr<T>;
+
+  template <class... Args>
+  inline static std::enable_if_t<std::is_constructible<T, Args...>::value,
+                                 AttachmentPtr<T>>
+  MakePtr(Args&&... args) {
+    return MakeAttachmentPtr<T>(std::forward<Args>(args)...);
+  }
+
+ private:
+  FactoryT factory_;
+};
+
+}  // namespace internal
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_GPU_ATTACHMENTS_H_
diff --git a/mediapipe/gpu/gl_context.h b/mediapipe/gpu/gl_context.h
index 7f1fbbbdcd..9b40310f04 100644
--- a/mediapipe/gpu/gl_context.h
+++ b/mediapipe/gpu/gl_context.h
@@ -29,6 +29,7 @@
 #include "mediapipe/framework/port/statusor.h"
 #include "mediapipe/framework/port/threadpool.h"
 #include "mediapipe/framework/timestamp.h"
+#include "mediapipe/gpu/attachments.h"
 #include "mediapipe/gpu/gl_base.h"
 #include "mediapipe/gpu/gpu_buffer_format.h"
 
@@ -286,42 +287,15 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
   // Sets default texture filtering parameters.
   void SetStandardTextureParams(GLenum target, GLint internal_format);
 
+  using AttachmentBase = internal::AttachmentBase<GlContext>;
   template <class T>
-  using AttachmentPtr = std::unique_ptr<T, std::function<void(void*)>>;
-
-  template <class T, class... Args>
-  static std::enable_if_t<std::is_constructible<T, Args...>::value,
-                          AttachmentPtr<T>>
-  MakeAttachmentPtr(Args&&... args) {
-    return {new T(std::forward<Args>(args)...),
-            [](void* ptr) { delete static_cast<T*>(ptr); }};
-  }
-
-  class AttachmentBase {};
-
-  template <class T>
-  class Attachment : public AttachmentBase {
-   public:
-    using FactoryT = std::function<AttachmentPtr<T>(GlContext&)>;
-    Attachment(FactoryT factory) : factory_(factory) {}
-
-    Attachment(const Attachment&) = delete;
-    Attachment(Attachment&&) = delete;
-    Attachment& operator=(const Attachment&) = delete;
-    Attachment& operator=(Attachment&&) = delete;
-
-    T& Get(GlContext& ctx) const { return ctx.GetCachedAttachment(*this); }
-
-    const FactoryT& factory() const { return factory_; }
-
-   private:
-    FactoryT factory_;
-  };
+  using Attachment = internal::Attachment<GlContext, T>;
 
   // TODO: const result?
   template <class T>
   T& GetCachedAttachment(const Attachment<T>& attachment) {
     DCHECK(IsCurrent());
-    AttachmentPtr<void>& entry = attachments_[&attachment];
+    internal::AttachmentPtr<void>& entry = attachments_[&attachment];
     if (entry == nullptr) {
       entry = attachment.factory()(*this);
     }
@@ -454,7 +428,8 @@
   // better mechanism?
   bool can_linear_filter_float_textures_;
 
-  absl::flat_hash_map<const AttachmentBase*, AttachmentPtr<void>> attachments_;
+  absl::flat_hash_map<const AttachmentBase*, internal::AttachmentPtr<void>>
+      attachments_;
 
   // Number of glFinish calls completed on the GL thread.
  // Changes should be guarded by mutex_. However, we use simple atomic
diff --git a/mediapipe/gpu/gpu_buffer.cc b/mediapipe/gpu/gpu_buffer.cc
index bb215dbbd2..e899fc85dd 100644
--- a/mediapipe/gpu/gpu_buffer.cc
+++ b/mediapipe/gpu/gpu_buffer.cc
@@ -11,7 +11,7 @@
 namespace mediapipe {
 
 internal::GpuBufferStorage& GpuBuffer::GetStorageForView(
-    TypeRef view_provider_type, bool for_writing) const {
+    TypeId view_provider_type, bool for_writing) const {
   const std::shared_ptr<internal::GpuBufferStorage>* chosen_storage = nullptr;
 
   // First see if any current storage supports the view.
diff --git a/mediapipe/gpu/gpu_buffer.h b/mediapipe/gpu/gpu_buffer.h
index 47f334a242..88bff7e1f2 100644
--- a/mediapipe/gpu/gpu_buffer.h
+++ b/mediapipe/gpu/gpu_buffer.h
@@ -130,8 +130,6 @@
   }
 
  private:
-  using TypeRef = internal::TypeRef;
-
   class PlaceholderGpuBufferStorage
       : public internal::GpuBufferStorageImpl<PlaceholderGpuBufferStorage> {
    public:
@@ -147,14 +145,13 @@
     GpuBufferFormat format_ = GpuBufferFormat::kUnknown;
   };
 
-  internal::GpuBufferStorage& GetStorageForView(TypeRef view_provider_type,
+  internal::GpuBufferStorage& GetStorageForView(TypeId view_provider_type,
                                                 bool for_writing) const;
 
   template <class T>
   internal::ViewProvider<T>* GetViewProvider(bool for_writing) const {
     using VP = internal::ViewProvider<T>;
-    return GetStorageForView(TypeRef::Get<VP>(), for_writing)
-        .template down_cast<VP>();
+    return GetStorageForView(kTypeId<VP>, for_writing).template down_cast<VP>();
   }
 
   std::shared_ptr<internal::GpuBufferStorage>& no_storage() const {
diff --git a/mediapipe/gpu/gpu_buffer_storage.cc b/mediapipe/gpu/gpu_buffer_storage.cc
index 2f06876532..e525fe94f1 100644
--- a/mediapipe/gpu/gpu_buffer_storage.cc
+++ b/mediapipe/gpu/gpu_buffer_storage.cc
@@ -8,14 +8,14 @@
 using StorageConverter = GpuBufferStorageRegistry::StorageConverter;
 using RegistryToken = GpuBufferStorageRegistry::RegistryToken;
 
 StorageFactory GpuBufferStorageRegistry::StorageFactoryForViewProvider(
-    TypeRef view_provider_type) {
+    TypeId view_provider_type) {
   auto it = factory_for_view_provider_.find(view_provider_type);
   if (it == factory_for_view_provider_.end()) return nullptr;
   return it->second;
 }
 
 StorageConverter GpuBufferStorageRegistry::StorageConverterForViewProvider(
-    TypeRef view_provider_type, TypeRef existing_storage_type) {
+    TypeId view_provider_type, TypeId existing_storage_type) {
   auto it = converter_for_view_provider_and_existing_storage_.find(
       {view_provider_type, existing_storage_type});
   if (it == converter_for_view_provider_and_existing_storage_.end())
@@ -24,7 +24,7 @@
 }
 
 RegistryToken GpuBufferStorageRegistry::Register(
-    StorageFactory factory, std::vector<TypeRef> provider_hashes) {
+    StorageFactory factory, std::vector<TypeId> provider_hashes) {
   // TODO: choose between multiple factories for same provider type.
   for (const auto p : provider_hashes) {
     factory_for_view_provider_[p] = factory;
@@ -33,8 +33,8 @@
 }
 
 RegistryToken GpuBufferStorageRegistry::Register(
-    StorageConverter converter, std::vector<TypeRef> provider_hashes,
-    TypeRef source_storage) {
+    StorageConverter converter, std::vector<TypeId> provider_hashes,
+    TypeId source_storage) {
   // TODO: choose between multiple converters for same provider type.
   for (const auto p : provider_hashes) {
     converter_for_view_provider_and_existing_storage_[{p, source_storage}] =
diff --git a/mediapipe/gpu/gpu_buffer_storage.h b/mediapipe/gpu/gpu_buffer_storage.h
index 059448683e..3d872eb664 100644
--- a/mediapipe/gpu/gpu_buffer_storage.h
+++ b/mediapipe/gpu/gpu_buffer_storage.h
@@ -1,9 +1,11 @@
 #ifndef MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_H_
 #define MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_H_
 
+#include <cstddef>
 #include <functional>
 #include <memory>
 #include <type_traits>
+#include <utility>
 
 #include "absl/container/flat_hash_map.h"
 #include "mediapipe/framework/deps/no_destructor.h"
@@ -20,31 +22,6 @@ struct types {};
 template <class V>
 class ViewProvider;
 
-// An identifier for a type. We have often used size_t holding a hash for this
-// purpose in MediaPipe, but a non-primitive type makes the code more readable.
-// Ideally we should clean up the various ways this is handled throughout the
-// framework and consolidate the utilities in type_util. When that is done,
-// this type can be replaced.
-class TypeRef {
- public:
-  template <class T>
-  static TypeRef Get() {
-    return TypeRef{tool::GetTypeHash<T>()};
-  }
-
-  bool operator==(const TypeRef& other) const { return hash_ == other.hash_; }
-
-  template <typename H>
-  friend H AbslHashValue(H h, const TypeRef& r) {
-    return H::combine(std::move(h), r.hash_);
-  }
-
- private:
-  explicit TypeRef(size_t hash) : hash_(hash) {}
-
-  size_t hash_;
-};
-
 // Interface for a backing storage for GpuBuffer.
 class GpuBufferStorage {
  public:
@@ -56,18 +33,18 @@
   // The public methods delegate to the type-erased private virtual method.
   template <class T>
   T* down_cast() {
-    return static_cast<T*>(const_cast<void*>(down_cast(TypeRef::Get<T>())));
+    return static_cast<T*>(const_cast<void*>(down_cast(kTypeId<T>)));
   }
   template <class T>
   const T* down_cast() const {
-    return static_cast<const T*>(down_cast(TypeRef::Get<T>()));
+    return static_cast<const T*>(down_cast(kTypeId<T>));
   }
 
-  bool can_down_cast_to(TypeRef to) const { return down_cast(to) != nullptr; }
-  virtual TypeRef storage_type() const = 0;
+  bool can_down_cast_to(TypeId to) const { return down_cast(to) != nullptr; }
+  virtual TypeId storage_type() const = 0;
 
  private:
-  virtual const void* down_cast(TypeRef to) const = 0;
+  virtual const void* down_cast(TypeId to) const = 0;
 };
 
 // Used to disambiguate between overloads by manually specifying their
 // priority.
@@ -113,18 +90,18 @@
         -> std::shared_ptr<StorageTo> {
           return converter(std::static_pointer_cast<StorageFrom>(source));
         },
-        StorageTo::GetProviderTypes(), TypeRef::Get<StorageFrom>());
+        StorageTo::GetProviderTypes(), kTypeId<StorageFrom>);
   }
 
   // Returns a factory function for a storage that implements
   // view_provider_type.
-  StorageFactory StorageFactoryForViewProvider(TypeRef view_provider_type);
+  StorageFactory StorageFactoryForViewProvider(TypeId view_provider_type);
 
   // Returns a conversion function that, given a storage of
   // existing_storage_type, converts its contents to a new storage that
   // implements view_provider_type.
   StorageConverter StorageConverterForViewProvider(
-      TypeRef view_provider_type, TypeRef existing_storage_type);
+      TypeId view_provider_type, TypeId existing_storage_type);
 
  private:
   template
@@ -139,13 +116,13 @@
   }
 
   RegistryToken Register(StorageFactory factory,
-                         std::vector<TypeRef> provider_hashes);
+                         std::vector<TypeId> provider_hashes);
   RegistryToken Register(StorageConverter converter,
-                         std::vector<TypeRef> provider_hashes,
-                         TypeRef source_storage);
+                         std::vector<TypeId> provider_hashes,
+                         TypeId source_storage);
 
-  absl::flat_hash_map<TypeRef, StorageFactory> factory_for_view_provider_;
-  absl::flat_hash_map<std::pair<TypeRef, TypeRef>, StorageConverter>
+  absl::flat_hash_map<TypeId, StorageFactory> factory_for_view_provider_;
+  absl::flat_hash_map<std::pair<TypeId, TypeId>, StorageConverter>
       converter_for_view_provider_and_existing_storage_;
 };
 
@@ -166,21 +143,21 @@ struct ForceStaticInstantiation {
 template <class T, class... U>
 class GpuBufferStorageImpl : public GpuBufferStorage, public U... {
  public:
-  static const std::vector<TypeRef>& GetProviderTypes() {
-    static std::vector<TypeRef> kHashes{TypeRef::Get<U>()...};
+  static const std::vector<TypeId>& GetProviderTypes() {
+    static std::vector<TypeId> kHashes{kTypeId<U>...};
     return kHashes;
   }
 
 private:
-  virtual const void* down_cast(TypeRef to) const override {
+  virtual const void* down_cast(TypeId to) const override {
     return down_cast_impl(to, types<U...>{});
   }
-  TypeRef storage_type() const override { return TypeRef::Get<T>(); }
+  TypeId storage_type() const override { return kTypeId<T>; }
 
-  const void* down_cast_impl(TypeRef to, types<>) const { return nullptr; }
+  const void* down_cast_impl(TypeId to, types<>) const { return nullptr; }
   template <class V, class... W>
-  const void* down_cast_impl(TypeRef to, types<V, W...>) const {
-    if (to == TypeRef::Get<V>()) return static_cast<const V*>(this);
+  const void* down_cast_impl(TypeId to, types<V, W...>) const {
+    if (to == kTypeId<V>) return static_cast<const V*>(this);
     return down_cast_impl(to, types<W...>{});
   }
diff --git a/mediapipe/graphs/face_effect/subgraphs/single_face_geometry_from_landmarks_gpu.pbtxt b/mediapipe/graphs/face_effect/subgraphs/single_face_geometry_from_landmarks_gpu.pbtxt
index 364e386549..ec9dce7a34 100644
--- a/mediapipe/graphs/face_effect/subgraphs/single_face_geometry_from_landmarks_gpu.pbtxt
+++ b/mediapipe/graphs/face_effect/subgraphs/single_face_geometry_from_landmarks_gpu.pbtxt
@@ -74,7 +74,7 @@
 # Puts the single set of smoothed landmarks back into a collection to simplify
 # passing the result into the `FaceGeometryFromLandmarks` subgraph.
 node {
-  calculator: "ConcatenateLandmarListVectorCalculator"
+  calculator: "ConcatenateNormalizedLandmarkListVectorCalculator"
   input_stream: "smoothed_face_landmarks"
   output_stream: "multi_smoothed_face_landmarks"
 }
diff --git a/mediapipe/java/com/google/mediapipe/components/GlSurfaceViewRenderer.java b/mediapipe/java/com/google/mediapipe/components/GlSurfaceViewRenderer.java
index 44af3f5859..7a6c547a20 100644
--- a/mediapipe/java/com/google/mediapipe/components/GlSurfaceViewRenderer.java
+++ b/mediapipe/java/com/google/mediapipe/components/GlSurfaceViewRenderer.java
@@ -221,8 +221,7 @@ public void setNextFrame(TextureFrame frame) {
       Matrix.setIdentityM(textureTransformMatrix, 0 /* offset */);
     }
     TextureFrame oldFrame = nextFrame.getAndSet(frame);
-    if (oldFrame != null
-        && (frame == null || (oldFrame.getTextureName() != frame.getTextureName()))) {
+    if (oldFrame != null && oldFrame != frame) {
       oldFrame.release();
     }
     surfaceTexture = null;
diff --git a/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java b/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java
index b3290f70e1..b724c6eaee 100644
--- a/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java
+++ b/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java
@@ -27,14 +27,22 @@ public class GraphTextureFrame implements TextureFrame {
   private int width;
   private int height;
   private long timestamp = Long.MIN_VALUE;
+  // True when created with PacketGetter.getTextureFrameDeferredSync(). This
+  // will result in gpuWait when calling getTextureName().
+  private final boolean deferredSync;
 
   GraphTextureFrame(long nativeHandle, long timestamp) {
+    this(nativeHandle, timestamp, false);
+  }
+
+  GraphTextureFrame(long nativeHandle, long timestamp, boolean deferredSync) {
     nativeBufferHandle = nativeHandle;
     // TODO: use a single JNI call to fill in all info
     textureName = nativeGetTextureName(nativeBufferHandle);
     width = nativeGetWidth(nativeBufferHandle);
    height = nativeGetHeight(nativeBufferHandle);
     this.timestamp = timestamp;
+    this.deferredSync = deferredSync;
   }
 
   /**
@@ -42,13 +50,22 @@ public class GraphTextureFrame implements TextureFrame {
    *
    *

   * <p>Note: if this texture has been obtained using
   * getTextureFrameDeferredSync, a GPU wait on the producer sync will be done
   * here. That means this method should be called on the GL context that
-   * will actually use the texture.
+   * will actually use the texture. Note that in this case, it is also
+   * susceptible to a race condition if release() is called after the
+   * nativeBufferHandle check has already passed.
   */
  @Override
  public int getTextureName() {
-    // Note that, if a CPU wait has already been done, the sync point will have
-    // been cleared and this will turn into a no-op. See GlFenceSyncPoint::Wait.
-    nativeGpuWait(nativeBufferHandle);
+    // Return the special texture id 0 if the handle is 0, i.e. the frame has
+    // already been released.
+    if (nativeBufferHandle == 0) {
+      return 0;
+    }
+    // Gpu wait only if deferredSync is true, such as when this
+    // GraphTextureFrame is created using
+    // PacketGetter.getTextureFrameDeferredSync().
+    if (deferredSync) {
+      // Note that, if a CPU wait has already been done, the sync point will
+      // have been cleared and this will turn into a no-op. See
+      // GlFenceSyncPoint::Wait.
+      nativeGpuWait(nativeBufferHandle);
+    }
     return textureName;
  }
diff --git a/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java b/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java
index f22de08dce..7e66e0b75b 100644
--- a/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java
+++ b/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java
@@ -316,7 +316,7 @@ public static GraphTextureFrame getTextureFrame(final Packet packet) {
   public static GraphTextureFrame getTextureFrameDeferredSync(final Packet packet) {
     return new GraphTextureFrame(
         nativeGetGpuBuffer(packet.getNativeHandle(), /* waitOnCpu= */ false),
-        packet.getTimestamp());
+        packet.getTimestamp(), /* deferredSync= */ true);
   }
 
   private static native long nativeGetPacketFromReference(long nativePacketHandle);
diff --git a/mediapipe/java/com/google/mediapipe/solutions/facedetection/FaceDetection.java b/mediapipe/java/com/google/mediapipe/solutions/facedetection/FaceDetection.java
index be9be1f332..37aca84843 100644
--- a/mediapipe/java/com/google/mediapipe/solutions/facedetection/FaceDetection.java
+++ b/mediapipe/java/com/google/mediapipe/solutions/facedetection/FaceDetection.java
@@ -16,7 +16,6 @@
 
 import android.content.Context;
 import com.google.common.collect.ImmutableList;
-import com.google.mediapipe.formats.proto.DetectionProto.Detection;
 import com.google.mediapipe.framework.MediaPipeException;
 import com.google.mediapipe.framework.Packet;
 import com.google.mediapipe.solutioncore.ErrorListener;
@@ -24,6 +23,7 @@
 import com.google.mediapipe.solutioncore.OutputHandler;
 import com.google.mediapipe.solutioncore.ResultListener;
 import com.google.mediapipe.solutioncore.SolutionInfo;
+import com.google.mediapipe.formats.proto.DetectionProto.Detection;
 import java.util.HashMap;
 import java.util.Map;
 import javax.annotation.Nullable;
diff --git a/mediapipe/java/com/google/mediapipe/solutions/facedetection/FaceDetectionResult.java b/mediapipe/java/com/google/mediapipe/solutions/facedetection/FaceDetectionResult.java
index d665a95f6a..413095f6e6 100644
--- a/mediapipe/java/com/google/mediapipe/solutions/facedetection/FaceDetectionResult.java
+++ b/mediapipe/java/com/google/mediapipe/solutions/facedetection/FaceDetectionResult.java
@@ -17,10 +17,10 @@
 
 import android.graphics.Bitmap;
 import com.google.auto.value.AutoBuilder;
 import com.google.common.collect.ImmutableList;
-import com.google.mediapipe.formats.proto.DetectionProto.Detection;
 import com.google.mediapipe.framework.Packet;
 import com.google.mediapipe.framework.TextureFrame;
 import com.google.mediapipe.solutioncore.ImageSolutionResult;
+import com.google.mediapipe.formats.proto.DetectionProto.Detection;
 import java.util.List;
 
 /**
diff --git a/mediapipe/modules/face_detection/BUILD b/mediapipe/modules/face_detection/BUILD
index b1cddeb6f5..84c9388ea7 100644
--- a/mediapipe/modules/face_detection/BUILD
+++ b/mediapipe/modules/face_detection/BUILD
@@ -16,6 +16,8 @@
 load(
     "//mediapipe/framework/tool:mediapipe_graph.bzl",
     "mediapipe_simple_subgraph",
 )
+load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
+load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test")
 
 licenses(["notice"])
 
@@ -26,7 +28,7 @@
     graph = "face_detection_short_range_by_roi_cpu.pbtxt",
     register_as = "FaceDetectionShortRangeByRoiCpu",
     deps = [
-        ":face_detection_short_range_common",
+        ":face_detection_short_range",
         "//mediapipe/calculators/tensor:image_to_tensor_calculator",
         "//mediapipe/calculators/tensor:inference_calculator",
         "//mediapipe/calculators/util:to_image_calculator",
@@ -38,7 +40,7 @@
     graph = "face_detection_short_range_by_roi_gpu.pbtxt",
     register_as = "FaceDetectionShortRangeByRoiGpu",
     deps = [
-        ":face_detection_short_range_common",
+        ":face_detection_short_range",
         "//mediapipe/calculators/tensor:image_to_tensor_calculator",
         "//mediapipe/calculators/tensor:inference_calculator",
         "//mediapipe/calculators/util:to_image_calculator",
@@ -50,10 +52,7 @@
     graph = "face_detection_short_range_cpu.pbtxt",
     register_as = "FaceDetectionShortRangeCpu",
     deps = [
-        ":face_detection_short_range_common",
-        "//mediapipe/calculators/tensor:image_to_tensor_calculator",
-        "//mediapipe/calculators/tensor:inference_calculator",
-        "//mediapipe/calculators/util:to_image_calculator",
+        ":face_detection_short_range",
     ],
 )
 
@@ -62,22 +61,66 @@
     graph = "face_detection_short_range_gpu.pbtxt",
     register_as = "FaceDetectionShortRangeGpu",
     deps = [
-        ":face_detection_short_range_common",
-        "//mediapipe/calculators/tensor:image_to_tensor_calculator",
-        "//mediapipe/calculators/tensor:inference_calculator",
-        "//mediapipe/calculators/util:to_image_calculator",
+        ":face_detection_short_range",
+    ],
+)
+
+mediapipe_simple_subgraph(
+    name = "face_detection_short_range",
+    graph = "face_detection_short_range.pbtxt",
+    register_as = "FaceDetectionShortRange",
+    deps = [
+        ":face_detection",
     ],
 )
 
 mediapipe_simple_subgraph(
-    name = "face_detection_short_range_common",
-    graph = "face_detection_short_range_common.pbtxt",
-    register_as = "FaceDetectionShortRangeCommon",
+    name = "face_detection_full_range",
+    graph = "face_detection_full_range.pbtxt",
+    register_as = "FaceDetectionFullRange",
     deps = [
+        ":face_detection",
+    ],
+)
+
+mediapipe_simple_subgraph(
+    name = "face_detection_without_roi",
+    graph = "face_detection_without_roi.pbtxt",
+    register_as = "FaceDetectionWithoutRoi",
+    deps = [
+        ":face_detection",
+    ],
+)
+
+mediapipe_simple_subgraph(
+    name = "face_detection",
+    graph = "face_detection.pbtxt",
+    register_as = "FaceDetection",
+    deps = [
+        ":face_detection_cc_proto",
+        ":face_detection_options_lib",
+        "//mediapipe/calculators/core:gate_calculator",
+        "//mediapipe/calculators/tensor:image_to_tensor_calculator",
+        "//mediapipe/calculators/tensor:inference_calculator",
"//mediapipe/calculators/tensor:tensors_to_detections_calculator", "//mediapipe/calculators/tflite:ssd_anchors_calculator", "//mediapipe/calculators/util:detection_projection_calculator", "//mediapipe/calculators/util:non_max_suppression_calculator", + "//mediapipe/calculators/util:to_image_calculator", + ], +) + +mediapipe_proto_library( + name = "face_detection_proto", + srcs = ["face_detection.proto"], + deps = [ + "//mediapipe/calculators/core:gate_calculator_proto", + "//mediapipe/calculators/tensor:image_to_tensor_calculator_proto", + "//mediapipe/calculators/tensor:inference_calculator_proto", + "//mediapipe/calculators/tensor:tensors_to_detections_calculator_proto", + "//mediapipe/calculators/tflite:ssd_anchors_calculator_proto", + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/gpu:gpu_origin_proto", ], ) @@ -86,10 +129,7 @@ mediapipe_simple_subgraph( graph = "face_detection_full_range_cpu.pbtxt", register_as = "FaceDetectionFullRangeCpu", deps = [ - ":face_detection_full_range_common", - "//mediapipe/calculators/tensor:image_to_tensor_calculator", - "//mediapipe/calculators/tensor:inference_calculator", - "//mediapipe/calculators/util:to_image_calculator", + ":face_detection_full_range", ], ) @@ -98,22 +138,7 @@ mediapipe_simple_subgraph( graph = "face_detection_full_range_gpu.pbtxt", register_as = "FaceDetectionFullRangeGpu", deps = [ - ":face_detection_full_range_common", - "//mediapipe/calculators/tensor:image_to_tensor_calculator", - "//mediapipe/calculators/tensor:inference_calculator", - "//mediapipe/calculators/util:to_image_calculator", - ], -) - -mediapipe_simple_subgraph( - name = "face_detection_full_range_common", - graph = "face_detection_full_range_common.pbtxt", - register_as = "FaceDetectionFullRangeCommon", - deps = [ - "//mediapipe/calculators/tensor:tensors_to_detections_calculator", - "//mediapipe/calculators/tflite:ssd_anchors_calculator", - "//mediapipe/calculators/util:detection_projection_calculator", - "//mediapipe/calculators/util:non_max_suppression_calculator", + ":face_detection_full_range", ], ) @@ -122,7 +147,7 @@ mediapipe_simple_subgraph( graph = "face_detection_short_range_image.pbtxt", register_as = "FaceDetectionShortRangeImage", deps = [ - ":face_detection_short_range_common", + ":face_detection_short_range", "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/calculators/tensor:image_to_tensor_calculator", "//mediapipe/calculators/tensor:inference_calculator", @@ -134,7 +159,7 @@ mediapipe_simple_subgraph( graph = "face_detection_full_range_image.pbtxt", register_as = "FaceDetectionFullRangeImage", deps = [ - ":face_detection_full_range_common", + ":face_detection_full_range", "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/calculators/tensor:image_to_tensor_calculator", "//mediapipe/calculators/tensor:inference_calculator", diff --git a/mediapipe/modules/face_detection/face_detection.pbtxt b/mediapipe/modules/face_detection/face_detection.pbtxt new file mode 100644 index 0000000000..b85d224d2b --- /dev/null +++ b/mediapipe/modules/face_detection/face_detection.pbtxt @@ -0,0 +1,164 @@ +# MediaPipe graph to detect faces. +# +# EXAMPLE: +# node { +# calculator: "FaceDetectionFrontCpu" +# input_stream: "IMAGE:image" +# input_stream: "ROI:roi" +# output_stream: "DETECTIONS:face_detections" +# } + +type: "FaceDetection" + +# The input image, either ImageFrame, GpuBuffer, or (multi-backend) Image. 
+input_stream: "IMAGE:image" + +# ROI (region of interest) within the given image where faces should be +# detected. (NormalizedRect) +input_stream: "ROI:roi" + +# Detected faces. (std::vector) +# NOTE: there will not be an output packet in the DETECTIONS stream for this +# particular timestamp if none of faces detected. However, the MediaPipe +# framework will internally inform the downstream calculators of the absence of +# this packet so that they don't wait for it unnecessarily. +output_stream: "DETECTIONS:detections" + +graph_options: { + [type.googleapis.com/mediapipe.FaceDetectionOptions] {} +} + +# Converts the input CPU or GPU image to the multi-backend image type (Image). +node: { + calculator: "ToImageCalculator" + input_stream: "IMAGE:image" + output_stream: "IMAGE:multi_backend_image" +} + +# Transforms the input image into a 128x128 tensor while keeping the aspect +# ratio (what is expected by the corresponding face detection model), resulting +# in potential letterboxing in the transformed image. +node: { + calculator: "ImageToTensorCalculator" + input_stream: "IMAGE:multi_backend_image" + input_stream: "NORM_RECT:roi" + output_stream: "TENSORS:input_tensors" + output_stream: "MATRIX:transform_matrix" + options: { + [mediapipe.ImageToTensorCalculatorOptions.ext] { + keep_aspect_ratio: true + output_tensor_float_range { + min: -1.0 + max: 1.0 + } + border_mode: BORDER_ZERO + } + } + option_value: "gpu_origin:options/gpu_origin" + option_value: "output_tensor_width:options/tensor_width" + option_value: "output_tensor_height:options/tensor_height" +} + +# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a +# vector of tensors representing, for instance, detection boxes/keypoints and +# scores. +node { + calculator: "InferenceCalculator" + input_stream: "TENSORS:input_tensors" + output_stream: "TENSORS:detection_tensors" + options: { + [mediapipe.InferenceCalculatorOptions.ext] {} + } + option_value: "delegate:options/delegate" + option_value: "model_path:options/model_path" +} + +# Detection tensors. (std::vector) +#input_stream: "TENSORS:detection_tensors" + +# A 4x4 row-major-order matrix that maps a point represented in the detection +# tensors to a desired coordinate system, e.g., in the original input image +# before scaling/cropping. (std::array) +#input_stream: "MATRIX:transform_matrix" + +# Detected faces. (std::vector) +# NOTE: there will not be an output packet in the DETECTIONS stream for this +# particular timestamp if none of faces detected. However, the MediaPipe +# framework will internally inform the downstream calculators of the absence of +# this packet so that they don't wait for it unnecessarily. +#output_stream: "DETECTIONS:detections" + +# Generates a single side packet containing a vector of SSD anchors based on +# the specification in the options. 
+# Generates a single side packet containing a vector of SSD anchors based on
+# the specification in the options.
+node {
+  calculator: "SsdAnchorsCalculator"
+  output_side_packet: "anchors"
+  options: {
+    [mediapipe.SsdAnchorsCalculatorOptions.ext] {
+      num_layers: 1
+      min_scale: 0.1484375
+      max_scale: 0.75
+      anchor_offset_x: 0.5
+      anchor_offset_y: 0.5
+      aspect_ratios: 1.0
+      fixed_anchor_size: true
+    }
+  }
+  option_value: "input_size_width:tensor_width"
+  option_value: "input_size_height:tensor_height"
+  option_value: "num_layers:num_layers"
+  option_value: "strides:strides"
+  option_value: "interpolated_scale_aspect_ratio:interpolated_scale_aspect_ratio"
+}
+
+# Decodes the detection tensors generated by the TensorFlow Lite model, based
+# on the SSD anchors and the specification in the options, into a vector of
+# detections. Each detection describes a detected object.
+node {
+  calculator: "TensorsToDetectionsCalculator"
+  input_stream: "TENSORS:detection_tensors"
+  input_side_packet: "ANCHORS:anchors"
+  output_stream: "DETECTIONS:unfiltered_detections"
+  options: {
+    [mediapipe.TensorsToDetectionsCalculatorOptions.ext] {
+      num_classes: 1
+      num_coords: 16
+      box_coord_offset: 0
+      keypoint_coord_offset: 4
+      num_keypoints: 6
+      num_values_per_keypoint: 2
+      sigmoid_score: true
+      score_clipping_thresh: 100.0
+      reverse_output_order: true
+    }
+  }
+  option_value: "num_boxes:num_boxes"
+  option_value: "x_scale:x_scale"
+  option_value: "y_scale:y_scale"
+  option_value: "h_scale:h_scale"
+  option_value: "w_scale:w_scale"
+  option_value: "min_score_thresh:min_score_thresh"
+}
+
+# Performs non-max suppression to remove excessive detections.
+node {
+  calculator: "NonMaxSuppressionCalculator"
+  input_stream: "unfiltered_detections"
+  output_stream: "filtered_detections"
+  options: {
+    [mediapipe.NonMaxSuppressionCalculatorOptions.ext] {
+      min_suppression_threshold: 0.3
+      overlap_type: INTERSECTION_OVER_UNION
+      algorithm: WEIGHTED
+    }
+  }
+}
+
+# Projects the detections from input-tensor coordinates to the corresponding
+# locations on the original image (input to the graph).
+node {
+  calculator: "DetectionProjectionCalculator"
+  input_stream: "DETECTIONS:filtered_detections"
+  input_stream: "PROJECTION_MATRIX:transform_matrix"
+  output_stream: "DETECTIONS:detections"
+}
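
With every model-specific constant routed through FaceDetectionOptions, a single compiled graph can be re-targeted at run time. A hedged sketch of driving it from Python via the graph_options hook this change adds to SolutionBase further below (the binarypb path and stream names are assumptions based on the files in this module):

import numpy as np

from mediapipe.modules.face_detection import face_detection_pb2
from mediapipe.python.solution_base import SolutionBase

opts = face_detection_pb2.FaceDetectionOptions()
opts.min_score_thresh = 0.6  # tighten the detector for this run

solution = SolutionBase(
    binary_graph_path=('mediapipe/modules/face_detection/'
                       'face_detection_short_range_cpu.binarypb'),
    graph_options=opts)
frame = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder input image
results = solution.process({'image': frame})
print(results.detections)
solution.close()
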
diff --git a/mediapipe/modules/face_detection/face_detection.proto b/mediapipe/modules/face_detection/face_detection.proto
new file mode 100644
index 0000000000..f5df8d6470
--- /dev/null
+++ b/mediapipe/modules/face_detection/face_detection.proto
@@ -0,0 +1,59 @@
+// Copyright 2020 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto2";
+
+package mediapipe;
+
+import "mediapipe/calculators/tensor/inference_calculator.proto";
+import "mediapipe/framework/calculator_options.proto";
+import "mediapipe/gpu/gpu_origin.proto";
+
+option java_package = "com.google.mediapipe.modules.facedetection";
+option java_outer_classname = "FaceDetectionFrontProto";
+
+// Options for the FaceDetection subgraph: model selection, input tensor
+// geometry, SSD anchor/decoding parameters, and the inference delegate.
+message FaceDetectionOptions {
+  extend mediapipe.CalculatorOptions {
+    optional FaceDetectionOptions ext = 374290926;
+  }
+  // Path to the TF Lite model (ex: /path/to/modelname.tflite).
+  optional string model_path = 1;
+
+  // The coordinate origin corner, either CONVENTIONAL or TOP_LEFT.
+  optional GpuOrigin.Mode gpu_origin = 11;
+
+  // Size of the tensor provided to the face-detection model.
+  optional int32 tensor_width = 21;
+  optional int32 tensor_height = 22;
+  // Number of output feature maps to generate anchors on.
+  optional int32 num_layers = 23;
+  // Strides of each output feature map.
+  repeated int32 strides = 24;
+  // The aspect ratio of the interpolated anchor from the SsdAnchorsCalculator.
+  optional float interpolated_scale_aspect_ratio = 25 [default = 1.0];
+
+  // The number of output boxes predicted by the detection model.
+  optional int32 num_boxes = 31;
+  // Parameters for decoding the SSD detection model output.
+  optional float x_scale = 32 [default = 0.0];
+  optional float y_scale = 33 [default = 0.0];
+  optional float w_scale = 34 [default = 0.0];
+  optional float h_scale = 35 [default = 0.0];
+  // Score threshold for preserving decoded SSD detections.
+  optional float min_score_thresh = 36;
+
+  // TfLite delegate to run inference.
+  optional InferenceCalculatorOptions.Delegate delegate = 6;
+}
diff --git a/mediapipe/modules/face_detection/face_detection_full_range.pbtxt b/mediapipe/modules/face_detection/face_detection_full_range.pbtxt
new file mode 100644
index 0000000000..b526b67f92
--- /dev/null
+++ b/mediapipe/modules/face_detection/face_detection_full_range.pbtxt
@@ -0,0 +1,54 @@
+# MediaPipe graph to detect faces. (CPU input and inference by default.)
+#
+# It is required that "face_detection_full_range_sparse.tflite" is available at
+# "mediapipe/modules/face_detection/face_detection_full_range_sparse.tflite"
+# path during execution.
+#
+# EXAMPLE:
+#   node {
+#     calculator: "FaceDetectionFullRange"
+#     input_stream: "IMAGE:image_frame"
+#     output_stream: "DETECTIONS:face_detections"
+#   }
+
+type: "FaceDetectionFullRange"
+
+# The input image, either ImageFrame, GpuBuffer, or (multi-backend) Image.
+input_stream: "IMAGE:image"
+
+# ROI (region of interest) within the given image where faces should be
+# detected. (NormalizedRect)
+input_stream: "ROI:roi"
+
+# Detected faces. (std::vector<Detection>)
+output_stream: "DETECTIONS:detections"
+
+graph_options: {
+  [type.googleapis.com/mediapipe.FaceDetectionOptions] {}
+}
+
+node {
+  calculator: "FaceDetection"
+  input_stream: "IMAGE:image"
+  input_stream: "ROI:roi"
+  output_stream: "DETECTIONS:detections"
+  node_options: {
+    [type.googleapis.com/mediapipe.FaceDetectionOptions] {
+      model_path: "mediapipe/modules/face_detection/face_detection_full_range_sparse.tflite"
+      tensor_width: 192
+      tensor_height: 192
+
+      num_layers: 1
+      strides: 4
+      interpolated_scale_aspect_ratio: 0.0
+
+      num_boxes: 2304
+      x_scale: 192.0
+      y_scale: 192.0
+      h_scale: 192.0
+      w_scale: 192.0
+      min_score_thresh: 0.6
+    }
+  }
+  option_value: "OPTIONS:options"
+}
diff --git a/mediapipe/modules/face_detection/face_detection_full_range_common.pbtxt b/mediapipe/modules/face_detection/face_detection_full_range_common.pbtxt
deleted file mode 100644
index 937e8be1b3..0000000000
--- a/mediapipe/modules/face_detection/face_detection_full_range_common.pbtxt
+++ /dev/null
@@ -1,102 +0,0 @@
-# MediaPipe graph performing common processing to detect faces using
-# face_detection_full_range_sparse.tflite model, currently consisting of tensor
-# post processing.
-# -# EXAMPLE: -# node { -# calculator: "FaceDetectionFullRangeCommon" -# input_stream: "TENSORS:detection_tensors" -# input_stream: "MATRIX:transform_matrix" -# output_stream: "DETECTIONS:detections" -# } - -type: "FaceDetectionShortRangeCommon" - -# Detection tensors. (std::vector) -input_stream: "TENSORS:detection_tensors" - -# A 4x4 row-major-order matrix that maps a point represented in the detection -# tensors to a desired coordinate system, e.g., in the original input image -# before scaling/cropping. (std::array) -input_stream: "MATRIX:transform_matrix" - -# Detected faces. (std::vector) -# NOTE: there will not be an output packet in the DETECTIONS stream for this -# particular timestamp if none of faces detected. However, the MediaPipe -# framework will internally inform the downstream calculators of the absence of -# this packet so that they don't wait for it unnecessarily. -output_stream: "DETECTIONS:detections" - -# Generates a single side packet containing a vector of SSD anchors based on -# the specification in the options. -node { - calculator: "SsdAnchorsCalculator" - output_side_packet: "anchors" - options: { - [mediapipe.SsdAnchorsCalculatorOptions.ext] { - num_layers: 1 - min_scale: 0.1484375 - max_scale: 0.75 - input_size_height: 192 - input_size_width: 192 - anchor_offset_x: 0.5 - anchor_offset_y: 0.5 - strides: 4 - aspect_ratios: 1.0 - fixed_anchor_size: true - interpolated_scale_aspect_ratio: 0.0 - } - } -} - -# Decodes the detection tensors generated by the TensorFlow Lite model, based on -# the SSD anchors and the specification in the options, into a vector of -# detections. Each detection describes a detected object. -node { - calculator: "TensorsToDetectionsCalculator" - input_stream: "TENSORS:detection_tensors" - input_side_packet: "ANCHORS:anchors" - output_stream: "DETECTIONS:unfiltered_detections" - options: { - [mediapipe.TensorsToDetectionsCalculatorOptions.ext] { - num_classes: 1 - num_boxes: 2304 - num_coords: 16 - box_coord_offset: 0 - keypoint_coord_offset: 4 - num_keypoints: 6 - num_values_per_keypoint: 2 - sigmoid_score: true - score_clipping_thresh: 100.0 - reverse_output_order: true - x_scale: 192.0 - y_scale: 192.0 - h_scale: 192.0 - w_scale: 192.0 - min_score_thresh: 0.6 - } - } -} - -# Performs non-max suppression to remove excessive detections. -node { - calculator: "NonMaxSuppressionCalculator" - input_stream: "unfiltered_detections" - output_stream: "filtered_detections" - options: { - [mediapipe.NonMaxSuppressionCalculatorOptions.ext] { - min_suppression_threshold: 0.3 - overlap_type: INTERSECTION_OVER_UNION - algorithm: WEIGHTED - } - } -} - -# Projects the detections from input tensor to the corresponding locations on -# the original image (input to the graph). -node { - calculator: "DetectionProjectionCalculator" - input_stream: "DETECTIONS:filtered_detections" - input_stream: "PROJECTION_MATRIX:transform_matrix" - output_stream: "DETECTIONS:detections" -} diff --git a/mediapipe/modules/face_detection/face_detection_full_range_cpu.pbtxt b/mediapipe/modules/face_detection/face_detection_full_range_cpu.pbtxt index 2350401907..50c0f5d3ee 100644 --- a/mediapipe/modules/face_detection/face_detection_full_range_cpu.pbtxt +++ b/mediapipe/modules/face_detection/face_detection_full_range_cpu.pbtxt @@ -1,80 +1,25 @@ -# MediaPipe graph to detect faces. (CPU input, and inference is executed on -# CPU.) 
-# -# It is required that "face_detection_full_range_sparse.tflite" is available at -# "mediapipe/modules/face_detection/face_detection_full_range_sparse.tflite" -# path during execution. -# -# EXAMPLE: -# node { -# calculator: "FaceDetectionFullRangeCpu" -# input_stream: "IMAGE:image" -# output_stream: "DETECTIONS:face_detections" -# } +# MediaPipe graph to detect faces. (CPU input and inference.) type: "FaceDetectionFullRangeCpu" -# CPU image. (ImageFrame) +# The input image, either ImageFrame, or (multi-backend) Image. input_stream: "IMAGE:image" # Detected faces. (std::vector) -# NOTE: there will not be an output packet in the DETECTIONS stream for this -# particular timestamp if none of faces detected. However, the MediaPipe -# framework will internally inform the downstream calculators of the absence of -# this packet so that they don't wait for it unnecessarily. output_stream: "DETECTIONS:detections" -# Converts the input CPU image (ImageFrame) to the multi-backend image type -# (Image). -node: { - calculator: "ToImageCalculator" - input_stream: "IMAGE_CPU:image" - output_stream: "IMAGE:multi_backend_image" +graph_options: { + [type.googleapis.com/mediapipe.FaceDetectionOptions] {} } -# Transforms the input image into a 192x192 tensor while keeping the aspect -# ratio (what is expected by the corresponding face detection model), resulting -# in potential letterboxing in the transformed image. -node: { - calculator: "ImageToTensorCalculator" - input_stream: "IMAGE:multi_backend_image" - output_stream: "TENSORS:input_tensors" - output_stream: "MATRIX:transform_matrix" - options: { - [mediapipe.ImageToTensorCalculatorOptions.ext] { - output_tensor_width: 192 - output_tensor_height: 192 - keep_aspect_ratio: true - output_tensor_float_range { - min: -1.0 - max: 1.0 - } - border_mode: BORDER_ZERO - } - } -} - -# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a -# vector of tensors representing, for instance, detection boxes/keypoints and -# scores. node { - calculator: "InferenceCalculator" - input_stream: "TENSORS:input_tensors" - output_stream: "TENSORS:detection_tensors" - options: { - [mediapipe.InferenceCalculatorOptions.ext] { - model_path: "mediapipe/modules/face_detection/face_detection_full_range_sparse.tflite" - delegate { - xnnpack {} - } + calculator: "FaceDetectionFullRange" + input_stream: "IMAGE:image" + output_stream: "DETECTIONS:detections" + node_options: { + [type.googleapis.com/mediapipe.FaceDetectionOptions] { + delegate { xnnpack {} } } } -} - -# Performs tensor post processing to generate face detections. -node { - calculator: "FaceDetectionFullRangeCommon" - input_stream: "TENSORS:detection_tensors" - input_stream: "MATRIX:transform_matrix" - output_stream: "DETECTIONS:detections" + option_value: "OPTIONS:options" } diff --git a/mediapipe/modules/face_detection/face_detection_full_range_gpu.pbtxt b/mediapipe/modules/face_detection/face_detection_full_range_gpu.pbtxt index 703b717808..52b6e361d3 100644 --- a/mediapipe/modules/face_detection/face_detection_full_range_gpu.pbtxt +++ b/mediapipe/modules/face_detection/face_detection_full_range_gpu.pbtxt @@ -1,80 +1,26 @@ -# MediaPipe graph to detect faces. (GPU input, and inference is executed on -# GPU.) -# -# It is required that "face_detection_full_range_sparse.tflite" is available at -# "mediapipe/modules/face_detection/face_detection_full_range_sparse.tflite" -# path during execution. 
-# -# EXAMPLE: -# node { -# calculator: "FaceDetectionFullRangeGpu" -# input_stream: "IMAGE:image" -# output_stream: "DETECTIONS:face_detections" -# } +# MediaPipe graph to detect faces. (GPU input and inference.) type: "FaceDetectionFullRangeGpu" -# GPU image. (GpuBuffer) +# The input image, either GpuBuffer, or (multi-backend) Image. input_stream: "IMAGE:image" # Detected faces. (std::vector) -# NOTE: there will not be an output packet in the DETECTIONS stream for this -# particular timestamp if none of faces detected. However, the MediaPipe -# framework will internally inform the downstream calculators of the absence of -# this packet so that they don't wait for it unnecessarily. output_stream: "DETECTIONS:detections" -# Converts the input GPU image (GpuBuffer) to the multi-backend image type -# (Image). -node: { - calculator: "ToImageCalculator" - input_stream: "IMAGE_GPU:image" - output_stream: "IMAGE:multi_backend_image" +graph_options: { + [type.googleapis.com/mediapipe.FaceDetectionOptions] {} } -# Transforms the input image into a 128x128 tensor while keeping the aspect -# ratio (what is expected by the corresponding face detection model), resulting -# in potential letterboxing in the transformed image. -node: { - calculator: "ImageToTensorCalculator" - input_stream: "IMAGE:multi_backend_image" - output_stream: "TENSORS:input_tensors" - output_stream: "MATRIX:transform_matrix" - options: { - [mediapipe.ImageToTensorCalculatorOptions.ext] { - output_tensor_width: 192 - output_tensor_height: 192 - keep_aspect_ratio: true - output_tensor_float_range { - min: -1.0 - max: 1.0 - } - border_mode: BORDER_ZERO - gpu_origin: TOP_LEFT - } - } -} - -# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a -# vector of tensors representing, for instance, detection boxes/keypoints and -# scores. node { - calculator: "InferenceCalculator" - input_stream: "TENSORS:input_tensors" - output_stream: "TENSORS:detection_tensors" - options: { - [mediapipe.InferenceCalculatorOptions.ext] { - model_path: "mediapipe/modules/face_detection/face_detection_full_range_sparse.tflite" - # + calculator: "FaceDetectionFullRange" + input_stream: "IMAGE:image" + output_stream: "DETECTIONS:detections" + node_options: { + [type.googleapis.com/mediapipe.FaceDetectionOptions] { + gpu_origin: TOP_LEFT delegate: { gpu { use_advanced_gpu_api: true } } } } -} - -# Performs tensor post processing to generate face detections. -node { - calculator: "FaceDetectionFullRangeCommon" - input_stream: "TENSORS:detection_tensors" - input_stream: "MATRIX:transform_matrix" - output_stream: "DETECTIONS:detections" + option_value: "OPTIONS:options" } diff --git a/mediapipe/modules/face_detection/face_detection_full_range_image.pbtxt b/mediapipe/modules/face_detection/face_detection_full_range_image.pbtxt index 4e0bc0b4db..b645638fd3 100644 --- a/mediapipe/modules/face_detection/face_detection_full_range_image.pbtxt +++ b/mediapipe/modules/face_detection/face_detection_full_range_image.pbtxt @@ -36,51 +36,12 @@ node { } } -# Transforms the input image into a 128x128 tensor while keeping the aspect -# ratio (what is expected by the corresponding face detection model), resulting -# in potential letterboxing in the transformed image. 
-node: { - calculator: "ImageToTensorCalculator" - input_stream: "IMAGE:throttled_image" - output_stream: "TENSORS:input_tensors" - output_stream: "MATRIX:transform_matrix" - options: { - [mediapipe.ImageToTensorCalculatorOptions.ext] { - output_tensor_width: 192 - output_tensor_height: 192 - keep_aspect_ratio: true - output_tensor_float_range { - min: -1.0 - max: 1.0 - } - border_mode: BORDER_ZERO - gpu_origin: CONVENTIONAL - } - } -} - -# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a -# vector of tensors representing, for instance, detection boxes/keypoints and -# scores. -# TODO: Use GraphOptions to modify the delegate field to be -# `delegate { xnnpack {} }` for the CPU only use cases. -node { - calculator: "InferenceCalculator" - input_stream: "TENSORS:input_tensors" - output_stream: "TENSORS:detection_tensors" - options: { - [mediapipe.InferenceCalculatorOptions.ext] { - model_path: "mediapipe/modules/face_detection/face_detection_full_range_sparse.tflite" - # - delegate: { gpu { use_advanced_gpu_api: true } } - } - } -} - -# Performs tensor post processing to generate face detections. node { - calculator: "FaceDetectionFullRangeCommon" - input_stream: "TENSORS:detection_tensors" - input_stream: "MATRIX:transform_matrix" + calculator: "FaceDetectionFullRange" + input_stream: "IMAGE:throttled_image" output_stream: "DETECTIONS:detections" + node_options: { + [type.googleapis.com/mediapipe.FaceDetectionOptions] {} + } + option_value: "OPTIONS:options" } diff --git a/mediapipe/modules/face_detection/face_detection_short_range.pbtxt b/mediapipe/modules/face_detection/face_detection_short_range.pbtxt new file mode 100644 index 0000000000..eb9ed32005 --- /dev/null +++ b/mediapipe/modules/face_detection/face_detection_short_range.pbtxt @@ -0,0 +1,57 @@ +# MediaPipe graph to detect faces. (CPU input and inference by default.) +# +# It is required that "face_detection_short_range.tflite" is available at +# "mediapipe/modules/face_detection/face_detection_short_range.tflite" +# path during execution. +# +# EXAMPLE: +# node { +# calculator: "FaceDetectionShortRange" +# input_stream: "IMAGE:image_frame" +# output_stream: "DETECTIONS:face_detections" +# } + +type: "FaceDetectionShortRange" + +# The input image, either ImageFrame, GpuBuffer, or (multi-backend) Image. +input_stream: "IMAGE:image" + +# ROI (region of interest) within the given image where faces should be +# detected. (NormalizedRect) +input_stream: "ROI:roi" + +# Detected faces. 
(std::vector) +output_stream: "DETECTIONS:detections" + +graph_options: { + [type.googleapis.com/mediapipe.FaceDetectionOptions] {} +} + +node { + calculator: "FaceDetection" + input_stream: "IMAGE:image" + input_stream: "ROI:roi" + output_stream: "DETECTIONS:detections" + node_options: { + [type.googleapis.com/mediapipe.FaceDetectionOptions] { + model_path: "mediapipe/modules/face_detection/face_detection_short_range.tflite" + tensor_width: 128 + tensor_height: 128 + + num_layers: 4 + strides: 8 + strides: 16 + strides: 16 + strides: 16 + interpolated_scale_aspect_ratio: 1.0 + + num_boxes: 896 + x_scale: 128.0 + y_scale: 128.0 + h_scale: 128.0 + w_scale: 128.0 + min_score_thresh: 0.5 + } + } + option_value: "OPTIONS:options" +} diff --git a/mediapipe/modules/face_detection/face_detection_short_range_by_roi_cpu.pbtxt b/mediapipe/modules/face_detection/face_detection_short_range_by_roi_cpu.pbtxt index b3adfeb833..6dab2966d9 100644 --- a/mediapipe/modules/face_detection/face_detection_short_range_by_roi_cpu.pbtxt +++ b/mediapipe/modules/face_detection/face_detection_short_range_by_roi_cpu.pbtxt @@ -1,21 +1,8 @@ -# MediaPipe graph to detect faces. (CPU input, and inference is executed on -# CPU.) -# -# It is required that "face_detection_short_range.tflite" is available at -# "mediapipe/modules/face_detection/face_detection_short_range.tflite" -# path during execution. -# -# EXAMPLE: -# node { -# calculator: "FaceDetectionShortRangeByRoiCpu" -# input_stream: "IMAGE:image" -# input_stream: "ROI:roi" -# output_stream: "DETECTIONS:face_detections" -# } +# MediaPipe graph to detect faces. (CPU input and inference, with region-of-interest.) -type: "FaceDetectionShortRangeByRoiCpu" +type: "FaceDetectionShortRangeCpu" -# CPU image. (ImageFrame) +# The input image, either ImageFrame, or (multi-backend) Image. input_stream: "IMAGE:image" # ROI (region of interest) within the given image where faces should be @@ -23,61 +10,21 @@ input_stream: "IMAGE:image" input_stream: "ROI:roi" # Detected faces. (std::vector) -# NOTE: there will not be an output packet in the DETECTIONS stream for this -# particular timestamp if none of faces detected. However, the MediaPipe -# framework will internally inform the downstream calculators of the absence of -# this packet so that they don't wait for it unnecessarily. output_stream: "DETECTIONS:detections" -# Converts the input CPU image (ImageFrame) to the multi-backend image type -# (Image). -node: { - calculator: "ToImageCalculator" - input_stream: "IMAGE_CPU:image" - output_stream: "IMAGE:multi_backend_image" +graph_options: { + [type.googleapis.com/mediapipe.FaceDetectionOptions] {} } -# Transforms specified region of image into 128x128 tensor keeping aspect ratio -# (padding tensor if needed). node { - calculator: "ImageToTensorCalculator" - input_stream: "IMAGE:multi_backend_image" - input_stream: "NORM_RECT:roi" - output_stream: "TENSORS:input_tensors" - output_stream: "MATRIX:transform_matrix" - options: { - [mediapipe.ImageToTensorCalculatorOptions.ext] { - output_tensor_width: 128 - output_tensor_height: 128 - keep_aspect_ratio: true - output_tensor_float_range { - min: -1.0 - max: 1.0 - } - border_mode: BORDER_ZERO - } - } -} - -# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a -# vector of tensors representing, for instance, detection boxes/keypoints and -# scores. 
-node {
-  calculator: "InferenceCalculator"
-  input_stream: "TENSORS:input_tensors"
-  output_stream: "TENSORS:detection_tensors"
-  options: {
-    [mediapipe.InferenceCalculatorOptions.ext] {
-      model_path: "mediapipe/modules/face_detection/face_detection_short_range.tflite"
+  calculator: "FaceDetectionShortRange"
+  input_stream: "IMAGE:image"
+  input_stream: "ROI:roi"
+  output_stream: "DETECTIONS:detections"
+  node_options: {
+    [type.googleapis.com/mediapipe.FaceDetectionOptions] {
       delegate { xnnpack {} }
     }
   }
-}
-
-# Performs tensor post processing to generate face detections.
-node {
-  calculator: "FaceDetectionShortRangeCommon"
-  input_stream: "TENSORS:detection_tensors"
-  input_stream: "MATRIX:transform_matrix"
-  output_stream: "DETECTIONS:detections"
+  option_value: "OPTIONS:options"
 }
diff --git a/mediapipe/modules/face_detection/face_detection_short_range_by_roi_gpu.pbtxt b/mediapipe/modules/face_detection/face_detection_short_range_by_roi_gpu.pbtxt
index c35331e0e6..6f9e9e98f0 100644
--- a/mediapipe/modules/face_detection/face_detection_short_range_by_roi_gpu.pbtxt
+++ b/mediapipe/modules/face_detection/face_detection_short_range_by_roi_gpu.pbtxt
@@ -1,21 +1,8 @@
-# MediaPipe graph to detect faces. (GPU input, and inference is executed on
-# GPU.)
-#
-# It is required that "face_detection_short_range.tflite" is available at
-# "mediapipe/modules/face_detection/face_detection_short_range.tflite"
-# path during execution.
-#
-# EXAMPLE:
-#   node {
-#     calculator: "FaceDetectionShortRangeByRoiGpu"
-#     input_stream: "IMAGE:image"
-#     input_stream: "ROI:roi"
-#     output_stream: "DETECTIONS:face_detections"
-#   }
+# MediaPipe graph to detect faces. (GPU input and inference, with region-of-interest.)
 
-type: "FaceDetectionShortRangeByRoiGpu"
+type: "FaceDetectionShortRangeGpu"
 
-# GPU image. (GpuBuffer)
+# The input image, either GpuBuffer, or (multi-backend) Image.
 input_stream: "IMAGE:image"
 
 # ROI (region of interest) within the given image where faces should be
@@ -23,61 +10,22 @@ input_stream: "IMAGE:image"
 input_stream: "ROI:roi"
 
 # Detected faces. (std::vector<Detection>)
-# NOTE: there will not be an output packet in the DETECTIONS stream for this
-# particular timestamp if none of faces detected. However, the MediaPipe
-# framework will internally inform the downstream calculators of the absence of
-# this packet so that they don't wait for it unnecessarily.
 output_stream: "DETECTIONS:detections"
 
-# Converts the input GPU image (GpuBuffer) to the multi-backend image type
-# (Image).
-node: {
-  calculator: "ToImageCalculator"
-  input_stream: "IMAGE_GPU:image"
-  output_stream: "IMAGE:multi_backend_image"
+graph_options: {
+  [type.googleapis.com/mediapipe.FaceDetectionOptions] {}
 }
 
-# Transforms specified region of image into 128x128 tensor keeping aspect ratio
-# (padding tensor if needed).
node { - calculator: "ImageToTensorCalculator" - input_stream: "IMAGE:multi_backend_image" - input_stream: "NORM_RECT:roi" - output_stream: "TENSORS:input_tensors" - output_stream: "MATRIX:transform_matrix" - options: { - [mediapipe.ImageToTensorCalculatorOptions.ext] { - output_tensor_width: 128 - output_tensor_height: 128 - keep_aspect_ratio: true - output_tensor_float_range { - min: -1.0 - max: 1.0 - } - border_mode: BORDER_ZERO + calculator: "FaceDetectionShortRange" + input_stream: "IMAGE:image" + input_stream: "ROI:roi" + output_stream: "DETECTIONS:detections" + node_options: { + [type.googleapis.com/mediapipe.FaceDetectionOptions] { gpu_origin: TOP_LEFT + delegate: { gpu { use_advanced_gpu_api: true } } } } -} - -# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a -# vector of tensors representing, for instance, detection boxes/keypoints and -# scores. -node { - calculator: "InferenceCalculator" - input_stream: "TENSORS:input_tensors" - output_stream: "TENSORS:detection_tensors" - options: { - [mediapipe.InferenceCalculatorOptions.ext] { - model_path: "mediapipe/modules/face_detection/face_detection_short_range.tflite" - } - } -} - -# Performs tensor post processing to generate face detections. -node { - calculator: "FaceDetectionShortRangeCommon" - input_stream: "TENSORS:detection_tensors" - input_stream: "MATRIX:transform_matrix" - output_stream: "DETECTIONS:detections" + option_value: "OPTIONS:options" } diff --git a/mediapipe/modules/face_detection/face_detection_short_range_common.pbtxt b/mediapipe/modules/face_detection/face_detection_short_range_common.pbtxt deleted file mode 100644 index 4a6a54f447..0000000000 --- a/mediapipe/modules/face_detection/face_detection_short_range_common.pbtxt +++ /dev/null @@ -1,103 +0,0 @@ -# MediaPipe graph performing common processing to detect faces, currently -# consisting of tensor post processing. -# -# EXAMPLE: -# node { -# calculator: "FaceDetectionShortRangeCommon" -# input_stream: "TENSORS:detection_tensors" -# input_stream: "MATRIX:transform_matrix" -# output_stream: "DETECTIONS:detections" -# } - -type: "FaceDetectionShortRangeCommon" - -# Detection tensors. (std::vector) -input_stream: "TENSORS:detection_tensors" - -# A 4x4 row-major-order matrix that maps a point represented in the detection -# tensors to a desired coordinate system, e.g., in the original input image -# before scaling/cropping. (std::array) -input_stream: "MATRIX:transform_matrix" - -# Detected faces. (std::vector) -# NOTE: there will not be an output packet in the DETECTIONS stream for this -# particular timestamp if none of faces detected. However, the MediaPipe -# framework will internally inform the downstream calculators of the absence of -# this packet so that they don't wait for it unnecessarily. -output_stream: "DETECTIONS:detections" - -# Generates a single side packet containing a vector of SSD anchors based on -# the specification in the options. -node { - calculator: "SsdAnchorsCalculator" - output_side_packet: "anchors" - options: { - [mediapipe.SsdAnchorsCalculatorOptions.ext] { - num_layers: 4 - min_scale: 0.1484375 - max_scale: 0.75 - input_size_height: 128 - input_size_width: 128 - anchor_offset_x: 0.5 - anchor_offset_y: 0.5 - strides: 8 - strides: 16 - strides: 16 - strides: 16 - aspect_ratios: 1.0 - fixed_anchor_size: true - } - } -} - -# Decodes the detection tensors generated by the TensorFlow Lite model, based on -# the SSD anchors and the specification in the options, into a vector of -# detections. 
Each detection describes a detected object. -node { - calculator: "TensorsToDetectionsCalculator" - input_stream: "TENSORS:detection_tensors" - input_side_packet: "ANCHORS:anchors" - output_stream: "DETECTIONS:unfiltered_detections" - options: { - [mediapipe.TensorsToDetectionsCalculatorOptions.ext] { - num_classes: 1 - num_boxes: 896 - num_coords: 16 - box_coord_offset: 0 - keypoint_coord_offset: 4 - num_keypoints: 6 - num_values_per_keypoint: 2 - sigmoid_score: true - score_clipping_thresh: 100.0 - reverse_output_order: true - x_scale: 128.0 - y_scale: 128.0 - h_scale: 128.0 - w_scale: 128.0 - min_score_thresh: 0.5 - } - } -} - -# Performs non-max suppression to remove excessive detections. -node { - calculator: "NonMaxSuppressionCalculator" - input_stream: "unfiltered_detections" - output_stream: "filtered_detections" - options: { - [mediapipe.NonMaxSuppressionCalculatorOptions.ext] { - min_suppression_threshold: 0.3 - overlap_type: INTERSECTION_OVER_UNION - algorithm: WEIGHTED - } - } -} - -# Projects the detections from input tensor to the corresponding locations on -# the original image (input to the graph). -node { - calculator: "DetectionProjectionCalculator" - input_stream: "DETECTIONS:filtered_detections" - input_stream: "PROJECTION_MATRIX:transform_matrix" - output_stream: "DETECTIONS:detections" -} diff --git a/mediapipe/modules/face_detection/face_detection_short_range_cpu.pbtxt b/mediapipe/modules/face_detection/face_detection_short_range_cpu.pbtxt index 0db242049c..21b63917a8 100644 --- a/mediapipe/modules/face_detection/face_detection_short_range_cpu.pbtxt +++ b/mediapipe/modules/face_detection/face_detection_short_range_cpu.pbtxt @@ -1,78 +1,25 @@ -# MediaPipe graph to detect faces. (CPU input, and inference is executed on -# CPU.) -# -# It is required that "face_detection_short_range.tflite" is available at -# "mediapipe/modules/face_detection/face_detection_short_range.tflite" -# path during execution. -# -# EXAMPLE: -# node { -# calculator: "FaceDetectionShortRangeCpu" -# input_stream: "IMAGE:image" -# output_stream: "DETECTIONS:face_detections" -# } +# MediaPipe graph to detect faces. (CPU input and inference.) type: "FaceDetectionShortRangeCpu" -# CPU image. (ImageFrame) +# The input image, either ImageFrame, or (multi-backend) Image. input_stream: "IMAGE:image" # Detected faces. (std::vector) -# NOTE: there will not be an output packet in the DETECTIONS stream for this -# particular timestamp if none of faces detected. However, the MediaPipe -# framework will internally inform the downstream calculators of the absence of -# this packet so that they don't wait for it unnecessarily. output_stream: "DETECTIONS:detections" -# Converts the input CPU image (ImageFrame) to the multi-backend image type -# (Image). -node: { - calculator: "ToImageCalculator" - input_stream: "IMAGE_CPU:image" - output_stream: "IMAGE:multi_backend_image" +graph_options: { + [type.googleapis.com/mediapipe.FaceDetectionOptions] {} } -# Transforms the input image into a 128x128 tensor while keeping the aspect -# ratio (what is expected by the corresponding face detection model), resulting -# in potential letterboxing in the transformed image. 
-node: { - calculator: "ImageToTensorCalculator" - input_stream: "IMAGE:multi_backend_image" - output_stream: "TENSORS:input_tensors" - output_stream: "MATRIX:transform_matrix" - options: { - [mediapipe.ImageToTensorCalculatorOptions.ext] { - output_tensor_width: 128 - output_tensor_height: 128 - keep_aspect_ratio: true - output_tensor_float_range { - min: -1.0 - max: 1.0 - } - border_mode: BORDER_ZERO - } - } -} - -# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a -# vector of tensors representing, for instance, detection boxes/keypoints and -# scores. node { - calculator: "InferenceCalculator" - input_stream: "TENSORS:input_tensors" - output_stream: "TENSORS:detection_tensors" - options: { - [mediapipe.InferenceCalculatorOptions.ext] { - model_path: "mediapipe/modules/face_detection/face_detection_short_range.tflite" + calculator: "FaceDetectionShortRange" + input_stream: "IMAGE:image" + output_stream: "DETECTIONS:detections" + node_options: { + [type.googleapis.com/mediapipe.FaceDetectionOptions] { delegate { xnnpack {} } } } -} - -# Performs tensor post processing to generate face detections. -node { - calculator: "FaceDetectionShortRangeCommon" - input_stream: "TENSORS:detection_tensors" - input_stream: "MATRIX:transform_matrix" - output_stream: "DETECTIONS:detections" + option_value: "OPTIONS:options" } diff --git a/mediapipe/modules/face_detection/face_detection_short_range_gpu.pbtxt b/mediapipe/modules/face_detection/face_detection_short_range_gpu.pbtxt index ce0d25b133..ededa1353c 100644 --- a/mediapipe/modules/face_detection/face_detection_short_range_gpu.pbtxt +++ b/mediapipe/modules/face_detection/face_detection_short_range_gpu.pbtxt @@ -1,78 +1,26 @@ -# MediaPipe graph to detect faces. (GPU input, and inference is executed on -# GPU.) -# -# It is required that "face_detection_short_range.tflite" is available at -# "mediapipe/modules/face_detection/face_detection_short_range.tflite" -# path during execution. -# -# EXAMPLE: -# node { -# calculator: "FaceDetectionShortRangeGpu" -# input_stream: "IMAGE:image" -# output_stream: "DETECTIONS:face_detections" -# } +# MediaPipe graph to detect faces. (GPU input and inference.) type: "FaceDetectionShortRangeGpu" -# GPU image. (GpuBuffer) +# The input image, either GpuBuffer, or (multi-backend) Image. input_stream: "IMAGE:image" # Detected faces. (std::vector) -# NOTE: there will not be an output packet in the DETECTIONS stream for this -# particular timestamp if none of faces detected. However, the MediaPipe -# framework will internally inform the downstream calculators of the absence of -# this packet so that they don't wait for it unnecessarily. output_stream: "DETECTIONS:detections" -# Converts the input GPU image (GpuBuffer) to the multi-backend image type -# (Image). -node: { - calculator: "ToImageCalculator" - input_stream: "IMAGE_GPU:image" - output_stream: "IMAGE:multi_backend_image" +graph_options: { + [type.googleapis.com/mediapipe.FaceDetectionOptions] {} } -# Transforms the input image into a 128x128 tensor while keeping the aspect -# ratio (what is expected by the corresponding face detection model), resulting -# in potential letterboxing in the transformed image. 
-node: { - calculator: "ImageToTensorCalculator" - input_stream: "IMAGE:multi_backend_image" - output_stream: "TENSORS:input_tensors" - output_stream: "MATRIX:transform_matrix" - options: { - [mediapipe.ImageToTensorCalculatorOptions.ext] { - output_tensor_width: 128 - output_tensor_height: 128 - keep_aspect_ratio: true - output_tensor_float_range { - min: -1.0 - max: 1.0 - } - border_mode: BORDER_ZERO - gpu_origin: TOP_LEFT - } - } -} - -# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a -# vector of tensors representing, for instance, detection boxes/keypoints and -# scores. node { - calculator: "InferenceCalculator" - input_stream: "TENSORS:input_tensors" - output_stream: "TENSORS:detection_tensors" - options: { - [mediapipe.InferenceCalculatorOptions.ext] { - model_path: "mediapipe/modules/face_detection/face_detection_short_range.tflite" + calculator: "FaceDetectionShortRange" + input_stream: "IMAGE:image" + output_stream: "DETECTIONS:detections" + node_options: { + [type.googleapis.com/mediapipe.FaceDetectionOptions] { + gpu_origin: TOP_LEFT + delegate: { gpu { use_advanced_gpu_api: true } } } } -} - -# Performs tensor post processing to generate face detections. -node { - calculator: "FaceDetectionShortRangeCommon" - input_stream: "TENSORS:detection_tensors" - input_stream: "MATRIX:transform_matrix" - output_stream: "DETECTIONS:detections" + option_value: "OPTIONS:options" } diff --git a/mediapipe/modules/face_detection/face_detection_short_range_image.pbtxt b/mediapipe/modules/face_detection/face_detection_short_range_image.pbtxt index a2590418b2..c421ae47bc 100644 --- a/mediapipe/modules/face_detection/face_detection_short_range_image.pbtxt +++ b/mediapipe/modules/face_detection/face_detection_short_range_image.pbtxt @@ -43,52 +43,12 @@ node { } } -# Transforms the input image into a 128x128 tensor while keeping the aspect -# ratio (what is expected by the corresponding face detection model), resulting -# in potential letterboxing in the transformed image. -node: { - calculator: "ImageToTensorCalculator" - input_stream: "IMAGE:throttled_image" - output_stream: "TENSORS:input_tensors" - output_stream: "MATRIX:transform_matrix" - options: { - [mediapipe.ImageToTensorCalculatorOptions.ext] { - output_tensor_width: 128 - output_tensor_height: 128 - keep_aspect_ratio: true - output_tensor_float_range { - min: -1.0 - max: 1.0 - } - border_mode: BORDER_ZERO - gpu_origin: CONVENTIONAL - } - } -} - -# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a -# vector of tensors representing, for instance, detection boxes/keypoints and -# scores. -# TODO: Use GraphOptions to modify the delegate field to be -# `delegate { xnnpack {} }` for the CPU only use cases. node { - calculator: "InferenceCalculator" - input_stream: "TENSORS:input_tensors" - output_stream: "TENSORS:detection_tensors" - options: { - [mediapipe.InferenceCalculatorOptions.ext] { - model_path: "mediapipe/modules/face_detection/face_detection_short_range.tflite" - - # - delegate: { gpu { use_advanced_gpu_api: true } } - } - } -} - -# Performs tensor post processing to generate face detections. 
-node {
-  calculator: "FaceDetectionShortRangeCommon"
-  input_stream: "TENSORS:detection_tensors"
-  input_stream: "MATRIX:transform_matrix"
+  calculator: "FaceDetectionShortRange"
+  input_stream: "IMAGE:throttled_image"
   output_stream: "DETECTIONS:detections"
+  node_options: {
+    [type.googleapis.com/mediapipe.FaceDetectionOptions] {}
+  }
+  option_value: "OPTIONS:options"
 }
diff --git a/mediapipe/modules/face_detection/face_detection_test.cc b/mediapipe/modules/face_detection/face_detection_test.cc
new file mode 100644
index 0000000000..0ed1b495df
--- /dev/null
+++ b/mediapipe/modules/face_detection/face_detection_test.cc
@@ -0,0 +1,384 @@
+// Copyright 2019 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <vector>
+
+#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
+#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
+#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h"
+#include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h"
+#include "mediapipe/calculators/util/non_max_suppression_calculator.pb.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/formats/detection.pb.h"
+#include "mediapipe/framework/formats/rect.pb.h"
+#include "mediapipe/framework/port/file_helpers.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/parse_text_proto.h"
+#include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/framework/tool/options_util.h"
+#include "mediapipe/framework/tool/test_util.h"
+#include "mediapipe/gpu/gpu_origin.pb.h"
+#include "mediapipe/modules/face_detection/face_detection.pb.h"
+
+#if !defined(__APPLE__) && !__ANDROID__
+#include "mediapipe/gpu/gl_app_texture_support.h"
+#include "mediapipe/gpu/gl_calculator_helper.h"
+#include "mediapipe/gpu/gpu_test_base.h"
+#endif  // !defined(__APPLE__) && !__ANDROID__
+
+namespace mediapipe {
+namespace {
+using mediapipe::FaceDetectionOptions;
+
+// Ensure protobuf registration. (Template arguments restored from the
+// includes above; the exact set is assumed.)
+void RegisterProtobufTypes() {
+  MakePacket<mediapipe::ImageToTensorCalculatorOptions>();
+  MakePacket<mediapipe::InferenceCalculatorOptions>();
+  MakePacket<mediapipe::TensorsToDetectionsCalculatorOptions>();
+
+  MakePacket<mediapipe::SsdAnchorsCalculatorOptions>();
+  MakePacket<mediapipe::NonMaxSuppressionCalculatorOptions>();
+  MakePacket<mediapipe::FaceDetectionOptions>();
+  MakePacket<mediapipe::GpuOrigin>();
+}
+
+// Returns a Packet with an ImageFrame showing a face.
+Packet TestImageFrame() {
+  std::unique_ptr<ImageFrame> input_image = LoadTestPng(
+      file::JoinPath(GetTestRootDir(), "mediapipe/objc/testdata/sergey.png"));
+  EXPECT_EQ(input_image->Height(), 600);
+  return MakePacket<ImageFrame>(std::move(*input_image));
+}
+
+// Returns the registered type name for the basic face detection graph.
+std::string GetFaceDetectionGraphType() { return "FaceDetectionWithoutRoi"; }
+
+// Returns the config from "face_detection_without_roi.pbtxt".
+CalculatorGraphConfig GetFaceDetectionGraph() {
+  return GraphRegistry().CreateByName("", GetFaceDetectionGraphType()).value();
+}
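
A rough Python counterpart to these config helpers, under the assumption that the compiled .binarypb files referenced below ship with this module and keep their node_options intact:

from mediapipe.framework import calculator_pb2
from mediapipe.modules.face_detection import face_detection_pb2

def load_binary_graph(path):
  """Parses a compiled CalculatorGraphConfig from a .binarypb file."""
  config = calculator_pb2.CalculatorGraphConfig()
  with open(path, 'rb') as f:
    config.ParseFromString(f.read())
  return config

config = load_binary_graph(
    'mediapipe/modules/face_detection/face_detection_short_range.binarypb')
# The wrapper graph holds a single FaceDetection node whose packed options
# can be unpacked back into a FaceDetectionOptions message.
opts = face_detection_pb2.FaceDetectionOptions()
config.node[0].node_options[0].Unpack(opts)
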
+// Returns the config from "face_detection.pbtxt".
+CalculatorGraphConfig GetFaceDetectionWithRoiGraph() {
+  return GraphRegistry().CreateByName("", "FaceDetection").value();
+}
+
+// Returns the config from "face_detection_short_range_cpu.pbtxt".
+CalculatorGraphConfig GetFaceDetectionShortRangeCpu() {
+  CalculatorGraphConfig config =
+      GraphRegistry().CreateByName("", "FaceDetectionShortRangeCpu").value();
+  return config;
+}
+
+// Returns the FaceDetectionOptions from "face_detection_short_range.pbtxt".
+FaceDetectionOptions GetFaceDetectionShortRangeOptions() {
+  CalculatorGraphConfig config;
+  LoadTestGraph(&config,
+                GetTestFilePath("mediapipe/modules/face_detection/"
+                                "face_detection_short_range.binarypb"));
+  tool::OptionsMap map;
+  map.Initialize(config.node(0));
+  return map.Get<FaceDetectionOptions>();
+}
+
+// Returns the FaceDetectionOptions from "face_detection_full_range.pbtxt".
+FaceDetectionOptions GetFaceDetectionFullRangeOptions() {
+  CalculatorGraphConfig config;
+  LoadTestGraph(&config, GetTestFilePath("mediapipe/modules/face_detection/"
+                                         "face_detection_full_range.binarypb"));
+  tool::OptionsMap map;
+  map.Initialize(config.node(0));
+  return map.Get<FaceDetectionOptions>();
+}
+
+// Returns the FaceDetectionOptions needed to enable CPU processing.
+FaceDetectionOptions GetCpuOptions() {
+  FaceDetectionOptions result;
+  result.mutable_delegate()->mutable_xnnpack();
+  return result;
+}
+
+// Returns the FaceDetectionOptions needed to enable GPU processing.
+FaceDetectionOptions GetGpuOptions() {
+  FaceDetectionOptions result;
+  result.set_gpu_origin(mediapipe::GpuOrigin_Mode::GpuOrigin_Mode_TOP_LEFT);
+  result.mutable_delegate()->mutable_gpu()->set_use_advanced_gpu_api(true);
+  return result;
+}
+
+// Returns an example region of interest rectangle.
+mediapipe::NormalizedRect GetTestRoi() {
+  mediapipe::NormalizedRect result;
+  result.set_x_center(0.5);
+  result.set_y_center(0.5);
+  result.set_width(0.8);
+  result.set_height(0.8);
+  return result;
+}
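
Rough Python equivalents of GetCpuOptions() and GetGpuOptions() above, assuming the generated face_detection_pb2 and gpu_origin_pb2 modules (field names come from the face_detection.proto added in this change):

from mediapipe.gpu import gpu_origin_pb2
from mediapipe.modules.face_detection import face_detection_pb2

def cpu_options():
  opts = face_detection_pb2.FaceDetectionOptions()
  opts.delegate.xnnpack.SetInParent()  # CPU inference via XNNPACK
  return opts

def gpu_options():
  opts = face_detection_pb2.FaceDetectionOptions()
  opts.gpu_origin = gpu_origin_pb2.GpuOrigin.TOP_LEFT
  opts.delegate.gpu.use_advanced_gpu_api = True
  return opts
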
+// Tests for options input and output packets and streams.
+class FaceDetectionTest : public ::testing::Test {
+ protected:
+  void SetUp() override { RegisterProtobufTypes(); }
+  void TearDown() override {}
+};
+
+TEST_F(FaceDetectionTest, ExpandFaceDetectionShortRangeCpu) {
+  CalculatorGraphConfig config = GetFaceDetectionShortRangeCpu();
+  Packet frame1 = TestImageFrame();
+
+  CalculatorGraph graph;
+  MP_ASSERT_OK(graph.Initialize(config));
+
+  std::vector<Packet> output;
+  MP_ASSERT_OK(graph.ObserveOutputStream("detections", [&](const Packet& p) {
+    output.push_back(p);
+    return absl::OkStatus();
+  }));
+  MP_ASSERT_OK(graph.StartRun({}));
+  MP_ASSERT_OK(
+      graph.AddPacketToInputStream("image", frame1.At(Timestamp(20000))));
+  MP_ASSERT_OK(graph.CloseAllPacketSources());
+  MP_EXPECT_OK(graph.WaitUntilDone());
+  ASSERT_EQ(output.size(), 1);
+  EXPECT_EQ(output.front().Get<std::vector<Detection>>().size(), 1);
+}
+
+TEST_F(FaceDetectionTest, ExpandFaceDetection) {
+  CalculatorGraphConfig config = GetFaceDetectionGraph();
+  mediapipe::FaceDetectionOptions face_options =
+      GetFaceDetectionShortRangeOptions();
+  face_options.MergeFrom(GetCpuOptions());
+  config.clear_graph_options();
+  config.add_graph_options()->PackFrom(face_options);
+  Packet frame1 = TestImageFrame();
+
+  CalculatorGraph graph;
+  MP_ASSERT_OK(graph.Initialize(config));
+
+  std::vector<Packet> output;
+  MP_ASSERT_OK(graph.ObserveOutputStream("detections", [&](const Packet& p) {
+    output.push_back(p);
+    return absl::OkStatus();
+  }));
+  MP_ASSERT_OK(graph.StartRun({}));
+  MP_ASSERT_OK(
+      graph.AddPacketToInputStream("image", frame1.At(Timestamp(20000))));
+  MP_ASSERT_OK(graph.CloseAllPacketSources());
+  MP_EXPECT_OK(graph.WaitUntilDone());
+  ASSERT_EQ(output.size(), 1);
+  EXPECT_EQ(output.front().Get<std::vector<Detection>>().size(), 1);
+}
+
+TEST_F(FaceDetectionTest, FaceDetectionShortRangeApi) {
+  CalculatorGraphConfig config = GetFaceDetectionGraph();
+  config.clear_graph_options();
+  mediapipe::FaceDetectionOptions face_options =
+      GetFaceDetectionShortRangeOptions();
+  Subgraph::SubgraphOptions graph_options;
+  face_options.MergeFrom(GetCpuOptions());
+  graph_options.add_node_options()->PackFrom(face_options);
+  Packet frame1 = TestImageFrame();
+
+  CalculatorGraph graph;
+  MP_ASSERT_OK(graph.Initialize({config}, {}, {}, GetFaceDetectionGraphType(),
+                                &graph_options));
+
+  std::vector<Packet> output;
+  MP_ASSERT_OK(graph.ObserveOutputStream("detections", [&](const Packet& p) {
+    output.push_back(p);
+    return absl::OkStatus();
+  }));
+  MP_ASSERT_OK(graph.StartRun({}));
+  MP_ASSERT_OK(
+      graph.AddPacketToInputStream("image", frame1.At(Timestamp(20000))));
+  MP_ASSERT_OK(graph.CloseAllPacketSources());
+  MP_EXPECT_OK(graph.WaitUntilDone());
+  ASSERT_EQ(output.size(), 1);
+  EXPECT_EQ(output.front().Get<std::vector<Detection>>().size(), 1);
+}
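
The observe/run/drain pattern in these tests maps directly onto the Python framework bindings; a sketch, with the graph config and stream names assumed to match the C++ tests above:

import mediapipe as mp

def run_face_detection(config, frame, timestamp_us=20000):
  """Feeds one ImageFrame packet through the graph; returns 'detections' packets."""
  output = []
  graph = mp.CalculatorGraph(graph_config=config)
  graph.observe_output_stream(
      'detections', lambda stream_name, packet: output.append(packet))
  graph.start_run({})
  graph.add_packet_to_input_stream(
      stream='image', packet=frame.at(timestamp_us))
  graph.close()  # closes all packet sources and waits until the graph is done
  return output
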
+TEST_F(FaceDetectionTest, FaceDetectionWrapperApi) {
+  CalculatorGraphConfig config = GetFaceDetectionGraph();
+  config.clear_graph_options();
+  mediapipe::FaceDetectionOptions face_options =
+      GetFaceDetectionShortRangeOptions();
+  face_options.MergeFrom(GetCpuOptions());
+  Subgraph::SubgraphOptions graph_options;
+  graph_options.add_node_options()->PackFrom(face_options);
+  Packet frame1 = TestImageFrame();
+
+  CalculatorGraph graph;
+  MP_ASSERT_OK(graph.Initialize({config}, {}, {}, GetFaceDetectionGraphType(),
+                                &graph_options));
+
+  std::vector<Packet> output;
+  MP_ASSERT_OK(graph.ObserveOutputStream("detections", [&](const Packet& p) {
+    output.push_back(p);
+    return absl::OkStatus();
+  }));
+  MP_ASSERT_OK(graph.StartRun({}));
+  MP_ASSERT_OK(
+      graph.AddPacketToInputStream("image", frame1.At(Timestamp(20000))));
+  MP_ASSERT_OK(graph.CloseAllPacketSources());
+  MP_EXPECT_OK(graph.WaitUntilDone());
+  ASSERT_EQ(output.size(), 1);
+  EXPECT_EQ(output.front().Get<std::vector<Detection>>().size(), 1);
+}
+
+TEST_F(FaceDetectionTest, FaceDetectionFullRangeApi) {
+  CalculatorGraphConfig config = GetFaceDetectionGraph();
+  config.clear_graph_options();
+  mediapipe::FaceDetectionOptions face_options =
+      GetFaceDetectionFullRangeOptions();
+  Subgraph::SubgraphOptions graph_options;
+  face_options.MergeFrom(GetCpuOptions());
+  graph_options.add_node_options()->PackFrom(face_options);
+  Packet frame1 = TestImageFrame();
+
+  CalculatorGraph graph;
+  MP_ASSERT_OK(graph.Initialize({config}, {}, {}, GetFaceDetectionGraphType(),
+                                &graph_options));
+
+  std::vector<Packet> output;
+  MP_ASSERT_OK(graph.ObserveOutputStream("detections", [&](const Packet& p) {
+    output.push_back(p);
+    return absl::OkStatus();
+  }));
+  MP_ASSERT_OK(graph.StartRun({}));
+  MP_ASSERT_OK(
+      graph.AddPacketToInputStream("image", frame1.At(Timestamp(20000))));
+  MP_ASSERT_OK(graph.CloseAllPacketSources());
+  MP_EXPECT_OK(graph.WaitUntilDone());
+  ASSERT_EQ(output.size(), 1);
+  EXPECT_EQ(output.front().Get<std::vector<Detection>>().size(), 1);
+}
+
+TEST_F(FaceDetectionTest, FaceDetectionShortRangeByRoiCpu) {
+  CalculatorGraphConfig config = GetFaceDetectionWithRoiGraph();
+  config.clear_graph_options();
+  mediapipe::FaceDetectionOptions face_options =
+      GetFaceDetectionShortRangeOptions();
+  face_options.MergeFrom(GetCpuOptions());
+  Subgraph::SubgraphOptions graph_options;
+  graph_options.add_node_options()->PackFrom(face_options);
+  Packet frame1 = TestImageFrame();
+
+  CalculatorGraph graph;
+  MP_ASSERT_OK(
+      graph.Initialize({config}, {}, {}, "FaceDetection", &graph_options));
+
+  std::vector<Packet> output;
+  MP_ASSERT_OK(graph.ObserveOutputStream("detections", [&](const Packet& p) {
+    output.push_back(p);
+    return absl::OkStatus();
+  }));
+  MP_ASSERT_OK(graph.StartRun({}));
+  MP_ASSERT_OK(
+      graph.AddPacketToInputStream("image", frame1.At(Timestamp(20000))));
+  MP_ASSERT_OK(graph.AddPacketToInputStream(
+      "roi", MakePacket<mediapipe::NormalizedRect>(GetTestRoi())
+                 .At(Timestamp(20000))));
+  MP_ASSERT_OK(graph.CloseAllPacketSources());
+  MP_EXPECT_OK(graph.WaitUntilDone());
+  ASSERT_EQ(output.size(), 1);
+  EXPECT_EQ(output.front().Get<std::vector<Detection>>().size(), 1);
+}
+
+// These GpuBuffer tests are disabled on mobile for now.
+#if !defined(__APPLE__) && !__ANDROID__
+
+class FaceDetectionGpuTest : public mediapipe::GpuTestBase {
+ protected:
+  void SetUp() override {}
+  void TearDown() override {}
+
+  // Returns a Packet with a GpuBuffer from an ImageFrame.
+  Packet GpuBuffer(Packet image_frame) {
+    std::unique_ptr<mediapipe::GpuBuffer> gpu_buffer;
+    helper_.RunInGlContext([this, &image_frame, &gpu_buffer] {
+      auto src = helper_.CreateSourceTexture(image_frame.Get<ImageFrame>());
+      gpu_buffer = src.GetFrame<mediapipe::GpuBuffer>();
+    });
+    return Adopt(gpu_buffer.release());
+  }
+};
+
+TEST_F(FaceDetectionGpuTest, FaceDetectionFullRangeGpu) {
+  CalculatorGraphConfig config = GetFaceDetectionGraph();
+  config.clear_graph_options();
+  mediapipe::FaceDetectionOptions face_options =
+      GetFaceDetectionFullRangeOptions();
+  face_options.MergeFrom(GetGpuOptions());
+
+  Subgraph::SubgraphOptions graph_options;
+  graph_options.add_node_options()->PackFrom(face_options);
+  Packet frame1 = GpuBuffer(TestImageFrame());
+
+  CalculatorGraph graph;
+  MP_ASSERT_OK(graph.Initialize({config}, {}, {}, GetFaceDetectionGraphType(),
+                                &graph_options));
+
+  MP_ASSERT_OK(mediapipe::SetExternalGlContextForGraph(
+      &graph, helper_.GetGlContext().native_context()));
+  std::vector<Packet> output;
+  MP_ASSERT_OK(graph.ObserveOutputStream("detections", [&](const Packet& p) {
+    output.push_back(p);
+    return absl::OkStatus();
+  }));
+  MP_ASSERT_OK(graph.StartRun({}));
+  MP_ASSERT_OK(
+      graph.AddPacketToInputStream("image", frame1.At(Timestamp(20000))));
+  MP_ASSERT_OK(graph.CloseAllPacketSources());
+  MP_EXPECT_OK(graph.WaitUntilDone());
+  ASSERT_EQ(output.size(), 1);
+  EXPECT_EQ(output.front().Get<std::vector<Detection>>().size(), 1);
+}
+
+TEST_F(FaceDetectionGpuTest, FaceDetectionShortRangeGpu) {
+  CalculatorGraphConfig config = GetFaceDetectionGraph();
+  config.clear_graph_options();
+  mediapipe::FaceDetectionOptions face_options =
+      GetFaceDetectionShortRangeOptions();
+  face_options.MergeFrom(GetGpuOptions());
+
+  Subgraph::SubgraphOptions graph_options;
+  graph_options.add_node_options()->PackFrom(face_options);
+  Packet frame1 = GpuBuffer(TestImageFrame());
+
+  CalculatorGraph graph;
+  MP_ASSERT_OK(graph.Initialize({config}, {}, {}, GetFaceDetectionGraphType(),
+                                &graph_options));
+
+  MP_ASSERT_OK(mediapipe::SetExternalGlContextForGraph(
+      &graph, helper_.GetGlContext().native_context()));
+  std::vector<Packet> output;
+  MP_ASSERT_OK(graph.ObserveOutputStream("detections", [&](const Packet& p) {
+    output.push_back(p);
+    return absl::OkStatus();
+  }));
+  MP_ASSERT_OK(graph.StartRun({}));
+  MP_ASSERT_OK(
+      graph.AddPacketToInputStream("image", frame1.At(Timestamp(20000))));
+  MP_ASSERT_OK(graph.CloseAllPacketSources());
+  MP_EXPECT_OK(graph.WaitUntilDone());
+  ASSERT_EQ(output.size(), 1);
+  EXPECT_EQ(output.front().Get<std::vector<Detection>>().size(), 1);
+}
+
+#endif  // !defined(__APPLE__) && !__ANDROID__
+
+}  // namespace
+}  // namespace mediapipe
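
A Python-level smoke test analogous to the C++ coverage above, using the published solution wrapper (the test image path and BGR-to-RGB conversion are assumptions; model_selection 0 selects the short-range model, 1 the full-range model):

import cv2
import mediapipe as mp

mp_face_detection = mp.solutions.face_detection

image = cv2.imread('mediapipe/objc/testdata/sergey.png')  # BGR, as in the C++ tests
with mp_face_detection.FaceDetection(
    model_selection=0, min_detection_confidence=0.5) as face_detection:
  results = face_detection.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

assert results.detections is not None
assert len(results.detections) == 1  # one face expected in the test image
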
diff --git a/mediapipe/modules/face_detection/face_detection_without_roi.pbtxt b/mediapipe/modules/face_detection/face_detection_without_roi.pbtxt
new file mode 100644
index 0000000000..b72893460f
--- /dev/null
+++ b/mediapipe/modules/face_detection/face_detection_without_roi.pbtxt
@@ -0,0 +1,35 @@
+# MediaPipe graph to detect faces.
+# This graph omits the "ROI" input stream of the FaceDetection graph.
+# For now, top-level graph input streams can only be omitted using an
+# enclosing graph, see b/202896911.
+# TODO: Remove this graph after b/202896911 is addressed.
+#
+# EXAMPLE:
+#   node {
+#     calculator: "FaceDetectionWithoutRoi"
+#     input_stream: "IMAGE:image"
+#     output_stream: "DETECTIONS:face_detections"
+#   }
+
+type: "FaceDetectionWithoutRoi"
+
+# The input image, either ImageFrame, GpuBuffer, or (multi-backend) Image.
+input_stream: "IMAGE:image"
+
+# Detected faces. (std::vector<Detection>)
+output_stream: "DETECTIONS:detections"
+
+# The face detection graph options.
+graph_options: {
+  [type.googleapis.com/mediapipe.FaceDetectionOptions] {}
+}
+
+node {
+  calculator: "FaceDetection"
+  input_stream: "IMAGE:image"
+  output_stream: "DETECTIONS:detections"
+  node_options: {
+    [type.googleapis.com/mediapipe.FaceDetectionOptions] {}
+  }
+  option_value: "OPTIONS:options"
+}
diff --git a/mediapipe/python/pybind/packet_getter.cc b/mediapipe/python/pybind/packet_getter.cc
index 0abe928a59..f0cc84f3f3 100644
--- a/mediapipe/python/pybind/packet_getter.cc
+++ b/mediapipe/python/pybind/packet_getter.cc
@@ -247,7 +247,22 @@ void PublicPacketGetters(pybind11::module* m) {
   )doc");
 
   m->def(
-      "get_float_list", &GetContent<std::vector<float>>,
+      "get_float_list",
+      [](const Packet& packet) {
+        if (packet.ValidateAsType<std::vector<float>>().ok()) {
+          return packet.Get<std::vector<float>>();
+        } else if (packet.ValidateAsType<std::array<float, 4>>().ok()) {
+          auto float_array = packet.Get<std::array<float, 4>>();
+          return std::vector<float>(float_array.begin(), float_array.end());
+        } else if (packet.ValidateAsType<std::array<float, 16>>().ok()) {
+          auto float_array = packet.Get<std::array<float, 16>>();
+          return std::vector<float>(float_array.begin(), float_array.end());
+        } else {
+          throw RaisePyError(PyExc_ValueError,
+                             "Packet doesn't contain std::vector<float> or "
+                             "std::array<float> containers.");
+        }
+      },
       R"doc(Get the content of a MediaPipe float vector Packet as a float list.
 
 Args:
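
The net effect of the widened getter: packets holding std::vector<float>, std::array<float, 4>, or std::array<float, 16> (e.g. a 4x4 transform matrix observed from a graph's MATRIX stream) all come back as a plain Python list. A small sketch of the Python-visible behavior (values chosen to be exactly representable as float32):

from mediapipe.python import packet_creator, packet_getter

p = packet_creator.create_float_vector([0.5, 0.25, 1.0])
assert packet_getter.get_float_list(p) == [0.5, 0.25, 1.0]
# Packets carrying the fixed-size std::array<float, N> containers now go
# through the same getter instead of raising ValueError.
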
+ """ + + if hasattr(values, 'items'): + values = values.items() + for pair in values: + (field, value) = pair + fields = field.split('.') + m = options_message + while len(fields) > 1: + m = getattr(m, fields[0]) + del fields[0] + v = getattr(m, fields[0]) + if hasattr(v, 'append'): + del v[:] + v.extend(value) + elif hasattr(v, 'CopyFrom'): + v.CopyFrom(value) + else: + setattr(m, fields[0], value) + return options_message + + def _set_extension(self, + extension_list: containers.RepeatedCompositeFieldContainer, + extension_value: message.Message) -> None: + """Sets one value in a repeated protobuf.Any extension field.""" + for extension_any in extension_list: + if extension_any.Is(extension_value.DESCRIPTOR): + v = type(extension_value)() + extension_any.Unpack(v) + v.MergeFrom(extension_value) + extension_any.Pack(v) + return + extension_list.add().Pack(extension_value) + def _make_packet(self, packet_data_type: PacketDataType, data: Any) -> packet.Packet: if (packet_data_type == PacketDataType.IMAGE_FRAME or diff --git a/mediapipe/python/solutions/drawing_utils.py b/mediapipe/python/solutions/drawing_utils.py index ea5d881cbf..bebcbe97c9 100644 --- a/mediapipe/python/solutions/drawing_utils.py +++ b/mediapipe/python/solutions/drawing_utils.py @@ -28,7 +28,7 @@ _PRESENCE_THRESHOLD = 0.5 _VISIBILITY_THRESHOLD = 0.5 -_RGB_CHANNELS = 3 +_BGR_CHANNELS = 3 WHITE_COLOR = (224, 224, 224) BLACK_COLOR = (0, 0, 0) @@ -74,7 +74,7 @@ def draw_detection( """Draws the detction bounding box and keypoints on the image. Args: - image: A three channel RGB image represented as numpy ndarray. + image: A three channel BGR image represented as numpy ndarray. detection: A detection proto message to be annotated on the image. keypoint_drawing_spec: A DrawingSpec object that specifies the keypoints' drawing settings such as color, line thickness, and circle radius. @@ -83,13 +83,13 @@ def draw_detection( Raises: ValueError: If one of the followings: - a) If the input image is not three channel RGB. + a) If the input image is not three channel BGR. b) If the location data is not relative data. """ if not detection.location_data: return - if image.shape[2] != _RGB_CHANNELS: - raise ValueError('Input image must contain three channel rgb data.') + if image.shape[2] != _BGR_CHANNELS: + raise ValueError('Input image must contain three channel bgr data.') image_rows, image_cols, _ = image.shape location = detection.location_data @@ -130,7 +130,7 @@ def draw_landmarks( """Draws the landmarks and the connections on the image. Args: - image: A three channel RGB image represented as numpy ndarray. + image: A three channel BGR image represented as numpy ndarray. landmark_list: A normalized landmark list proto message to be annotated on the image. connections: A list of landmark index tuples that specifies how landmarks to @@ -147,13 +147,13 @@ def draw_landmarks( Raises: ValueError: If one of the followings: - a) If the input image is not three channel RGB. + a) If the input image is not three channel BGR. b) If any connetions contain invalid landmark index. """ if not landmark_list: return - if image.shape[2] != _RGB_CHANNELS: - raise ValueError('Input image must contain three channel rgb data.') + if image.shape[2] != _BGR_CHANNELS: + raise ValueError('Input image must contain three channel bgr data.') image_rows, image_cols, _ = image.shape idx_to_coordinates = {} for idx, landmark in enumerate(landmark_list.landmark): @@ -208,7 +208,7 @@ def draw_axis( """Draws the 3D axis on the image. 
diff --git a/mediapipe/python/solutions/drawing_utils.py b/mediapipe/python/solutions/drawing_utils.py
index ea5d881cbf..bebcbe97c9 100644
--- a/mediapipe/python/solutions/drawing_utils.py
+++ b/mediapipe/python/solutions/drawing_utils.py
@@ -28,7 +28,7 @@
 
 _PRESENCE_THRESHOLD = 0.5
 _VISIBILITY_THRESHOLD = 0.5
-_RGB_CHANNELS = 3
+_BGR_CHANNELS = 3
 
 WHITE_COLOR = (224, 224, 224)
 BLACK_COLOR = (0, 0, 0)
@@ -74,7 +74,7 @@ def draw_detection(
   """Draws the detection bounding box and keypoints on the image.
 
   Args:
-    image: A three channel RGB image represented as numpy ndarray.
+    image: A three channel BGR image represented as numpy ndarray.
     detection: A detection proto message to be annotated on the image.
     keypoint_drawing_spec: A DrawingSpec object that specifies the keypoints'
       drawing settings such as color, line thickness, and circle radius.
@@ -83,13 +83,13 @@
 
   Raises:
     ValueError: If one of the following:
-      a) If the input image is not three channel RGB.
+      a) If the input image is not three channel BGR.
       b) If the location data is not relative data.
   """
   if not detection.location_data:
     return
-  if image.shape[2] != _RGB_CHANNELS:
-    raise ValueError('Input image must contain three channel rgb data.')
+  if image.shape[2] != _BGR_CHANNELS:
+    raise ValueError('Input image must contain three channel bgr data.')
   image_rows, image_cols, _ = image.shape
 
   location = detection.location_data
@@ -130,7 +130,7 @@ def draw_landmarks(
   """Draws the landmarks and the connections on the image.
 
   Args:
-    image: A three channel RGB image represented as numpy ndarray.
+    image: A three channel BGR image represented as numpy ndarray.
     landmark_list: A normalized landmark list proto message to be annotated on
       the image.
     connections: A list of landmark index tuples that specifies how landmarks to
@@ -147,13 +147,13 @@
 
   Raises:
     ValueError: If one of the following:
-      a) If the input image is not three channel RGB.
+      a) If the input image is not three channel BGR.
       b) If any connections contain invalid landmark index.
   """
   if not landmark_list:
     return
-  if image.shape[2] != _RGB_CHANNELS:
-    raise ValueError('Input image must contain three channel rgb data.')
+  if image.shape[2] != _BGR_CHANNELS:
+    raise ValueError('Input image must contain three channel bgr data.')
   image_rows, image_cols, _ = image.shape
   idx_to_coordinates = {}
   for idx, landmark in enumerate(landmark_list.landmark):
@@ -208,7 +208,7 @@ def draw_axis(
   """Draws the 3D axis on the image.
 
   Args:
-    image: A three channel RGB image represented as numpy ndarray.
+    image: A three channel BGR image represented as numpy ndarray.
     rotation: Rotation matrix from object to camera coordinate frame.
     translation: Translation vector from object to camera coordinate frame.
     focal_length: camera focal length along x and y directions.
@@ -219,10 +219,10 @@
 
   Raises:
     ValueError: If one of the following:
-      a) If the input image is not three channel RGB.
+      a) If the input image is not three channel BGR.
   """
-  if image.shape[2] != _RGB_CHANNELS:
-    raise ValueError('Input image must contain three channel rgb data.')
+  if image.shape[2] != _BGR_CHANNELS:
+    raise ValueError('Input image must contain three channel bgr data.')
   image_rows, image_cols, _ = image.shape
   # Create axis points in camera coordinate frame.
   axis_world = np.float32([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])
diff --git a/mediapipe/python/solutions/drawing_utils_test.py b/mediapipe/python/solutions/drawing_utils_test.py
index 7a0acb31a9..0039f9a905 100644
--- a/mediapipe/python/solutions/drawing_utils_test.py
+++ b/mediapipe/python/solutions/drawing_utils_test.py
@@ -27,7 +27,8 @@
 
 DEFAULT_BBOX_DRAWING_SPEC = drawing_utils.DrawingSpec()
 DEFAULT_CONNECTION_DRAWING_SPEC = drawing_utils.DrawingSpec()
-DEFAULT_CIRCLE_DRAWING_SPEC = drawing_utils.DrawingSpec(color=(0, 0, 255))
+DEFAULT_CIRCLE_DRAWING_SPEC = drawing_utils.DrawingSpec(
+    color=drawing_utils.RED_COLOR)
 DEFAULT_AXIS_DRAWING_SPEC = drawing_utils.DrawingSpec()
 DEFAULT_CYCLE_BORDER_COLOR = (224, 224, 224)
@@ -37,13 +38,13 @@ class DrawingUtilTest(parameterized.TestCase):
 
   def test_invalid_input_image(self):
     image = np.arange(18, dtype=np.uint8).reshape(3, 3, 2)
     with self.assertRaisesRegex(
-        ValueError, 'Input image must contain three channel rgb data.'):
+        ValueError, 'Input image must contain three channel bgr data.'):
      drawing_utils.draw_landmarks(image, landmark_pb2.NormalizedLandmarkList())
     with self.assertRaisesRegex(
-        ValueError, 'Input image must contain three channel rgb data.'):
+        ValueError, 'Input image must contain three channel bgr data.'):
       drawing_utils.draw_detection(image, detection_pb2.Detection())
     with self.assertRaisesRegex(
-        ValueError, 'Input image must contain three channel rgb data.'):
+        ValueError, 'Input image must contain three channel bgr data.'):
       rotation = np.eye(3, dtype=np.float32)
       translation = np.array([0., 0., 1.])
       drawing_utils.draw_axis(image, rotation, translation)
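The RGB-to-BGR renames above align the drawing helpers with OpenCV's default channel order. A sketch of the intended flow (the file name is a placeholder): draw on the BGR frame directly, and convert to RGB only for inference, since the solutions themselves still expect RGB input.

    import cv2
    import mediapipe as mp

    image_bgr = cv2.imread('input.jpg')  # OpenCV decodes to BGR.
    with mp.solutions.face_detection.FaceDetection(
        min_detection_confidence=0.5) as face_detection:
      # The model expects RGB, so convert just for process().
      results = face_detection.process(
          cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
    for detection in results.detections or []:
      # The drawing utils now document (and validate) three-channel BGR input.
      mp.solutions.drawing_utils.draw_detection(image_bgr, detection)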
diff --git a/mediapipe/python/solutions/face_detection.py b/mediapipe/python/solutions/face_detection.py
index 7d4da8fe97..eee9f1175d 100644
--- a/mediapipe/python/solutions/face_detection.py
+++ b/mediapipe/python/solutions/face_detection.py
@@ -19,13 +19,7 @@
 import numpy as np
 
 from mediapipe.framework.formats import detection_pb2
 from mediapipe.framework.formats import location_data_pb2
-# pylint: disable=unused-import
-from mediapipe.calculators.tensor import image_to_tensor_calculator_pb2
-from mediapipe.calculators.tensor import inference_calculator_pb2
-from mediapipe.calculators.tensor import tensors_to_detections_calculator_pb2
-from mediapipe.calculators.tflite import ssd_anchors_calculator_pb2
-from mediapipe.calculators.util import non_max_suppression_calculator_pb2
-# pylint: enable=unused-import
+from mediapipe.modules.face_detection import face_detection_pb2
 from mediapipe.python.solution_base import SolutionBase
 
 _SHORT_RANGE_GRAPH_FILE_PATH = 'mediapipe/modules/face_detection/face_detection_short_range_cpu.binarypb'
@@ -84,14 +78,13 @@ def __init__(self, min_detection_confidence=0.5, model_selection=0):
     """
     binary_graph_path = _FULL_RANGE_GRAPH_FILE_PATH if model_selection == 1 else _SHORT_RANGE_GRAPH_FILE_PATH
-    subgraph_name = 'facedetectionfullrangecommon' if model_selection == 1 else 'facedetectionshortrangecommon'
     super().__init__(
         binary_graph_path=binary_graph_path,
-        calculator_params={
-            subgraph_name + '__TensorsToDetectionsCalculator.min_score_thresh':
-                min_detection_confidence,
-        },
+        graph_options=self.create_graph_options(
+            face_detection_pb2.FaceDetectionOptions(), {
+                'min_score_thresh': min_detection_confidence,
+            }),
         outputs=['detections'])
 
   def process(self, image: np.ndarray) -> NamedTuple:
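The public constructor is unchanged by this migration; `min_detection_confidence` is now routed into `FaceDetectionOptions.min_score_thresh` rather than into a generated calculator-name path. For instance:

    import mediapipe as mp

    # model_selection=1 picks the full-range graph, 0 the short-range one.
    face_detection = mp.solutions.face_detection.FaceDetection(
        min_detection_confidence=0.7, model_selection=1)
    face_detection.close()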
diff --git a/mediapipe/python/solutions/face_mesh.py b/mediapipe/python/solutions/face_mesh.py
index 1fe9d91cca..997c0661df 100644
--- a/mediapipe/python/solutions/face_mesh.py
+++ b/mediapipe/python/solutions/face_mesh.py
@@ -99,7 +99,7 @@ def __init__(self,
             'use_prev_landmarks': not static_image_mode,
         },
         calculator_params={
-            'facedetectionshortrangecpu__facedetectionshortrangecommon__TensorsToDetectionsCalculator.min_score_thresh':
+            'facedetectionshortrangecpu__facedetectionshortrange__facedetection__TensorsToDetectionsCalculator.min_score_thresh':
                 min_detection_confidence,
             'facelandmarkcpu__ThresholdingCalculator.threshold':
                 min_tracking_confidence,
diff --git a/mediapipe/util/BUILD b/mediapipe/util/BUILD
index 12a34a4f5c..d41a124422 100644
--- a/mediapipe/util/BUILD
+++ b/mediapipe/util/BUILD
@@ -137,6 +137,8 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":label_map_cc_proto",
+        "//mediapipe/framework/port:core_proto",
+        "//mediapipe/framework/port:integral_types",
        "//mediapipe/framework/port:statusor",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
diff --git a/mediapipe/util/label_map.proto b/mediapipe/util/label_map.proto
index 5d1123fb22..79301d2b63 100644
--- a/mediapipe/util/label_map.proto
+++ b/mediapipe/util/label_map.proto
@@ -33,8 +33,3 @@ message LabelMapItem {
   // hierarchy.
   repeated string child_name = 3;
 }
-
-// Mapping from index to a label map item.
-message LabelMap {
-  map<int64, LabelMapItem> index_to_item = 1;
-}
diff --git a/mediapipe/util/label_map_util.cc b/mediapipe/util/label_map_util.cc
index 849cf4299d..914a2ba765 100644
--- a/mediapipe/util/label_map_util.cc
+++ b/mediapipe/util/label_map_util.cc
@@ -25,7 +25,7 @@
 
 namespace mediapipe {
 
-absl::StatusOr<LabelMap> BuildLabelMapFromFiles(
+absl::StatusOr<proto_ns::Map<int64, LabelMapItem>> BuildLabelMapFromFiles(
     absl::string_view labels_file_contents,
     absl::string_view display_names_file) {
   if (labels_file_contents.empty()) {
@@ -68,9 +68,9 @@ absl::StatusOr<LabelMap> BuildLabelMapFromFiles(
       label_map_items[i].set_display_name(display_names[i]);
     }
   }
-  LabelMap label_map;
+  proto_ns::Map<int64, LabelMapItem> label_map;
   for (int i = 0; i < label_map_items.size(); ++i) {
-    (*label_map.mutable_index_to_item())[i] = label_map_items[i];
+    label_map[i] = label_map_items[i];
   }
   return label_map;
 }
diff --git a/mediapipe/util/label_map_util.h b/mediapipe/util/label_map_util.h
index 75a5f7e751..cef2618b89 100644
--- a/mediapipe/util/label_map_util.h
+++ b/mediapipe/util/label_map_util.h
@@ -16,6 +16,8 @@
 #define MEDIAPIPE_UTIL_LABEL_MAP_UTIL_H_
 
 #include "absl/strings/string_view.h"
+#include "mediapipe/framework/port/integral_types.h"
+#include "mediapipe/framework/port/proto_ns.h"
 #include "mediapipe/framework/port/statusor.h"
 #include "mediapipe/util/label_map.pb.h"
 
@@ -25,9 +27,9 @@ namespace mediapipe {
 
 // both expected to contain one label per line.
 // Returns an error, e.g. if there's a mismatch between the number of labels and
 // display names.
-absl::StatusOr<LabelMap> BuildLabelMapFromFiles(
-    absl::string_view labels_file_contents,
-    absl::string_view display_names_file);
+absl::StatusOr<proto_ns::Map<int64, LabelMapItem>>
+BuildLabelMapFromFiles(absl::string_view labels_file_contents,
+                       absl::string_view display_names_file);
 
 } // namespace mediapipe
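For intuition, the new return value is a bare index-to-item map rather than a wrapper message. A rough Python analogue of BuildLabelMapFromFiles' behavior (illustration only; plain dicts stand in for proto_ns::Map and LabelMapItem):

    # One label per line; optional display-name text of the same length.
    def build_label_map(labels_text, display_names_text=None):
      labels = labels_text.splitlines()
      display_names = (
          display_names_text.splitlines() if display_names_text else [])
      if display_names and len(display_names) != len(labels):
        raise ValueError('Mismatch between number of labels and display names.')
      label_map = {}
      for index, name in enumerate(labels):
        item = {'name': name}
        if display_names:
          item['display_name'] = display_names[index]
        label_map[index] = item
      return label_map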
diff --git a/mediapipe/util/sequence/README.md b/mediapipe/util/sequence/README.md
index e3e0172904..d1b6250555 100644
--- a/mediapipe/util/sequence/README.md
+++ b/mediapipe/util/sequence/README.md
@@ -422,7 +422,7 @@ tasks and tracking (or class) fields for tracking information.
 |`region/point/x`|feature list float list|`add_bbox_point_x` / `AddBBoxPointX`|A list of normalized x values for points in a frame.|
 |`region/point/y`|feature list float list|`add_bbox_point_y` / `AddBBoxPointY`|A list of normalized y values for points in a frame.|
 |`region/point/\*`| *special* |`add_bbox_point` / `AddBBoxPoint`|Operates on point/x,point/y with a single call.|
-|`region/point/radius`|feature list float list|`add_bbox_point_radius` / `AddBBoxPointRadius`|A list of radii for points in a frame.|
+|`region/radius`|feature list float list|`add_bbox_point_radius` / `AddBBoxRadius`|A list of radii for points in a frame.|
 |`region/3d_point/x`|feature list float list|`add_bbox_3d_point_x` / `AddBBox3dPointX`|A list of normalized x values for points in a frame.|
 |`region/3d_point/y`|feature list float list|`add_bbox_3d_point_y` / `AddBBox3dPointY`|A list of normalized y values for points in a frame.|
 |`region/3d_point/z`|feature list float list|`add_bbox_3d_point_z` / `AddBBox3dPointZ`|A list of normalized z values for points in a frame.|
@@ -460,6 +460,7 @@ tasks and tracking (or class) fields for tracking information.
 |`image/label/confidence`|feature list float list|`add_image_label_confidence` / `AddImageLabelConfidence`|If an image at a specific timestamp should have a label, use this. If a range of time, prefer Segments instead.|
 |`image/format`|context bytes|`set_image_format` / `SetImageFormat`|The encoding format of the images.|
 |`image/channels`|context int|`set_image_channels` / `SetImageChannels`|The number of channels in the image.|
+|`image/colorspace`|context bytes|`set_image_colorspace` / `SetImageColorspace`|The colorspace of the images.|
 |`image/height`|context int|`set_image_height` / `SetImageHeight`|The height of the image in pixels.|
 |`image/width`|context int|`set_image_width` / `SetImageWidth`|The width of the image in pixels.|
 |`image/frame_rate`|context float|`set_image_frame_rate` / `SetImageFrameRate`|The rate of images in frames per second.|
diff --git a/mediapipe/util/sequence/media_sequence.cc b/mediapipe/util/sequence/media_sequence.cc
index eef63c5419..62bfa19b48 100644
--- a/mediapipe/util/sequence/media_sequence.cc
+++ b/mediapipe/util/sequence/media_sequence.cc
@@ -161,7 +161,8 @@ absl::Status ReconcileMetadataFeatureFloats(
     if (absl::StrContains(key, kFeatureFloatsKey)) {
       const auto prefix = key.substr(0, key.find(kFeatureFloatsKey) - 1);
       int number_of_elements = GetFeatureFloatsAt(prefix, *sequence, 0).size();
-      if (HasFeatureDimensions(prefix, *sequence)) {
+      if (HasFeatureDimensions(prefix, *sequence) &&
+          !GetFeatureDimensions(prefix, *sequence).empty()) {
         int64 product = 1;
         for (int64 value : GetFeatureDimensions(prefix, *sequence)) {
           product *= value;
diff --git a/mediapipe/util/sequence/media_sequence.h b/mediapipe/util/sequence/media_sequence.h
index 2f71024b4e..8b55bfd919 100644
--- a/mediapipe/util/sequence/media_sequence.h
+++ b/mediapipe/util/sequence/media_sequence.h
@@ -501,6 +501,9 @@ void Clear3dPoint(const std::string& prefix,
   FIXED_PREFIX_VECTOR_BYTES_FEATURE_LIST(                                    \
       CONCAT_STR2(identifier, EmbeddingEncoded), kRegionEmbeddingEncodedKey, \
       prefix)                                                                \
+  FIXED_PREFIX_VECTOR_FLOAT_FEATURE_LIST(                                    \
+      CONCAT_STR2(identifier, EmbeddingConfidence),                          \
+      kRegionEmbeddingConfidenceKey, prefix)                                 \
   FIXED_PREFIX_VECTOR_INT64_CONTEXT_FEATURE(                                 \
       CONCAT_STR2(identifier, EmbeddingDimensionsPerRegion),                 \
       kRegionEmbeddingDimensionsPerRegionKey, prefix)                        \
diff --git a/mediapipe/util/sequence/media_sequence_test.cc b/mediapipe/util/sequence/media_sequence_test.cc
index ca3021ea4d..3402aea557 100644
--- a/mediapipe/util/sequence/media_sequence_test.cc
+++ b/mediapipe/util/sequence/media_sequence_test.cc
@@ -400,12 +400,20 @@ TEST(MediaSequenceTest, RoundTripBBoxEmbedding) {
   tensorflow::SequenceExample sequence;
   std::vector<std::vector<std::string>> embeddings = {
       {"embedding00", "embedding01"}, {"embedding10", "embedding11"}};
+  std::vector<std::vector<float>> confidences = {{0.7, 0.8}, {0.9, 0.95}};
   for (int i = 0; i < embeddings.size(); ++i) {
     AddBBoxEmbeddingEncoded("GT_KEY", embeddings[i], &sequence);
     ASSERT_EQ(GetBBoxEmbeddingEncodedSize("GT_KEY", sequence), i + 1);
     const auto& sequence_embeddings =
         GetBBoxEmbeddingEncodedAt("GT_KEY", sequence, i);
     EXPECT_THAT(sequence_embeddings, testing::ElementsAreArray(embeddings[i]));
+
+    AddBBoxEmbeddingConfidence("GT_KEY", confidences[i], &sequence);
+    ASSERT_EQ(GetBBoxEmbeddingConfidenceSize("GT_KEY", sequence), i + 1);
+    const auto& sequence_confidences =
+        GetBBoxEmbeddingConfidenceAt("GT_KEY", sequence, i);
+    EXPECT_THAT(sequence_confidences,
+                testing::ElementsAreArray(confidences[i]));
   }
 }
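A brief sketch of the added keys from the Python side. `set_image_colorspace` matches the README row above; `add_bbox_embedding_confidence` is an assumed snake_case mirror of the `AddBBoxEmbeddingConfidence` macro added in media_sequence.h, following the naming convention in the table.

    import tensorflow as tf
    from mediapipe.util.sequence import media_sequence as ms

    sequence = tf.train.SequenceExample()
    ms.set_image_colorspace(b'RGB', sequence)  # new image/colorspace key
    # Assumed mirror of AddBBoxEmbeddingConfidence: one confidence per region
    # embedding in the frame.
    ms.add_bbox_embedding_confidence([0.7, 0.8], sequence)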
diff --git a/mediapipe/util/tracking/tracking.cc b/mediapipe/util/tracking/tracking.cc
index 8d0afb08f4..7e80cd5cef 100644
--- a/mediapipe/util/tracking/tracking.cc
+++ b/mediapipe/util/tracking/tracking.cc
@@ -648,7 +648,6 @@ bool MotionBoxLines(const MotionBoxState& state, const Vector2_f& scaling,
                     std::array<Vector3_f, 4>* box_lines) {
   CHECK(box_lines);
   std::array<Vector2_f, 4> corners = MotionBoxCorners(state, scaling);
-  std::array<Vector3_f, 4> lines;
   for (int k = 0; k < 4; ++k) {
     const Vector2_f diff = corners[(k + 1) % 4] - corners[k];
     const Vector2_f normal = diff.Ortho().Normalize();
diff --git a/requirements.txt b/requirements.txt
index b37158d806..00e51ffd65 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,4 +3,4 @@ attrs>=19.1.0
 matplotlib
 numpy
 opencv-contrib-python
-protobuf>=3.11.4
+protobuf>=3.11,<4
diff --git a/setup_android_sdk_and_ndk.sh b/setup_android_sdk_and_ndk.sh
index c16021eda5..edb27deb74 100644
--- a/setup_android_sdk_and_ndk.sh
+++ b/setup_android_sdk_and_ndk.sh
@@ -17,7 +17,7 @@
 # Script to setup Android SDK and NDK.
 # usage:
 # $ cd <mediapipe root dir>
-# $ bash ./setup_android_sdk_and_ndk.sh ~/Android/Sdk ~/Android/Ndk r21
+# $ bash ./setup_android_sdk_and_ndk.sh ~/Android/Sdk ~/Android/Ndk r21 [--accept-licenses]
 
 set -e
 
@@ -39,6 +39,7 @@ fi
 android_sdk_path=$1
 android_ndk_path=$2
 ndk_version=$3
+licenses=$4
 
 if [ -z $1 ]
 then
@@ -68,6 +69,10 @@ else
   unzip /tmp/android_sdk/commandline_tools.zip -d /tmp/android_sdk/
   mkdir -p $android_sdk_path
   /tmp/android_sdk/cmdline-tools/bin/sdkmanager --update --sdk_root=${android_sdk_path}
+  if [ "$licenses" == "--accept-licenses" ]
+  then
+    yes | /tmp/android_sdk/cmdline-tools/bin/sdkmanager --licenses --sdk_root=${android_sdk_path}
+  fi
   /tmp/android_sdk/cmdline-tools/bin/sdkmanager "build-tools;30.0.3" "platform-tools" "platforms;android-30" "extras;android;m2repository" --sdk_root=${android_sdk_path}
   rm -rf /tmp/android_sdk/
   echo "Android SDK is now installed. Consider setting \$ANDROID_HOME environment variable to be ${android_sdk_path}"