Skip to content

Commit

Permalink
Added functionality to plot cropped analysis points on cropped video;…
Browse files Browse the repository at this point in the history
… expanded filters
  • Loading branch information
AlexEMG committed Jun 5, 2019
1 parent 4cacee6 commit 72a1ea5
Show file tree
Hide file tree
Showing 7 changed files with 213 additions and 141 deletions.
3 changes: 2 additions & 1 deletion deeplabcut/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@
if os.environ.get('Colab', default=False) == 'True':
print("Project loaded in colab-mode. Apparently Colab has trouble loading statsmodels, so the smooting & outlier frame extraction is disabled. Sorry!")
else:
from deeplabcut.refine_training_dataset import extract_outlier_frames, merge_datasets, filterpredictions
from deeplabcut.refine_training_dataset import extract_outlier_frames, merge_datasets
from deeplabcut.post_processing import filterpredictions

#Direct import for convenience
from deeplabcut.pose_estimation_tensorflow import train_network
Expand Down
9 changes: 9 additions & 0 deletions deeplabcut/post_processing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""
DeepLabCut2.0 Toolbox
https://github.com/AlexEMG/DeepLabCut
A Mathis, [email protected]
T Nath, [email protected]
M Mathis, [email protected]
"""

from deeplabcut.post_processing.filtering import *
143 changes: 143 additions & 0 deletions deeplabcut/post_processing/filtering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""
DeepLabCut2.0 Toolbox
https://github.com/AlexEMG/DeepLabCut
A Mathis, [email protected]
T Nath, [email protected]
M Mathis, [email protected]
"""
import numpy as np
import os
from pathlib import Path
import pandas as pd

from deeplabcut.utils import auxiliaryfunctions, visualization
from deeplabcut.utils import frameselectiontools
from deeplabcut.refine_training_dataset.outlier_frames import FitSARIMAXModel

import argparse
from tqdm import tqdm
import matplotlib.pyplot as plt
from skimage.util import img_as_ubyte
from scipy import signal


def filterpredictions(config,video,videotype='avi',shuffle=1,trainingsetindex=0,filterype='median',windowlength=5,p_bound=.001,ARdegree=3,MAdegree=1,alpha=.01,save_as_csv=True,destfolder=None):
"""
Fits frame-by-frame pose predictions with ARIMA model (filtertype='arima') or median filter (default).
Parameter
----------
config : string
Full path of the config.yaml file as a string.
video : string
Full path of the video to extract the frame from. Make sure that this video is already analyzed.
shuffle : int, optional
The shufle index of training dataset. The extracted frames will be stored in the labeled-dataset for
the corresponding shuffle of training dataset. Default is set to 1
trainingsetindex: int, optional
Integer specifying which TrainingsetFraction to use. By default the first (note that TrainingFraction is a list in config.yaml).
filterype: string
Select which filter, 'arima' or 'median' filter.
windowlength: int
For filtertype='median' filters the input array using a local window-size given by windowlength. The array will automatically be zero-padded.
https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.medfilt.html The windowlenght should be an odd number.
p_bound: float between 0 and 1, optional
For filtertype 'arima' this parameter defines the likelihood below,
below which a body part will be consided as missing data for filtering purposes.
ARdegree: int, optional
For filtertype 'arima' Autoregressive degree of Sarimax model degree.
see https://www.statsmodels.org/dev/generated/statsmodels.tsa.statespace.sarimax.SARIMAX.html
MAdegree: int
For filtertype 'arima' Moving Avarage degree of Sarimax model degree.
See https://www.statsmodels.org/dev/generated/statsmodels.tsa.statespace.sarimax.SARIMAX.html
alpha: float
Significance level for detecting outliers based on confidence interval of fitted SARIMAX model.
save_as_csv: bool, optional
Saves the predictions in a .csv file. The default is ``False``; if provided it must be either ``True`` or ``False``
destfolder: string, optional
Specifies the destination folder for analysis data (default is the path of the video). Note that for subsequent analysis this
folder also needs to be passed.
Example
--------
Arima model:
deeplabcut.filterpredictions('C:\\myproject\\reaching-task\\config.yaml',['C:\\myproject\\trailtracking-task\\test.mp4'],shuffle=3,filterype='arima',ARdegree=5,MAdegree=2)
Use median filter over 10bins:
deeplabcut.filterpredictions('C:\\myproject\\reaching-task\\config.yaml',['C:\\myproject\\trailtracking-task\\test.mp4'],shuffle=3,windowlength=10)
One can then use the filtered rather than the frame-by-frame predictions by calling:
deeplabcut.plot_trajectories('C:\\myproject\\reaching-task\\config.yaml',['C:\\myproject\\trailtracking-task\\test.mp4'],shuffle=3,filtered=True)
deeplabcut.create_labeled_video('C:\\myproject\\reaching-task\\config.yaml',['C:\\myproject\\trailtracking-task\\test.mp4'],shuffle=3,filtered=True)
--------
Returns filtered pandas array with the same structure as normal output of network.
"""
cfg = auxiliaryfunctions.read_config(config)
scorer=auxiliaryfunctions.GetScorerName(cfg,shuffle,trainFraction = cfg['TrainingFraction'][trainingsetindex])
Videos=auxiliaryfunctions.Getlistofvideos(video,videotype)

