From b2b1f11eced84b60495ed7dbb003c59092cce36a Mon Sep 17 00:00:00 2001 From: minghangli-uni Date: Mon, 23 Sep 2024 12:27:22 +1000 Subject: [PATCH] refactor cross-block --- expts_manager/Expts_manager.py | 329 ++++++++++++++++----------------- 1 file changed, 160 insertions(+), 169 deletions(-) diff --git a/expts_manager/Expts_manager.py b/expts_manager/Expts_manager.py index 31ce34c..fd137de 100755 --- a/expts_manager/Expts_manager.py +++ b/expts_manager/Expts_manager.py @@ -14,7 +14,6 @@ import os import sys import re -import copy import subprocess import shutil import glob @@ -156,6 +155,8 @@ def _initialise_variables(self): commt_dict_change (dict): Specific for MOM_input, dictionary of comments for parameters. append_group_list (list): Specific for f90nml, the list containing tunning parameters. expt_names list(str): Optional user-defined directory names for perturbation experiments. + tmp_count (int): count the number of parameter groups in a single parameter block in process. + group_count (int): total number of parameter groups in a single parameter block. """ self.nml_ctrl = None self.tag_model = None @@ -237,7 +238,7 @@ def manage_ctrl_expt(self): if os.path.exists(base_path): print(f"Base path is already created and located at {base_path}") - if self._count_file_nums() == 4: + if not os.path.isfile(os.path.join(base_path,"config.yaml")): print( "previous commit fails, please try with an updated commit hash for the control experiment!" ) @@ -321,72 +322,37 @@ def _setup_ctrl_expt(self): namelist and MOM_input for the control experiment if needed. """ for file_name in os.listdir(self.base_path): - # Update parameters from namelists - if file_name.endswith("_in") or file_name.endswith(".nml"): - yaml_data = self.indata.get(file_name, None) - if yaml_data: - if ( - "dynamics_nml" in yaml_data - and "turning_angle" in yaml_data["dynamics_nml"] - ): - cosw = np.cos( - yaml_data["dynamics_nml"]["turning_angle"] * np.pi / 180.0 + yaml_data = self.indata.get(file_name, None) + if yaml_data: + # Update parameters from namelists + if file_name.endswith("_in") or file_name.endswith(".nml"): + self._update_nml_params(self.base_path, yaml_data, file_name) + + # Update config entries from `nuopc.runconfig` and `config_yaml` + if file_name in (("nuopc.runconfig", "config.yaml")): + if file_name == "nuopc.runconfig": + self._update_runconfig_params(self.base_path, yaml_data, file_name) + elif file_name == "config.yaml": + self._update_config_params(self.base_path, yaml_data, file_name) + + # Update parameters from `MOM_input` + if file_name == "MOM_input": + # parse existing MOM_input + MOM_inputParser = self._parser_mom6_input( + os.path.join(self.base_path, file_name) ) - sinw = np.sin( - yaml_data["dynamics_nml"]["turning_angle"] * np.pi / 180.0 + param_dict = MOM_inputParser.param_dict # read parameter dictionary + commt_dict = MOM_inputParser.commt_dict # read comment dictionary + param_dict.update(yaml_data) + # overwrite to the same `MOM_input` + MOM_inputParser.writefile_MOM_input( + os.path.join(self.base_path, file_name) ) - yaml_data["dynamics_nml"]["cosw"] = cosw - yaml_data["dynamics_nml"]["sinw"] = sinw - del yaml_data["dynamics_nml"]["turning_angle"] - - # read existing namelist file from the control experiment - nml_ctrl = f90nml.read(os.path.join(self.base_path, file_name)) - # update the namelist with the YAML input file - self._update_config_entries(nml_ctrl, yaml_data) - # write the updated namelist back to the file - nml_ctrl.write(os.path.join(self.base_path, file_name), force=True) - - # Update config entries from `nuopc.runconfig` and `config_yaml` - if file_name in (("nuopc.runconfig", "config.yaml")): - yaml_data = self.indata.get(file_name, None) - if yaml_data: - tmp_file_path = os.path.join(self.base_path, file_name) - if file_name == "nuopc.runconfig": - file_read = self.read_nuopc_config(tmp_file_path) - self._update_config_entries(file_read, yaml_data) - self.write_nuopc_config(file_read, tmp_file_path) - elif file_name == "config.yaml": - file_read = self._read_ryaml(tmp_file_path) - yaml_data["jobname"] = self.base_dir_name - self._update_config_entries(file_read, yaml_data) - if yaml_data["jobname"] != self.base_dir_name: - raise ValueError( - f"jobname must be the same as {self.base_dir_name}!" - ) - self._write_ryaml(file_read, tmp_file_path) - - # Update parameters from `MOM_input` - if file_name == "MOM_input": - yaml_data = self.indata.get(file_name, None) - if yaml_data: - # parse existing MOM_input - MOM_inputParser = self._parser_mom6_input( - os.path.join(self.base_path, file_name) - ) - param_dict = MOM_inputParser.param_dict # read parameter dictionary - commt_dict = MOM_inputParser.commt_dict # read comment dictionary - param_dict.update(yaml_data) - # overwrite to the same `MOM_input` - MOM_inputParser.writefile_MOM_input( - os.path.join(self.base_path, file_name) - ) - # Update only coupling timestep from `nuopc.runseq` - if file_name == "nuopc.runseq": - yaml_data = self.indata.get("cpl_dt", None) - if yaml_data: - nuopc_runseq_file = os.path.join(self.base_path, file_name) - self._update_cpl_dt_nuopc_seq(nuopc_runseq_file, yaml_data) + # Update only coupling timestep from `nuopc.runseq` + if file_name == "nuopc.runseq": + nuopc_runseq_file = os.path.join(self.base_path, file_name) + self._update_cpl_dt_nuopc_seq(nuopc_runseq_file, yaml_data) def _check_and_commit_changes(self): """ @@ -438,54 +404,13 @@ def manage_perturb_expt(self): for k, nmls in namelists.items(): if not nmls: continue - # parameter tunning within one file + + # parameter tunning within one group in a single file if not k.startswith("cross_block"): self._process_params_blocks(k, nmls) - # parameter tunning across multiple files else: - self.tag_model, _ = self._determine_block_type(k) - self.group_count = self._count_second_level_keys(nmls) - self.tmp_count = 0 - # tmp_k => k (equivalent to `k`, when `cross_block` is disabled) - for tmp_k, tmp_nmls in namelists[k].items(): - if tmp_k == "cross_block_input": - self.expt_names = tmp_nmls # user-defined directories - self.num_expts = len(self.expt_names) # count dirs - else: - for k_sub in tmp_nmls: - self.tmp_count += 1 - name_dict = tmp_nmls[k_sub] - - if k_sub.endswith(self.combo_suffix): - if tmp_k.startswith("MOM_input"): - MOM_inputParser = self._parser_mom6_input( - os.path.join(self.base_path, "MOM_input") - ) - commt_dict = MOM_inputParser.commt_dict - else: - commt_dict = None - if name_dict is not None: - self._generate_combined_dicts( - name_dict, commt_dict, k_sub, tmp_k - ) - self.setup_expts(tmp_k) - - # reset user-defined dirs - self.expt_names = None - - def _count_second_level_keys(self, tmp_dict): - group_count = 0 - - for key, value in tmp_dict.items(): - # Skip 'cross_block_input' - if key == "cross_block_input": - continue - if isinstance(value, dict): - for inner_key, inner_value in value.items(): - if isinstance(inner_value, dict): - group_count += 1 - - return group_count + # parameter tunning across multiple files + self._process_params_blocks_cross_files(k, namelists) def _process_params_blocks(self, k, nmls): """ @@ -508,30 +433,106 @@ def _determine_block_type(self, k): Args: k (str): The key indicating the type of parameter block. """ - # parameter blocks, in which contains one or more groups of parameters, e.g., input.nml, ice_in etc. + # parameter blocks, in which contains one or more groups of parameters, + # e.g., input.nml, ice_in etc. if k.endswith(("_in", ".nml")): tag_model = "nml" - # [Optional] The key in the YAML file specifies a list of user-defined directory names related to parameter testing. - expt_dir_name = k[:-3] if k.endswith("_in") else k[:-4] elif k == "MOM_input": tag_model = "mom6" - expt_dir_name = k elif k == "nuopc.runseq": tag_model = "cpl_dt" - expt_dir_name = k[-6:] elif k == "config.yaml": tag_model = "config" - expt_dir_name = k[:6] + "_input" elif k == "nuopc.runconfig": tag_model = "runconfig" - expt_dir_name = k[-9:] + "_input" elif k.startswith("cross_block"): tag_model = "cb" - expt_dir_name = "cross_block" + "_input" else: raise ValueError(f"Unsupported block type: {k}") + # [Optional] The key in the YAML input file specifies a list of + # user-defined directory names related to parameter testing. + expt_dir_name = k + "_dirs" return tag_model, expt_dir_name + def _count_second_level_keys(self, tmp_dict, expt_dir_name): + """ + Counts the number of groups + """ + group_count = 0 + for key, value in tmp_dict.items(): + # skip the user-defined expt directory name + if key == expt_dir_name: + continue + if isinstance(value, dict): + for inner_key, inner_value in value.items(): + if isinstance(inner_value, dict): + group_count += 1 + return group_count + + def _process_params_blocks_cross_files(self, k, namelists): + """ + Determines the type of parameter block for cross-blocks and processes them accordingly. + + Args: + k (str): The key indicating the type of parameter block. + namelists (dict): The highest-level namelist dictionary. + """ + self.tag_model, expt_dir_name = self._determine_block_type(k) + self.group_count = self._count_second_level_keys(namelists[k], expt_dir_name) + self.tmp_count = 0 + + # tmp_k => k (equivalent to `k`, when `cross_block` is disabled) + # k: filename + for tmp_k, tmp_nmls in namelists[k].items(): + if tmp_k.startswith(expt_dir_name): + self._set_cross_block_dirs(tmp_nmls) + else: + self._handle_params_cross_files(tmp_k, tmp_nmls) + + # reset user-defined dirs + self._reset_expt_names() + + def _set_cross_block_dirs(self, tmp_nmls): + """ + Sets cross block directories and sets perturbation directory names. + """ + self.expt_names = tmp_nmls # user-defined directories + self.num_expts = len(self.expt_names) # count dirs + + def _handle_params_cross_files(self, tmp_k, tmp_nmls): + """ + Processes all parameters in the namelist. + """ + for k_sub in tmp_nmls: + self.tmp_count += 1 + name_dict = tmp_nmls[k_sub] + if k_sub.endswith(self.combo_suffix): + if tmp_k.startswith("MOM_input"): + MOM_inputParser = self._parser_mom6_input( + os.path.join(self.base_path, "MOM_input") + ) + commt_dict = MOM_inputParser.commt_dict + else: + commt_dict = None + if name_dict is not None: + self._generate_combined_dicts(name_dict, commt_dict, k_sub, tmp_k) + self.setup_expts(tmp_k) + + def _reset_expt_names(self): + """ + Resets user-defined perturbation experiment names. + """ + self.expt_names = None + + def _parser_mom6_input(self, path): + """ + Parses MOM6 input file. + """ + mom6parser = self.MOM6InputParser.MOM6InputParser() + mom6parser.read_input(path) + mom6parser.parse_lines() + return mom6parser + def _process_params_group(self, k, k_sub, nmls, expt_dir_name, tag_model): """ Processes individual parameter groups based on the tag model. @@ -566,7 +567,7 @@ def _handle_runconfig_group(self, k, k_sub, expt_dir_name, nmls): """ Handles config.yaml and nuopc.runconfig parameter groups specific to `config` tag model. """ - if not k_sub.startswith("runconfig_input"): + if not k_sub.startswith("nuopc.runconfig_dirs"): self._process_parameter_group_common(k, k_sub, nmls, expt_dir_name) def _handle_nml_group(self, k, k_sub, expt_dir_name, nmls): @@ -725,8 +726,8 @@ def setup_expts(self, parameter_block): restartpath = self._generate_restart_symlink(expt_path) self._update_metadata_yaml_perturb(expt_path, param_dict, restartpath) - # only update perturbation jobname [TODO: put somewhere else] - self._update_config_yaml_perturb(expt_path, expt_name) + # # only update perturbation jobname [TODO: put somewhere else] + # self._update_config_yaml_perturb(expt_path, expt_name) # optionally update nuopc.runconfig for perturbation runs if self.tag_model not in (("cb", "runconfig")): @@ -744,7 +745,7 @@ def setup_expts(self, parameter_block): elif self.tag_model == "cpl_dt" or parameter_block == "nuopc.runseq": self._update_cpl_dt_params(expt_path, param_dict, parameter_block) elif self.tag_model == "config" or parameter_block == "config.yaml": - self._update_config_params(expt_path, param_dict, parameter_block, i) + self._update_config_params(expt_path, param_dict, parameter_block) elif self.tag_model == "runconfig" or parameter_block == "nuopc.runconfig": self._update_runconfig_params(expt_path, param_dict, parameter_block, i) @@ -761,7 +762,7 @@ def setup_expts(self, parameter_block): if self.tag_model != "cb": # reset to None after the loop to update user-defined perturbation experiment names! - self.expt_names = None + self._reset_expt_names() def _generate_expt_names(self, indx): if self.expt_names is None: @@ -815,7 +816,7 @@ def _update_mom6_params(self, expt_path, param_dict): ) MOM6_or_parser.writefile_MOM_input(os.path.join(expt_path, "MOM_override")) - def _update_nml_params(self, expt_path, param_dict, parameter_block, indx): + def _update_nml_params(self, expt_path, param_dict, parameter_block, indx=None): """ Updates namelist parameters and overwrites namelist file. @@ -824,28 +825,27 @@ def _update_nml_params(self, expt_path, param_dict, parameter_block, indx): param_dict (dict): The dictionary of parameters to update. parameter_block (str): The name of the namelist file. """ - nml_path = os.path.join(expt_path, parameter_block) - nml_group = self.append_group_list[indx] - - # rename the namlist by removing the suffix if the suffix with `_combo` - if nml_group.endswith(self.combo_suffix): - nml_group = nml_group[: -len(self.combo_suffix)] - - patch_dict = {nml_group: {}} - for nml_name, nml_value in param_dict.items(): - if nml_name == "turning_angle": - cosw = np.cos(nml_value * np.pi / 180.0) - sinw = np.sin(nml_value * np.pi / 180.0) - patch_dict[nml_group]["cosw"] = cosw - patch_dict[nml_group]["sinw"] = sinw - else: # for generic parameters - patch_dict[nml_group][nml_name] = nml_value - - f90nml.patch(nml_path, patch_dict, nml_path + "_tmp") + if indx is not None: + nml_group = self.append_group_list[indx] + # rename the namlist by removing the suffix if the suffix with `_combo` + if nml_group.endswith(self.combo_suffix): + nml_group = nml_group[: -len(self.combo_suffix)] + patch_dict = {nml_group: {}} + for nml_name, nml_value in param_dict.items(): + if nml_name == "turning_angle": + cosw = np.cos(nml_value * np.pi / 180.0) + sinw = np.sin(nml_value * np.pi / 180.0) + patch_dict[nml_group]["cosw"] = cosw + patch_dict[nml_group]["sinw"] = sinw + else: # for generic parameters + patch_dict[nml_group][nml_name] = nml_value + f90nml.patch(nml_path, patch_dict, nml_path + "_tmp") + else: + f90nml.patch(nml_path, param_dict, nml_path + "_tmp") os.rename(nml_path + "_tmp", nml_path) - def _update_config_params(self, expt_path, param_dict, parameter_block, indx): + def _update_config_params(self, expt_path, param_dict, parameter_block): """ Updates namelist parameters and overwrites namelist file. @@ -854,18 +854,23 @@ def _update_config_params(self, expt_path, param_dict, parameter_block, indx): param_dict (dict): The dictionary of parameters to update. parameter_block (str): The name of the namelist file. """ - nml_path = os.path.join(expt_path, parameter_block) - nml_group = self.append_group_list[indx] + expt_name = os.path.basename(expt_path) - # rename the namlist by removing the suffix if the suffix with `_combo` - if nml_group.endswith(self.combo_suffix): - nml_group = nml_group[: -len(self.combo_suffix)] file_read = self._read_ryaml(nml_path) + if "jobname" in param_dict: + if param_dict["jobname"] != expt_name: + warnings.warn( + f"\n" + f"-- jobname must be the same as {expt_name}, " + f"hence jobname is forced to be {expt_name}!", + UserWarning + ) + param_dict["jobname"] = expt_name self._update_config_entries(file_read, param_dict) self._write_ryaml(file_read, nml_path) - def _update_runconfig_params(self, expt_path, param_dict, parameter_block, indx): + def _update_runconfig_params(self, expt_path, param_dict, parameter_block, indx=None): """ Updates namelist parameters and overwrites namelist file. @@ -874,14 +879,13 @@ def _update_runconfig_params(self, expt_path, param_dict, parameter_block, indx) param_dict (dict): The dictionary of parameters to update. parameter_block (str): The name of the namelist file. """ - nml_path = os.path.join(expt_path, parameter_block) - nml_group = self.append_group_list[indx] - - # rename the namlist by removing the suffix if the suffix with `_combo` - if nml_group.endswith(self.combo_suffix): - nml_group = nml_group[: -len(self.combo_suffix)] - param_dict = self.nested_dict(nml_group, param_dict) + if indx is not None: + nml_group = self.append_group_list[indx] + # rename the namlist by removing the suffix if the suffix with `_combo` + if nml_group.endswith(self.combo_suffix): + nml_group = nml_group[: -len(self.combo_suffix)] + param_dict = self.nested_dict(nml_group, param_dict) file_read = self.read_nuopc_config(nml_path) self._update_config_entries(file_read, param_dict) self.write_nuopc_config(file_read, nml_path) @@ -948,19 +952,6 @@ def _update_nuopc_config_perturb(self, path): self._update_config_entries(nuopc_runconfig, nuopc_input) self.write_nuopc_config(nuopc_runconfig, nuopc_file_path) - def _update_config_yaml_perturb(self, expt_path, expt_name): - """ - Updates `jobname` only for now. - - Args: - expt_path (str): The path to the perturbation experiment directory. - expt_name (str): The name of the perturbation experiment. - """ - config_path = os.path.join(expt_path, "config.yaml") - config_data = self._read_ryaml(config_path) - config_data["jobname"] = expt_name - self._write_ryaml(config_data, config_path) - def _update_metadata_yaml_perturb(self, expt_path, param_dict, restartpath): """ Updates the `metadata.yaml` file with relevant metadata.