From f7dc2bcbfa8f2c070b503a4b6295d3d0e3d165db Mon Sep 17 00:00:00 2001
From: ouyangwenyu <wenyuouyang@outlook.com>
Date: Fri, 22 Mar 2024 09:51:52 +0800
Subject: [PATCH] start testing for new version

---
 definitions.py                              |  18 -
 hydromodel/__init__.py                      |  81 ++++-
 hydromodel/datasets/data_postprocess.py     |  20 +-
 hydromodel/models/xaj_bmi.py                | 188 ++++++-----
 hydromodel/trainers/calibrate_ga_xaj_bmi.py |   6 +-
 hydromodel/trainers/plots.py                | 346 ++++++++++++--------
 test/picture.py                             | 116 -------
 test/test-xaj-bmi.py                        |  43 ---
 test/test_data.py                           | 171 +---------
 test/test_gr4j.py                           |  15 +-
 test/test_hydromodel.py                     |  24 --
 test/test_hymod.py                          |  11 +-
 test/test_rr_event_iden.py                  |  17 +-
 test/test_xaj.py                            |   9 +-
 test/test_xaj_bmi.py                        |   3 -
 15 files changed, 425 insertions(+), 643 deletions(-)
 delete mode 100644 definitions.py
 delete mode 100644 test/picture.py
 delete mode 100644 test/test-xaj-bmi.py
 delete mode 100644 test/test_hydromodel.py

diff --git a/definitions.py b/definitions.py
deleted file mode 100644
index dcadc13..0000000
--- a/definitions.py
+++ /dev/null
@@ -1,18 +0,0 @@
-"""
-Author: Wenyu Ouyang
-Date: 2021-07-26 08:51:23
-LastEditTime: 2022-11-16 18:47:10
-LastEditors: Wenyu Ouyang
-Description: some configs for hydro-model-xaj
-FilePath: \hydro-model-xaj\definitions.py
-Copyright (c) 2021-2022 Wenyu Ouyang. All rights reserved.
-"""
-import os
-from pathlib import Path
-
-ROOT_DIR = os.path.dirname(os.path.abspath('/home/ldaning/code/biye/hydro-model-xaj/definitions.py'))  # This is your Project Root
-path = Path(ROOT_DIR)
-DATASET_DIR = os.path.join(path.parent.parent.absolute(), "data")
-print("Please Check your directory:")
-print("ROOT_DIR of the repo: ", ROOT_DIR)
-print("DATASET_DIR of the repo: ", DATASET_DIR)
diff --git a/hydromodel/__init__.py b/hydromodel/__init__.py
index b62cfb8..fc8f51b 100644
--- a/hydromodel/__init__.py
+++ b/hydromodel/__init__.py
@@ -1,5 +1,110 @@
-"""Top-level package for hydromodel."""
+r"""
+Author: Wenyu Ouyang
+Date: 2024-02-09 15:56:48
+LastEditTime: 2024-03-22 09:12:40
+LastEditors: Wenyu Ouyang
+Description: Top-level package for hydromodel
+FilePath: \hydro-model-xaj\hydromodel\__init__.py
+Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved.
+"""
+
+import os
+from pathlib import Path
+from hydroutils import hydro_file
+import yaml
 
 __author__ = """Wenyu Ouyang"""
-__email__ = 'wenyuouyang@outlook.com'
-__version__ = '0.0.1'
+__email__ = "wenyuouyang@outlook.com"
+__version__ = "0.0.1"
+
+
+CACHE_DIR = hydro_file.get_cache_dir()
+SETTING_FILE = os.path.join(Path.home(), "hydro_setting.yml")
+
+
+def read_setting(setting_path):
+    """Read and validate the hydromodel setting file.
+
+    Parameters
+    ----------
+    setting_path : str or os.PathLike
+        Path of the YAML configuration file.
+
+    Returns
+    -------
+    dict
+        The parsed configuration; it holds a ``local_data_path`` mapping
+        with every expected sub-key.
+
+    Raises
+    ------
+    FileNotFoundError
+        If ``setting_path`` does not exist.
+    ValueError
+        If the file is empty or does not follow the expected structure.
+    """
+    if not os.path.exists(setting_path):
+        raise FileNotFoundError(f"Configuration file not found: {setting_path}")
+
+    # Read as UTF-8 explicitly so non-ASCII paths do not depend on the locale
+    with open(setting_path, "r", encoding="utf-8") as file:
+        setting = yaml.safe_load(file)
+
+    example_setting = (
+        "local_data_path:\n"
+        "  root: 'D:\\data\\waterism' # Update with your root data directory\n"
+        "  datasets-origin: 'D:\\data\\waterism\\datasets-origin' # datasets-origin is the directory you put downloaded datasets\n"
+        "  datasets-interim: 'D:\\data\\waterism\\datasets-interim' # the other choice for the directory you put downloaded datasets\n"
+        "  basins-origin: 'D:\\data\\waterism\\basins-origin' # the directory put your own data\n"
+        "  basins-interim: 'D:\\data\\waterism\\basins-interim' # the other choice for your own data"
+    )
+
+    if setting is None:
+        raise ValueError(
+            f"Configuration file is empty or has invalid format.\n\nExample configuration:\n{example_setting}"
+        )
+
+    # Define the expected structure
+    expected_structure = {
+        "local_data_path": [
+            "root",
+            "datasets-origin",
+            "datasets-interim",
+            "basins-origin",
+            "basins-interim",
+        ],
+    }
+
+    # Validate the structure
+    try:
+        # A scalar or list at the top level would turn ``in`` into a substring
+        # or element test, so insist on a mapping before looking up keys
+        if not isinstance(setting, dict):
+            raise KeyError("the top level of the config must be a mapping")
+        for key, subkeys in expected_structure.items():
+            if key not in setting:
+                raise KeyError(f"Missing required key in config: {key}")
+
+            if isinstance(subkeys, list):
+                if not isinstance(setting[key], dict):
+                    raise KeyError(f"'{key}' must be a mapping of sub-keys")
+                for subkey in subkeys:
+                    if subkey not in setting[key]:
+                        raise KeyError(f"Missing required subkey '{subkey}' in '{key}'")
+    except KeyError as e:
+        raise ValueError(
+            f"Incorrect configuration format: {e}\n\nExample configuration:\n{example_setting}"
+        ) from e
+
+    return setting
+
+
+# Keep the package importable without a valid setting file: SETTING stays None
+# in that case, so ``from hydromodel import SETTING`` does not fail.
+SETTING = None
+try:
+    SETTING = read_setting(SETTING_FILE)
+except ValueError as e:
+    print(e)
+except Exception as e:
+    print(f"Unexpected error: {e}")
diff --git a/hydromodel/datasets/data_postprocess.py b/hydromodel/datasets/data_postprocess.py
index 9a4d80a..75d71fb 100644
--- a/hydromodel/datasets/data_postprocess.py
+++ b/hydromodel/datasets/data_postprocess.py
@@ -3,13 +3,9 @@
 import pandas as pd
 import pathlib
 import spotpy
-from pathlib import Path
-import sys
 
 from hydroutils import hydro_file
 
-sys.path.append(os.path.dirname(Path(os.path.abspath(__file__)).parent.parent))
-import definitions
 from hydromodel.models.model_config import MODEL_PARAM_DICT
 from hydromodel.models.xaj import xaj
 
