forked from DeepLabCut/DeepLabCut
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added functionality to plot cropped analysis points on cropped video;…
… expanded filters
- Loading branch information
Showing
7 changed files
with
213 additions
and
141 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
""" | ||
DeepLabCut2.0 Toolbox | ||
https://github.com/AlexEMG/DeepLabCut | ||
A Mathis, [email protected] | ||
T Nath, [email protected] | ||
M Mathis, [email protected] | ||
""" | ||
|
||
from deeplabcut.post_processing.filtering import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
""" | ||
DeepLabCut2.0 Toolbox | ||
https://github.com/AlexEMG/DeepLabCut | ||
A Mathis, [email protected] | ||
T Nath, [email protected] | ||
M Mathis, [email protected] | ||
""" | ||
import numpy as np | ||
import os | ||
from pathlib import Path | ||
import pandas as pd | ||
|
||
from deeplabcut.utils import auxiliaryfunctions, visualization | ||
from deeplabcut.utils import frameselectiontools | ||
from deeplabcut.refine_training_dataset.outlier_frames import FitSARIMAXModel | ||
|
||
import argparse | ||
from tqdm import tqdm | ||
import matplotlib.pyplot as plt | ||
from skimage.util import img_as_ubyte | ||
from scipy import signal | ||
|
||
|
||
def filterpredictions(config,video,videotype='avi',shuffle=1,trainingsetindex=0,filterype='median',windowlength=5,p_bound=.001,ARdegree=3,MAdegree=1,alpha=.01,save_as_csv=True,destfolder=None):
    """
    Fits frame-by-frame pose predictions with a SARIMAX model (filterype='arima')
    or a median filter (default).

    Parameters
    ----------
    config : string
        Full path of the config.yaml file as a string.
    video : string or list of strings
        Full path of the video(s) whose predictions should be filtered.
        Make sure the video(s) have already been analyzed.
    videotype : string, optional
        File extension used when ``video`` is a directory. Default 'avi'.
    shuffle : int, optional
        The shuffle index of the training dataset. The extracted frames will be
        stored in the labeled-dataset for the corresponding shuffle of the
        training dataset. Default is 1.
    trainingsetindex : int, optional
        Integer specifying which TrainingsetFraction to use. By default the
        first (note that TrainingFraction is a list in config.yaml).
    filterype : string
        Which filter to apply: 'arima' or 'median' (default).
        NOTE(review): the parameter name is missing a 't' ('filtertype'); it is
        kept as-is because callers pass it by keyword.
    windowlength : int
        For filterype='median': local window size passed to
        scipy.signal.medfilt; the array is automatically zero-padded.
        Must be an ODD number. See
        https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.medfilt.html
    p_bound : float between 0 and 1, optional
        For filterype='arima': likelihood below which a body part is
        considered missing data for filtering purposes.
    ARdegree : int, optional
        For filterype='arima': autoregressive degree of the SARIMAX model.
        See https://www.statsmodels.org/dev/generated/statsmodels.tsa.statespace.sarimax.SARIMAX.html
    MAdegree : int
        For filterype='arima': moving-average degree of the SARIMAX model.
        See https://www.statsmodels.org/dev/generated/statsmodels.tsa.statespace.sarimax.SARIMAX.html
    alpha : float
        Significance level for detecting outliers based on the confidence
        interval of the fitted SARIMAX model.
    save_as_csv : bool, optional
        Also saves the filtered predictions as a .csv file. Default is ``True``.
    destfolder : string, optional
        Destination folder for the analysis files (default is the folder of
        the video). Note that for subsequent analysis this folder also needs
        to be passed.

    Example
    --------
    Arima model:
    deeplabcut.filterpredictions('C:\\myproject\\reaching-task\\config.yaml',['C:\\myproject\\trailtracking-task\\test.mp4'],shuffle=3,filterype='arima',ARdegree=5,MAdegree=2)

    Use median filter over an 11-bin window (must be odd):
    deeplabcut.filterpredictions('C:\\myproject\\reaching-task\\config.yaml',['C:\\myproject\\trailtracking-task\\test.mp4'],shuffle=3,windowlength=11)

    One can then use the filtered rather than the frame-by-frame predictions by calling:
    deeplabcut.plot_trajectories('C:\\myproject\\reaching-task\\config.yaml',['C:\\myproject\\trailtracking-task\\test.mp4'],shuffle=3,filtered=True)
    deeplabcut.create_labeled_video('C:\\myproject\\reaching-task\\config.yaml',['C:\\myproject\\trailtracking-task\\test.mp4'],shuffle=3,filtered=True)

    Returns
    -------
    None. Writes a filtered pandas array with the same structure as the
    normal network output ('...filtered.h5', optionally '.csv') next to the
    analysis files.
    """
    cfg = auxiliaryfunctions.read_config(config)
    scorer = auxiliaryfunctions.GetScorerName(cfg, shuffle, trainFraction=cfg['TrainingFraction'][trainingsetindex])
    Videos = auxiliaryfunctions.Getlistofvideos(video, videotype)

    if not Videos:
        print("No videos found; nothing to filter.")
        return

    for vname in Videos:
        # Resolve the output folder per video: honor an explicit destfolder
        # (previously computed but never used) and avoid the original bug
        # where a None destfolder bound inside the loop leaked the first
        # video's folder into all later iterations.
        folder = destfolder if destfolder is not None else str(Path(vname).parents[0])

        # Original printed a literal "%s" (arguments were comma-separated,
        # not %-formatted) and always said "ARIMA" even for the median filter.
        print("Filtering with %s model: %s" % (filterype, vname))

        dataname = str(Path(vname).stem) + scorer
        # dataname (video stem + scorer) never contains '.h5', so the
        # original split('.h5')[0] was a no-op; append directly.
        filteredname = dataname + 'filtered.h5'

        try:
            pd.read_hdf(os.path.join(folder, filteredname))
            print("Video already filtered...")
            continue
        except FileNotFoundError:
            pass

        # Keep the try body minimal so only a genuinely missing analysis
        # file is reported as "not analyzed" (the original wrapped the whole
        # filtering loop, swallowing unrelated FileNotFoundErrors).
        try:
            Dataframe = pd.read_hdf(os.path.join(folder, dataname + '.h5'))
        except FileNotFoundError:
            print("Video not analyzed -- Run analyze_videos first.")
            continue

        filtered_frames = []
        for bp in tqdm(cfg['bodyparts']):
            pdindex = pd.MultiIndex.from_product(
                [[scorer], [bp], ['x', 'y', 'likelihood']],
                names=['scorer', 'bodyparts', 'coords'])
            x = Dataframe[scorer][bp]['x'].values
            y = Dataframe[scorer][bp]['y'].values
            p = Dataframe[scorer][bp]['likelihood'].values

            if filterype == 'arima':
                meanx, _CIx = FitSARIMAXModel(x, p, p_bound, alpha, ARdegree, MAdegree, False)
                meany, _CIy = FitSARIMAXModel(y, p, p_bound, alpha, ARdegree, MAdegree, False)
                # Anchor the first sample to the raw prediction (the fitted
                # model's first value is unreliable).
                meanx[0] = x[0]
                meany[0] = y[0]
            else:
                meanx = signal.medfilt(x, kernel_size=windowlength)
                meany = signal.medfilt(y, kernel_size=windowlength)

            filtered_frames.append(
                pd.DataFrame(np.stack([meanx, meany, p], axis=1), columns=pdindex))

        # Single column-wise concat instead of the original O(n^2)
        # per-bodypart pd.concat([data.T, item.T]).T accumulation.
        data = pd.concat(filtered_frames, axis=1)

        data.to_hdf(os.path.join(folder, filteredname), 'df_with_missing', format='table', mode='w')
        if save_as_csv:
            print("Saving filtered csv poses!")
            data.to_csv(os.path.join(folder, filteredname.split('.h5')[0] + '.csv'))
|
||
if __name__ == '__main__':
    # CLI entry point. The original parsed the arguments but never invoked
    # anything, so running the module from the command line did nothing;
    # now it actually runs filterpredictions on the parsed arguments.
    parser = argparse.ArgumentParser(
        description="Filter frame-by-frame pose predictions for analyzed videos.")
    parser.add_argument('config', help="Full path of the config.yaml file.")
    parser.add_argument('videos', help="Full path of the video (or folder of videos) to filter.")
    cli_args = parser.parse_args()
    filterpredictions(cli_args.config, cli_args.videos)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.