Skip to content

Commit

Permalink
Making the net_spec python3 compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
philkr committed Jul 8, 2015
1 parent 77d66df commit 7093b0b
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 7 deletions.
5 changes: 3 additions & 2 deletions examples/pycaffe/caffenet.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from caffe import layers as L, params as P, to_proto
from caffe.proto import caffe_pb2
from __future__ import print_function

# helper function for common structures

Expand Down Expand Up @@ -45,10 +46,10 @@ def caffenet(lmdb, batch_size=256, include_acc=False):

def make_net():
with open('train.prototxt', 'w') as f:
print >>f, caffenet('/path/to/caffe-train-lmdb')
print(caffenet('/path/to/caffe-train-lmdb'), file=f)

with open('test.prototxt', 'w') as f:
print >>f, caffenet('/path/to/caffe-val-lmdb', batch_size=50, include_acc=True)
print(caffenet('/path/to/caffe-val-lmdb', batch_size=50, include_acc=True), file=f)

if __name__ == '__main__':
make_net()
11 changes: 6 additions & 5 deletions python/caffe/net_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class -- assign to its attributes directly to name layers, and call

from .proto import caffe_pb2
from google import protobuf
import six


def param_name_dict():
Expand Down Expand Up @@ -63,12 +64,12 @@ def assign_proto(proto, name, val):
if isinstance(val[0], dict):
for item in val:
proto_item = getattr(proto, name).add()
for k, v in item.iteritems():
for k, v in six.iteritems(item):
assign_proto(proto_item, k, v)
else:
getattr(proto, name).extend(val)
elif isinstance(val, dict):
for k, v in val.iteritems():
for k, v in six.iteritems(val):
assign_proto(getattr(proto, name), k, v)
else:
setattr(proto, name, val)
Expand Down Expand Up @@ -131,7 +132,7 @@ def _to_proto(self, layers, names, autonames):
layer.top.append(self._get_name(top, names, autonames))
layer.name = self._get_name(self.tops[0], names, autonames)

for k, v in self.params.iteritems():
for k, v in six.iteritems(self.params):
# special case to handle generic *params
if k.endswith('param'):
assign_proto(layer, k, v)
Expand Down Expand Up @@ -161,10 +162,10 @@ def __getattr__(self, name):
return self.tops[name]

def to_proto(self):
names = {v: k for k, v in self.tops.iteritems()}
names = {v: k for k, v in six.iteritems(self.tops)}
autonames = {}
layers = OrderedDict()
for name, top in self.tops.iteritems():
for name, top in six.iteritems(self.tops):
top.fn._to_proto(layers, names, autonames)
net = caffe_pb2.NetParameter()
net.layer.extend(layers.values())
Expand Down
1 change: 1 addition & 0 deletions python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ protobuf>=2.5.0
python-gflags>=2.0
pyyaml>=3.10
Pillow>=2.3.0
six>=1.1.0

0 comments on commit 7093b0b

Please sign in to comment.