Skip to content

Commit

Permalink
optimize the implementation of PermutationImportance.predict
Browse files Browse the repository at this point in the history
  • Loading branch information
jpaillard committed Nov 14, 2024
1 parent ce37632 commit 212198a
Showing 1 changed file with 26 additions and 12 deletions.
38 changes: 26 additions & 12 deletions hidimstat/permutation_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def fit(self, X, y=None, groups=None):
self.groups = groups
return self

def predict(self, X, y):
def predict(self, X, y=None):
"""
Compute the prediction of the model with permuted data for each group.
Expand Down Expand Up @@ -114,17 +114,31 @@ def _joblib_predict_one_group(X, j):
group_ids = self.groups[j]
non_group_ids = np.delete(np.arange(X.shape[1]), group_ids)

for _ in range(self.n_permutations):
X_j_perm = self.rng.permutation(X_j)
X_perm = np.empty_like(X)
X_perm[:, non_group_ids] = X_minus_j
X_perm[:, group_ids] = X_j_perm
if isinstance(X, pd.DataFrame):
X_perm = pd.DataFrame(X_perm, columns=X.columns)

y_pred_perm = getattr(self.estimator, self.method)(X_perm)
list_y_pred_perm.append(y_pred_perm)
return np.array(list_y_pred_perm)
# Create an array X_perm_j of shape (n_permutations, n_samples, n_features)
# where the j-th group of covariates is permuted
X_perm_j = np.empty((self.n_permutations, X.shape[0], X.shape[1]))
X_perm_j[:, :, non_group_ids] = X_minus_j
# Create the permuted data for the j-th group of covariates
group_j_permuted = np.array(
[self.rng.permutation(X_j) for _ in range(self.n_permutations)]
)
X_perm_j[:, :, group_ids] = group_j_permuted
# Reshape X_perm_j to allow for batch prediction
X_perm_batch = X_perm_j.reshape(-1, X.shape[1])
if isinstance(X, pd.DataFrame):
X_perm_batch = pd.DataFrame(
X_perm_batch.reshape(-1, X.shape[1]), columns=X.columns
)
y_pred_perm = getattr(self.estimator, self.method)(X_perm_batch)

# The batched output is 1D, or 2D (e.g. in case of classification). Reshape it to restore the leading n_permutations axis
if y_pred_perm.ndim == 1:
y_pred_perm = y_pred_perm.reshape(self.n_permutations, X.shape[0])
else:
y_pred_perm = y_pred_perm.reshape(
self.n_permutations, X.shape[0], y_pred_perm.shape[1]
)
return y_pred_perm

# Parallelize the computation of the importance scores for each group
out_list = Parallel(n_jobs=self.n_jobs)(
Expand Down

0 comments on commit 212198a

Please sign in to comment.