diff --git a/hydromodel/trainers/evaluate.py b/hydromodel/trainers/evaluate.py index 70421f0..49e5cb7 100644 --- a/hydromodel/trainers/evaluate.py +++ b/hydromodel/trainers/evaluate.py @@ -1,7 +1,7 @@ """ Author: Wenyu Ouyang Date: 2022-10-25 21:16:22 -LastEditTime: 2024-09-17 14:23:29 +LastEditTime: 2024-09-17 15:13:57 LastEditors: Wenyu Ouyang Description: Plots for calibration and testing results FilePath: \hydromodel\hydromodel\trainers\evaluate.py @@ -325,22 +325,25 @@ def _read_save_sceua_calibrated_params(basin_id, save_dir, sceua_calibrated_file # the following code is from spotpy but its performance is not good so we use pandas to replace it # results = spotpy.analyser.load_csv_results(sceua_calibrated_file_name) # bestindex, bestobjf = spotpy.analyser.get_minlikeindex(results) - best_model_run = results[bestindex] - fields = [word for word in best_model_run.dtype.names if word.startswith("par")] - best_calibrate_params = pd.DataFrame(list(best_model_run[fields])) + best_model_run = results.iloc[bestindex] + fields = [word for word in best_model_run.index if word.startswith("par")] + best_calibrate_params = pd.DataFrame( + [best_model_run[fields].values], columns=fields + ) save_file = os.path.join(save_dir, basin_id + "_calibrate_params.txt") - best_calibrate_params.to_csv(save_file, sep=",", index=False, header=True) + # to keep consistent with the original code, we save the best parameters to a txt file + best_calibrate_params.T.to_csv(save_file, index=False, columns=None) # Return the best result as a single row - return np.array(best_calibrate_params).reshape(1, -1) + return best_calibrate_params.to_numpy().reshape(1, -1) def _read_all_basin_params(basins, param_dir): params_list = [] for basin_id in basins: db_name = os.path.join(param_dir, basin_id) - # 读取每个流域的参数 + # Read parameters for each basin basin_params = _read_save_sceua_calibrated_params(basin_id, param_dir, db_name) - # 确保basin_params是一维的 + # Ensure basin_params is one-dimensional basin_params = basin_params.flatten() params_list.append(basin_params) return np.vstack(params_list)