-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
aa70de8
commit c3c75a2
Showing
1 changed file
with
130 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,11 +5,13 @@ | |
to set up control and perturbation experiments, modify configuration files, | ||
and manage related utilities. | ||
Latest version: https://github.com/minghangli-uni/Expts_manager | ||
Latest version: https://github.com/COSIMA/om3-scripts/pull/34 | ||
Author: Minghang Li | ||
Email: [email protected] | ||
License: Apache 2.0 License http://www.apache.org/licenses/LICENSE-2.0.txt | ||
""" | ||
|
||
|
||
# =========================================================================== | ||
import os | ||
import sys | ||
|
@@ -118,6 +120,7 @@ def load_variables(self, yamlfile): | |
self.yamlfile = yamlfile | ||
self.indata = self._read_ryaml(yamlfile) | ||
|
||
self.model = self.indata["model"] | ||
self.utils_url = self.indata["utils_url"] | ||
self.utils_dir_name = self.indata["utils_dir_name"] | ||
self.utils_branch_name = self.indata["utils_branch_name"] | ||
|
@@ -236,6 +239,9 @@ def manage_ctrl_expt(self): | |
base_path = self.base_path | ||
ctrl_nruns = self.ctrl_nruns | ||
|
||
# access-om2 specific | ||
ocn_path = os.path.join(base_path, "ocean") | ||
|
||
if os.path.exists(base_path): | ||
print(f"Base path is already created and located at {base_path}") | ||
if not os.path.isfile(os.path.join(base_path, "config.yaml")): | ||
|
@@ -253,7 +259,10 @@ def manage_ctrl_expt(self): | |
|
||
# [optional] modify diag_table | ||
if self.diag_ctrl and self.diag_path: | ||
self._copy_diag_table(base_path) | ||
if self.model == "access-om3": | ||
self._copy_diag_table(base_path) | ||
elif self.model == "access-om2": | ||
self._copy_diag_table(ocn_path) | ||
|
||
# setup the control experiments | ||
self._setup_ctrl_expt() | ||
|
@@ -321,40 +330,52 @@ def _setup_ctrl_expt(self): | |
Updates configuration files (config.yaml, nuopc.runconfig etc), | ||
namelist and MOM_input for the control experiment if needed. | ||
""" | ||
for file_name in os.listdir(self.base_path): | ||
yaml_data = self.indata.get(file_name, None) | ||
if yaml_data: | ||
# Update parameters from namelists | ||
if file_name.endswith("_in") or file_name.endswith(".nml"): | ||
self._update_nml_params(self.base_path, yaml_data, file_name) | ||
|
||
# Update config entries from `nuopc.runconfig` and `config_yaml` | ||
if file_name in (("nuopc.runconfig", "config.yaml")): | ||
# for file_name in os.listdir(self.base_path): | ||
for root, dirs, files in os.walk(self.base_path): | ||
dirs[:] = [tmp_d for tmp_d in dirs if ".git" not in tmp_d] | ||
for f in files: | ||
if ".git" in f: | ||
continue | ||
file_name = os.path.relpath(os.path.join(root, f), self.base_path) | ||
yaml_data = self.indata.get(file_name, None) | ||
|
||
if yaml_data: | ||
# Update parameters from namelists | ||
if file_name.endswith("_in") or file_name.endswith(".nml"): | ||
self._update_nml_params(self.base_path, yaml_data, file_name) | ||
|
||
# Update config entries from `nuopc.runconfig` | ||
if file_name == "nuopc.runconfig": | ||
self._update_runconfig_params( | ||
self.base_path, yaml_data, file_name | ||
) | ||
elif file_name == "config.yaml": | ||
|
||
# Update config entries from `config_yaml` | ||
if file_name == "config.yaml": | ||
self._update_config_params(self.base_path, yaml_data, file_name) | ||
|
||
# Update parameters from `MOM_input` | ||
if file_name == "MOM_input": | ||
# parse existing MOM_input | ||
MOM_inputParser = self._parser_mom6_input( | ||
os.path.join(self.base_path, file_name) | ||
) | ||
param_dict = MOM_inputParser.param_dict # read parameter dictionary | ||
commt_dict = MOM_inputParser.commt_dict # read comment dictionary | ||
param_dict.update(yaml_data) | ||
# overwrite to the same `MOM_input` | ||
MOM_inputParser.writefile_MOM_input( | ||
os.path.join(self.base_path, file_name) | ||
) | ||
# Update and overwrite parameters from and into `MOM_input` | ||
if file_name == "MOM_input": | ||
# parse existing MOM_input | ||
MOM_inputParser = self._parser_mom6_input( | ||
os.path.join(self.base_path, file_name) | ||
) | ||
param_dict = ( | ||
MOM_inputParser.param_dict | ||
) # read parameter dictionary | ||
commt_dict = ( | ||
MOM_inputParser.commt_dict | ||
) # read comment dictionary | ||
param_dict.update(yaml_data) | ||
# overwrite to the same `MOM_input` | ||
MOM_inputParser.writefile_MOM_input( | ||
os.path.join(self.base_path, file_name) | ||
) | ||
|
||
# Update only coupling timestep from `nuopc.runseq` | ||
if file_name == "nuopc.runseq": | ||
nuopc_runseq_file = os.path.join(self.base_path, file_name) | ||
self._update_cpl_dt_nuopc_seq(nuopc_runseq_file, yaml_data) | ||
# Update only coupling timestep from `nuopc.runseq` | ||
if file_name == "nuopc.runseq": | ||
nuopc_runseq_file = os.path.join(self.base_path, file_name) | ||
self._update_cpl_dt_nuopc_seq(nuopc_runseq_file, yaml_data) | ||
|
||
def _check_and_commit_changes(self): | ||
""" | ||
|
@@ -399,7 +420,8 @@ def manage_perturb_expt(self): | |
] # main section, top level key that groups different namlists | ||
if not namelists: | ||
warnings.warn( | ||
"NO namelists were provided, hence there are no parameter-tunning tests!" | ||
"NO namelists were provided, hence there are no parameter-tunning tests!", | ||
UserWarning, | ||
) | ||
return | ||
|
||
|
@@ -685,11 +707,13 @@ def _generate_combined_dicts(self, name_dict, commt_dict, k_sub, parameter_block | |
param_dict_change_list = [] | ||
append_group_list = [] | ||
for i in range(self.num_expts): | ||
name_dict = self._preprocess_nested_dicts(name_dict) | ||
param_dict_change = {k: name_dict[k][i] for k in name_dict} | ||
append_group = k_sub | ||
append_group_list.append(append_group) | ||
param_dict_change_list.append(param_dict_change) | ||
self.param_dict_change_list = param_dict_change_list | ||
|
||
if self.tag_model == "mom6" or parameter_block == "MOM_input": | ||
self.commt_dict_change = {k: commt_dict.get(k, "") for k in name_dict} | ||
elif ( | ||
|
@@ -699,6 +723,29 @@ def _generate_combined_dicts(self, name_dict, commt_dict, k_sub, parameter_block | |
): | ||
self.append_group_list = append_group_list | ||
|
||
def _preprocess_nested_dicts(self, input_data): | ||
""" | ||
Pre-processes nested dictionary with lists. | ||
""" | ||
res_dicts = {} | ||
for tmp_key, tmp_values in input_data.items(): | ||
if isinstance(tmp_values, list) and all( | ||
isinstance(v, dict) for v in tmp_values | ||
): | ||
res_dicts[tmp_key] = [] | ||
num_entries = len(next(iter(tmp_values[0].values()))) | ||
for i in range(num_entries): | ||
entry_list = [] | ||
for submodel in tmp_values: | ||
entry = {} | ||
for k, v in submodel.items(): | ||
entry[k] = v[i] | ||
entry_list.append(entry) | ||
res_dicts[tmp_key].append(entry_list) | ||
else: | ||
res_dicts[tmp_key] = tmp_values | ||
return res_dicts | ||
|
||
def setup_expts(self, parameter_block): | ||
""" | ||
Sets up perturbation experiments based on the YAML input file provided in `Expts_manager.yaml`. | ||
|
@@ -728,13 +775,14 @@ def setup_expts(self, parameter_block): | |
restartpath = self._generate_restart_symlink(expt_path) | ||
self._update_metadata_yaml_perturb(expt_path, param_dict, restartpath) | ||
|
||
# # only update perturbation jobname [TODO: put somewhere else] | ||
# self._update_config_yaml_perturb(expt_path, expt_name) | ||
|
||
# optionally update nuopc.runconfig for perturbation runs | ||
# if there is no parameter tunning under cb or runconfig flags! | ||
if self.tag_model not in (("cb", "runconfig")): | ||
self._update_nuopc_config_perturb(expt_path) | ||
|
||
# update jobname same as perturbation experiment name | ||
self._update_perturb_jobname(expt_path, expt_name) | ||
|
||
# update params for each parameter block | ||
if self.tag_model == "mom6" or parameter_block == "MOM_input": | ||
self._update_mom6_params(expt_path, param_dict) | ||
|
@@ -757,6 +805,7 @@ def setup_expts(self, parameter_block): | |
duplicated_bool = self._check_duplicated_jobs(pbs_jobs, expt_path) | ||
else: | ||
duplicated_bool = False | ||
|
||
# start runs, count existing runs and do additional runs if needed | ||
self._start_experiment_runs( | ||
expt_path, expt_name, duplicated_bool, self.nruns | ||
|
@@ -842,11 +891,42 @@ def _update_nml_params(self, expt_path, param_dict, parameter_block, indx=None): | |
patch_dict[nml_group]["sinw"] = sinw | ||
else: # for generic parameters | ||
patch_dict[nml_group][nml_name] = nml_value | ||
f90nml.patch(nml_path, patch_dict, nml_path + "_tmp") | ||
param_dict = patch_dict | ||
f90nml.patch(nml_path, param_dict, nml_path + "_tmp") | ||
else: | ||
f90nml.patch(nml_path, param_dict, nml_path + "_tmp") | ||
os.rename(nml_path + "_tmp", nml_path) | ||
|
||
self._format_nml_params(nml_path, param_dict) | ||
|
||
def _format_nml_params(self, nml_path, param_dict): | ||
""" | ||
Handles pre-formatted strings or values. | ||
Args: | ||
nml_path (str): The path to specific f90 namelist file. | ||
param_dict (dict): The dictionary of parameters to update. | ||
e.g., in yaml input file, | ||
ocean/input.nml: | ||
mom_oasis3_interface_nml: | ||
fields_in: "'u_flux', 'v_flux', 'lprec'" | ||
fields_out: "'t_surf', 's_surf', 'u_surf'" | ||
results in, | ||
&mom_oasis3_interface_nml | ||
fields_in = 'u_flux', 'v_flux', 'lprec' | ||
fields_out = 't_surf', 's_surf', 'u_surf' | ||
""" | ||
with open(nml_path, "r") as f: | ||
fileread = f.readlines() | ||
for tmp_group, tmp_subgroups in param_dict.items(): | ||
for tmp_param, tmp_values in tmp_subgroups.items(): | ||
for i in range(len(fileread)): | ||
if tmp_param in fileread[i]: | ||
fileread[i] = f" {tmp_param} = {tmp_values}\n" | ||
break | ||
with open(nml_path, "w") as f: | ||
f.writelines(fileread) | ||
|
||
def _update_config_params(self, expt_path, param_dict, parameter_block): | ||
""" | ||
Updates namelist parameters and overwrites namelist file. | ||
|
@@ -922,8 +1002,8 @@ def _generate_restart_symlink(self, expt_path): | |
existing_restarts = glob.glob( | ||
os.path.join(self.base_path, "archive", "restart*") | ||
) | ||
if existing_restarts and not self.force_restart: | ||
return | ||
# if existing_restarts and not self.force_restart: | ||
# return | ||
|
||
link_restart = os.path.join("archive", "restart" + self.startfrom_str) | ||
# restart dir from control experiment | ||
|
@@ -956,6 +1036,18 @@ def _update_nuopc_config_perturb(self, path): | |
self._update_config_entries(nuopc_runconfig, nuopc_input) | ||
self.write_nuopc_config(nuopc_runconfig, nuopc_file_path) | ||
|
||
def _update_perturb_jobname(self, expt_path, expt_name): | ||
""" | ||
Updates `jobname` only for now. | ||
Args: | ||
expt_path (str): The path to the perturbation experiment directory. | ||
expt_name (str): The name of the perturbation experiment. | ||
""" | ||
config_path = os.path.join(expt_path, "config.yaml") | ||
config_data = self._read_ryaml(config_path) | ||
config_data["jobname"] = expt_name | ||
self._write_ryaml(config_data, config_path) | ||
|
||
def _update_metadata_yaml_perturb(self, expt_path, param_dict, restartpath): | ||
""" | ||
Updates the `metadata.yaml` file with relevant metadata. | ||
|
@@ -1279,8 +1371,8 @@ def main(self): | |
Main function for the program. | ||
""" | ||
parser = argparse.ArgumentParser( | ||
description="Manage ACCESS-OM3 experiments.\ | ||
Latest version and help: https://github.com/minghangli-uni/Expts_manager" | ||
description="Manage ACCESS-OM2 or ACCESS-OM3 experiments.\ | ||
Latest version and help: https://github.com/COSIMA/om3-scripts/pull/34" | ||
) | ||
parser.add_argument( | ||
"INPUT_YAML", | ||
|