def _load_csv_results_pandas(sceua_calibrated_file_name):
    """
    Load the csv database of a calibration run into a pandas DataFrame

    Parameters
    ----------
    sceua_calibrated_file_name
        path of the result file WITHOUT the ".csv" suffix; the suffix is appended here

    Returns
    -------
    pd.DataFrame
        the content of the csv file, one column per csv field
    """
    return pd.read_csv(f"{sceua_calibrated_file_name}.csv")


def _get_minlikeindex_pandas(results_df, like_index=1, verbose=True):
    """
    Get the run with the minimum objective function value from a result DataFrame

    Parameters
    ----------
    results_df : pd.DataFrame
        results holding a "like<like_index>" column with the objective function values
    like_index : int, optional
        number of the "like" column to inspect, by default 1
    verbose : bool, optional
        if True, print which run has the lowest objective function value

    Returns
    -------
    tuple
        positional index of the first run that reaches the minimum, and the minimum itself

    Raises
    ------
    KeyError
        if the "like<like_index>" column does not exist
    ValueError
        if the column is empty or contains only NaN values
    """
    like_column = f"like{like_index}"
    likes = results_df[like_column].to_numpy()

    # nanargmin ignores NaN and returns the first position of the minimum,
    # i.e. the same run that np.where(likes == np.nanmin(likes))[0][0] selects
    best_index = np.nanargmin(likes)
    minimum = likes[best_index]

    if verbose:
        value = str(round(minimum, 4))
        text2 = " has the lowest objective function with: "
        print(f"Run number {str(best_index)}{text2}{value}")

    return best_index, minimum


def _read_save_sceua_calibrated_params(basin_id, save_dir, sceua_calibrated_file_name):
    """
    read the parameters' file generated by spotpy SCE-UA when finishing calibration

    The run with the minimum objective function is selected, its parameters
    (the columns whose name starts with "par") are written to
    "<basin_id>_calibrate_params.txt" in save_dir and are returned.

    Parameters
    ----------
    basin_id : str
        id of the basin; used as prefix of the saved file name
    save_dir : str
        directory where the best parameters are saved
    sceua_calibrated_file_name : str
        path of the csv result file WITHOUT the ".csv" suffix

    Returns
    -------
    np.ndarray
        the best parameters as a single row, shape (1, number of parameters)
    """
    # pandas is used instead of spotpy.analyser.load_csv_results / get_minlikeindex
    # because reading the csv with spotpy performs poorly
    results = _load_csv_results_pandas(sceua_calibrated_file_name)
    # Index of the position in the results with the minimum objective function
    bestindex, _ = _get_minlikeindex_pandas(results)
    # results is a DataFrame, so the best run must be taken by position (iloc);
    # results[bestindex] would be a column lookup and fail
    par_columns = [col for col in results.columns if str(col).startswith("par")]
    best_values = results.loc[:, par_columns].iloc[bestindex].to_numpy()
    # one parameter per row in a single column, as in the file written before
    best_calibrate_params = pd.DataFrame(best_values)
    save_file = os.path.join(save_dir, basin_id + "_calibrate_params.txt")
    best_calibrate_params.to_csv(save_file, sep=",", index=False, header=True)
    # Return the best result as a single row
    return np.array(best_calibrate_params).reshape(1, -1)
@pytest.fixture(scope="session")
def hymod_setup():
    """
    Session-scoped fixture that provides the spotpy hymod example setup.

    When "test/SCEUA_hymod.csv" does not exist (the path is resolved against
    the current working directory), an SCE-UA calibration is run first with
    dbname "test/SCEUA_hymod" and dbformat "csv"; presumably this is what
    creates that csv file for the tests reading it -- confirm against spotpy.

    Returns
    -------
    the ``spot_setup`` instance created with spotpy's rmse objective function
    (the setup object itself, not the calibration results)
    """
    hymod_model = spot_setup(spotpy.objectivefunctions.rmse)

    # A result file from a previous session is reused: skip the calibration
    if os.path.exists("test/SCEUA_hymod.csv"):
        return hymod_model

    # SCE-UA sampler storing its runs in a csv database
    sceua_sampler = spotpy.algorithms.sceua(
        hymod_model, dbname="test/SCEUA_hymod", dbformat="csv"
    )

    # Upper bound of model runs for the calibration
    max_repetitions = 5000
    sceua_sampler.sample(max_repetitions, ngs=7, kstop=3, peps=0.1, pcento=0.1)

    return hymod_model
def test_get_minlikeindex(hymod_setup):
    """
    Check that _get_minlikeindex_pandas agrees with spotpy's get_minlikeindex.

    The ``hymod_setup`` fixture is requested only for its side effect: it is
    expected to make sure "test/SCEUA_hymod.csv" exists before the file is read.
    """
    # Local import: no module-level pandas import is visible for this test file
    import pandas as pd

    filename = "test/SCEUA_hymod"
    # Read data using np.genfromtxt (original)
    results_np = np.genfromtxt(
        f"{filename}.csv", delimiter=",", names=True, invalid_raise=False
    )

    # Read data using pandas (new method)
    results_pd = pd.read_csv(f"{filename}.csv")

    # Run the original get_minlikeindex function
    original_index, original_minimum = get_minlikeindex(
        results_np, like_index=1, verbose=False
    )

    # Run the new get_minlikeindex_pandas function
    new_index, new_minimum = _get_minlikeindex_pandas(
        results_pd, like_index=1, verbose=False
    )

    # Compare the results from both methods
    assert (
        original_index == new_index
    ), f"Index mismatch: {original_index} != {new_index}"
    assert np.isclose(
        original_minimum, new_minimum
    ), f"Minimum value mismatch: {original_minimum} != {new_minimum}"