Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TensorFlow 2.0 above support added #1098

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,22 @@ Python3, tensorflow 1.0, numpy, opencv 3.

### Getting started

You can choose _one_ of the following three ways to get started with darkflow.
- Install tf_slim module since contrib is deprecated from tf2+
```pip install git+https://github.com/ShanuDey/tf-slim.git```

1. Just build the Cython extensions in place. NOTE: If installing this way you will have to use `./flow` in the cloned darkflow directory instead of `flow` as darkflow is not installed globally.
- You can choose _one_ of the following three ways to get started with darkflow.

1. Just build the Cython extensions in place. NOTE: If installing this way you will have to use `./flow` in the cloned darkflow directory instead of `flow` as darkflow is not installed globally.
```
python3 setup.py build_ext --inplace
```

2. Let pip install darkflow globally in dev mode (still globally accessible, but changes to the code immediately take effect)
2. Let pip install darkflow globally in dev mode (still globally accessible, but changes to the code immediately take effect)
```
pip install -e .
```

3. Install with pip globally
3. Install with pip globally
```
pip install .
```
Expand Down
42 changes: 21 additions & 21 deletions darkflow/net/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
class TFNet(object):

_TRAINER = dict({
'rmsprop': tf.train.RMSPropOptimizer,
'adadelta': tf.train.AdadeltaOptimizer,
'adagrad': tf.train.AdagradOptimizer,
'adagradDA': tf.train.AdagradDAOptimizer,
'momentum': tf.train.MomentumOptimizer,
'adam': tf.train.AdamOptimizer,
'ftrl': tf.train.FtrlOptimizer,
'sgd': tf.train.GradientDescentOptimizer
'rmsprop': tf.compat.v1.train.RMSPropOptimizer,
'adadelta': tf.compat.v1.train.AdadeltaOptimizer,
'adagrad': tf.compat.v1.train.AdagradOptimizer,
'adagradDA': tf.compat.v1.train.AdagradDAOptimizer,
'momentum': tf.compat.v1.train.MomentumOptimizer,
'adam': tf.compat.v1.train.AdamOptimizer,
'ftrl': tf.compat.v1.train.FtrlOptimizer,
'sgd': tf.compat.v1.train.GradientDescentOptimizer
})

# imported methods
Expand Down Expand Up @@ -78,8 +78,8 @@ def __init__(self, FLAGS, darknet = None):
time.time() - start))

def build_from_pb(self):
with tf.gfile.FastGFile(self.FLAGS.pbLoad, "rb") as f:
graph_def = tf.GraphDef()
with tf.compat.v1.gfile.FastGFile(self.FLAGS.pbLoad, "rb") as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())

