forked from dneprDroid/tfsecured
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathencrypt_model.py
109 lines (79 loc) · 2.84 KB
/
encrypt_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import base64
import hashlib
import os
import sys
import string
import random
# Third-party dependencies are imported behind guards so a missing package
# produces an actionable install hint. Only ImportError is caught: any other
# failure raised while importing (a broken install, Ctrl-C) must surface as-is.
try:
    from Crypto import Random
    from Crypto.Cipher import AES
except ImportError as err:
    # The ``Crypto`` package is provided by pycryptodome, a drop-in replacement
    # for the unmaintained pycrypto distribution.
    raise ImportError('Install Crypto! \n pip install pycryptodome') from err
try:
    import tensorflow as tf
except ImportError as err:
    raise ImportError('Install Tensorflow!') from err
class AESCipher(object):
    """AES (CBC mode) encryption helper for serialized model bytes.

    The AES key is the SHA-256 digest of a passphrase. ``encrypt`` prepends a
    fresh random IV to every ciphertext, and plaintext is padded so its length
    is a multiple of ``bs`` (32) bytes.

    NOTE(review): the ciphertext carries no MAC, so tampering is not detected;
    and the key is a single unsalted SHA-256 of the passphrase, which is weak
    for human-chosen passphrases. Both are part of the on-disk format, so
    changing them requires updating whatever decrypts the output.
    """

    def __init__(self, _key):
        """Derive the 32-byte AES key from the passphrase string ``_key``."""
        # Padding granularity in bytes (a multiple of the 16-byte AES block).
        # It shapes the output format; presumably the decrypting consumer
        # expects it, so verify there before changing.
        self.bs = 32
        self.key = hashlib.sha256(_key.encode()).digest()

    def encrypt(self, raw):
        """Encrypt bytes ``raw`` and return ``iv + ciphertext`` as bytes."""
        raw = self._pad(raw)
        iv = Random.new().read(AES.block_size)
        cipher = AES.new(self.key, AES.MODE_CBC, iv)
        return iv + cipher.encrypt(raw)

    def decrypt(self, enc):
        """Decrypt ``iv + ciphertext`` as produced by ``encrypt``.

        Returns the original plaintext bytes.

        Raises:
            ValueError: if the decrypted padding is invalid, which typically
                means the wrong key or corrupted data.
        """
        iv = enc[:AES.block_size]
        cipher = AES.new(self.key, AES.MODE_CBC, iv)
        return self._unpad(cipher.decrypt(enc[AES.block_size:]))

    def _pad(self, s):
        """Return bytes ``s`` padded to a multiple of ``self.bs``.

        Appends ``n`` copies of the byte value ``n`` (1 <= n <= bs). A full
        ``bs`` bytes of padding is added when ``s`` is already aligned, so the
        padding is always unambiguous to strip.
        """
        n = self.bs - len(s) % self.bs
        return s + bytes([n]) * n

    @staticmethod
    def _unpad(s):
        """Strip the padding added by ``_pad`` and return the payload bytes.

        Raises:
            ValueError: if ``s`` is empty or its trailing bytes are not valid
                padding. Previously this silently returned empty or truncated
                data (e.g. a trailing 0 byte yielded ``b''``).
        """
        if not s:
            raise ValueError('Cannot unpad empty data')
        n = s[-1]
        if n < 1 or n > len(s) or s[-n:] != bytes([n]) * n:
            raise ValueError('Invalid padding')
        return s[:-n]
############### Util Methods ###############
def load_graph(path):
    """Read a serialized TensorFlow ``GraphDef`` from ``path``.

    The parsed graph is also imported into TensorFlow's current default graph
    via ``tf.import_graph_def``. The caller in this script only uses the
    returned ``GraphDef``, so the import presumably serves as a sanity check
    that the file holds a usable graph.

    Returns:
        The parsed ``tf.GraphDef``.

    Raises:
        FileNotFoundError: if nothing exists at ``path``. This is a subclass of
            ``Exception``, so existing ``except Exception`` handlers still
            catch it.
    """
    if not tf.gfile.Exists(path):
        raise FileNotFoundError('File doesn\'t exist at path: %s' % path)
    # The ``with`` block closes the file; parse only after it is released.
    with tf.gfile.GFile(path, 'rb') as f:
        serialized = f.read()
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized)
    tf.import_graph_def(graph_def, name=None)
    return graph_def
def generate_output_path(input_path, suffix):
    """Insert ``suffix`` between the stem and extension of ``input_path``.

    For example ``'models/net.pb'`` with ``'-encrypted'`` becomes
    ``'models/net-encrypted.pb'``. A path without an extension simply gets
    ``suffix`` appended.
    """
    root, ext = os.path.splitext(input_path)
    new_root = root + suffix
    return new_root + ext
def random_string(size=30):
    """Return ``size`` random characters drawn from ``A-Z`` and ``0-9``.

    Characters come from ``random.SystemRandom``, which uses the operating
    system's randomness source (``os.urandom``) rather than the seeded
    Mersenne Twister.
    """
    alphabet = string.ascii_uppercase + string.digits
    rng = random.SystemRandom()
    picked = [rng.choice(alphabet) for _ in range(size)]
    return ''.join(picked)
def read_arg(index, default=None, err_msg=None):
    """Return the command-line argument ``sys.argv[index]``.

    When that argument is absent, ``default`` is returned if it is not None.
    Otherwise an ``Exception`` is raised carrying ``err_msg``, or a generic
    message naming ``index`` when no ``err_msg`` was given.
    """
    if index < len(sys.argv):
        return sys.argv[index]
    if default is not None:
        return default
    message = err_msg if err_msg is not None else 'Not found arg with index %s' % index
    raise Exception(message)
#############################################
def main():
    """Command-line entry point: encrypt a serialized GraphDef file with AES.

    Positional arguments (all optional):
        1. input model path (default ``demo/models/saved_model.pb``)
        2. output path (default: the input path with ``-encrypted`` inserted
           before the extension)
        3. passphrase (default: a random 30-character string)

    The passphrase is printed at the end because the AES key is derived from
    it, so it is needed to decrypt the output.
    """
    usage = 'python encrypt_model.py <INPUT_PB_MODEL> <OUTPUT_PB_MODEL> <KEY>'
    print('\nUSAGE: %s\n' % usage)

    input_path = read_arg(1, default='demo/models/saved_model.pb')
    default_out = generate_output_path(input_path, '-encrypted')
    output_path = read_arg(2, default=default_out)
    # NOTE(review): a key passed as argv[3] is visible in shell history and
    # process listings; consider reading it from the environment or a prompt.
    key = read_arg(3, default=random_string())

    graph_def = load_graph(input_path)
    cipher = AESCipher(key)
    encrypted = cipher.encrypt(graph_def.SerializeToString())

    # The ``with`` block closes the file; no explicit close() is needed.
    with tf.gfile.GFile(output_path, 'wb') as f:
        f.write(encrypted)
    print('Saved with key="%s" to %s' % (key, output_path))
if __name__ == "__main__":
    # Run the command-line tool only when executed as a script, not on import.
    main()