Skip to content

Commit

Permalink
speed up cpi.predict
Browse files Browse the repository at this point in the history
  • Loading branch information
jpaillard committed Nov 14, 2024
1 parent 3ca0a79 commit dc63c98
Showing 1 changed file with 36 additions and 12 deletions.
48 changes: 36 additions & 12 deletions hidimstat/cpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,18 +169,42 @@ def _joblib_predict_one_group(imputation_model, X, j):
X_j_hat = imputation_model.predict(X_minus_j).reshape(X_j.shape)
residual_j = X_j - X_j_hat

for _ in range(self.n_permutations):
X_j_perm = X_j_hat + self.rng.permutation(residual_j)
X_perm = np.empty_like(X)
X_perm[:, non_group_ids] = X_minus_j
X_perm[:, group_ids] = X_j_perm
if isinstance(X, pd.DataFrame):
X_perm = pd.DataFrame(X_perm, columns=X.columns)

y_pred_perm = getattr(self.estimator, self.method)(X_perm)
list_y_pred_perm.append(y_pred_perm)

return np.array(list_y_pred_perm)
X_perm_all = np.empty((self.n_permutations, X.shape[0], X.shape[1]))
X_perm_all[:, :, non_group_ids] = X_minus_j
# n_permutations x n_samples x n_features_j
residual_j_perm = np.array(
[self.rng.permutation(residual_j) for _ in range(self.n_permutations)]
)
X_perm_all[:, :, group_ids] = X_j_hat[np.newaxis, :, :] + residual_j_perm

X_perm_batch = X_perm_all.reshape(-1, X.shape[1])
if isinstance(X, pd.DataFrame):
X_perm_batch = pd.DataFrame(
X_perm_batch.reshape(-1, X.shape[1]), columns=X.columns
)
y_pred_perm = getattr(self.estimator, self.method)(X_perm_batch)

# In case of classification, the output is a 2D array. Reshape accordingly
if y_pred_perm.ndim == 1:
y_pred_perm = y_pred_perm.reshape(self.n_permutations, X.shape[0])
else:
y_pred_perm = y_pred_perm.reshape(
self.n_permutations, X.shape[0], y_pred_perm.shape[1]
)
return y_pred_perm

# for _ in range(self.n_permutations):
# X_j_perm = X_j_hat + self.rng.permutation(residual_j)
# X_perm = np.empty_like(X)
# X_perm[:, non_group_ids] = X_minus_j
# X_perm[:, group_ids] = X_j_perm
# if isinstance(X, pd.DataFrame):
# X_perm = pd.DataFrame(X_perm, columns=X.columns)

# y_pred_perm = getattr(self.estimator, self.method)(X_perm)
# list_y_pred_perm.append(y_pred_perm)

# return np.array(list_y_pred_perm)

# Parallelize the computation of the importance scores for each group
out_list = Parallel(n_jobs=self.n_jobs)(
Expand Down

0 comments on commit dc63c98

Please sign in to comment.