Skip to content

Commit

Permalink
use pandas to read spotpy-param-cal-process-file
Browse files Browse the repository at this point in the history
  • Loading branch information
OuyangWenyu committed Sep 17, 2024
1 parent fc46262 commit a9c7067
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 29 deletions.
55 changes: 44 additions & 11 deletions hydromodel/trainers/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
"""
Author: Wenyu Ouyang
Date: 2022-10-25 21:16:22
LastEditTime: 2024-09-14 20:31:58
LastEditTime: 2024-09-17 14:23:29
LastEditors: Wenyu Ouyang
Description: Plots for calibration and testing results
FilePath: \hydromodel\hydromodel\trainers\evaluate.py
Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved.
"""

import pathlib
import pandas as pd
import os
import yaml
import numpy as np
import pandas as pd
import xarray as xr
import spotpy
import yaml

from hydroutils import hydro_file, hydro_stat
from hydroutils import hydro_stat
from hydrodatasource.utils.utils import streamflow_unit_conv

from hydromodel.datasets import *
Expand Down Expand Up @@ -270,6 +268,38 @@ def load_results(self):
return xr.open_dataset(file_path)


def _load_csv_results_pandas(sceua_calibrated_file_name):
return pd.read_csv(sceua_calibrated_file_name + ".csv")


def _get_minlikeindex_pandas(results_df, like_index=1, verbose=True):
"""
Get the minimum objectivefunction of your result DataFrame
:results_df: Expects a pandas DataFrame with a "like" column for objective functions
:type: DataFrame
:return: Index of the position in the DataFrame with the minimum objective function
value and the value of the minimum objective function
:rtype: int and float
"""
# Extract the 'like' column based on the like_index
like_column = f"like{str(like_index)}"
likes = results_df[like_column].values

# Find the minimum value in the 'like' column
minimum = np.nanmin(likes)
value = str(round(minimum, 4))
index = np.where(likes == minimum)
text2 = " has the lowest objective function with: "
textv = f"Run number {str(index[0][0])}{text2}{value}"

if verbose:
print(textv)

return index[0][0], minimum


def _read_save_sceua_calibrated_params(basin_id, save_dir, sceua_calibrated_file_name):
"""
read the parameters' file generated by spotpy SCE-UA when finishing calibration
Expand All @@ -289,16 +319,19 @@ def _read_save_sceua_calibrated_params(basin_id, save_dir, sceua_calibrated_file
-------
"""
results = spotpy.analyser.load_csv_results(sceua_calibrated_file_name)
bestindex, bestobjf = spotpy.analyser.get_minlikeindex(
results
) # 结果数组中具有最小目标函数的位置的索引
results = _load_csv_results_pandas(sceua_calibrated_file_name)
# Index of the position in the results array with the minimum objective function
bestindex, bestobjf = _get_minlikeindex_pandas(results)
# the following code is from spotpy but its performance is not good so we use pandas to replace it
# results = spotpy.analyser.load_csv_results(sceua_calibrated_file_name)
# bestindex, bestobjf = spotpy.analyser.get_minlikeindex(results)
best_model_run = results[bestindex]
fields = [word for word in best_model_run.dtype.names if word.startswith("par")]
best_calibrate_params = pd.DataFrame(list(best_model_run[fields]))
save_file = os.path.join(save_dir, basin_id + "_calibrate_params.txt")
best_calibrate_params.to_csv(save_file, sep=",", index=False, header=True)
return np.array(best_calibrate_params).reshape(1, -1) # 返回一列最佳的结果
# Return the best result as a single row
return np.array(best_calibrate_params).reshape(1, -1)


def _read_all_basin_params(basins, param_dir):
Expand Down
37 changes: 37 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
"""
Author: Wenyu Ouyang
Date: 2024-08-14 16:34:32
LastEditTime: 2024-09-17 14:32:47
LastEditors: Wenyu Ouyang
Description: Some common fixtures for testing
FilePath: \hydromodel\test\conftest.py
Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved.
"""

import os
import numpy as np
import pytest
import spotpy
from spotpy.examples.spot_setup_hymod_python import spot_setup

from hydrodataset import Camels
from hydromodel import SETTING
Expand Down Expand Up @@ -52,3 +64,28 @@ def qobs(basin_area, camels, basins):
r = qobs["streamflow"] / basin_area["area_gages2"]
r_mmd = r.pint.to(target_unit)
return np.expand_dims(r_mmd.to_numpy().transpose(1, 0), axis=2)


