Skip to content

Commit

Permalink
remove unnecessary outputs from numpy ensemble function
Browse files Browse the repository at this point in the history
  • Loading branch information
themattinthehatt committed Dec 11, 2024
1 parent 01bd293 commit f9d9bdf
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 78 deletions.
24 changes: 4 additions & 20 deletions eks/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# ------------------------------------------------------------------------------------------


# TODO: don't return arrays AND dicts
def ensemble(
markers_list: list,
keys: list,
Expand All @@ -43,21 +42,13 @@ def ensemble(
shape (samples, n_keypoints)
ensemble_stacks: np.ndarray
shape (n_models, samples, n_keypoints)
keypoints_avg_dict: dict
keys: marker keypoints, values: shape (samples)
keypoints_var_dict: dict
keys: marker keypoints, values: shape (samples)
keypoints_stack_dict: dict(dict)
keys: model_ids, keys: marker keypoints, values: shape (samples)
"""

ensemble_preds = []
ensemble_vars = []
ensemble_likes = []
ensemble_stacks = []
keypoints_avg_dict = {}
keypoints_var_dict = {}
keypoints_stack_dict = defaultdict(dict)

if avg_mode == 'median':
average_func = np.nanmedian
Expand All @@ -72,12 +63,9 @@ def ensemble(
stack = np.zeros((markers_list[0].shape[0], len(markers_list)))
for k in range(len(markers_list)):
stack[:, k] = markers_list[k][key]
ensemble_stacks.append(stack)
avg = average_func(stack, axis=1)
ensemble_preds.append(avg)
ensemble_stacks.append(stack)
keypoints_avg_dict[key] = avg
for i, keypoints in enumerate(stack.T):
keypoints_stack_dict[i][key] = stack.T[i]

# collect likelihoods
likelihood_stack = np.ones((markers_list[0].shape[0], len(markers_list)))
Expand All @@ -94,18 +82,14 @@ def ensemble(
var = var / mean_conf_per_keypoint # low-confidence --> inflated obs variances
elif var_mode != 'var':
raise ValueError(f"var_mode={var_mode} not supported")

ensemble_vars.append(var)
keypoints_var_dict[key] = var

ensemble_preds = np.asarray(ensemble_preds).T
ensemble_vars = np.asarray(ensemble_vars).T
ensemble_likes = np.asarray(ensemble_likes).T
ensemble_stacks = np.asarray(ensemble_stacks).T
return (
ensemble_preds, ensemble_vars, ensemble_likes, ensemble_stacks,
keypoints_avg_dict, keypoints_var_dict, keypoints_stack_dict,
)

return ensemble_preds, ensemble_vars, ensemble_likes, ensemble_stacks


def forward_pass(y, m0, S0, C, R, A, Q, ensemble_vars):
Expand Down
14 changes: 6 additions & 8 deletions eks/ibl_paw_multiview_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,14 @@ def ensemble_kalman_smoother_ibl_paw(
markers_list_right_cam.append(markers_right_cam)

# compute ensemble median left camera
left_cam_ensemble_preds, left_cam_ensemble_vars, _, left_cam_ensemble_stacks, \
left_cam_keypoints_mean_dict, left_cam_keypoints_var_dict, \
left_cam_keypoints_stack_dict = \
ensemble(markers_list_left_cam, keys, avg_mode=ensembling_mode, var_mode='var')
left_cam_ensemble_preds, left_cam_ensemble_vars, _, left_cam_ensemble_stacks = ensemble(
markers_list_left_cam, keys, avg_mode=ensembling_mode, var_mode='var',
)

# compute ensemble median right camera
right_cam_ensemble_preds, right_cam_ensemble_vars, _, right_cam_ensemble_stacks, \
right_cam_keypoints_mean_dict, right_cam_keypoints_var_dict, \
right_cam_keypoints_stack_dict = \
ensemble(markers_list_right_cam, keys, avg_mode=ensembling_mode, var_mode='var')
right_cam_ensemble_preds, right_cam_ensemble_vars, _, right_cam_ensemble_stacks = ensemble(
markers_list_right_cam, keys, avg_mode=ensembling_mode, var_mode='var',
)

# keep percentage of the points for multi-view PCA based lowest ensemble variance
hstacked_vars = np.hstack((left_cam_ensemble_vars, right_cam_ensemble_vars))
Expand Down
30 changes: 11 additions & 19 deletions eks/ibl_pupil_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@ def get_pupil_diameter(dlc):
"""
diameters = []
# Get the x,ys coordinates of the four pupil points
top, bottom, left, right = [np.vstack((dlc[f'pupil_{point}_r_x'], dlc[f'pupil_{point}_r_y']))
for point in ['top', 'bottom', 'left', 'right']]
top, bottom, left, right = [
np.vstack((dlc[f'pupil_{point}_r_x'], dlc[f'pupil_{point}_r_y']))
for point in ['top', 'bottom', 'left', 'right']
]
# First compute direct diameters
diameters.append(np.linalg.norm(top - bottom, axis=0))
diameters.append(np.linalg.norm(left - right, axis=0))
Expand Down Expand Up @@ -168,29 +170,19 @@ def ensemble_kalman_smoother_ibl_pupil(
'pupil_top_r_x', 'pupil_top_r_y', 'pupil_bottom_r_x', 'pupil_bottom_r_y',
'pupil_right_r_x', 'pupil_right_r_y', 'pupil_left_r_x', 'pupil_left_r_y',
]
(
ensemble_preds, ensemble_vars, ensemble_likes, ensemble_stacks,
keypoints_mean_dict, keypoints_var_dict, keypoints_stack_dict,
) = ensemble(markers_list, keys, avg_mode=avg_mode, var_mode=var_mode)

# compute center of mass
pupil_locations = get_pupil_location(keypoints_mean_dict)
pupil_diameters = get_pupil_diameter(keypoints_mean_dict)
diameters = []
for i in range(len(markers_list)):
keypoints_dict = keypoints_stack_dict[i]
diameter = get_pupil_diameter(keypoints_dict)
diameters.append(diameter)
ensemble_preds, ensemble_vars, ensemble_likes, ensemble_stacks = ensemble(
markers_list, keys, avg_mode=avg_mode, var_mode=var_mode,
)

# compute center of mass + diameter
pupil_diameters = get_pupil_diameter({key: ensemble_preds[:, i] for i, key in enumerate(keys)})
pupil_locations = get_pupil_location({key: ensemble_preds[:, i] for i, key in enumerate(keys)})
mean_x_obs = np.mean(pupil_locations[:, 0])
mean_y_obs = np.mean(pupil_locations[:, 1])

# make the mean zero
x_t_obs, y_t_obs = pupil_locations[:, 0] - mean_x_obs, pupil_locations[:, 1] - mean_y_obs

# latent variables (observed)
# latent variables - diameter, com_x, com_y
# z_t_obs = np.vstack((pupil_diameters, x_t_obs, y_t_obs))

# --------------------------------------
# Set values for kalman filter
# --------------------------------------
Expand Down
13 changes: 3 additions & 10 deletions eks/multicam_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,20 +82,13 @@ def ensemble_kalman_smoother_multicam(
cam_ensemble_preds = []
cam_ensemble_vars = []
cam_ensemble_stacks = []
cam_keypoints_mean_dict = []
cam_keypoints_var_dict = []
cam_keypoints_stack_dict = []
for camera in range(num_cameras):
cam_ensemble_preds_curr, cam_ensemble_vars_curr, _, cam_ensemble_stacks_curr, \
cam_keypoints_mean_dict_curr, cam_keypoints_var_dict_curr, \
cam_keypoints_stack_dict_curr = \
ensemble(markers_list_cams[camera], keys, avg_mode=ensembling_mode)
cam_ensemble_preds_curr, cam_ensemble_vars_curr, _, cam_ensemble_stacks_curr = ensemble(
markers_list_cams[camera], keys, avg_mode=ensembling_mode,
)
cam_ensemble_preds.append(cam_ensemble_preds_curr)
cam_ensemble_vars.append(cam_ensemble_vars_curr)
cam_ensemble_stacks.append(cam_ensemble_stacks_curr)
cam_keypoints_mean_dict.append(cam_keypoints_mean_dict_curr)
cam_keypoints_var_dict.append(cam_keypoints_var_dict_curr)
cam_keypoints_stack_dict.append(cam_keypoints_stack_dict_curr)

# filter by low ensemble variances
hstacked_vars = np.hstack(cam_ensemble_vars)
Expand Down
45 changes: 24 additions & 21 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@ def test_ensemble():
markers_list.append(pd.DataFrame(data))

# Run the ensemble function with 'median' mode
ensemble_preds, ensemble_vars, ensemble_likes, ensemble_stacks, keypoints_avg_dict, \
keypoints_var_dict, keypoints_stack_dict = ensemble(
markers_list, keys, avg_mode='median', var_mode='var',
)
ensemble_preds, ensemble_vars, ensemble_likes, ensemble_stacks = ensemble(
markers_list, keys, avg_mode='median', var_mode='var',
)

# Verify shapes of output arrays
assert ensemble_preds.shape == (num_samples, num_keypoints), \
Expand All @@ -42,29 +41,33 @@ def test_ensemble():
f"Likes expected shape {(num_samples, num_keypoints)}, got {ensemble_likes.shape}"
assert ensemble_stacks.shape == (3, num_samples, num_keypoints), \
f"Stacks expected shape {(3, num_samples, num_keypoints)}, got {ensemble_stacks.shape}"

# Verify contents of dictionaries
assert set(keypoints_avg_dict.keys()) == set(keys), \
f"Expected keys {keys}, got {keypoints_avg_dict.keys()}"
assert set(keypoints_var_dict.keys()) == set(keys), \
f"Expected keys {keys}, got {keypoints_var_dict.keys()}"
assert len(keypoints_stack_dict) == 3, \
f"Expected 3 models, got {len(keypoints_stack_dict)}"
# Check values for a keypoint (manually compute median and variance)
for i, key in enumerate(keys):
stack = np.array([df[key].values for df in markers_list]).T
expected_mean = np.nanmedian(stack, axis=1)
expected_variance = np.nanvar(stack, axis=1)
assert np.allclose(ensemble_preds[:, i], expected_mean), \
f"Medians not computed correctly in numpy ensemble function"
assert np.allclose(ensemble_vars[:, i], expected_variance), \
f"Vars not computed correctly in numpy ensemble function"
assert np.all(ensemble_likes[:, i] == 0.5), \
f"Likelihoods not computed correctly in numpy ensemble function"

# Run the ensemble function with avg_mode='mean' and var_mode='conf_weighted_var'
ensemble_preds, ensemble_vars, ensemble_likes, ensemble_stacks, keypoints_avg_dict, \
keypoints_var_dict, keypoints_stack_dict = ensemble(
markers_list, keys, avg_mode='mean', var_mode='conf_weighted_var',
)
ensemble_preds, ensemble_vars, ensemble_likes, ensemble_stacks = ensemble(
markers_list, keys, avg_mode='mean', var_mode='conf_weighted_var',
)
# Check values for a keypoint (manually compute mean and variance)
for key in keys:
for i, key in enumerate(keys):
stack = np.array([df[key].values for df in markers_list]).T
expected_mean = np.nanmean(stack, axis=1)
expected_variance = 2.0 * np.nanvar(stack, axis=1) # 2x since likelihoods all 0.5
assert np.allclose(keypoints_avg_dict[key], expected_mean), \
f"Means expected {expected_mean} for {key}, got {keypoints_avg_dict[key]}"
assert np.allclose(keypoints_var_dict[key], expected_variance), \
f"Vars expected {expected_variance} for {key}, got {keypoints_var_dict[key]}"
assert np.allclose(ensemble_preds[:, i], expected_mean), \
f"Means not computed correctly in numpy ensemble function"
assert np.allclose(ensemble_vars[:, i], expected_variance), \
f"Conf weighted vars not computed correctly in numpy ensemble function"
assert np.all(ensemble_likes[:, i] == 0.5), \
f"Likelihoods not computed correctly in numpy ensemble function"


def test_kalman_dot_basic():
Expand Down

0 comments on commit f9d9bdf

Please sign in to comment.