Skip to content

Commit

Permalink
Add MODEL_VIEW side input to tflite_model_calculator
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 609184982
  • Loading branch information
MediaPipe Team authored and copybara-github committed Feb 22, 2024
1 parent b15e5c0 commit d2bc9e5
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 12 deletions.
4 changes: 3 additions & 1 deletion mediapipe/calculators/tflite/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
# limitations under the License.
#

load("@bazel_skylib//lib:selects.bzl", "selects")
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
load("@bazel_skylib//lib:selects.bzl", "selects")

licenses(["notice"])

Expand Down Expand Up @@ -322,6 +322,7 @@ cc_library_with_tflite(
"//mediapipe/framework:packet",
"//mediapipe/framework/port:ret_check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:string_view",
],
alwayslink = 1,
)
Expand Down Expand Up @@ -547,6 +548,7 @@ cc_test(
"//mediapipe/calculators/util:local_file_contents_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"@org_tensorflow//tensorflow/lite:framework",
Expand Down
41 changes: 31 additions & 10 deletions mediapipe/calculators/tflite/tflite_model_calculator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
#include <string>

#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/allocation.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/model_builder.h"

namespace mediapipe {

Expand All @@ -36,6 +37,10 @@ namespace mediapipe {
// blob and use it as input here.
// MODEL_FD - Tflite model file descriptor std::tuple<int, size_t, size_t>
// containing (fd, offset, size).
// MODEL_VIEW - TfLite model file contents in absl::string_view, whose
// underline buffer is owned outside of this calculator. User can
// get the model view from a managed environment and pass it to
// the graph as input side packet.
//
// Output side packets:
// MODEL - TfLite model. (std::unique_ptr<tflite::FlatBufferModel,
Expand All @@ -55,17 +60,25 @@ class TfLiteModelCalculator : public CalculatorBase {
std::unique_ptr<tflite::FlatBufferModel,
std::function<void(tflite::FlatBufferModel*)>>;

static constexpr absl::string_view kModelViewTag = "MODEL_VIEW";
static constexpr absl::string_view kModelBlobTag = "MODEL_BLOB";
static constexpr absl::string_view kModelFDTag = "MODEL_FD";

static absl::Status GetContract(CalculatorContract* cc) {
if (cc->InputSidePackets().HasTag("MODEL_BLOB")) {
cc->InputSidePackets().Tag("MODEL_BLOB").Set<std::string>();
if (cc->InputSidePackets().HasTag(kModelBlobTag)) {
cc->InputSidePackets().Tag(kModelBlobTag).Set<std::string>();
}

if (cc->InputSidePackets().HasTag("MODEL_FD")) {
if (cc->InputSidePackets().HasTag(kModelFDTag)) {
cc->InputSidePackets()
.Tag("MODEL_FD")
.Tag(kModelFDTag)
.Set<std::tuple<int, size_t, size_t>>();
}

if (cc->InputSidePackets().HasTag(kModelViewTag)) {
cc->InputSidePackets().Tag(kModelViewTag).Set<absl::string_view>();
}

cc->OutputSidePackets().Tag("MODEL").Set<TfLiteModelPtr>();
return absl::OkStatus();
}
Expand All @@ -74,16 +87,24 @@ class TfLiteModelCalculator : public CalculatorBase {
Packet model_packet;
std::unique_ptr<tflite::FlatBufferModel> model;

if (cc->InputSidePackets().HasTag("MODEL_BLOB")) {
model_packet = cc->InputSidePackets().Tag("MODEL_BLOB");
if (cc->InputSidePackets().HasTag(kModelBlobTag)) {
model_packet = cc->InputSidePackets().Tag(kModelBlobTag);
const std::string& model_blob = model_packet.Get<std::string>();
model = tflite::FlatBufferModel::BuildFromBuffer(model_blob.data(),
model_blob.size());
}

if (cc->InputSidePackets().HasTag("MODEL_FD")) {
if (cc->InputSidePackets().HasTag(kModelViewTag)) {
model_packet = cc->InputSidePackets().Tag(kModelViewTag);
const absl::string_view& model_view =
model_packet.Get<absl::string_view>();
model = tflite::FlatBufferModel::BuildFromBuffer(model_view.data(),
model_view.size());
}

if (cc->InputSidePackets().HasTag(kModelFDTag)) {
#if defined(ABSL_HAVE_MMAP) && !TFLITE_WITH_STABLE_ABI
model_packet = cc->InputSidePackets().Tag("MODEL_FD");
model_packet = cc->InputSidePackets().Tag(kModelFDTag);
const auto& model_fd =
model_packet.Get<std::tuple<int, size_t, size_t>>();
auto model_allocation = std::make_unique<tflite::MMAPAllocation>(
Expand All @@ -97,7 +118,7 @@ class TfLiteModelCalculator : public CalculatorBase {
#endif
}

RET_CHECK(model) << "Failed to load TfLite model from blob.";
RET_CHECK(model) << "Failed to load TfLite model.";

cc->OutputSidePackets().Tag("MODEL").Set(
MakePacket<TfLiteModelPtr>(TfLiteModelPtr(
Expand Down
53 changes: 52 additions & 1 deletion mediapipe/calculators/tflite/tflite_model_calculator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/model_builder.h"

namespace mediapipe {

Expand Down Expand Up @@ -85,4 +86,54 @@ TEST(TfLiteModelCalculatorTest, SmokeTest) {
}
}

TEST(TfLiteModelCalculatorTest, LoadFromModelView) {
std::string model_content;
MP_ASSERT_OK(mediapipe::file::GetContents(
"mediapipe/calculators/tflite/testdata/add.bin", &model_content));

// Prepare single calculator graph to and wait for packets.
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(
R"pb(
input_side_packet: "model_view"
node {
calculator: "TfLiteModelCalculator"
input_side_packet: "MODEL_VIEW:model_view"
output_side_packet: "MODEL:model"
}
)pb");
CalculatorGraph graph(graph_config);
MP_ASSERT_OK(graph.StartRun(
{{"model_view",
mediapipe::MakePacket<absl::string_view>(model_content)}}));
MP_ASSERT_OK(graph.WaitUntilIdle());
auto status_or_packet = graph.GetOutputSidePacket("model");
MP_ASSERT_OK(status_or_packet);
auto model_packet = status_or_packet.value();
const auto& model = model_packet.Get<
std::unique_ptr<tflite::FlatBufferModel,
std::function<void(tflite::FlatBufferModel*)>>>();

auto expected_model = tflite::FlatBufferModel::BuildFromFile(
"mediapipe/calculators/tflite/testdata/add.bin");

EXPECT_EQ(model->GetModel()->version(),
expected_model->GetModel()->version());
EXPECT_EQ(model->GetModel()->buffers()->size(),
expected_model->GetModel()->buffers()->size());
const int num_subgraphs = expected_model->GetModel()->subgraphs()->size();
EXPECT_EQ(model->GetModel()->subgraphs()->size(), num_subgraphs);
for (int i = 0; i < num_subgraphs; ++i) {
const auto* expected_subgraph =
expected_model->GetModel()->subgraphs()->Get(i);
const auto* subgraph = model->GetModel()->subgraphs()->Get(i);
const int num_tensors = expected_subgraph->tensors()->size();
EXPECT_EQ(subgraph->tensors()->size(), num_tensors);
for (int j = 0; j < num_tensors; ++j) {
EXPECT_EQ(subgraph->tensors()->Get(j)->name()->str(),
expected_subgraph->tensors()->Get(j)->name()->str());
}
}
}

} // namespace mediapipe

0 comments on commit d2bc9e5

Please sign in to comment.