30 accelerate cpi by using batch prediction and numpy array operations instead of for loop #31

Conversation

jpaillard
Collaborator

Description

This PR reimplements the method CPI.predict. The idea is to replace the for loop over permutations with a single batched prediction over all permuted arrays.

N: number of samples
D: number of features
B: number of permutations

New

for p in range B:
   X_perm_j.append(sampling with jth group (conditionally) permuted)

X_perm_j  // shape B x N x D
y_pred_perm <- estimator.predict(X_perm_j)

Previous

for p in range B:
   X_perm_j <- sampling with jth group (conditionally) permuted  // shape N x D
   y_pred_perm_p <- estimator.predict(X_perm_j )   // Shape N
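The two variants above can be sketched in NumPy as follows. This is only an illustration under stated assumptions, not the code of this PR: `predict` is a hypothetical stand-in for `estimator.predict` that accepts 2D input only, `permuted_copies` uses a plain (unconditional) shuffle of one column instead of the conditional sampling done by CPI, and all function names are made up for the example. The point it shows is that a stack of shape B x N x D can be flattened to (B * N) x D, predicted in one call, and reshaped back to B x N.

```python
import numpy as np


def predict(X, coef):
    """Hypothetical stand-in for estimator.predict; accepts 2D input only."""
    if X.ndim != 2:
        raise ValueError("expected an array of shape (n_samples, n_features)")
    return X @ coef


def permuted_copies(X, j, n_permutations, rng):
    """Return B copies of X with column j shuffled, shape (B, N, D).

    Assumption: a plain shuffle replaces the conditional sampling of CPI.
    """
    X_perm = np.repeat(X[np.newaxis], n_permutations, axis=0)
    for b in range(n_permutations):
        X_perm[b, :, j] = rng.permutation(X[:, j])
    return X_perm


def predict_loop(X_perm, coef):
    """Previous approach: one predict call per permutation."""
    return np.stack([predict(X_b, coef) for X_b in X_perm])


def predict_batched(X_perm, coef):
    """New approach: flatten to (B * N, D), predict once, reshape to (B, N)."""
    B, N, D = X_perm.shape
    y_flat = predict(X_perm.reshape(B * N, D), coef)
    return y_flat.reshape(B, N)


rng = np.random.RandomState(0)
X = rng.randn(20, 10)   # N = 20 samples, D = 10 features
coef = rng.randn(10)
j = 3                   # index of the permuted feature

X_perm = permuted_copies(X, j, n_permutations=5, rng=rng)
y_loop = predict_loop(X_perm, coef)
y_batched = predict_batched(X_perm, coef)
print(y_loop.shape, np.allclose(y_loop, y_batched))
```

Because the permuted arrays are drawn before either prediction path runs, both paths see identical inputs and differ only in how many times the predictor is called.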

Results

Using pytest-benchmark, I measured a substantial reduction in computation time.

[Image: pytest-benchmark results]

Reproducibility

The above benchmark can be reproduced as follows:

pip install pytest-benchmark

Add the following test to the test_cpi.py file:

def test_benchmark(benchmark):

    rng = np.random.RandomState(0)
    X_train = rng.randn(80, 10)
    y_train = rng.randn(80)
    X_test = rng.randn(20, 10)
    print(y_train)

    regression_model = LinearRegression()
    regression_model.fit(X_train, y_train)
    imputation_model = LinearRegression()

    cpi = CPI(
        estimator=regression_model,
        imputation_model=imputation_model,
        n_permutations=20,
        method="predict",
        random_state=0,
        n_jobs=1,
    )
    cpi.fit(
        X_train,
        y_train,
        groups=None,
    )
    benchmark(cpi.predict, X_test)
    # Save the output to check reproducibility.
    # Comment out the benchmark line above first, because it
    # changes the rng state in an unpredictable way.
    # np.save("./.pytest_cache/y_pred_2.npy", cpi.predict(X_test))

Run the benchmark on the previous (main branch) implementation:

git checkout main
pytest hidimstat/test/test_cpi.py::test_benchmark --benchmark-json previous_implementation

Run the benchmark on the new implementation and compare results:

git checkout 30-accelerate-cpi-by-using-batch-prediction-and-numpy-array-operations-instead-of-for-loop
pytest hidimstat/test/test_cpi.py::test_benchmark --benchmark-compare previous_implementation

Consistency with previous implementation

The random seeding is done in a way that guarantees exact consistency with the outputs of the previous implementation.
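One way to check this kind of consistency is to save the predictions from each branch (as in the commented-out `np.save` line of the test above) and compare the two files. The sketch below is a hypothetical helper, not part of this PR: the function name `outputs_match` and the file name `y_pred_1.npy` are assumptions, and the arrays written here are synthetic placeholders rather than real CPI outputs.

```python
import os
import tempfile

import numpy as np


def outputs_match(path_previous, path_new, exact=True):
    """Compare two saved prediction arrays (hypothetical helper).

    With exact=True the arrays must be bitwise equal; otherwise they
    only need to agree up to floating-point tolerance.
    """
    y_prev = np.load(path_previous)
    y_new = np.load(path_new)
    if y_prev.shape != y_new.shape:
        return False
    if exact:
        return bool(np.array_equal(y_prev, y_new))
    return bool(np.allclose(y_prev, y_new))


with tempfile.TemporaryDirectory() as tmp:
    # Synthetic stand-ins for the predictions saved on each branch.
    y = np.random.RandomState(0).randn(20)
    path_prev = os.path.join(tmp, "y_pred_1.npy")  # assumed file name
    path_new = os.path.join(tmp, "y_pred_2.npy")

    np.save(path_prev, y)
    np.save(path_new, y.copy())
    same = outputs_match(path_prev, path_new)

    # A small perturbation must be detected as a mismatch.
    np.save(path_new, y + 1e-3)
    different = outputs_match(path_prev, path_new)

print(same, different)
```

Exact comparison is the stricter check and matches the claim of exact consistency; the tolerance-based mode is only useful when the order of floating-point operations differs between implementations.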


codecov bot commented Nov 14, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 77.11%. Comparing base (9b3d98e) to head (212198a).
Report is 4 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main      #31      +/-   ##
==========================================
+ Coverage   77.09%   77.11%   +0.02%     
==========================================
  Files          46       46              
  Lines        2462     2465       +3     
==========================================
+ Hits         1898     1901       +3     
  Misses        564      564              


@jpaillard jpaillard requested a review from bthirion November 14, 2024 17:58
Contributor

@bthirion bthirion left a comment


LGTM, thx.

@jpaillard jpaillard merged commit 7571832 into main Nov 15, 2024
9 checks passed
Development

Successfully merging this pull request may close these issues.

Accelerate CPI by using batch prediction and Numpy array operations instead of for loop
2 participants