Skip to content

Commit

Permalink
rename variables
Browse files — browse the repository at this point in the history
  • Loading branch information
jpaillard committed Nov 14, 2024
1 parent dc63c98 commit ce37632
Showing 1 changed file with 9 additions and 20 deletions.
29 changes: 9 additions & 20 deletions hidimstat/cpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def _joblib_predict_one_group(imputation_model, X, j):
Compute the prediction of the model with the permuted data for a
single group of covariates.
"""
list_y_pred_perm = []

if isinstance(X, pd.DataFrame):
X_j = X[self.groups[j]].copy().values
X_minus_j = X.drop(columns=self.groups[j]).values
Expand All @@ -169,15 +169,17 @@ def _joblib_predict_one_group(imputation_model, X, j):
X_j_hat = imputation_model.predict(X_minus_j).reshape(X_j.shape)
residual_j = X_j - X_j_hat

X_perm_all = np.empty((self.n_permutations, X.shape[0], X.shape[1]))
X_perm_all[:, :, non_group_ids] = X_minus_j
# n_permutations x n_samples x n_features_j
# Create an array X_perm_j of shape (n_permutations, n_samples, n_features)
# where the j-th group of covariates is (conditionally) permuted
X_perm_j = np.empty((self.n_permutations, X.shape[0], X.shape[1]))
X_perm_j[:, :, non_group_ids] = X_minus_j
# Create the permuted data for the j-th group of covariates
residual_j_perm = np.array(
[self.rng.permutation(residual_j) for _ in range(self.n_permutations)]
)
X_perm_all[:, :, group_ids] = X_j_hat[np.newaxis, :, :] + residual_j_perm

X_perm_batch = X_perm_all.reshape(-1, X.shape[1])
X_perm_j[:, :, group_ids] = X_j_hat[np.newaxis, :, :] + residual_j_perm
# Reshape X_perm_j to allow for batch prediction
X_perm_batch = X_perm_j.reshape(-1, X.shape[1])
if isinstance(X, pd.DataFrame):
X_perm_batch = pd.DataFrame(
X_perm_batch.reshape(-1, X.shape[1]), columns=X.columns
Expand All @@ -193,19 +195,6 @@ def _joblib_predict_one_group(imputation_model, X, j):
)
return y_pred_perm

# for _ in range(self.n_permutations):
# X_j_perm = X_j_hat + self.rng.permutation(residual_j)
# X_perm = np.empty_like(X)
# X_perm[:, non_group_ids] = X_minus_j
# X_perm[:, group_ids] = X_j_perm
# if isinstance(X, pd.DataFrame):
# X_perm = pd.DataFrame(X_perm, columns=X.columns)

# y_pred_perm = getattr(self.estimator, self.method)(X_perm)
# list_y_pred_perm.append(y_pred_perm)

# return np.array(list_y_pred_perm)

# Parallelize the computation of the importance scores for each group
out_list = Parallel(n_jobs=self.n_jobs)(
delayed(_joblib_predict_one_group)(imputation_model, X, j)
Expand Down

0 comments on commit ce37632

Please sign in to comment.