@pytest.fixture(scope="session")
def hymod_setup():
    """
    Session-scoped fixture that provides the spotpy hymod example setup.

    When "test/SCEUA_hymod.csv" does not exist yet, an SCE-UA sampling run is
    performed first, using "test/SCEUA_hymod" as the csv database name.
    The setup object is returned in both cases.
    """
    hymod = spot_setup(spotpy.objectivefunctions.rmse)

    if os.path.exists("test/SCEUA_hymod.csv"):
        # A results file is already present, so the sampling run is skipped
        return hymod

    # SCE-UA sampler writing its runs to a csv database
    sceua_sampler = spotpy.algorithms.sceua(
        hymod, dbname="test/SCEUA_hymod", dbformat="csv"
    )

    # Maximum number of iterations for the sampler
    max_repetitions = 5000
    sceua_sampler.sample(max_repetitions, ngs=7, kstop=3, peps=0.1, pcento=0.1)

    return hymod
20 changes: 4 additions & 16 deletions test/test_data_visualize.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
"""
Author: Wenyu Ouyang
Date: 2022-10-25 21:16:22
LastEditTime: 2024-09-11 21:25:42
LastEditTime: 2024-09-17 14:30:58
LastEditors: Wenyu Ouyang
Description: Test for results visualization
FilePath: \hydromodel\test\test_data_visualize.py
Copyright (c) 2021-2022 Wenyu Ouyang. All rights reserved.
"""

import spotpy
import matplotlib.pyplot as plt
import spotpy
from spotpy.examples.spot_setup_hymod_python import spot_setup as hymod_setup

from hydroutils import hydro_time

Expand All @@ -20,19 +20,7 @@
from hydromodel.trainers.evaluate import _read_save_sceua_calibrated_params


def test_run_hymod_calibration():
# a case from spotpy example
setup = hymod_setup(spotpy.objectivefunctions.rmse)

# 创建SCE-UA算法的sampler
sampler = spotpy.algorithms.sceua(setup, dbname="test/SCEUA_hymod", dbformat="csv")

# 设置校准参数
repetitions = 5000 # 最大迭代次数

# 运行sampler
sampler.sample(repetitions, ngs=7, kstop=3, peps=0.1, pcento=0.1)

def test_run_hymod_calibration(hymod_setup):
# 从CSV文件加载结果
results = spotpy.analyser.load_csv_results("test/SCEUA_hymod")

Expand All @@ -58,7 +46,7 @@ def test_run_hymod_calibration():
linestyle="solid",
label=f"Best objf.={str(bestobjf)}",
)
plt.plot(setup.evaluation(), "r.", markersize=3, label="Observation data")
plt.plot(hymod_setup.evaluation(), "r.", markersize=3, label="Observation data")
plt.xlabel("Number of Observation Points")
plt.ylabel("Discharge [l s-1]")
plt.legend(loc="upper right")
Expand Down
38 changes: 36 additions & 2 deletions test/test_evaluate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Author: Wenyu Ouyang
Date: 2024-09-14 17:00:29
LastEditTime: 2024-09-14 17:25:16
LastEditTime: 2024-09-17 14:32:30
LastEditors: Wenyu Ouyang
Description: Test the evaluate module
FilePath: \hydromodel\test\test_evaluate.py
Expand All @@ -12,7 +12,10 @@
import pytest
import xarray as xr
import numpy as np
from hydromodel.trainers.evaluate import Evaluator

from spotpy.analyser import get_minlikeindex

from hydromodel.trainers.evaluate import Evaluator, _get_minlikeindex_pandas


@pytest.fixture
Expand Down Expand Up @@ -95,3 +98,34 @@ def test_predict(sample_dataset, evaluator, mocker):
assert "prcp" in qsim
assert "prcp" in qobs
assert "prcp" in etsim


def test_get_minlikeindex(hymod_setup):
    """Check that _get_minlikeindex_pandas agrees with spotpy's get_minlikeindex."""
    csv_file = "test/SCEUA_hymod" + ".csv"

    # Same file parsed twice: structured numpy array for the spotpy function,
    # DataFrame for the pandas-based one
    results_np = np.genfromtxt(
        csv_file, delimiter=",", names=True, invalid_raise=False
    )
    results_pd = pd.read_csv(csv_file)

    # Both functions are called with identical options
    options = {"like_index": 1, "verbose": False}
    original_index, original_minimum = get_minlikeindex(results_np, **options)
    new_index, new_minimum = _get_minlikeindex_pandas(results_pd, **options)

    # The position and the value of the minimum must match
    assert (
        original_index == new_index
    ), f"Index mismatch: {original_index} != {new_index}"
    assert np.isclose(
        original_minimum, new_minimum
    ), f"Minimum value mismatch: {original_minimum} != {new_minimum}"

    print("Test passed! Both methods return identical results.")

0 comments on commit a9c7067

Please sign in to comment.