if len(Videos)>0:
for video in Videos:

if destfolder is None:
destfolder = str(Path(video).parents[0])

print("Filtering with ARIMA model %s",video)
videofolder = str(Path(video).parents[0])
dataname = str(Path(video).stem)+scorer
filteredname=dataname.split('.h5')[0]+'filtered.h5'
try:
Dataframe = pd.read_hdf(os.path.join(videofolder,filteredname))
print("Video already filtered...")
except FileNotFoundError:
try:
Dataframe = pd.read_hdf(os.path.join(videofolder,dataname+'.h5'))
for bpindex,bp in tqdm(enumerate(cfg['bodyparts'])):
pdindex = pd.MultiIndex.from_product([[scorer], [bp], ['x', 'y','likelihood']],names=['scorer', 'bodyparts', 'coords'])
x,y,p=Dataframe[scorer][bp]['x'].values,Dataframe[scorer][bp]['y'].values,Dataframe[scorer][bp]['likelihood'].values

if filterype=='arima':
meanx,CIx=FitSARIMAXModel(x,p,p_bound,alpha,ARdegree,MAdegree,False)
meany,CIy=FitSARIMAXModel(y,p,p_bound,alpha,ARdegree,MAdegree,False)

meanx[0]=x[0]
meany[0]=y[0]
else:
meanx=signal.medfilt(x,kernel_size=windowlength)
meany=signal.medfilt(y,kernel_size=windowlength)

if bpindex==0:
data = pd.DataFrame(np.hstack([np.expand_dims(meanx,axis=1),np.expand_dims(meany,axis=1),np.expand_dims(p,axis=1)]), columns=pdindex)
else:
item=pd.DataFrame(np.hstack([np.expand_dims(meanx,axis=1),np.expand_dims(meany,axis=1),np.expand_dims(p,axis=1)]), columns=pdindex)
data=pd.concat([data.T, item.T]).T

