forked from dneprDroid/tfsecured
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTFPredictor.mm
102 lines (80 loc) · 2.78 KB
/
TFPredictor.mm
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
//
// TFPredictor.mm
// TFSecured
//
// Created by user on 5/3/18.
// Copyright © 2018 user. All rights reserved.
//
#import "TFPredictor.h"
#import "NSError+Util.h"
#import "Utils.hpp"
#include <tensorflow/core/framework/op.h>
#include <tensorflow/core/framework/op_kernel.h>
#include <tensorflow/core/public/session.h>
#include <tensorflow/core/framework/shape_inference.h>
#include <iostream>
#include <fstream>
#include "../../../../TFSecured/GraphDefDecryptor.hpp"
using namespace tensorflow;
/// Private state of TFPredictor (C++ members live here so the public header stays Objective-C).
@interface TFPredictor () {
// Owned, heap-allocated graph definition: created with `new` in the factory method,
// passed to tfsecured::GraphDefDecryptAES when the model is loaded, and deleted in -dealloc.
GraphDef *graph;
// Name of the graph node that is fed the input tensor in -predictTensor:output:.
std::string inNode;
// Name of the graph node that is fetched as the prediction result in -predictTensor:output:.
std::string outNode;
}
// Path of the encrypted model file; converted to a UTF-8 C string and handed to
// tfsecured::GraphDefDecryptAES by -loadModelWithKey:error:.
@property(copy, nonatomic) NSString *modelPath;
@end
@implementation TFPredictor

/// Converts `string` to a UTF-8 std::string.
/// Returns an empty string when `string` is nil or has no UTF-8 representation:
/// -cStringUsingEncoding: returns NULL in those cases, and constructing a
/// std::string from a NULL pointer is undefined behavior.
static std::string TFStdStringFromNSString(NSString *string) {
    const char *utf8 = [string cStringUsingEncoding:NSUTF8StringEncoding];
    return utf8 != nullptr ? std::string(utf8) : std::string();
}

/// Creates a predictor for the encrypted model stored at `modelPath`.
///
/// @param modelPath Path of the encrypted model file; it is read by -loadModelWithKey:error:.
/// @param inNode    Name of the graph node that receives the input tensor.
/// @param outNode   Name of the graph node whose value is returned as the prediction.
/// @return A predictor holding an empty GraphDef; call -loadModelWithKey:error: before
///         -predictTensor:output:.
///
/// NOTE: despite its `init` prefix this is a class factory method; the selector is kept
/// unchanged for existing callers.
+ (instancetype)initWith:(NSString*)modelPath
           inputNodeName:(NSString*)inNode
          outputNodeName:(NSString*)outNode {
    TFPredictor *pred = [[self alloc] init];
    pred.modelPath = modelPath;
    // nil node names become empty strings instead of std::string(NULL) (undefined behavior).
    pred->inNode = TFStdStringFromNSString(inNode);
    pred->outNode = TFStdStringFromNSString(outNode);
    pred->graph = new GraphDef;
    return pred;
}

/// Decrypts the model at `modelPath` with `key` via tfsecured::GraphDefDecryptAES,
/// passing the predictor's GraphDef as the destination.
///
/// @param key      Decryption key, forwarded as a UTF-8 std::string.
/// @param callback Optional. Invoked with an NSError when the model path or key is nil /
///                 not UTF-8 encodable, or when decryption fails. Not invoked on success.
- (void)loadModelWithKey:(NSString*)key error:(nullable TFErrorCallback) callback {
    const char *path = [self.modelPath cStringUsingEncoding:NSUTF8StringEncoding];
    const char *keyCString = [key cStringUsingEncoding:NSUTF8StringEncoding];

    // -cStringUsingEncoding: yields NULL for nil or non-UTF-8 strings; streaming or copying
    // a NULL char* is undefined behavior, so report it as an error instead.
    if (path == nullptr || keyCString == nullptr) {
        NSString *reason = (path == nullptr)
            ? @"Model path is nil or not representable as UTF-8"
            : @"Decryption key is nil or not representable as UTF-8";
        printf("Error reading graph: %s\n", reason.UTF8String);
        if (callback)
            callback([NSError withMsg: @"Error reading graph"
                                 code: tensorflow::error::INVALID_ARGUMENT
                            localized: reason]);
        return;
    }

    // Instances created without the factory method (e.g. via +new) have no graph yet.
    if (graph == nullptr) {
        graph = new GraphDef;
    }

    std::string keyUnhashed(keyCString);
    std::cout << "Loading pb model from path: " << path << std::endl;
    auto status = tfsecured::GraphDefDecryptAES(path, graph, keyUnhashed);
    if (!status.ok()) {
        printf("Error reading graph: %s\n", status.error_message().c_str());
        if (callback)
            callback([NSError withMsg: @"Error reading graph"
                                 code: status.code()
                            localized: NSStringFromCString(status.error_message().c_str())]);
    }
}

/// Runs the graph once, feeding `input` to the input node and fetching the output node.
///
/// A new Session is created from the graph on every call. On success the first fetched
/// tensor is assigned to `*output`. On any failure a message is logged and `*output` is
/// left unchanged.
///
/// @param input  Tensor fed to the input node.
/// @param output Receives the result; must not be NULL.
- (void)predictTensor:(const Tensor&)input output: (Tensor*)output {
    if (output == nullptr) {
        std::cout << "Output tensor pointer is null!" << "\n";
        return;
    }
    if (graph == nullptr) {
        std::cout << "Graph is not allocated!" << "\n";
        return;
    }

    SessionOptions options;
    std::unique_ptr<Session> session(NewSession(options));
    // NewSession returns nullptr when the session cannot be created.
    if (!session) {
        printf("Error creating session: NewSession returned null\n");
        return;
    }

    Status status = session->Create(*graph);
    if (!status.ok()) {
        printf("Error creating session: %s\n", status.error_message().c_str());
        return;
    }

    std::cout << "Tensor input shape: " << input.shape().DebugString() << "\n";
    std::vector<tensorflow::Tensor> outputs;
    status = session->Run({{inNode, input}},
                          {outNode},
                          {},
                          &outputs);
    if (!status.ok()) {
        // Include TensorFlow's reason so failures are diagnosable from the log.
        std::cout << "Session running is failed! " << status.error_message() << "\n";
        return;
    }
    if (outputs.empty()) {
        std::cout << "Outputs are empty!" << "\n";
        return;
    }
    *output = outputs.front();
}

/// Releases the owned GraphDef (allocated with `new`; deleting nullptr is a no-op).
- (void)dealloc {
    printf("...... TFPredictor deallocation ......\n");
    delete graph;
}

@end