Skip to content

Commit

Permalink
Merge branch 'jchutrue-master'
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexEMG committed Jun 6, 2019
2 parents 37d8a7b + d6bbe48 commit ef8d67f
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 52 deletions.
70 changes: 19 additions & 51 deletions deeplabcut/generate_training_dataset/trainingsetmanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,11 @@
else:
mpl.use('TkAgg')
import matplotlib.pyplot as plt



#if os.environ.get('DLClight', default=False) == 'True':
# mpl.use('AGG') #anti-grain geometry engine #https://matplotlib.org/faq/usage_faq.html
# pass
#else:
# mpl.use('TkAgg')
#import matplotlib.pyplot as plt
from skimage import io

import yaml
from deeplabcut import DEBUG
from deeplabcut.utils import auxiliaryfunctions, conversioncode
from deeplabcut.utils import auxiliaryfunctions, conversioncode, auxfun_models

#matplotlib.use('Agg')

Expand Down Expand Up @@ -434,15 +425,14 @@ def mergeandsplit(config,trainindex=0,uniform=True,windows2linux=False):
>>> deeplabcut.create_training_dataset(config,Shuffles=[3],trainIndexes=trainIndexes,testIndexes=testIndexes)
To freeze a (uniform) split:
>>> trainIndices, testIndices=deeplabcut.mergeandsplit(config,trainindex=0,uniform=True)
>>> trainIndexes, testIndexes=deeplabcut.mergeandsplit(config,trainindex=0,uniform=True)
You can then create two model instances that have the identical trainingset. Thereby you can assess the role of various parameters on the performance of DLC.
>>> deeplabcut.create_training_dataset(config,Shuffles=[0],trainIndices=trainIndices,testIndices=testIndices)
>>> deeplabcut.create_training_dataset(config,Shuffles=[1],trainIndices=trainIndices,testIndices=testIndices)
>>> deeplabcut.create_training_dataset(config,Shuffles=[0],trainIndexes=trainIndexes,testIndexes=testIndexes)
>>> deeplabcut.create_training_dataset(config,Shuffles=[1],trainIndexes=trainIndexes,testIndexes=testIndexes)
--------
"""

# Loading metadata from config file:
cfg = auxiliaryfunctions.read_config(config)
scorer = cfg['scorer']
Expand Down Expand Up @@ -479,7 +469,7 @@ def mergeandsplit(config,trainindex=0,uniform=True,windows2linux=False):
return trainIndexes, testIndexes


def create_training_dataset(config,num_shuffles=1,Shuffles=None,windows2linux=False,trainIndices=None,testIndices=None):
def create_training_dataset(config,num_shuffles=1,Shuffles=None,windows2linux=False,trainIndexes=None,testIndexes=None):
"""
Creates a training dataset. Labels from all the extracted frames are merged into a single .h5 file.\n
Only the videos included in the config file are used to create this dataset.\n
Expand All @@ -501,9 +491,6 @@ def create_training_dataset(config,num_shuffles=1,Shuffles=None,windows2linux=Fa
The annotation files contain path formated according to your operating system. If you label on windows
but train & evaluate on a unix system (e.g. ubunt, colab, Mac) set this variable to True to convert the paths.
trainIndices and testIndices: list of indices for traininng and testing. Use mergeandsplit(config,trainindex=0,uniform=True,windows2linux=False) to create them
See help for deeplabcut.mergeandsplit?
Example
--------
>>> deeplabcut.create_training_dataset('/analysis/project/reaching-task/config.yaml',num_shuffles=1)
Expand All @@ -513,9 +500,7 @@ def create_training_dataset(config,num_shuffles=1,Shuffles=None,windows2linux=Fa
"""
from skimage import io
import scipy.io as sio
import deeplabcut
import subprocess


# Loading metadata from config file:
cfg = auxiliaryfunctions.read_config(config)
scorer = cfg['scorer']
Expand All @@ -526,28 +511,16 @@ def create_training_dataset(config,num_shuffles=1,Shuffles=None,windows2linux=Fa

Data = merge_annotateddatasets(cfg,project_path,Path(os.path.join(project_path,trainingsetfolder)),windows2linux)
Data = Data[scorer] #extract labeled data

#set model type. we will allow more in the future.
if cfg['resnet']==50:
net_type ='resnet_'+str(cfg['resnet'])
resnet_path = str(Path(deeplabcut.__file__).parents[0] / 'pose_estimation_tensorflow/models/pretrained/resnet_v1_50.ckpt')
elif cfg['resnet']==101:
net_type ='resnet_'+str(cfg['resnet'])
resnet_path = str(Path(deeplabcut.__file__).parents[0] / 'pose_estimation_tensorflow/models/pretrained/resnet_v1_101.ckpt')
else:
print("Currently only ResNet 50 or 101 supported, please change 'resnet' entry in config.yaml!")
num_shuffles=-1 #thus the loop below is empty...

if not Path(resnet_path).is_file():
"""
Downloads the ImageNet pretrained weights for ResNet.
"""
start = os.getcwd()
os.chdir(str(Path(resnet_path).parents[0]))
print("Downloading the pretrained model (ResNets)....")
subprocess.call("download.sh", shell=True)
os.chdir(start)



#loading & linking pretrained models
net_type ='resnet_'+str(cfg['resnet'])
import deeplabcut
parent_path = Path(os.path.dirname(deeplabcut.__file__))
defaultconfigfile = str(parent_path / 'pose_cfg.yaml')

model_path,num_shuffles=auxfun_models.Check4weights(net_type,parent_path,num_shuffles)

if Shuffles==None:
Shuffles=range(1,num_shuffles+1,1)
else:
Expand All @@ -558,11 +531,9 @@ def create_training_dataset(config,num_shuffles=1,Shuffles=None,windows2linux=Fa
for shuffle in Shuffles: # Creating shuffles starting from 1
for trainFraction in TrainingFraction:
#trainIndexes, testIndexes = SplitTrials(range(len(Data.index)), trainFraction)
if trainIndices is None and testIndices is None:
if trainIndexes is None and testIndexes is None:
trainIndexes, testIndexes = SplitTrials(range(len(Data.index)), trainFraction)
else: # set to passed values...
trainIndexes=trainIndices
testIndexes=testIndices
else:
print("You passed a split with the following fraction:", len(trainIndexes)*1./(len(testIndexes)+len(trainIndexes))*100)

####################################################
Expand Down Expand Up @@ -649,13 +620,10 @@ def create_training_dataset(config,num_shuffles=1,Shuffles=None,windows2linux=Fa
"num_joints": len(bodyparts),
"all_joints": [[i] for i in range(len(bodyparts))],
"all_joints_names": [str(bpt) for bpt in bodyparts],
"init_weights": resnet_path,
"init_weights": model_path,
"project_path": str(cfg['project_path']),
"net_type": net_type
}

defaultconfigfile = str(Path(deeplabcut.__file__).parents[0] / 'pose_cfg.yaml')

trainingdata = MakeTrain_pose_yaml(items2change,path_train_config,defaultconfigfile)
keys2save = [
"dataset", "num_joints", "all_joints", "all_joints_names",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# legacy.
#!/bin/sh

curl http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz | tar xvz
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
resnet_50: http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz
resnet_101: http://download.tensorflow.org/models/resnet_v1_101_2016_08_28.tar.gz
2 changes: 2 additions & 0 deletions deeplabcut/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from deeplabcut.utils.make_labeled_video import *
from deeplabcut.utils.auxiliaryfunctions import *
from deeplabcut.utils.auxfun_models import *

from deeplabcut.utils.video_processor import *
from deeplabcut.utils.plotting import *

Expand Down
46 changes: 46 additions & 0 deletions deeplabcut/utils/auxfun_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""
DeepLabCut Toolbox
https://github.com/AlexEMG/DeepLabCut
A Mathis, [email protected]
M Mathis, [email protected]
"""

from deeplabcut.utils import auxiliaryfunctions

def Check4weights(modeltype,parent_path,num_shuffles):
''' gets local path to network weights and checks if they are present. If not, downloads them from tensorflow.org '''
if 'resnet_50' == modeltype:
model_path = parent_path / 'pose_estimation_tensorflow/models/pretrained/resnet_v1_50.ckpt'
elif 'resnet_101' == modeltype:
model_path = parent_path / 'pose_estimation_tensorflow/models/pretrained/resnet_v1_101.ckpt'
else:
print("Currently only ResNet 50 or 101 supported, please change 'resnet' entry in config.yaml!")
num_shuffles=-1 #thus the loop below is empty...
model_path=parent_path

if num_shuffles>0:
if not model_path.is_file():
Downloadweights(modeltype,model_path)

return str(model_path),num_shuffles

def Downloadweights(modeltype,model_path):
"""
Downloads the ImageNet pretrained weights for ResNet.
"""

import urllib
import tarfile
from io import BytesIO

target_dir = model_path.parents[0]
neturls=auxiliaryfunctions.read_plainconfig(target_dir / 'pretrained_model_urls.yaml')
try:
url = neturls[modeltype]
print("Downloading a ImageNet-pretrained model from {}....".format(url))
response = urllib.request.urlopen(url)
with tarfile.open(fileobj=BytesIO(response.read()), mode='r:gz') as tar:
tar.extractall(path=target_dir)
except KeyError:
print("Model does not exist", modeltype)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
'tqdm>4','wheel==0.31.1'],
scripts=['deeplabcut/pose_estimation_tensorflow/models/pretrained/download.sh'],
packages=setuptools.find_packages(),
data_files=[('deeplabcut',['deeplabcut/pose_cfg.yaml'])],
data_files=[('deeplabcut',['deeplabcut/pose_cfg.yaml','deeplabcut/pose_estimation_tensorflow/models/pretrained/pretrained_model_urls.yaml'])],
include_package_data=True,
classifiers=(
"Programming Language :: Python :: 3",
Expand Down

0 comments on commit ef8d67f

Please sign in to comment.