30: Accelerate CPI by using batch prediction and NumPy array operations instead of a for loop #31

Merged
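The change is the same in both files: instead of calling the estimator once per permutation inside a Python for loop, all permuted copies of X are stacked into one array and the estimator is called a single time on the flattened batch. A minimal standalone sketch of that pattern, using a toy scikit-learn regressor and permuting the first feature (all names below are illustrative, not taken from hidimstat):

import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
n_samples, n_features, n_permutations = 100, 5, 50
X = rng.normal(size=(n_samples, n_features))
y = X @ rng.normal(size=n_features)
estimator = LinearRegression().fit(X, y)

# Loop version: one predict call per permutation (n_permutations calls).
preds_loop = np.array([
    estimator.predict(np.column_stack([rng.permutation(X[:, 0]), X[:, 1:]]))
    for _ in range(n_permutations)
])

# Batched version: build every permuted copy first, then predict once.
X_perm = np.repeat(X[np.newaxis, :, :], n_permutations, axis=0)
for k in range(n_permutations):
    X_perm[k, :, 0] = rng.permutation(X[:, 0])
preds_batch = estimator.predict(X_perm.reshape(-1, n_features))
preds_batch = preds_batch.reshape(n_permutations, n_samples)

print(preds_loop.shape, preds_batch.shape)  # (50, 100) (50, 100)

The shuffling itself can stay in a cheap Python loop (the PR keeps a list comprehension for it); the speed-up comes from replacing n_permutations predict calls with one.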
hidimstat/cpi.py (39 changes: 26 additions & 13 deletions)

@@ -150,7 +150,7 @@ def _joblib_predict_one_group(imputation_model, X, j):
             Compute the prediction of the model with the permuted data for a
             single group of covariates.
             """
-            list_y_pred_perm = []
+
             if isinstance(X, pd.DataFrame):
                 X_j = X[self.groups[j]].copy().values
                 X_minus_j = X.drop(columns=self.groups[j]).values
@@ -169,18 +169,31 @@ def _joblib_predict_one_group(imputation_model, X, j):
             X_j_hat = imputation_model.predict(X_minus_j).reshape(X_j.shape)
             residual_j = X_j - X_j_hat

-            for _ in range(self.n_permutations):
-                X_j_perm = X_j_hat + self.rng.permutation(residual_j)
-                X_perm = np.empty_like(X)
-                X_perm[:, non_group_ids] = X_minus_j
-                X_perm[:, group_ids] = X_j_perm
-                if isinstance(X, pd.DataFrame):
-                    X_perm = pd.DataFrame(X_perm, columns=X.columns)
-
-                y_pred_perm = getattr(self.estimator, self.method)(X_perm)
-                list_y_pred_perm.append(y_pred_perm)
-
-            return np.array(list_y_pred_perm)
+            # Create an array X_perm_j of shape (n_permutations, n_samples, n_features)
+            # where the j-th group of covariates is (conditionally) permuted
+            X_perm_j = np.empty((self.n_permutations, X.shape[0], X.shape[1]))
+            X_perm_j[:, :, non_group_ids] = X_minus_j
+            # Create the permuted data for the j-th group of covariates
+            residual_j_perm = np.array(
+                [self.rng.permutation(residual_j) for _ in range(self.n_permutations)]
+            )
+            X_perm_j[:, :, group_ids] = X_j_hat[np.newaxis, :, :] + residual_j_perm
+            # Reshape X_perm_j to allow for batch prediction
+            X_perm_batch = X_perm_j.reshape(-1, X.shape[1])
+            if isinstance(X, pd.DataFrame):
+                X_perm_batch = pd.DataFrame(
+                    X_perm_batch.reshape(-1, X.shape[1]), columns=X.columns
+                )
+            y_pred_perm = getattr(self.estimator, self.method)(X_perm_batch)
+
+            # In case of classification, the output is a 2D array. Reshape accordingly
+            if y_pred_perm.ndim == 1:
+                y_pred_perm = y_pred_perm.reshape(self.n_permutations, X.shape[0])
+            else:
+                y_pred_perm = y_pred_perm.reshape(
+                    self.n_permutations, X.shape[0], y_pred_perm.shape[1]
+                )
+            return y_pred_perm

         # Parallelize the computation of the importance scores for each group
         out_list = Parallel(n_jobs=self.n_jobs)(
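For reference, the CPI variant above does not shuffle the raw columns of group j: it imputes X_j from the remaining covariates, permutes only the residuals, and stacks the n_permutations reconstructed copies for a single batched prediction, as in the new cpi.py code. A standalone sketch of that logic, assuming a fitted estimator and a fitted imputation model with scikit-learn-style predict methods (the function and variable names are illustrative, not part of hidimstat's API):

import numpy as np

def cpi_batch_predict(estimator, imputation_model, X, group_ids, n_permutations, rng):
    """Conditionally permute one group and predict on the whole batch at once (regression case)."""
    non_group_ids = np.delete(np.arange(X.shape[1]), group_ids)
    X_j, X_minus_j = X[:, group_ids], X[:, non_group_ids]

    # Impute the group from the other covariates; only the residuals get shuffled.
    X_j_hat = imputation_model.predict(X_minus_j).reshape(X_j.shape)
    residual_j = X_j - X_j_hat
    residual_perm = np.array(
        [rng.permutation(residual_j) for _ in range(n_permutations)]
    )

    # Stack every permuted copy: shape (n_permutations, n_samples, n_features).
    X_perm = np.empty((n_permutations, X.shape[0], X.shape[1]))
    X_perm[:, :, non_group_ids] = X_minus_j
    X_perm[:, :, group_ids] = X_j_hat[np.newaxis] + residual_perm

    # One predict call on the flattened batch, then reshape back per permutation.
    y_pred = estimator.predict(X_perm.reshape(-1, X.shape[1]))
    return y_pred.reshape(n_permutations, X.shape[0])

Called with a fitted regressor, this returns an array of shape (n_permutations, n_samples), matching what the modified _joblib_predict_one_group returns in the regression case.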
hidimstat/permutation_importance.py (38 changes: 26 additions & 12 deletions)

@@ -69,7 +69,7 @@ def fit(self, X, y=None, groups=None):
        self.groups = groups
        return self

-    def predict(self, X, y):
+    def predict(self, X, y=None):
        """
        Compute the prediction of the model with permuted data for each group.

@@ -114,17 +114,31 @@ def _joblib_predict_one_group(X, j):
             group_ids = self.groups[j]
             non_group_ids = np.delete(np.arange(X.shape[1]), group_ids)

-            for _ in range(self.n_permutations):
-                X_j_perm = self.rng.permutation(X_j)
-                X_perm = np.empty_like(X)
-                X_perm[:, non_group_ids] = X_minus_j
-                X_perm[:, group_ids] = X_j_perm
-                if isinstance(X, pd.DataFrame):
-                    X_perm = pd.DataFrame(X_perm, columns=X.columns)
-
-                y_pred_perm = getattr(self.estimator, self.method)(X_perm)
-                list_y_pred_perm.append(y_pred_perm)
-            return np.array(list_y_pred_perm)
+            # Create an array X_perm_j of shape (n_permutations, n_samples, n_features)
+            # where the j-th group of covariates is permuted
+            X_perm_j = np.empty((self.n_permutations, X.shape[0], X.shape[1]))
+            X_perm_j[:, :, non_group_ids] = X_minus_j
+            # Create the permuted data for the j-th group of covariates
+            group_j_permuted = np.array(
+                [self.rng.permutation(X_j) for _ in range(self.n_permutations)]
+            )
+            X_perm_j[:, :, group_ids] = group_j_permuted
+            # Reshape X_perm_j to allow for batch prediction
+            X_perm_batch = X_perm_j.reshape(-1, X.shape[1])
+            if isinstance(X, pd.DataFrame):
+                X_perm_batch = pd.DataFrame(
+                    X_perm_batch.reshape(-1, X.shape[1]), columns=X.columns
+                )
+            y_pred_perm = getattr(self.estimator, self.method)(X_perm_batch)
+
+            # In case of classification, the output is a 2D array. Reshape accordingly
+            if y_pred_perm.ndim == 1:
+                y_pred_perm = y_pred_perm.reshape(self.n_permutations, X.shape[0])
+            else:
+                y_pred_perm = y_pred_perm.reshape(
+                    self.n_permutations, X.shape[0], y_pred_perm.shape[1]
+                )
+            return y_pred_perm

        # Parallelize the computation of the importance scores for each group
        out_list = Parallel(n_jobs=self.n_jobs)(
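A detail shared by both files: the batched call returns a flat result, of shape (n_permutations * n_samples,) for predict and (n_permutations * n_samples, n_classes) for predict_proba, which is why the code checks ndim before reshaping. A quick illustration with a toy classifier (names are illustrative only):

import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
n_samples, n_features, n_permutations = 60, 4, 10
X = rng.normal(size=(n_samples, n_features))
y = (X[:, 0] > 0).astype(int)
clf = LogisticRegression().fit(X, y)

# Pretend X_perm_batch stacks n_permutations permuted copies of X.
X_perm_batch = np.tile(X, (n_permutations, 1))   # (600, 4)
proba = clf.predict_proba(X_perm_batch)          # (600, 2)

# Reshape back to one block of predictions per permutation.
proba = proba.reshape(n_permutations, n_samples, proba.shape[1])
print(proba.shape)  # (10, 60, 2)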