Skip to content

Commit

Permalink
allow conf_weighted_var computation for pupil
Browse files Browse the repository at this point in the history
  • Loading branch information
themattinthehatt committed Dec 11, 2024
1 parent 83fc4c6 commit 37e48fd
Show file tree
Hide file tree
Showing 10 changed files with 158 additions and 142 deletions.
4 changes: 2 additions & 2 deletions eks/command_line_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@ def handle_parse_args(script_type):
add_camera_names(parser)
add_quantile_keep_pca(parser)
add_s(parser)
elif script_type == 'pupil':
elif script_type == 'ibl_pupil':
add_diameter_s(parser)
add_com_s(parser)
elif script_type == 'paw':
elif script_type == 'ibl_paw':
add_s(parser)
add_quantile_keep_pca(parser)
else:
Expand Down
121 changes: 61 additions & 60 deletions eks/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,81 +16,82 @@
# ------------------------------------------------------------------------------------------


def ensemble(markers_list, keys, mode='median'):
"""Computes ensemble median (or mean) and variance of list of DLC marker dataframes
def ensemble(
markers_list: list,
keys: list,
avg_mode: str = 'median',
var_mode: str = 'confidence_weighted_var',
) -> tuple:
"""Compute ensemble mean/median and variance of marker dataframes.
Args:
markers_list: list
List of DLC marker dataframes`
keys: list
List of keys in each marker dataframe
mode: string
Averaging mode which includes 'median', 'mean', or 'confidence_weighted_mean'.
markers_list: List of DLC marker dataframes
keys: List of keys in each marker dataframe
avg_mode
'median' | 'mean'
var_mode
'confidence_weighted_var' | 'var'
Returns:
ensemble_preds: np.ndarray
shape (samples, n_keypoints)
ensemble_vars: np.ndarray
shape (samples, n_keypoints)
ensemble_stacks: np.ndarray
shape (n_models, samples, n_keypoints)
keypoints_avg_dict: dict
keys: marker keypoints, values: shape (samples)
keypoints_var_dict: dict
keys: marker keypoints, values: shape (samples)
keypoints_stack_dict: dict(dict)
keys: model_ids, keys: marker keypoints, values: shape (samples)
tuple:
ensemble_preds: np.ndarray
shape (samples, n_keypoints)
ensemble_vars: np.ndarray
shape (samples, n_keypoints)
ensemble_stacks: np.ndarray
shape (n_models, samples, n_keypoints)
keypoints_avg_dict: dict
keys: marker keypoints, values: shape (samples)
keypoints_var_dict: dict
keys: marker keypoints, values: shape (samples)
keypoints_stack_dict: dict(dict)
keys: model_ids, keys: marker keypoints, values: shape (samples)
"""
ensemble_stacks = []
ensemble_vars = []
ensemble_preds = []
keypoints_avg_dict = {}
keypoints_var_dict = {}
keypoints_stack_dict = defaultdict(dict)
if mode != 'confidence_weighted_mean':
if mode == 'median':
average_func = np.nanmedian
elif mode == 'mean':
average_func = np.nanmean
else:
raise ValueError(f"{mode} averaging not supported")

if avg_mode == 'median':
average_func = np.nanmedian
elif avg_mode == 'mean':
average_func = np.nanmean
else:
raise ValueError(f"avg_mode={avg_mode} not supported")

for key in keys:
if mode != 'confidence_weighted_mean':
stack = np.zeros((len(markers_list), markers_list[0].shape[0]))
for k in range(len(markers_list)):
stack[k] = markers_list[k][key]
stack = stack.T
avg = average_func(stack, 1)
var = np.nanvar(stack, 1)
ensemble_preds.append(avg)
ensemble_vars.append(var)
ensemble_stacks.append(stack)
keypoints_avg_dict[key] = avg
keypoints_var_dict[key] = var
for i, keypoints in enumerate(stack.T):
keypoints_stack_dict[i][key] = stack.T[i]
else:

# compute mean/median
stack = np.zeros((markers_list[0].shape[0], len(markers_list)))
for k in range(len(markers_list)):
stack[:, k] = markers_list[k][key]
avg = average_func(stack, axis=1)
ensemble_preds.append(avg)
ensemble_stacks.append(stack)
keypoints_avg_dict[key] = avg
for i, keypoints in enumerate(stack.T):
keypoints_stack_dict[i][key] = stack.T[i]

# compute variance
var = np.nanvar(stack, axis=1)
if var_mode in ['conf_weighted_var', 'confidence_weighted_var']:
likelihood_key = key[:-1] + 'likelihood'
if likelihood_key not in markers_list[0]:
raise ValueError(f"{likelihood_key} needs to be in your marker_df to use {mode}")
stack = np.zeros((len(markers_list), markers_list[0].shape[0]))
likelihood_stack = np.zeros((len(markers_list), markers_list[0].shape[0]))
raise ValueError(
f"{likelihood_key} needs to be in your marker_df to use {var_mode}")
likelihood_stack = np.zeros((markers_list[0].shape[0], len(markers_list)))
for k in range(len(markers_list)):
stack[k] = markers_list[k][key]
likelihood_stack[k] = markers_list[k][likelihood_key]
stack = stack.T
likelihood_stack = likelihood_stack.T
conf_per_keypoint = np.sum(likelihood_stack, 1)
mean_conf_per_keypoint = np.sum(likelihood_stack, 1) / likelihood_stack.shape[1]
avg = np.sum(stack * likelihood_stack, 1) / conf_per_keypoint
var = np.nanvar(stack, 1)
likelihood_stack[:, k] = markers_list[k][likelihood_key]
mean_conf_per_keypoint = np.mean(likelihood_stack, axis=1)
var = var / mean_conf_per_keypoint # low-confidence --> inflated obs variances
ensemble_preds.append(avg)
ensemble_vars.append(var)
ensemble_stacks.append(stack)
keypoints_avg_dict[key] = avg
keypoints_var_dict[key] = var
for i, keypoints in enumerate(stack.T):
keypoints_stack_dict[i][key] = stack.T[i]
elif var_mode != 'var':
raise ValueError(f"var_mode={var_mode} not supported")

ensemble_vars.append(var)
keypoints_var_dict[key] = var

ensemble_preds = np.asarray(ensemble_preds).T
ensemble_vars = np.asarray(ensemble_vars).T
Expand Down
23 changes: 15 additions & 8 deletions eks/ibl_paw_multiview_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from eks.utils import make_dlc_pandas_index


# TODO:
# - allow conf_weighted_mean for ensemble variance computation


def remove_camera_means(ensemble_stacks, camera_means):
scaled_ensemble_stacks = ensemble_stacks.copy()
for k in range(len(ensemble_stacks)):
Expand All @@ -31,10 +35,15 @@ def pca(S, n_comps):


def ensemble_kalman_smoother_ibl_paw(
markers_list_left_cam, markers_list_right_cam, timestamps_left_cam,
timestamps_right_cam, keypoint_names, smooth_param, quantile_keep_pca,
ensembling_mode='median',
zscore_threshold=2, img_width=128):
markers_list_left_cam, markers_list_right_cam,
timestamps_left_cam, timestamps_right_cam,
keypoint_names,
smooth_param,
quantile_keep_pca,
ensembling_mode='median',
zscore_threshold=2,
img_width=128,
):
"""
--(IBL-specific)-
-Use multi-view constraints to fit a 3d latent subspace for each body part with 2
Expand Down Expand Up @@ -63,8 +72,6 @@ def ensemble_kalman_smoother_ibl_paw(
(default 2).
img_width
The width of the image being smoothed (128 default, IBL-specific).
Returns
-------
Returns
-------
Expand Down Expand Up @@ -128,13 +135,13 @@ def ensemble_kalman_smoother_ibl_paw(
left_cam_ensemble_preds, left_cam_ensemble_vars, left_cam_ensemble_stacks, \
left_cam_keypoints_mean_dict, left_cam_keypoints_var_dict, \
left_cam_keypoints_stack_dict = \
ensemble(markers_list_left_cam, keys, mode=ensembling_mode)
ensemble(markers_list_left_cam, keys, avg_mode=ensembling_mode, var_mode='var')

# compute ensemble median right camera
right_cam_ensemble_preds, right_cam_ensemble_vars, right_cam_ensemble_stacks, \
right_cam_keypoints_mean_dict, right_cam_keypoints_var_dict, \
right_cam_keypoints_stack_dict = \
ensemble(markers_list_right_cam, keys, mode=ensembling_mode)
ensemble(markers_list_right_cam, keys, avg_mode=ensembling_mode, var_mode='var')

# keep percentage of the points for multi-view PCA based lowest ensemble variance
hstacked_vars = np.hstack((left_cam_ensemble_vars, right_cam_ensemble_vars))
Expand Down
26 changes: 21 additions & 5 deletions eks/ibl_pupil_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ def fit_eks_pupil(
save_file: str,
smooth_params: list,
s_frames: Optional[list] = None,
avg_mode: str = 'median',
var_mode: str = 'confidence_weighted_var',
) -> tuple:
"""Function to fit the Ensemble Kalman Smoother for the ibl-pupil dataset.
Expand All @@ -91,6 +93,10 @@ def fit_eks_pupil(
save_file: File to save outputs.
smooth_params: List containing diameter_s and com_s.
s_frames: Frames for automatic optimization if needed.
avg_mode
'median' | 'mean'
var_mode
'confidence_weighted_var' | 'var'
Returns:
tuple:
Expand All @@ -112,7 +118,9 @@ def fit_eks_pupil(
markers_list=input_dfs_list,
keypoint_names=keypoint_names,
smooth_params=smooth_params,
s_frames=s_frames
s_frames=s_frames,
avg_mode=avg_mode,
var_mode=var_mode,
)

# Save the output DataFrame to CSV
Expand All @@ -128,16 +136,22 @@ def ensemble_kalman_smoother_ibl_pupil(
keypoint_names: list,
smooth_params: list,
s_frames: Optional[list] = None,
avg_mode: str = 'median',
var_mode: str = 'confidence_weighted_var',
zscore_threshold: float = 2,
) -> tuple:
"""Perform Ensemble Kalman Smoothing on pupil data.
Args:
markers_list: pd.DataFrames
each list element is a dataframe of predictions from one ensemble member
each list element is a dataframe of predictions from one ensemble member
keypoint_names
smooth_params: contains smoothing parameters for diameter and center of mass
s_frames: frames for automatic optimization if s is not provided
avg_mode
'median' | 'mean'
var_mode
'confidence_weighted_var' | 'var'
zscore_threshold: Minimum std threshold to reduce the effect of low ensemble std on a
zscore metric (default 2).
Expand All @@ -151,10 +165,12 @@ def ensemble_kalman_smoother_ibl_pupil(
"""

# compute ensemble median
keys = ['pupil_top_r_x', 'pupil_top_r_y', 'pupil_bottom_r_x', 'pupil_bottom_r_y',
'pupil_right_r_x', 'pupil_right_r_y', 'pupil_left_r_x', 'pupil_left_r_y']
keys = [
'pupil_top_r_x', 'pupil_top_r_y', 'pupil_bottom_r_x', 'pupil_bottom_r_y',
'pupil_right_r_x', 'pupil_right_r_y', 'pupil_left_r_x', 'pupil_left_r_y',
]
ensemble_preds, ensemble_vars, ensemble_stacks, keypoints_mean_dict, keypoints_var_dict, \
keypoints_stack_dict = ensemble(markers_list, keys)
keypoints_stack_dict = ensemble(markers_list, keys, avg_mode=avg_mode, var_mode=var_mode)

# compute center of mass
pupil_locations = get_pupil_location(keypoints_mean_dict)
Expand Down
2 changes: 1 addition & 1 deletion eks/multicam_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def ensemble_kalman_smoother_multicam(
cam_ensemble_preds_curr, cam_ensemble_vars_curr, cam_ensemble_stacks_curr, \
cam_keypoints_mean_dict_curr, cam_keypoints_var_dict_curr, \
cam_keypoints_stack_dict_curr = \
ensemble(markers_list_cams[camera], keys, mode=ensembling_mode)
ensemble(markers_list_cams[camera], keys, avg_mode=ensembling_mode)
cam_ensemble_preds.append(cam_ensemble_preds_curr)
cam_ensemble_vars.append(cam_ensemble_vars_curr)
cam_ensemble_stacks.append(cam_ensemble_stacks_curr)
Expand Down
4 changes: 3 additions & 1 deletion scripts/ibl_paw_multiview_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from eks.ibl_paw_multiview_smoother import ensemble_kalman_smoother_ibl_paw
from eks.utils import convert_lp_dlc


smoother_type = 'ibl_paw'

# Collect User-Provided Args
smoother_type = 'paw'
args = handle_parse_args(smoother_type)
input_dir = os.path.abspath(args.input_dir)
save_dir = handle_io(input_dir, args.save_dir) # defaults to outputs\
Expand Down
6 changes: 4 additions & 2 deletions scripts/ibl_pupil_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from eks.ibl_pupil_smoother import fit_eks_pupil
from eks.utils import format_data, plot_results


smoother_type = 'ibl_pupil'

# Collect User-Provided Arguments
smoother_type = 'pupil'
args = handle_parse_args(smoother_type)

# Determine input source (directory or list of files)
Expand All @@ -31,7 +33,7 @@
input_source=input_source,
save_file=os.path.join(save_dir, save_filename or 'eks_pupil.csv'),
smooth_params=[diameter_s, com_s],
s_frames=s_frames
s_frames=s_frames,
)

# Plot results
Expand Down
4 changes: 3 additions & 1 deletion scripts/mirrored_multicam_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from eks.multicam_smoother import ensemble_kalman_smoother_multicam
from eks.utils import format_data, plot_results, populate_output_dataframe

# Collect User-Provided Args

smoother_type = 'multicam'

# Collect User-Provided Args
args = handle_parse_args(smoother_type)
input_dir = os.path.abspath(args.input_dir)
save_dir = handle_io(input_dir, args.save_dir) # defaults to outputs
Expand Down
Loading

0 comments on commit 37e48fd

Please sign in to comment.