tf.import_graph_def(
Expand All @@ -91,9 +91,9 @@ def build_from_pb(self):
self.framework = create_framework(self.meta, self.FLAGS)

# Placeholders
self.inp = tf.get_default_graph().get_tensor_by_name('input:0')
self.inp = tf.compat.v1.get_default_graph().get_tensor_by_name('input:0')
self.feed = dict() # other placeholders
self.out = tf.get_default_graph().get_tensor_by_name('output:0')
self.out = tf.compat.v1.get_default_graph().get_tensor_by_name('output:0')

self.setup_meta_ops()

Expand All @@ -102,7 +102,7 @@ def build_forward(self):

# Placeholders
inp_size = [None] + self.meta['inp_size']
self.inp = tf.placeholder(tf.float32, inp_size, 'input')
self.inp = tf.compat.v1.placeholder(tf.float32, inp_size, 'input')
self.feed = dict() # other placeholders

# Build the forward pass
Expand All @@ -129,7 +129,7 @@ def setup_meta_ops(self):
utility = min(self.FLAGS.gpu, 1.)
if utility > 0.0:
self.say('GPU mode with {} usage'.format(utility))
cfg['gpu_options'] = tf.GPUOptions(
cfg['gpu_options'] = tf.compat.v1.GPUOptions(
per_process_gpu_memory_fraction = utility)
cfg['allow_soft_placement'] = True
else:
Expand All @@ -139,14 +139,14 @@ def setup_meta_ops(self):
if self.FLAGS.train: self.build_train_op()

if self.FLAGS.summary:
self.summary_op = tf.summary.merge_all()
self.writer = tf.summary.FileWriter(self.FLAGS.summary + 'train')
self.summary_op = tf.compat.v1.summary.merge_all()
self.writer = tf.compat.v1.summary.FileWriter(self.FLAGS.summary + 'train')

self.sess = tf.Session(config = tf.ConfigProto(**cfg))
self.sess.run(tf.global_variables_initializer())
self.sess = tf.compat.v1.Session(config = tf.compat.v1.ConfigProto(**cfg))
self.sess.run(tf.compat.v1.global_variables_initializer())

if not self.ntrain: return
self.saver = tf.train.Saver(tf.global_variables(),
self.saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(),
max_to_keep = self.FLAGS.keep)
if self.FLAGS.load != 0: self.load_from_ckpt()

Expand All @@ -165,7 +165,7 @@ def savepb(self):
flags_pb.train = False
# rebuild another tfnet. all const.
tfnet_pb = TFNet(flags_pb, darknet_pb)
tfnet_pb.sess = tf.Session(graph = tfnet_pb.graph)
tfnet_pb.sess = tf.compat.v1.Session(graph = tfnet_pb.graph)
# tfnet_pb.predict() # uncomment for unit testing
name = 'built_graph/{}.pb'.format(self.meta['name'])
os.makedirs(os.path.dirname(name), exist_ok=True)
Expand All @@ -174,4 +174,4 @@ def savepb(self):
json.dump(self.meta, fp)
self.say('Saving const graph def to {}'.format(name))
graph_def = tfnet_pb.sess.graph_def
tf.train.write_graph(graph_def,'./', name, False)
tf.io.write_graph(graph_def,'./', name, False)
8 changes: 4 additions & 4 deletions darkflow/net/help.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,15 @@ def load_old_graph(self, ckpt):
ckpt_loader = create_loader(ckpt)
self.say(old_graph_msg.format(ckpt))

for var in tf.global_variables():
for var in tf.compat.v1.global_variables():
name = var.name.split(':')[0]
args = [name, var.get_shape()]
val = ckpt_loader(args)
assert val is not None, \
'Cannot find and load {}'.format(var.name)
shp = val.shape
plh = tf.placeholder(tf.float32, shp)
op = tf.assign(var, plh)
plh = tf.compat.v1.placeholder(tf.float32, shp)
op = tf.compat.v1.assign(var, plh)
self.sess.run(op, {plh: val})

def _get_fps(self, frame):
Expand Down Expand Up @@ -156,7 +156,7 @@ def to_darknet(self):
darknet_ckpt = self.darknet

with self.graph.as_default() as g:
for var in tf.global_variables():
for var in tf.compat.v1.global_variables():
name = var.name.split(':')[0]
var_name = name.split('-')
l_idx = int(var_name[0])
Expand Down
8 changes: 4 additions & 4 deletions darkflow/net/ops/baseop.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ def wrap_variable(self, var):
if not self.var: return

val = self.lay.w[var]
self.lay.w[var] = tf.constant_initializer(val)
self.lay.w[var] = tf.compat.v1.constant_initializer(val)
if var in self._SLIM: return
with tf.variable_scope(self.scope):
self.lay.w[var] = tf.get_variable(var,
with tf.compat.v1.variable_scope(self.scope):
self.lay.w[var] = tf.compat.v1.get_variable(var,
shape = self.lay.wshape[var],
dtype = tf.float32,
initializer = self.lay.w[var])
Expand All @@ -81,7 +81,7 @@ def wrap_pholder(self, ph, feed):
sig = '{}/{}'.format(self.scope, ph)
val = self.lay.h[ph]

self.lay.h[ph] = tf.placeholder_with_default(
self.lay.h[ph] = tf.compat.v1.placeholder_with_default(
val['dfault'], val['shape'], name = sig)
feed[self.lay.h[ph]] = val['feed']

Expand Down
14 changes: 7 additions & 7 deletions darkflow/net/ops/convolution.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import tensorflow.contrib.slim as slim
import tf_slim as slim
from .baseop import BaseOp
import tensorflow as tf
import numpy as np
Expand All @@ -24,7 +24,7 @@ def _forward(self):
def forward(self):
inp = self.inp.out
s = self.lay.stride
self.out = tf.extract_image_patches(
self.out = tf.image.extract_patches(
inp, [1,s,s,1], [1,s,s,1], [1,1,1,1], 'VALID')

def speak(self):
Expand All @@ -36,7 +36,7 @@ def speak(self):
class local(BaseOp):
def forward(self):
pad = [[self.lay.pad, self.lay.pad]] * 2;
temp = tf.pad(self.inp.out, [[0, 0]] + pad + [[0, 0]])
temp = tf.pad(tensor=self.inp.out, paddings=[[0, 0]] + pad + [[0, 0]])

k = self.lay.w['kernels']
ksz = self.lay.ksize
Expand All @@ -49,7 +49,7 @@ def forward(self):
i_, j_ = i + 1 - half, j + 1 - half
tij = temp[:, i_ : i_ + ksz, j_ : j_ + ksz,:]
row_i.append(
tf.nn.conv2d(tij, kij,
tf.nn.conv2d(input=tij, filters=kij,
padding = 'VALID',
strides = [1] * 4))
out += [tf.concat(row_i, 2)]
Expand All @@ -66,8 +66,8 @@ def speak(self):
class convolutional(BaseOp):
def forward(self):
pad = [[self.lay.pad, self.lay.pad]] * 2;
temp = tf.pad(self.inp.out, [[0, 0]] + pad + [[0, 0]])
temp = tf.nn.conv2d(temp, self.lay.w['kernel'], padding = 'VALID',
temp = tf.pad(tensor=self.inp.out, paddings=[[0, 0]] + pad + [[0, 0]])
temp = tf.nn.conv2d(input=temp, filters=self.lay.w['kernel'], padding = 'VALID',
name = self.scope, strides = [1] + [self.lay.stride] * 2 + [1])
if self.lay.batch_norm:
temp = self.batchnorm(self.lay, temp)
Expand Down Expand Up @@ -113,4 +113,4 @@ def speak(self):
args += [l.batch_norm * '+bnorm']
args += [l.activation]
msg = 'extr {}x{}p{}_{} {} {}'.format(*args)
return msg
return msg
14 changes: 7 additions & 7 deletions darkflow/net/ops/simple.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import tensorflow.contrib.slim as slim
import tf_slim as slim
from .baseop import BaseOp
import tensorflow as tf
from distutils.version import StrictVersion
Expand All @@ -22,7 +22,7 @@ def speak(self):

class connected(BaseOp):
def forward(self):
self.out = tf.nn.xw_plus_b(
self.out = tf.compat.v1.nn.xw_plus_b(
self.inp.out,
self.lay.w['weights'],
self.lay.w['biases'],
Expand Down Expand Up @@ -56,7 +56,7 @@ def speak(self):
class flatten(BaseOp):
def forward(self):
temp = tf.transpose(
self.inp.out, [0,3,1,2])
a=self.inp.out, perm=[0,3,1,2])
self.out = slim.flatten(
temp, scope = self.scope)

Expand All @@ -73,7 +73,7 @@ def speak(self): return 'softmax()'
class avgpool(BaseOp):
def forward(self):
self.out = tf.reduce_mean(
self.inp.out, [1, 2],
input_tensor=self.inp.out, axis=[1, 2],
name = self.scope
)

Expand All @@ -86,7 +86,7 @@ def forward(self):
self.lay.h['pdrop'] = 1.0
self.out = tf.nn.dropout(
self.inp.out,
self.lay.h['pdrop'],
1 - (self.lay.h['pdrop']),
name = self.scope
)

Expand All @@ -103,8 +103,8 @@ def speak(self):

class maxpool(BaseOp):
def forward(self):
self.out = tf.nn.max_pool(
self.inp.out, padding = 'SAME',
self.out = tf.nn.max_pool2d(
input=self.inp.out, padding = 'SAME',
ksize = [1] + [self.lay.ksize]*2 + [1],
strides = [1] + [self.lay.stride]*2 + [1],
name = self.scope
Expand Down
4 changes: 2 additions & 2 deletions darkflow/net/vanilla/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def loss(self, net_out):
loss = l1_loss(diff)

elif loss_type == 'softmax':
loss = tf.nn.softmax_cross_entropy_with_logits(logits, y)
loss = tf.reduce_mean(loss)
loss = tf.nn.softmax_cross_entropy_with_logits(labels=tf.stop_gradient(y))
loss = tf.reduce_mean(input_tensor=loss)

elif loss_type == 'svm':
assert 'train_size' in m, \
Expand Down
26 changes: 13 additions & 13 deletions darkflow/net/yolo/train.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import tensorflow.contrib.slim as slim
import tf_slim as slim
import pickle
import tensorflow as tf
from .misc import show
Expand Down Expand Up @@ -30,15 +30,15 @@ def loss(self, net_out):
size2 = [None, SS, B]

# return the below placeholders
_probs = tf.placeholder(tf.float32, size1)
_confs = tf.placeholder(tf.float32, size2)
_coord = tf.placeholder(tf.float32, size2 + [4])
_probs = tf.compat.v1.placeholder(tf.float32, size1)
_confs = tf.compat.v1.placeholder(tf.float32, size2)
_coord = tf.compat.v1.placeholder(tf.float32, size2 + [4])
# weights term for L2 loss
_proid = tf.placeholder(tf.float32, size1)
_proid = tf.compat.v1.placeholder(tf.float32, size1)
# material calculating IOU
_areas = tf.placeholder(tf.float32, size2)
_upleft = tf.placeholder(tf.float32, size2 + [2])
_botright = tf.placeholder(tf.float32, size2 + [2])
_areas = tf.compat.v1.placeholder(tf.float32, size2)
_upleft = tf.compat.v1.placeholder(tf.float32, size2 + [2])
_botright = tf.compat.v1.placeholder(tf.float32, size2 + [2])

self.placeholders = {
'probs':_probs, 'confs':_confs, 'coord':_coord, 'proid':_proid,
Expand All @@ -63,8 +63,8 @@ def loss(self, net_out):

# calculate the best IOU, set 0.0 confidence for worse boxes
iou = tf.truediv(intersect, _areas + area_pred - intersect)
best_box = tf.equal(iou, tf.reduce_max(iou, [2], True))
best_box = tf.to_float(best_box)
best_box = tf.equal(iou, tf.reduce_max(input_tensor=iou, axis=[2], keepdims=True))
best_box = tf.cast(best_box, dtype=tf.float32)
confs = tf.multiply(best_box, _confs)

# take care of the weight terms
Expand All @@ -87,6 +87,6 @@ def loss(self, net_out):
print('Building {} loss'.format(m['model']))
loss = tf.pow(net_out - true, 2)
loss = tf.multiply(loss, wght)
loss = tf.reduce_sum(loss, 1)
self.loss = .5 * tf.reduce_mean(loss)
tf.summary.scalar('{} loss'.format(m['model']), self.loss)
loss = tf.reduce_sum(input_tensor=loss, axis=1)
self.loss = .5 * tf.reduce_mean(input_tensor=loss)
tf.compat.v1.summary.scalar('{} loss'.format(m['model']), self.loss)
26 changes: 13 additions & 13 deletions darkflow/net/yolov2/train.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import tensorflow.contrib.slim as slim
import tf_slim as slim
import pickle
import tensorflow as tf
from ..yolo.misc import show
Expand Down Expand Up @@ -37,15 +37,15 @@ def loss(self, net_out):
size2 = [None, HW, B]

# return the below placeholders
_probs = tf.placeholder(tf.float32, size1)
_confs = tf.placeholder(tf.float32, size2)
_coord = tf.placeholder(tf.float32, size2 + [4])
_probs = tf.compat.v1.placeholder(tf.float32, size1)
_confs = tf.compat.v1.placeholder(tf.float32, size2)
_coord = tf.compat.v1.placeholder(tf.float32, size2 + [4])
# weights term for L2 loss
_proid = tf.placeholder(tf.float32, size1)
_proid = tf.compat.v1.placeholder(tf.float32, size1)
# material calculating IOU
_areas = tf.placeholder(tf.float32, size2)
_upleft = tf.placeholder(tf.float32, size2 + [2])
_botright = tf.placeholder(tf.float32, size2 + [2])
_areas = tf.compat.v1.placeholder(tf.float32, size2)
_upleft = tf.compat.v1.placeholder(tf.float32, size2 + [2])
_botright = tf.compat.v1.placeholder(tf.float32, size2 + [2])

self.placeholders = {
'probs':_probs, 'confs':_confs, 'coord':_coord, 'proid':_proid,
Expand Down Expand Up @@ -83,8 +83,8 @@ def loss(self, net_out):

# calculate the best IOU, set 0.0 confidence for worse boxes
iou = tf.truediv(intersect, _areas + area_pred - intersect)
best_box = tf.equal(iou, tf.reduce_max(iou, [2], True))
best_box = tf.to_float(best_box)
best_box = tf.equal(iou, tf.reduce_max(input_tensor=iou, axis=[2], keepdims=True))
best_box = tf.cast(best_box, dtype=tf.float32)
confs = tf.multiply(best_box, _confs)

# take care of the weight terms
Expand All @@ -102,6 +102,6 @@ def loss(self, net_out):
loss = tf.pow(adjusted_net_out - true, 2)
loss = tf.multiply(loss, wght)
loss = tf.reshape(loss, [-1, H*W*B*(4 + 1 + C)])
loss = tf.reduce_sum(loss, 1)
self.loss = .5 * tf.reduce_mean(loss)
tf.summary.scalar('{} loss'.format(m['model']), self.loss)
loss = tf.reduce_sum(input_tensor=loss, axis=1)
self.loss = .5 * tf.reduce_mean(input_tensor=loss)
tf.compat.v1.summary.scalar('{} loss'.format(m['model']), self.loss)
Loading