Skip to content

Commit

Permalink
working on #36
Browse files Browse the repository at this point in the history
add a method to resample if needed
  • Loading branch information
nag92 committed Sep 17, 2021
1 parent a00ffef commit 3b4eeed
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 12 deletions.
11 changes: 11 additions & 0 deletions GaitAnaylsisToolkit/LearningTools/Trainer/TPGMMQuaternions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@

import TPGMMTrainer


class TPGMMQuaternions(TPGMMTrainer.TPGMMTrainer):

def __init__(self, demo, file_name, n_rf, dt=0.01, reg=[1e-5], poly_degree=[15], A=[], b=[]):
super().__init__(demo, file_name, n_rf, dt, reg, poly_degree, A, b)

def train(self, save=True):
return super().train(save)
11 changes: 8 additions & 3 deletions GaitAnaylsisToolkit/LearningTools/Trainer/TPGMMTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class TPGMMTrainer(TrainerBase.TrainerBase):

def __init__(self, demo, file_name, n_rf, dt=0.01, reg=[1e-5], poly_degree=[15],A=[],b=[]):
def __init__(self, demo, file_name, n_rf, dt=0.01, reg=[1e-5], poly_degree=[15], resample=[False], A=[],b=[]):
"""
:param file_names: file to save training too
:param n_rfs: number of DMPs
Expand Down Expand Up @@ -38,8 +38,13 @@ def __init__(self, demo, file_name, n_rf, dt=0.01, reg=[1e-5], poly_degree=[15],
else:
my_reg = reg*(1+ len(demo))

for d, polyD in zip(demo, poly_degree):
demo_, dtw_data_ = self.resample(d, polyD)
if len(resample) == len(demo):
my_resample = [1e-8] + reg
else:
my_resample = resample*(1+ len(demo))

for d, polyD, resamp in zip(demo, poly_degree, my_resample):
demo_, dtw_data_ = self.resample(d, polyD, resamp)
rescaled.append(demo_)
self.dtw_data.append(dtw_data_)

Expand Down
14 changes: 5 additions & 9 deletions GaitAnaylsisToolkit/LearningTools/Trainer/TrainerBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def reg(self):
def reg(self, value):
self._reg = value

def resample(self, trajs, poly_degree):
def resample(self, trajs, poly_degree, resample):
"""
:param trajs: list of demos
Expand Down Expand Up @@ -62,26 +62,22 @@ def resample(self, trajs, poly_degree):
for ii, y in enumerate(trajs):
dtw_data = {}
d, cost_matrix, acc_cost_matrix, path = dtw(trajs[idx], y, dist=manhattan_distance)
# d, cost_matrix, acc_cost_matrix, path = dtw(x_fit, y, dist=manhattan_distance)
dtw_data["cost"] = d
dtw_data["cost_matrix"] = cost_matrix
dtw_data["acc_cost_matrix"] = acc_cost_matrix
dtw_data["path"] = path

data.append(dtw_data)
#data_warp = [y[path[1]][:x_fit.shape[0]]]
data_warp = [y[path[1]]]
data_warp_rsp = signal.resample(data_warp[0], x_fit.shape[0]) # resample dtw output 188 points to 118 points
#data_warp = [y[:][:x_fit.shape[0]]]
#coefs = poly.polyfit(t, data_warp[0], 20)
data_warp_rsp = y[path[1]][:x_fit.shape[0]]
if resample:
data_warp_rsp = signal.resample(data_warp[0], x_fit.shape[0]) # resample dtw output 188 points to 118 points
coefs = poly.polyfit(t, data_warp_rsp, poly_degree)
ffit = poly.Polynomial(coefs) # instead of np.poly1d
y_fit = ffit(t)
# y_fit = data_warp[0]
y_fit = data_warp_rsp
temp = [[np.array(ele)] for ele in y_fit.tolist()]
temp = np.array(temp)
demos.append(temp)
dtw_data["unsmooth_path"] =y[path[1]][:x_fit.shape[0]]
dtw_data["smooth_path"] = temp
return demos, data

Expand Down

0 comments on commit 3b4eeed

Please sign in to comment.