Skip to content

Commit

Permalink
Weighted average for leave-one-out cross validation
Browse files Browse the repository at this point in the history
  • Loading branch information
hwpang committed Sep 11, 2023
1 parent 2212786 commit 9bd4ddc
Showing 1 changed file with 32 additions and 14 deletions.
46 changes: 32 additions & 14 deletions arkane/encorr/bac.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,20 +1047,38 @@ def fit(self,
logging.info(f'RMSE/MAE before fitting: {stats_before.rmse:.2f}/{stats_before.mae:.2f} kcal/mol')
logging.info(f'RMSE/MAE after fitting: {stats_after.rmse:.2f}/{stats_after.mae:.2f} kcal/mol')

rmse_before = [test_data.calculate_stats().rmse for test_data in test_data_results]
mae_before = [test_data.calculate_stats().mae for test_data in test_data_results]
rmse_after = [test_data.calculate_stats(for_bac_data=True).rmse for test_data in test_data_results]
mae_after = [test_data.calculate_stats(for_bac_data=True).mae for test_data in test_data_results]

logging.info('\nCross-validation results:')
logging.info(f'Testing RMSE before fitting (mean +- 1 std): '
f'{np.average(rmse_before):.2f} +- {np.std(rmse_before):.2f} kcal/mol')
logging.info(f'Testing MAE before fitting (mean +- 1 std): '
f'{np.average(mae_before):.2f} +- {np.std(mae_before):.2f} kcal/mol')
logging.info(f'Testing RMSE after fitting (mean +- 1 std): '
f'{np.average(rmse_after):.2f} +- {np.std(rmse_after):.2f} kcal/mol')
logging.info(f'Testing MAE after fitting (mean +- 1 std): '
f'{np.average(mae_after):.2f} +- {np.std(mae_after):.2f} kcal/mol')
if self.n_folds == -1:
num_test_data = sum(len(test_data) for test_data in test_data_results)
rmse_before = np.sqrt(np.sum([test_data.calculate_stats().rmse**2 * len(test_data) for test_data in test_data_results]) / num_test_data)
mae_before = np.sum([test_data.calculate_stats().mae * len(test_data) for test_data in test_data_results]) / num_test_data
rmse_after = np.sqrt(np.sum([test_data.calculate_stats(for_bac_data=True).rmse**2 * len(test_data) for test_data in test_data_results]) / num_test_data)
mae_after = np.sum([test_data.calculate_stats(for_bac_data=True).mae * len(test_data) for test_data in test_data_results]) / num_test_data

logging.info('\nCross-validation results:')
logging.info(f'Testing RMSE before fitting: '
f'{rmse_before:.2f} kcal/mol')
logging.info(f'Testing MAE before fitting: '
f'{mae_before:.2f} kcal/mol')
logging.info(f'Testing RMSE after fitting: '
f'{rmse_after:.2f} kcal/mol')
logging.info(f'Testing MAE after fitting: '
f'{mae_after:.2f} kcal/mol')

else:
rmse_before = [test_data.calculate_stats().rmse for test_data in test_data_results]
mae_before = [test_data.calculate_stats().mae for test_data in test_data_results]
rmse_after = [test_data.calculate_stats(for_bac_data=True).rmse for test_data in test_data_results]
mae_after = [test_data.calculate_stats(for_bac_data=True).mae for test_data in test_data_results]

logging.info('\nCross-validation results:')
logging.info(f'Testing RMSE before fitting (mean +- 1 std): '
f'{np.average(rmse_before):.2f} +- {np.std(rmse_before):.2f} kcal/mol')
logging.info(f'Testing MAE before fitting (mean +- 1 std): '
f'{np.average(mae_before):.2f} +- {np.std(mae_before):.2f} kcal/mol')
logging.info(f'Testing RMSE after fitting (mean +- 1 std): '
f'{np.average(rmse_after):.2f} +- {np.std(rmse_after):.2f} kcal/mol')
logging.info(f'Testing MAE after fitting (mean +- 1 std): '
f'{np.average(mae_after):.2f} +- {np.std(mae_after):.2f} kcal/mol')


def get_confidence_intervals(x: np.ndarray,
Expand Down

0 comments on commit 9bd4ddc

Please sign in to comment.