@@ -34,23 +30,15 @@ def read_save_sceua_calibrated_params(basin_id, save_dir, sceua_calibrated_file_
 
     """
     results = spotpy.analyser.load_csv_results(sceua_calibrated_file_name)
-<<<<<<< HEAD
     bestindex, bestobjf = spotpy.analyser.get_minlikeindex(
         results
     )  # 结果数组中具有最小目标函数的位置的索引
-=======
-    bestindex, bestobjf = spotpy.analyser.get_minlikeindex(results) #结果数组中具有最小目标函数的位置的索引
->>>>>>> wangjingyi1999-event
     best_model_run = results[bestindex]
     fields = [word for word in best_model_run.dtype.names if word.startswith("par")] 
     best_calibrate_params = pd.DataFrame(list(best_model_run[fields]))
     save_file = os.path.join(save_dir, basin_id + "_calibrate_params.txt")
     best_calibrate_params.to_csv(save_file, sep=",", index=False, header=True)
-<<<<<<< HEAD
     return np.array(best_calibrate_params).reshape(1, -1)  # 返回一列最佳的结果
-=======
-    return np.array(best_calibrate_params).reshape(1, -1)    #返回一列最佳的结果
->>>>>>> wangjingyi1999-event
 
 
 def summarize_parameters(result_dir, model_info: dict):
@@ -223,12 +211,8 @@ def read_and_save_et_ouputs(result_dir, fold: int):
     train_period = data_info_train["time"]
     test_period = data_info_test["time"]
     # TODO: basins_lump_p_pe_q_fold NAME need to be unified
-    train_np_file = os.path.join(
-        exp_dir, "data_info_fold" + str(fold) + "_train.npy"
-    )
-    test_np_file = os.path.join(
-        exp_dir, "data_info_fold" + str(fold) + "_test.npy"
-    )
+    train_np_file = os.path.join(exp_dir, f"data_info_fold{fold}_train.npy")
+    test_np_file = os.path.join(exp_dir, f"data_info_fold{fold}_test.npy")
     # train_np_file = os.path.join(exp_dir, f"basins_lump_p_pe_q_fold{fold}_train.npy")
     # test_np_file = os.path.join(exp_dir, f"basins_lump_p_pe_q_fold{fold}_test.npy")
     train_data = np.load(train_np_file)
diff --git a/hydromodel/models/xaj_bmi.py b/hydromodel/models/xaj_bmi.py
index 2ee3fa9..44c4409 100644
--- a/hydromodel/models/xaj_bmi.py
+++ b/hydromodel/models/xaj_bmi.py
@@ -2,25 +2,27 @@
 from bmipy import Bmi
 import numpy as np
 
-from xaj.xajmodel import xaj_route, xaj_runoff
-from xaj.constant_unit import convert_unit,unit
-
-from grpc4bmi.constants import GRPC_MAX_MESSAGE_LENGTH
-import datetime 
+import datetime
 import pandas as pd
 import logging
-from xaj.configuration import configuration
 
 logger = logging.getLogger(__name__)
 
 PRECISION = 1e-5
+
+
 class xajBmi(Bmi):
     """Empty model wrapped in a BMI interface."""
-    name =  "hydro-model-xaj"
-    input_var_names = ("precipitation","ETp")
-    output_var_names = ("ET","discharge")
-    var_units = {"precipitation": "mm/day", "ETp": "mm/day", "discharge": "mm/day", "ET": "mm/day"}
 
+    name = "hydro-model-xaj"
+    input_var_names = ("precipitation", "ETp")
+    output_var_names = ("ET", "discharge")
+    var_units = {
+        "precipitation": "mm/day",
+        "ETp": "mm/day",
+        "discharge": "mm/day",
+        "ET": "mm/day",
+    }
 
     def __init__(self):
         """Create a BmiHeat model that is ready for initialization."""
@@ -30,53 +32,71 @@ def initialize(self, config_file):
         try:
             logger.info("xaj: initialize_model")
             config = configuration.read_config(config_file)
-            forcing_data = pd.read_csv(config['forcing_file'])
+            forcing_data = pd.read_csv(config["forcing_file"])
             p_and_e_df, p_and_e = configuration.extract_forcing(forcing_data)
-            p_and_e_warmup = p_and_e[0:config['warmup_length'],:,:]
-            params=np.tile([0.5], (1, 15))
-            self.q_sim_state,self.es_state,self.w0, self.w1,self.w2,self.s0, self.fr0, self.qi0, self.qg0 = configuration.warmup(p_and_e_warmup,params,config['warmup_length'])
-        
-            self._start_time_str,self._end_time_str, self._time_units = configuration.get_time_config(config) 
-            
+            p_and_e_warmup = p_and_e[0 : config["warmup_length"], :, :]
+            params = np.tile([0.5], (1, 15))
+            (
+                self.q_sim_state,
+                self.es_state,
+                self.w0,
+                self.w1,
+                self.w2,
+                self.s0,
+                self.fr0,
+                self.qi0,
+                self.qg0,
+            ) = configuration.warmup(p_and_e_warmup, params, config["warmup_length"])
+
+            self._start_time_str, self._end_time_str, self._time_units = (
+                configuration.get_time_config(config)
+            )
+
             self.params = params
-            self.warmup_length = config['warmup_length']
+            self.warmup_length = config["warmup_length"]
             self.p_and_e_df = p_and_e_df
             self.p_and_e = p_and_e
             self.config = config
-            self.basin_area = config['basin_area']
+            self.basin_area = config["basin_area"]
 
         except:
             import traceback
+
             traceback.print_exc()
             raise
 
-
     def update(self):
         """Update model for a single time step."""
-     
+
         self.time_step += 1
         # p_and_e_sim = self.p_and_e[self.warmup_length+1:self.time_step+self.warmup_length+1]
-        p_and_e_sim = self.p_and_e[1:self.time_step+1]
-        self.runoff_im, self.rss_,self.ris_, self.rgs_, self.es_runoff, self.rss=  xaj_runoff(p_and_e_sim,
-                w0=self.w0,  s0=self.s0, fr0=self.fr0,
+        p_and_e_sim = self.p_and_e[1 : self.time_step + 1]
+        self.runoff_im, self.rss_, self.ris_, self.rgs_, self.es_runoff, self.rss = (
+            xaj_runoff(
+                p_and_e_sim,
+                w0=self.w0,
+                s0=self.s0,
+                fr0=self.fr0,
                 params_runoff=self.params,
                 return_state=False,
-                )
-        if self.time_step+self.warmup_length+1 >= self.p_and_e.shape[0]:
-            q_sim,es = xaj_route(p_and_e_sim,
-                            params_route=self.params,
-                            model_name = "xaj",
-                            runoff_im=self.runoff_im, 
-                            rss_=self.rss_,
-                            ris_=self.ris_, 
-                            rgs_=self.rgs_,
-                            rss=self.rss,
-                            qi0=self.qi0,
-                            qg0=self.qg0, 
-                            es=self.es_runoff,
-                            )
-            self.p_sim = p_and_e_sim[:,:,0]
-            self.e_sim = p_and_e_sim[:,:,1]
+            )
+        )
+        if self.time_step + self.warmup_length + 1 >= self.p_and_e.shape[0]:
+            q_sim, es = xaj_route(
+                p_and_e_sim,
+                params_route=self.params,
+                model_name="xaj",
+                runoff_im=self.runoff_im,
+                rss_=self.rss_,
+                ris_=self.ris_,
+                rgs_=self.rgs_,
+                rss=self.rss,
+                qi0=self.qi0,
+                qg0=self.qg0,
+                es=self.es_runoff,
+            )
+            self.p_sim = p_and_e_sim[:, :, 0]
+            self.e_sim = p_and_e_sim[:, :, 1]
             q_sim = convert_unit(
                 q_sim,
                 unit_now="mm/day",
@@ -84,7 +104,7 @@ def update(self):
                 basin_area=float(self.basin_area),
             )
             self.q_sim = q_sim
-            self.es =  es
+            self.es = es
 
     def update_until(self, time):
         while self.get_current_time() + 0.001 < time:
@@ -105,7 +125,7 @@ def get_output_item_count(self) -> int:
 
     def get_input_var_names(self) -> Tuple[str]:
         return self.input_var_names
-    
+
     def get_output_var_names(self) -> Tuple[str]:
         return self.output_var_names
 
@@ -113,7 +133,7 @@ def get_var_grid(self, name: str) -> int:
         raise NotImplementedError()
 
     def get_var_type(self, name: str) -> str:
-        return 'float64'
+        return "float64"
 
     def get_var_units(self, name: str) -> str:
         return self.var_units[name]
@@ -126,20 +146,20 @@ def get_var_nbytes(self, name: str) -> int:
 
     def get_var_location(self, name: str) -> str:
         raise NotImplementedError()
-    
+
     def get_start_time(self):
         return self.start_Time(self._start_time_str)
 
     def get_current_time(self):
         # return self.start_Time(self._start_time_str) + datetime.timedelta(self.time_step+self.warmup_length)
-        if self._time_units == 'hours':
+        if self._time_units == "hours":
             time_step = datetime.timedelta(hours=self.time_step)
-        elif self._time_units == 'days':    
-            time_step = datetime.timedelta(days=self.time_step)    
-        return self.start_Time(self._start_time_str)+ time_step
+        elif self._time_units == "days":
+            time_step = datetime.timedelta(days=self.time_step)
+        return self.start_Time(self._start_time_str) + time_step
 
     def get_end_time(self):
-        return self.end_Time(self._end_time_str) 
+        return self.end_Time(self._end_time_str)
 
     def get_time_units(self) -> str:
         return self._time_units
@@ -147,35 +167,31 @@ def get_time_units(self) -> str:
     def get_time_step(self) -> float:
         return 1
 
-    def get_value(self, name: str) -> None: 
+    def get_value(self, name: str) -> None:
         logger.info("getting value for var %s", name)
         return self.get_value_ptr(name).flatten()
-        
+
     def get_value_ptr(self, name: str) -> np.ndarray:
-        if name == 'discharge':
+        if name == "discharge":
             return self.q_sim
-        elif name == 'ET':
+        elif name == "ET":
             return self.es
 
-    def get_value_at_indices(
-        self, name: str, inds: np.ndarray
-    ) -> np.ndarray:
-        
-        return self.get_value_ptr(name).take(inds)
+    def get_value_at_indices(self, name: str, inds: np.ndarray) -> np.ndarray:
 
+        return self.get_value_ptr(name).take(inds)
 
     def set_value(self, name: str, src: np.ndarray):
-        
+
         val = self.get_value_ptr(name)
         val[:] = src.reshape(val.shape)
-  
-        
+
     def set_value_at_indices(
         self, name: str, inds: np.ndarray, src: np.ndarray
     ) -> None:
         val = self.get_value_ptr(name)
         val.flat[inds] = src
-        
+
     # Grid information
     def get_grid_rank(self, grid: int) -> int:
         raise NotImplementedError()
@@ -228,47 +244,47 @@ def get_grid_nodes_per_face(
         self, grid: int, nodes_per_face: np.ndarray
     ) -> np.ndarray:
         raise NotImplementedError()
-    
-    
+
     def start_Time(self, _start_time_str):
 
         try:
-            if " " in _start_time_str:    
-                date, time = _start_time_str.split(" ")    
+            if " " in _start_time_str:
+                date, time = _start_time_str.split(" ")
             else:
-                date = _start_time_str   
-                time = None            
+                date = _start_time_str
+                time = None
             year, month, day = date.split("-")
             self._startTime = datetime.date(int(year), int(month), int(day))
-        
+
             if time:
-                hour, minute, second = time.split(":")      
-                # self._startTime = self._startTime.replace(hour=int(hour),   
-                #                                       minute=int(minute),  
+                hour, minute, second = time.split(":")
+                # self._startTime = self._startTime.replace(hour=int(hour),
+                #                                       minute=int(minute),
                 #                                       second=int(second))
-                self._startTime = datetime.datetime(int(year), int(month), int(day),int(hour),int(minute), int(second))
-        except ValueError:    
+                self._startTime = datetime.datetime(
+                    int(year), int(month), int(day), int(hour), int(minute), int(second)
+                )
+        except ValueError:
             raise ValueError("Invalid start date format!")
 
         return self._startTime
-    
+
     def end_Time(self, _end_time_str):
 
         try:
-            if " " in _end_time_str:    
-                date, time = _end_time_str.split(" ")    
+            if " " in _end_time_str:
+                date, time = _end_time_str.split(" ")
             else:
-                date = _end_time_str   
-                time = None            
+                date = _end_time_str
+                time = None
             year, month, day = date.split("-")
             self._endTime = datetime.date(int(year), int(month), int(day))
-        
+
             if time:
-                hour, minute, second = time.split(":")      
-                self._endTime = datetime.datetime(int(year), int(month), int(day),int(hour),int(minute), int(second))
-        except ValueError:    
+                hour, minute, second = time.split(":")
+                self._endTime = datetime.datetime(
+                    int(year), int(month), int(day), int(hour), int(minute), int(second)
+                )
+        except ValueError:
             raise ValueError("Invalid start date format!")
         return self._endTime
-    
-        
- 
diff --git a/hydromodel/trainers/calibrate_ga_xaj_bmi.py b/hydromodel/trainers/calibrate_ga_xaj_bmi.py
index bef49d6..fd1de7f 100644
--- a/hydromodel/trainers/calibrate_ga_xaj_bmi.py
+++ b/hydromodel/trainers/calibrate_ga_xaj_bmi.py
@@ -1,9 +1,8 @@
 """Calibrate XAJ model using DEAP"""
+
 import os
 import pickle
 import random
-import sys
-from pathlib import Path
 
 import numpy as np
 import pandas as pd
@@ -14,12 +13,9 @@
 import HydroErr as he
 from hydroutils import hydro_stat, hydro_file
 
-sys.path.append(os.path.dirname(Path(os.path.abspath(__file__)).parent.parent))
-import definitions
 from hydromodel.models.model_config import MODEL_PARAM_DICT
 from hydromodel.trainers.plots import plot_sim_and_obs, plot_train_iteration
 from hydromodel.models.xaj_bmi import xajBmi
-from hydromodel.utils import units
 
 
 def evaluate(individual, x_input, y_true, warmup_length, model):
diff --git a/hydromodel/trainers/plots.py b/hydromodel/trainers/plots.py
index 1e942ab..84562d0 100644
--- a/hydromodel/trainers/plots.py
+++ b/hydromodel/trainers/plots.py
@@ -1,10 +1,10 @@
 """
 Author: Wenyu Ouyang
 Date: 2022-10-25 21:16:22
-LastEditTime: 2024-03-21 20:07:52
+LastEditTime: 2024-03-22 09:51:19
 LastEditors: Wenyu Ouyang
 Description: Plots for calibration and testing results
-FilePath: \hydro-model-xaj\hydromodel\utils\plots.py
+FilePath: \hydro-model-xaj\hydromodel\trainers\plots.py
 Copyright (c) 2021-2022 Wenyu Ouyang. All rights reserved.
 """
 
@@ -16,8 +16,6 @@
 
 from hydroutils import hydro_file, hydro_stat
 
-from hydromodel.utils import units
-
 
 def plot_sim_and_obs(
     date,
@@ -25,7 +23,7 @@ def plot_sim_and_obs(
     obs,
     save_fig,
     xlabel="Date",
-    ylabel="Streamflow(" + units.unit["streamflow"] + ")",
+    ylabel=None,
 ):
     # matplotlib.use("Agg")
     fig = plt.figure(figsize=(9, 6))
@@ -156,79 +154,115 @@ def show_calibrate_result(
         stat_error, os.path.join(save_dir, "train_metrics.json")
     )
 
-    #循还画图
-    time = pd.read_excel('D:/研究生/毕业论文/new毕业论文/预答辩/碧流河水库/站点信息/洪水率定时间.xlsx')
+    # Plot every flood event in a loop
+    time = pd.read_excel(
+        "D:/研究生/毕业论文/new毕业论文/预答辩/碧流河水库/站点信息/洪水率定时间.xlsx"
+    )
     calibrate_starttime = pd.to_datetime("2012-06-10 0:00:00")
     calibrate_endtime = pd.to_datetime("2019-12-31 23:00:00")
     basin_area = float(basin_area)
-    best_simulation = [x * (basin_area*1000000/1000/3600) for x in best_simulation]
-    obs = [x * (basin_area*1000000/1000/3600) for x in spot_setup.evaluation()]
-    time['starttime']=pd.to_datetime(time['starttime'])
-    time['endtime']=pd.to_datetime(time['endtime'])
-    Prcp_list=[]
-    W_obs_list=[]
-    W_sim_list=[]
-    W_bias_abs_list=[]
-    W_bias_rela_list=[]
-    Q_max_obs_list=[]
-    Q_max_sim_list=[]
-    Q_bias_rela_list=[]
-    time_bias_list=[]
-    DC_list=[]
-    ID_list=[]
+    best_simulation = [
+        x * (basin_area * 1000000 / 1000 / 3600) for x in best_simulation
+    ]
+    obs = [x * (basin_area * 1000000 / 1000 / 3600) for x in spot_setup.evaluation()]
+    time["starttime"] = pd.to_datetime(time["starttime"])
+    time["endtime"] = pd.to_datetime(time["endtime"])
+    Prcp_list = []
+    W_obs_list = []
+    W_sim_list = []
+    W_bias_abs_list = []
+    W_bias_rela_list = []
+    Q_max_obs_list = []
+    Q_max_sim_list = []
+    Q_bias_rela_list = []
+    time_bias_list = []
+    DC_list = []
+    ID_list = []
     for i, row in time.iterrows():
-    # for i in range(len(time)):
-        if(row['starttime']<calibrate_endtime):
-        # if(time["starttime",0]<calibrate_endtime):
-                start_num = (row['starttime']-calibrate_starttime-pd.Timedelta(hours=warmup_length))/pd.Timedelta(hours=1)   
-                end_num = (row['endtime']-calibrate_starttime-pd.Timedelta(hours=warmup_length))/pd.Timedelta(hours=1)
-                start_period = (row['endtime']-calibrate_starttime)/pd.Timedelta(hours=1)
-                end_period = (row['endtime']-calibrate_starttime)/pd.Timedelta(hours=1)
-                start_period = int(start_period)
-                end_period = int(end_period)
-                start_num = int(start_num)
-                end_num = int(end_num)
-                t_range_train_changci = pd.date_range(row['starttime'],row['endtime'],freq='H')
-                save_fig = os.path.join(save_dir, "train_results"+str(i)+".png")
-                best_simulation_changci = best_simulation[start_num:end_num+1]
-                plot_sim_and_obs(t_range_train_changci, best_simulation[start_num:end_num+1], obs[start_num:end_num+1],prcp[start_num:end_num+1],save_fig)
-                Prcp=sum(prcp[start_num:end_num+1])
-                W_obs=sum(obs[start_num:end_num+1])*3600*1000/basin_area/1000000
-                W_sim = sum(best_simulation_changci) * 3600 * 1000 /basin_area/ 1000000
-                W_bias_abs=W_sim-W_obs
-                W_bias_rela = W_bias_abs/W_obs
-                Q_max_obs=np.max(obs[start_num:end_num+1])
-                Q_max_sim=np.max(best_simulation_changci)
-                Q_bias_rela = (Q_max_sim-Q_max_obs)/Q_max_obs
-                t1 =np.argmax(best_simulation_changci)
-                t2 =np.argmax(obs[start_num:end_num+1])
-                time_bias = t1-t2
-                DC = NSE(obs[start_num:end_num+1],best_simulation_changci)
-                ID = row['starttime'].strftime('%Y%m%d')
-                Prcp_list.append(Prcp)
-                W_obs_list.append(W_obs)
-                W_sim_list.append(W_sim)
-                W_bias_abs_list.append(W_bias_abs)
-                W_bias_rela_list.append(W_bias_rela)
-                Q_max_obs_list.append(Q_max_obs)
-                Q_max_sim_list.append(Q_max_sim)
-                Q_bias_rela_list.append(Q_bias_rela)
-                time_bias_list.append(time_bias)
-            
-                DC_list.append(DC)
-                ID_list.append(ID)
-                
-    bias =pd.DataFrame({"Prcp(mm)":Prcp_list,"W_obs(mm)":W_obs_list,
-                        "W_sim(mm)":W_sim_list,"W_bias_abs":W_bias_abs_list,
-                        "W_bias_rela":W_bias_rela_list,"Q_max_obs(m3/s)":Q_max_obs_list,
-                        "Q_max_sim(m3/s)":Q_max_sim_list,"Q_bias_rela":Q_bias_rela_list,
-                        "time_bias":time_bias_list,"DC":DC_list,"ID":ID_list})            
-    bias.to_csv(os.path.join('D:/研究生/毕业论文/new毕业论文/预答辩/碧流河水库/站点信息/train_metrics.csv'))
+        # for i in range(len(time)):
+        if row["starttime"] < calibrate_endtime:
+            # if(time["starttime",0]<calibrate_endtime):
+            start_num = (
+                row["starttime"]
+                - calibrate_starttime
+                - pd.Timedelta(hours=warmup_length)
+            ) / pd.Timedelta(hours=1)
+            end_num = (
+                row["endtime"] - calibrate_starttime - pd.Timedelta(hours=warmup_length)
+            ) / pd.Timedelta(hours=1)
+            start_period = (row["endtime"] - calibrate_starttime) / pd.Timedelta(
+                hours=1
+            )
+            end_period = (row["endtime"] - calibrate_starttime) / pd.Timedelta(hours=1)
+            start_period = int(start_period)
+            end_period = int(end_period)
+            start_num = int(start_num)
+            end_num = int(end_num)
+            t_range_train_changci = pd.date_range(
+                row["starttime"], row["endtime"], freq="H"
+            )
+            save_fig = os.path.join(save_dir, "train_results" + str(i) + ".png")
+            best_simulation_changci = best_simulation[start_num : end_num + 1]
+            plot_sim_and_obs(
+                t_range_train_changci,
+                best_simulation[start_num : end_num + 1],
+                obs[start_num : end_num + 1],
+                prcp[start_num : end_num + 1],
+                save_fig,
+            )
+            Prcp = sum(prcp[start_num : end_num + 1])
+            W_obs = (
+                sum(obs[start_num : end_num + 1]) * 3600 * 1000 / basin_area / 1000000
+            )
+            W_sim = sum(best_simulation_changci) * 3600 * 1000 / basin_area / 1000000
+            W_bias_abs = W_sim - W_obs
+            W_bias_rela = W_bias_abs / W_obs
+            Q_max_obs = np.max(obs[start_num : end_num + 1])
+            Q_max_sim = np.max(best_simulation_changci)
+            Q_bias_rela = (Q_max_sim - Q_max_obs) / Q_max_obs
+            t1 = np.argmax(best_simulation_changci)
+            t2 = np.argmax(obs[start_num : end_num + 1])
+            time_bias = t1 - t2
+            DC = NSE(obs[start_num : end_num + 1], best_simulation_changci)
+            ID = row["starttime"].strftime("%Y%m%d")
+            Prcp_list.append(Prcp)
+            W_obs_list.append(W_obs)
+            W_sim_list.append(W_sim)
+            W_bias_abs_list.append(W_bias_abs)
+            W_bias_rela_list.append(W_bias_rela)
+            Q_max_obs_list.append(Q_max_obs)
+            Q_max_sim_list.append(Q_max_sim)
+            Q_bias_rela_list.append(Q_bias_rela)
+            time_bias_list.append(time_bias)
+
+            DC_list.append(DC)
+            ID_list.append(ID)
+
+    bias = pd.DataFrame(
+        {
+            "Prcp(mm)": Prcp_list,
+            "W_obs(mm)": W_obs_list,
+            "W_sim(mm)": W_sim_list,
+            "W_bias_abs": W_bias_abs_list,
+            "W_bias_rela": W_bias_rela_list,
+            "Q_max_obs(m3/s)": Q_max_obs_list,
+            "Q_max_sim(m3/s)": Q_max_sim_list,
+            "Q_bias_rela": Q_bias_rela_list,
+            "time_bias": time_bias_list,
+            "DC": DC_list,
+            "ID": ID_list,
+        }
+    )
+    bias.to_csv(
+        os.path.join(
+            "D:/研究生/毕业论文/new毕业论文/预答辩/碧流河水库/站点信息/train_metrics.csv"
+        )
+    )
     t_range_train = pd.to_datetime(train_period[warmup_length:]).values.astype(
         "datetime64[h]"
     )
-    save_fig = os.path.join(save_dir, "train_results.png")   #生成结果图
-    plot_sim_and_obs(t_range_train, best_simulation,obs,prcp[:],save_fig)
+    save_fig = os.path.join(save_dir, "train_results.png")  # 生成结果图
+    plot_sim_and_obs(t_range_train, best_simulation, obs, prcp[:], save_fig)
 
 
 def show_test_result(basin_id, test_date, qsim, obs, save_dir):
@@ -237,12 +271,14 @@ def show_test_result(basin_id, test_date, qsim, obs, save_dir):
     hydro_file.serialize_json_np(
         stat_error, os.path.join(save_dir, "test_metrics.json")
     )
-    time = pd.read_excel('D:/研究生/毕业论文/new毕业论文/预答辩/碧流河水库/站点信息/洪水率定时间.xlsx')
+    time = pd.read_excel(
+        "D:/研究生/毕业论文/new毕业论文/预答辩/碧流河水库/站点信息/洪水率定时间.xlsx"
+    )
     test_starttime = pd.to_datetime("2020-01-01 00:00:00")
     test_endtime = pd.to_datetime("2022-08-31 23:00:00")
     # for i in range(len(time)):
     #     if(test_starttime<time.iloc[i,0]<test_endtime):
-    #             start_num = (time.iloc[i,0]-test_starttime-pd.Timedelta(hours=warmup_length))/pd.Timedelta(hours=1)   
+    #             start_num = (time.iloc[i,0]-test_starttime-pd.Timedelta(hours=warmup_length))/pd.Timedelta(hours=1)
     #             end_num = (time.iloc[i,1]-test_starttime-pd.Timedelta(hours=warmup_length))/pd.Timedelta(hours=1)
     #             start_period = (time.iloc[i,0]-test_starttime)/pd.Timedelta(hours=1)
     #             end_period = (time.iloc[i,1]-test_starttime)/pd.Timedelta(hours=1)
@@ -253,64 +289,93 @@ def show_test_result(basin_id, test_date, qsim, obs, save_dir):
     #             t_range_test_changci = pd.to_datetime(test_date[start_period:end_period]).values.astype("datetime64[h]")
     #             save_fig = os.path.join(save_dir, "test_results"+str(i)+".png")
     #             plot_sim_and_obs(t_range_test_changci, qsim.flatten()[start_num:end_num],obs.flatten()[start_num:end_num], prcp[start_num:end_num],save_fig)
-    Prcp_list=[]
-    W_obs_list=[]
-    W_sim_list=[]
-    W_bias_abs_list=[]
-    W_bias_rela_list=[]
-    Q_max_obs_list=[]
-    Q_max_sim_list=[]
-    Q_bias_rela_list=[]
-    time_bias_list=[]
-    DC_list=[]
-    ID_list=[]
+    Prcp_list = []
+    W_obs_list = []
+    W_sim_list = []
+    W_bias_abs_list = []
+    W_bias_rela_list = []
+    Q_max_obs_list = []
+    Q_max_sim_list = []
+    Q_bias_rela_list = []
+    time_bias_list = []
+    DC_list = []
+    ID_list = []
     for i, row in time.iterrows():
-        if(test_starttime<row['starttime']<test_endtime):
-                start_num = (row['starttime']-test_starttime-pd.Timedelta(hours=warmup_length))/pd.Timedelta(hours=1)   
-                end_num = (row['endtime']-test_starttime-pd.Timedelta(hours=warmup_length))/pd.Timedelta(hours=1)
-                start_period = (row['endtime']-test_starttime)/pd.Timedelta(hours=1)
-                end_period = (row['endtime']-test_starttime)/pd.Timedelta(hours=1)
-                start_period = int(start_period)
-                end_period = int(end_period)
-                start_num = int(start_num)
-                end_num = int(end_num)
-                t_range_train_changci = pd.date_range(row['starttime'],row['endtime'],freq='H')
-                save_fig = os.path.join(save_dir, "test_results"+str(i)+".png")
-                plot_sim_and_obs(t_range_train_changci, qsim.flatten()[start_num:end_num+1], obs.flatten()[start_num:end_num+1],prcp[start_num:end_num+1],save_fig)
-                Prcp=sum(prcp[start_num:end_num+1])
-                W_obs=sum(obs.flatten()[start_num:end_num+1])
-                W_sim =sum(qsim.flatten()[start_num:end_num+1])
-                W_bias_abs=W_sim-W_obs
-                W_bias_rela = W_bias_abs/W_obs
-                Q_max_obs=np.max(obs[start_num:end_num+1])
-                Q_max_sim=np.max(qsim.flatten()[start_num:end_num+1])
-                Q_bias_rela = (Q_max_sim-Q_max_obs)/Q_max_obs
-                t1 =np.argmax(qsim.flatten()[start_num:end_num+1])
-                t2 =np.argmax(obs[start_num:end_num+1])
-                time_bias = t1-t2
-                DC = NSE(obs.flatten()[start_num:end_num+1],qsim.flatten()[start_num:end_num+1])
-                ID = row['starttime'].strftime('%Y%m%d')
-                Prcp_list.append(Prcp)
-                W_obs_list.append(W_obs)
-                W_sim_list.append(W_sim)
-                W_bias_abs_list .append(W_bias_abs)
-                W_bias_rela_list.append(W_bias_rela)
-                Q_max_obs_list.append(Q_max_obs)
-                Q_max_sim_list.append(Q_max_sim)
-                Q_bias_rela_list.append(Q_bias_rela)
-                time_bias_list.append(time_bias)
-                DC_list.append(DC)
-                ID_list.append(ID)
-                
-    bias =pd.DataFrame({"Prcp(mm)":Prcp_list,"W_obs(mm)":W_obs_list,
-                        "W_sim(mm)":W_sim_list,"W_bias_abs":W_bias_abs_list,
-                        "W_bias_rela":W_bias_rela_list,"Q_max_obs(m3/s)":Q_max_obs_list,
-                        "Q_max_sim(m3/s)":Q_max_sim_list,"Q_bias_rela":Q_bias_rela_list,
-                        "time_bias":time_bias_list,"DC":DC_list,"ID":ID_list})            
-    bias.to_csv(os.path.join('D:/研究生/毕业论文/new毕业论文/预答辩/碧流河水库/站点信息/test_metrics.csv'))   
-        
+        if test_starttime < row["starttime"] < test_endtime:
+            start_num = (
+                row["starttime"] - test_starttime - pd.Timedelta(hours=warmup_length)
+            ) / pd.Timedelta(hours=1)
+            end_num = (
+                row["endtime"] - test_starttime - pd.Timedelta(hours=warmup_length)
+            ) / pd.Timedelta(hours=1)
+            start_period = (row["endtime"] - test_starttime) / pd.Timedelta(hours=1)
+            end_period = (row["endtime"] - test_starttime) / pd.Timedelta(hours=1)
+            start_period = int(start_period)
+            end_period = int(end_period)
+            start_num = int(start_num)
+            end_num = int(end_num)
+            t_range_train_changci = pd.date_range(
+                row["starttime"], row["endtime"], freq="H"
+            )
+            save_fig = os.path.join(save_dir, "test_results" + str(i) + ".png")
+            plot_sim_and_obs(
+                t_range_train_changci,
+                qsim.flatten()[start_num : end_num + 1],
+                obs.flatten()[start_num : end_num + 1],
+                prcp[start_num : end_num + 1],
+                save_fig,
+            )
+            Prcp = sum(prcp[start_num : end_num + 1])
+            W_obs = sum(obs.flatten()[start_num : end_num + 1])
+            W_sim = sum(qsim.flatten()[start_num : end_num + 1])
+            W_bias_abs = W_sim - W_obs
+            W_bias_rela = W_bias_abs / W_obs
+            Q_max_obs = np.max(obs[start_num : end_num + 1])
+            Q_max_sim = np.max(qsim.flatten()[start_num : end_num + 1])
+            Q_bias_rela = (Q_max_sim - Q_max_obs) / Q_max_obs
+            t1 = np.argmax(qsim.flatten()[start_num : end_num + 1])
+            t2 = np.argmax(obs[start_num : end_num + 1])
+            time_bias = t1 - t2
+            DC = NSE(
+                obs.flatten()[start_num : end_num + 1],
+                qsim.flatten()[start_num : end_num + 1],
+            )
+            ID = row["starttime"].strftime("%Y%m%d")
+            Prcp_list.append(Prcp)
+            W_obs_list.append(W_obs)
+            W_sim_list.append(W_sim)
+            W_bias_abs_list.append(W_bias_abs)
+            W_bias_rela_list.append(W_bias_rela)
+            Q_max_obs_list.append(Q_max_obs)
+            Q_max_sim_list.append(Q_max_sim)
+            Q_bias_rela_list.append(Q_bias_rela)
+            time_bias_list.append(time_bias)
+            DC_list.append(DC)
+            ID_list.append(ID)
+
+    bias = pd.DataFrame(
+        {
+            "Prcp(mm)": Prcp_list,
+            "W_obs(mm)": W_obs_list,
+            "W_sim(mm)": W_sim_list,
+            "W_bias_abs": W_bias_abs_list,
+            "W_bias_rela": W_bias_rela_list,
+            "Q_max_obs(m3/s)": Q_max_obs_list,
+            "Q_max_sim(m3/s)": Q_max_sim_list,
+            "Q_bias_rela": Q_bias_rela_list,
+            "time_bias": time_bias_list,
+            "DC": DC_list,
+            "ID": ID_list,
+        }
+    )
+    bias.to_csv(
+        os.path.join(
+            "D:/研究生/毕业论文/new毕业论文/预答辩/碧流河水库/站点信息/test_metrics.csv"
+        )
+    )
+
     save_fig = os.path.join(save_dir, "test_results.png")
-   
+
     plot_sim_and_obs(
         test_date[365:],
         qsim.flatten(),
@@ -318,21 +383,20 @@ def show_test_result(basin_id, test_date, qsim, obs, save_dir):
         prcp[:],
         save_fig,
     )
-    
-    
-    
-def NSE(obs,mol):
+
+
+def NSE(obs, mol):
     numerator = 0
     denominator = 0
     meangauge = 0
     count = 0
     for i in range(len(obs)):
-        if (obs[i]>=0):
-            numerator+=pow(abs(mol[i])-obs[i],2)
-            meangauge+=obs[i]
-            count+=1
-    meangauge=meangauge/count
+        if obs[i] >= 0:
+            numerator += pow(abs(mol[i]) - obs[i], 2)
+            meangauge += obs[i]
+            count += 1
+    meangauge = meangauge / count
     for i in range(len(obs)):
-        if (obs[i]>=0):
-            denominator+=pow(obs[i]-meangauge,2)
-    return 1-numerator/denominator
\ No newline at end of file
+        if obs[i] >= 0:
+            denominator += pow(obs[i] - meangauge, 2)
+    return 1 - numerator / denominator
diff --git a/test/picture.py b/test/picture.py
deleted file mode 100644
index 5fa0e1a..0000000
--- a/test/picture.py
+++ /dev/null
@@ -1,116 +0,0 @@
-from matplotlib import pyplot as plt
-import pandas as pd
-import os
-import numpy as np
-from numpy import *
-import matplotlib.dates as mdates
-import sys
-from pathlib import Path
-sys.path.append(os.path.dirname(Path(os.path.abspath(__file__)).parent.parent))
-# from hydromodel.utils import hydro_constant
-
-time = pd.read_excel('D:/研究生/毕业论文/new毕业论文/预答辩/碧流河水库/站点信息/DMCA.xlsx')
-time['starttime'] = pd.to_datetime(time['starttime'], format='%d/%m/%Y %H:%M')
-time['endtime'] = pd.to_datetime(time['endtime'], format='%d/%m/%Y %H:%M')
-sim = pd.read_excel('D:/研究生/毕业论文/new毕业论文/预答辩/碧流河水库/站点信息/picture.xlsx')
-sim['date'] = pd.to_datetime(sim['date'], format='%d/%m/%Y %H:%M')
-for i  in range(len(time)):
-    start_time = time['starttime'][i]
-    end_time = time['endtime'][i]
-    start_num = np.where(sim['date'] == start_time)[0]
-    end_num = np.where(sim['date'] == end_time)[0]
-    # date = pd.date_range(start_time, end_time, freq='H')
-    start_num = int(start_num)
-    end_num = int(end_num)
-    date =sim['date'][start_num:end_num]
-    sim_xaj = sim['sim_xaj'][start_num:end_num]
-    sim_dhf = sim['sim_dhf'][start_num:end_num]
-    obs = sim['streamflow(m3/s)'][start_num:end_num]
-    prcp = sim['prcp(mm/hour)'][start_num:end_num]
-    fig = plt.figure(figsize=(9,6),dpi=500)
-    ax = fig.subplots()
-    ax.plot(
-        date,
-        sim_xaj,
-        color="blue",
-        linestyle="-",
-        linewidth=1,
-        label="Simulation_xaj",
-    )
-    ax.plot(
-        date,
-        sim_dhf,
-        color="green",
-        linestyle="-",
-        linewidth=1,
-        label="Simulation_dhf",
-    )
-    ax.plot(
-        date,
-        obs,
-        # "r.",
-        color="black",
-        linestyle="-",
-        linewidth=1,
-        label="Observation",
-    )
-    ylim = np.max(np.vstack((obs, sim_xaj)))
-    print(start_time)
-    ax.set_ylim(0, ylim*1.3) 
-    ax.xaxis.set_major_formatter(mdates.DateFormatter("%y-%m-%d"))
-    xlabel="Date(∆t=1hour)"
-    ylabel="Streamflow(m^3/s)"
-    ax.set_xlabel(xlabel)
-    ax.set_ylabel(ylabel)
-    plt.legend(loc="upper right")
-    # sim_xaj = np.array(sim_xaj)
-    # obs = np.array(obs)
-    # numerator = 0
-    # denominator = 0
-    # meangauge = 0
-    # count = 0
-    # for h in range(len(obs)):
-    #     if (obs[h]>=0):
-    #         numerator+=pow(abs(sim_xaj[h])-obs[h],2)
-    #         meangauge+=obs[h]
-    #         count+=1
-    # meangauge=meangauge/count
-    # for m in range(len(obs)):
-    #     if (obs[m]>=0):
-    #         denominator+=pow(obs[m]-meangauge,2)
-    # NSE= 1-numerator/denominator
-    # plt.text(0.9, 0.6, 'NSE=%.2f' % NSE, 
-    #      horizontalalignment='center',  
-    #      verticalalignment='center',
-    #      transform = ax.transAxes,
-    #      fontsize=10)
-
-    ax2 = ax.twinx()
-    ax2.bar(date,prcp, label='Precipitation', color='royalblue',alpha=0.9,width=0.05)
-    ax2.set_ylabel('Precipitation(mm)')
-    plt.yticks(fontproperties = 'Times New Roman', size = 10)
-    prcp_max = np.max(prcp)
-    ax2.set_ylim(0, prcp_max*4)
-    ax2.invert_yaxis()  #y轴反向
-    ax2.legend(loc='upper left')
-    plt.tight_layout()  # 自动调整子图参数,使之填充整个图像区域
-    save_fig = os.path.join('D:/研究生/毕业论文/new毕业论文/预答辩/碧流河水库/站点信息/plot', "results"+str(i)+".png")
-    plt.savefig(save_fig, bbox_inches="tight")
-    plt.close()
-
-
-def NSE(obs,sim_xaj):
-    numerator = 0
-    denominator = 0
-    meangauge = 0
-    count = 0
-    for i in range(len(obs)):
-        if (obs[i]>=0):
-            numerator+=pow(abs(sim_xaj[i])-obs[i],2)
-            meangauge+=obs[i]
-            count+=1
-    meangauge=meangauge/count
-    for i in range(len(obs)):
-        if (obs[i]>=0):
-            denominator+=pow(obs[i]-meangauge,2)
-    NSE= 1-numerator/denominator
\ No newline at end of file
diff --git a/test/test-xaj-bmi.py b/test/test-xaj-bmi.py
deleted file mode 100644
index b35e994..0000000
--- a/test/test-xaj-bmi.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import logging
-logging.basicConfig(level=logging.INFO)
-from xaj_bmi import xajBmi
-import pandas as pd
-# from test.test_xaj import test_xaj
-# from configuration import configuration
-# import numpy as np
-model = xajBmi()
-print(model.get_component_name())
-
-
-model.initialize("xaj/runxaj.yaml")
-print("Start time:", model.get_start_time())
-print("End time:", model.get_end_time())
-print("Current time:", model.get_current_time())
-print("Time step:", model.get_time_step())
-print("Time units:", model.get_time_units())
-print(model.get_input_var_names())
-print(model.get_output_var_names())
-
-discharge = []
-ET = []
-time = []                                          
-while model.get_current_time() <= model.get_end_time():
-    time.append(model.get_current_time())
-    model.update()
-
-discharge=model.get_value("discharge")
-ET=model.get_value("ET")
-
-results = pd.DataFrame({
-                'discharge': discharge.flatten(),
-                'ET': ET.flatten(),  
-            })
-results.to_csv('/home/wangjingyi/code/hydro-model-xaj/scripts/xaj.csv')
-model.finalize()
-# params=np.tile([0.5], (1, 15))
-# config = configuration.read_config("scripts/runxaj.yaml")
-# forcing_data = pd.read_csv(config['forcing_file'])
-# p_and_e_df, p_and_e = configuration.extract_forcing(forcing_data)
-# test_xaj(p_and_e=p_and_e,params=params,warmup_length=360)
-
-
diff --git a/test/test_data.py b/test/test_data.py
index 2af99b3..5be94a9 100644
--- a/test/test_data.py
+++ b/test/test_data.py
@@ -1,173 +1,26 @@
 """
 Author: Wenyu Ouyang
 Date: 2022-10-25 21:16:22
-LastEditTime: 2024-03-21 18:44:13
+LastEditTime: 2024-03-22 09:26:38
 LastEditors: Wenyu Ouyang
 Description: Test for data preprocess
 FilePath: \hydro-model-xaj\test\test_data.py
 Copyright (c) 2021-2022 Wenyu Ouyang. All rights reserved.
 """
 
-import os
-from collections import OrderedDict
+from hydrodataset import Camels
 
-import numpy as np
-import pandas as pd
-import pytest
-import fnmatch
-import socket
-from datetime import datetime
-import pathlib
+from hydromodel import SETTING
 
-from hydroutils import hydro_file
 
-import definitions
-from hydromodel.utils import hydro_utils
-from hydromodel.datasets.data_preprocess import (
-    cross_valid_data,
-    split_train_test,
-)
-
-
-# @pytest.fixture()
-# def txt_file():
-#     return os.path.join(
-#         definitions.ROOT_DIR, "hydromodel", "example", "01013500_lump_p_pe_q.txt"
-#     )
-
-
-# @pytest.fixture()
-# def json_file():
-#     return os.path.join(definitions.ROOT_DIR, "hydromodel", "example", "data_info.json")
-
-
-# @pytest.fixture()
-# def npy_file():
-#     return os.path.join(
-#         definitions.ROOT_DIR, "hydromodel", "example", "basins_lump_p_pe_q.npy"
-#     )
-
-txt_file = pathlib.Path(
-    "/home/ldaning/code/biye/hydro-model-xaj/hydromodel/example/wuxi.csv"
-)
-forcing_data = pathlib.Path(
-    "/home/ldaning/code/biye/hydro-model-xaj/hydromodel/example/wuxi.csv"
-)
-json_file = pathlib.Path(
-    "/home/ldaning/code/biye/hydro-model-xaj/hydromodel/example/model_run_wuxi7/data_info.json"
-)
-npy_file = pathlib.Path(
-    "/home/ldaning/code/biye/hydro-model-xaj/hydromodel/example/model_run_wuxi7/data_info.npy"
-)
-
-
-# def test_save_data(txt_file, json_file, npy_file):
-data = pd.read_csv(txt_file)
-datetime_index = pd.to_datetime(data["date"], format="%Y/%m/%d %H:%M")
-# Note: The units are all mm/day! For streamflow, data is divided by basin's area
-# variables = ["prcp(mm/day)", "petfao56(mm/day)", "streamflow(mm/day)"]
-variables = ["prcp(mm/hour)", "pev(mm/hour)", "streamflow(m3/s)"]
-data_info = OrderedDict(
-    {
-        "time": data["date"].values.tolist(),
-        "basin": ["wuxi"],
-        "variable": variables,
-        "area": ["1992.62"],
-    }
-)
-hydro_utils.serialize_json(data_info, json_file)
-# 1 ft3 = 0.02831685 m3
-# ft3tom3 = 2.831685e-2
-
-# 1 km2 = 10^6 m2
-km2tom2 = 1e6
-# 1 m = 1000 mm
-mtomm = 1000
-# 1 day = 24 * 3600 s
-# daytos = 24 * 3600
-hourtos = 3600
-# trans ft3/s to mm/day
-# basin_area = 2055.56
-basin_area = 1992.62
-data[variables[-1]] = (
-    data[["streamflow(m3/s)"]].values
-    # * ft3tom3
-    / (basin_area * km2tom2)
-    * mtomm
-    * hourtos
-)
-df = data[variables]
-hydro_utils.serialize_numpy(np.expand_dims(df.values, axis=1), npy_file)
-
-
-# def test_load_data(txt_file, npy_file):
-#     data_ = pd.read_csv(txt_file)
-#     df = data_[["prcp(mm/day)", "petfao56(mm/day)"]]
-#     data = hydro_utils.unserialize_numpy(npy_file)[:, :, :2]
-#     np.testing.assert_array_equal(data, np.expand_dims(df.values, axis=1))
-
-
-# start_train = datetime(2014, 5, 1, 1)
-# end_train = datetime(2020, 1, 1, 7)
-# start_test = datetime(2020, 1, 1, 8)
-# end_test = datetime(2021, 10, 11, 23)
-# train_period = ["2014-05-01 09:00:00", "2019-01-01 08:00:00"]
-test_period = ["2019-01-01 07:00:00", "2021-10-12 09:00:00"]
-# test_period = ["2019-01-01 08:00:00", "2021-10-11 09:00:00"]
-train_period = ["2014-05-01 09:00:00", "2019-01-01 07:00:00"]
-period = ["2014-05-01 09:00:00", "2021-10-12 09:00:00"]
-cv_fold = 1
-warmup_length = 365
-
-# if not (cv_fold > 1):
-#     # no cross validation
-#     periods = np.sort(
-#         [train_period[0], train_period[1], test_period[0], test_period[1]]
-#     )
-#     print(periods)
-if cv_fold > 1:
-    cross_valid_data(json_file, npy_file, period, warmup_length, cv_fold)
-else:
-    split_train_test(json_file, npy_file, train_period, test_period)
+def test_load_dataset():
+    dataset_dir = SETTING["local_data_path"]["datasets-origin"]
+    camels = Camels(dataset_dir)
+    data = camels.read_ts_xrdataset(
+        ["01013500"], ["2014-05-01 09:00:00", "2019-01-01 07:00:00"], "streamflow"
+    )
+    print(data)
 
 
-kfold = [
-    int(f_name[len("data_info_fold") : -len("_test.json")])
-    for f_name in os.listdir(os.path.dirname(txt_file))
-    if fnmatch.fnmatch(f_name, "*_fold*_test.json")
-]
-kfold = np.sort(kfold)
-for fold in kfold:
-    print(f"Start to calibrate the {fold}-th fold")
-    train_data_info_file = os.path.join(
-        os.path.dirname(forcing_data), f"data_info_fold{str(fold)}_train.json"
-    )
-    train_data_file = os.path.join(
-        os.path.dirname(forcing_data), f"data_info_fold{str(fold)}_train.npy"
-    )
-    test_data_info_file = os.path.join(
-        os.path.dirname(forcing_data), f"data_info_fold{str(fold)}_test.json"
-    )
-    test_data_file = os.path.join(
-        os.path.dirname(forcing_data), f"data_info_fold{str(fold)}_test.npy"
-    )
-    if (
-        os.path.exists(train_data_info_file) is False
-        or os.path.exists(train_data_file) is False
-        or os.path.exists(test_data_info_file) is False
-        or os.path.exists(test_data_file) is False
-    ):
-        raise FileNotFoundError(
-            "The data files are not found, please run datapreprocess4calibrate.py first."
-        )
-    data_train = hydro_utils.unserialize_numpy(train_data_file)
-    print(data_train.shape)
-    data_test = hydro_utils.unserialize_numpy(test_data_file)
-    data_info_train = hydro_utils.unserialize_json_ordered(train_data_info_file)
-    data_info_test = hydro_utils.unserialize_json_ordered(test_data_info_file)
-    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
-    # one directory for one model + one hyperparam setting and one basin
-    save_dir = os.path.join(
-        os.path.dirname(forcing_data),
-        current_time + "_" + socket.gethostname() + "_fold" + str(fold),
-    )
+def test_read_your_own_data():
+    pass
diff --git a/test/test_gr4j.py b/test/test_gr4j.py
index 93c7175..ec9ad97 100644
--- a/test/test_gr4j.py
+++ b/test/test_gr4j.py
@@ -1,23 +1,20 @@
 """
 Author: Wenyu Ouyang
 Date: 2023-06-02 09:30:36
-LastEditTime: 2023-06-03 10:41:48
+LastEditTime: 2024-03-22 09:30:05
 LastEditors: Wenyu Ouyang
-Description: 
-FilePath: /hydro-model-xaj/test/test_gr4j.py
+Description: Test case for GR4J model
+FilePath: /hydro-model-xaj/test/test_gr4j.py
 Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved.
 """
+
 import os
 
 import numpy as np
 import pandas as pd
 import pytest
-import spotpy
-from matplotlib import pyplot as plt
-import definitions
-from hydromodel.trainers.calibrate_sceua import calibrate_by_sceua, SpotSetup
+from hydromodel import SETTING
 from hydromodel.models.gr4j import gr4j
-from hydromodel.trainers.plots import show_calibrate_result
 
 
 @pytest.fixture()
@@ -34,7 +31,7 @@ def warmup_length():
 
 @pytest.fixture()
 def the_data():
-    root_dir = definitions.ROOT_DIR
+    root_dir = SETTING["local_data_path"]["datasets-origin"]
     # test_data = pd.read_csv(os.path.join(root_dir, "hydromodel", "example", '01013500_lump_p_pe_q.txt'))
     return pd.read_csv(
         os.path.join(root_dir, "hydromodel", "example", "hymod_input.csv"), sep=";"
diff --git a/test/test_hydromodel.py b/test/test_hydromodel.py
deleted file mode 100644
index 509744b..0000000
--- a/test/test_hydromodel.py
+++ /dev/null
@@ -1,24 +0,0 @@
-#!/usr/bin/env python
-
-"""Tests for `hydromodel` package."""
-
-import pytest
-
-
-from hydromodel import hydromodel
-
-
-@pytest.fixture
-def response():
-    """Sample pytest fixture.
-
-    See more at: http://doc.pytest.org/en/latest/fixture.html
-    """
-    # import requests
-    # return requests.get('https://github.com/audreyr/cookiecutter-pypackage')
-
-
-def test_content(response):
-    """Sample pytest test function with the pytest fixture as an argument."""
-    # from bs4 import BeautifulSoup
-    # assert 'GitHub' in BeautifulSoup(response.content).title.string
diff --git a/test/test_hymod.py b/test/test_hymod.py
index 845f54f..0507e18 100644
--- a/test/test_hymod.py
+++ b/test/test_hymod.py
@@ -3,22 +3,19 @@
 Date: 2023-06-02 09:30:36
 LastEditTime: 2023-06-03 10:42:33
 LastEditors: Wenyu Ouyang
-Description: 
+Description: Test case for HYMOD model
 FilePath: /hydro-model-xaj/test/test_hymod.py
 Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved.
 """
+
 import os
 
-import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import pytest
-import spotpy
 
-import definitions
-from hydromodel.trainers.calibrate_sceua import calibrate_by_sceua, SpotSetup
+from hydromodel import SETTING
 from hydromodel.models.hymod import hymod
-from hydromodel.trainers.plots import show_calibrate_result
 
 
 @pytest.fixture()
@@ -30,7 +27,7 @@ def basin_area():
 
 @pytest.fixture()
 def the_data():
-    root_dir = definitions.ROOT_DIR
+    root_dir = SETTING["local_data_path"]["datasets-origin"]
     return pd.read_csv(
         os.path.join(root_dir, "hydromodel", "example", "hymod_input.csv"), sep=";"
     )
diff --git a/test/test_rr_event_iden.py b/test/test_rr_event_iden.py
index 46d9ec1..d4246e9 100644
--- a/test/test_rr_event_iden.py
+++ b/test/test_rr_event_iden.py
@@ -1,30 +1,37 @@
 """
 Author: Wenyu Ouyang
 Date: 2023-10-28 09:23:22
-LastEditTime: 2024-02-12 16:17:26
+LastEditTime: 2024-03-22 09:32:32
 LastEditors: Wenyu Ouyang
 Description: Test for rainfall-runoff event identification
-FilePath: \hydromodel\test\test_rr_event_iden.py
+FilePath: /hydro-model-xaj/test/test_rr_event_iden.py
 Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved.
 """
 
 import os
 import pandas as pd
-import definitions
+
+from hydromodel import SETTING
 from hydromodel.datasets.dmca_esr import rainfall_runoff_event_identify
 
 
 def test_rainfall_runoff_event_identify():
     rain = pd.read_csv(
         os.path.join(
-            definitions.ROOT_DIR, "hydromodel", "example", "daily_rainfall_27071.txt"
+            SETTING["local_data_path"]["root"],
+            "hydromodel",
+            "example",
+            "daily_rainfall_27071.txt",
         ),
         header=None,
         sep="\\s+",
     )
     flow = pd.read_csv(
         os.path.join(
-            definitions.ROOT_DIR, "hydromodel", "example", "daily_flow_27071.txt"
+            SETTING["local_data_path"]["root"],
+            "hydromodel",
+            "example",
+            "daily_flow_27071.txt",
         ),
         header=None,
         sep="\\s+",
diff --git a/test/test_xaj.py b/test/test_xaj.py
index 045e693..f4c3e05 100644
--- a/test/test_xaj.py
+++ b/test/test_xaj.py
@@ -6,11 +6,9 @@
 
 from hydroutils import hydro_time
 
-import definitions
+from hydromodel import SETTING
 from hydromodel.trainers.calibrate_sceua import calibrate_by_sceua
-from hydromodel.trainers.calibrate_ga import calibrate_by_ga
 from hydromodel.datasets.data_postprocess import read_save_sceua_calibrated_params
-from hydromodel.utils import units
 from hydromodel.trainers.plots import show_calibrate_result, show_test_result
 from hydromodel.models.xaj import xaj, uh_gamma, uh_conv
 
@@ -24,8 +22,7 @@ def basin_area():
 
 @pytest.fixture()
 def db_name():
-    db_name = os.path.join(definitions.ROOT_DIR, "test", "SCEUA_xaj_mz")
-    return db_name
+    return os.path.join(SETTING["local_data_path"]["root"], "test", "SCEUA_xaj_mz")
 
 
 @pytest.fixture()
@@ -35,7 +32,7 @@ def warmup_length():
 
 @pytest.fixture()
 def the_data():
-    root_dir = definitions.ROOT_DIR
+    root_dir = SETTING["local_data_path"]["root"]
     # test_data = pd.read_csv(os.path.join(root_dir, "hydromodel", "example", '01013500_lump_p_pe_q.txt'))
     return pd.read_csv(
         os.path.join(root_dir, "hydromodel", "example", "hymod_input.csv"), sep=";"
diff --git a/test/test_xaj_bmi.py b/test/test_xaj_bmi.py
index 0b8ba47..b1c1829 100644
--- a/test/test_xaj_bmi.py
+++ b/test/test_xaj_bmi.py
@@ -1,7 +1,5 @@
 import logging
 
-import definitions
-from hydromodel.models.configuration import read_config
 from hydromodel.models.xaj_bmi import xajBmi
 import pandas as pd
 import os
@@ -13,7 +11,6 @@
 
 from hydroutils import hydro_file
 
-from hydromodel.utils import units
 from hydromodel.datasets.data_preprocess import (
     cross_valid_data,
     split_train_test,