data.to_hdf(os.path.join(videofolder,filteredname), 'df_with_missing', format='table', mode='w')
if save_as_csv:
print("Saving filtered csv poses!")
data.to_csv(os.path.join(videofolder,filteredname.split('.h5')[0]+'.csv'))
except FileNotFoundError:
print("Video not analyzed -- Run analyze_videos first.")

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('config')
parser.add_argument('videos')
cli_args = parser.parse_args()
100 changes: 0 additions & 100 deletions deeplabcut/refine_training_dataset/outlier_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,106 +194,6 @@ def extract_outlier_frames(config,videos,videotype='avi',shuffle=1,trainingsetin
print("The video has not been analyzed yet!. You can only refine the labels, after the pose has been estimate. Please run 'analyze_video' first.")


def filterpredictions(config,video,videotype='avi',shuffle=1,trainingsetindex=0,p_bound=.001,ARdegree=3,MAdegree=1,alpha=.01,save_as_csv=True,destfolder=None):
"""
Fits frame-by-frame pose predictions with SARIMAX model.
Parameter
----------
config : string
Full path of the config.yaml file as a string.
video : string
Full path of the video to extract the frame from. Make sure that this video is already analyzed.
shuffle : int, optional
The shufle index of training dataset. The extracted frames will be stored in the labeled-dataset for
the corresponding shuffle of training dataset. Default is set to 1
trainingsetindex: int, optional
Integer specifying which TrainingsetFraction to use. By default the first (note that TrainingFraction is a list in config.yaml).
comparisonbodyparts: list of strings, optional
This select the body parts for which SARIMAX models are fit. Either ``all``, then all body parts
from config.yaml are used orr a list of strings that are a subset of the full list.
E.g. ['hand','Joystick'] for the demo Reaching-Mackenzie-2018-08-30/config.yaml to select only these two body parts.
p_bound: float between 0 and 1, optional
For outlieralgorithm 'uncertain' this parameter defines the likelihood below,
below which a body part will be consided as missing data for filtering purposes.
ARdegree: int, optional
For outlieralgorithm 'fitting': Autoregressive degree of Sarimax model degree.
see https://www.statsmodels.org/dev/generated/statsmodels.tsa.statespace.sarimax.SARIMAX.html
MAdegree: int
For outlieralgorithm 'fitting': Moving Avarage degree of Sarimax model degree.
See https://www.statsmodels.org/dev/generated/statsmodels.tsa.statespace.sarimax.SARIMAX.html
alpha: float
Significance level for detecting outliers based on confidence interval of fitted SARIMAX model.
save_as_csv: bool, optional
Saves the predictions in a .csv file. The default is ``False``; if provided it must be either ``True`` or ``False``
destfolder: string, optional
Specifies the destination folder for analysis data (default is the path of the video). Note that for subsequent analysis this
folder also needs to be passed.
Example
--------
deeplabcut.filterpredictions('C:\\myproject\\reaching-task\\config.yaml',['C:\\myproject\\trailtracking-task\\test.mp4'],shuffle=3,ARdegree=5,MAdegree=2)
One can then use the filtered rather than the frame-by-frame predictions by calling:
deeplabcut.plot_trajectories('C:\\myproject\\reaching-task\\config.yaml',['C:\\myproject\\trailtracking-task\\test.mp4'],shuffle=3,filtered=True)
deeplabcut.create_labeled_video('C:\\myproject\\reaching-task\\config.yaml',['C:\\myproject\\trailtracking-task\\test.mp4'],shuffle=3,filtered=True)
--------
Returns filtered pandas array with the same structure as normal output of network.
"""
cfg = auxiliaryfunctions.read_config(config)
scorer=auxiliaryfunctions.GetScorerName(cfg,shuffle,trainFraction = cfg['TrainingFraction'][trainingsetindex])
Videos=auxiliaryfunctions.Getlistofvideos(video,videotype)
if len(Videos)>0:
for video in Videos:

if destfolder is None:
destfolder = str(Path(video).parents[0])

print("Filtering with ARIMA model %s",video)
videofolder = str(Path(video).parents[0])
dataname = str(Path(video).stem)+scorer
filteredname=dataname.split('.h5')[0]+'filtered.h5'
try:
Dataframe = pd.read_hdf(os.path.join(videofolder,filteredname))
print("Video already filtered...")
except FileNotFoundError:
try:
Dataframe = pd.read_hdf(os.path.join(videofolder,dataname+'.h5'))
for bpindex,bp in tqdm(enumerate(cfg['bodyparts'])):
pdindex = pd.MultiIndex.from_product([[scorer], [bp], ['x', 'y','likelihood']],names=['scorer', 'bodyparts', 'coords'])
x,y,p=Dataframe[scorer][bp]['x'].values,Dataframe[scorer][bp]['y'].values,Dataframe[scorer][bp]['likelihood'].values
meanx,CIx=FitSARIMAXModel(x,p,p_bound,alpha,ARdegree,MAdegree,False)
meany,CIy=FitSARIMAXModel(y,p,p_bound,alpha,ARdegree,MAdegree,False)

meanx[0]=x[0]
meany[0]=y[0]

if bpindex==0:
data = pd.DataFrame(np.hstack([np.expand_dims(meanx,axis=1),np.expand_dims(meany,axis=1),np.expand_dims(p,axis=1)]), columns=pdindex)
else:
item=pd.DataFrame(np.hstack([np.expand_dims(meanx,axis=1),np.expand_dims(meany,axis=1),np.expand_dims(p,axis=1)]), columns=pdindex)
data=pd.concat([data.T, item.T]).T

data.to_hdf(os.path.join(videofolder,filteredname), 'df_with_missing', format='table', mode='w')
if save_as_csv:
print("Saving filtered csv poses!")
data.to_csv(os.path.join(videofolder,filteredname.split('.h5')[0]+'.csv'))
except FileNotFoundError:
print("Video not analyzed -- Run analyze_videos first.")

def convertparms2start(pn):
''' Creating a start value for sarimax in case of an value error
See: https://groups.google.com/forum/#!topic/pystatsmodels/S_Fo53F25Rk '''
Expand Down
Loading

0 comments on commit 72a1ea5

Please sign in to comment.