diff --git a/src/config_utils/assign_hyperparameters.py b/src/config_utils/assign_hyperparameters.py index 88ee308..4f25e6d 100644 --- a/src/config_utils/assign_hyperparameters.py +++ b/src/config_utils/assign_hyperparameters.py @@ -84,7 +84,7 @@ def fix_types(df: pd.DataFrame, types: dict[str, str]) -> pd.DataFrame: def include_permutations( - list_of_lists: list[list[str]], param_names: list[str], df: pd.DataFrame + list_of_lists: tuple[Any, ...], param_names: tuple[Any, ...], df: pd.DataFrame ) -> pd.DataFrame: """ Permutes discrete hyperparameter values before adding them to hyperparameter samples dataframe @@ -107,7 +107,9 @@ def include_permutations( df["temp"] = pd.Series([permutations] * len(df)) temp_df = df.explode("temp") temp_df.reset_index(inplace=True, drop=True) - df_with_permutations = temp_df.join(pd.DataFrame([*temp_df.temp], temp_df.index, param_names)) + df_with_permutations = temp_df.join( + pd.DataFrame([*temp_df.temp], temp_df.index, param_names) + ) df_with_permutations.drop(columns=["temp"], inplace=True) return df_with_permutations @@ -120,7 +122,9 @@ def check_list_lengths(l1: list[Any], l2: list[Any]) -> None: raise ValueError("Lists must be the same length.") from exc -def add_constant_params(names: list[str], values: list[Any], df: pd.DataFrame) -> pd.DataFrame: +def add_constant_params( + names: tuple[Any, ...], values: tuple[Any, ...], df: pd.DataFrame +) -> pd.DataFrame: """Adds any constant parameter values to the hyperparameter dataframe""" for name, value in zip(names, values): df[name] = value @@ -175,7 +179,9 @@ def _handle_continuous_config(param_cfg: DictConfig, sobol_power: int) -> pd.Dat return fixed_df -def _handle_discrete_config(param_cfg: DictConfig, hparam_df: pd.DataFrame) -> pd.DataFrame: +def _handle_discrete_config( + param_cfg: DictConfig, hparam_df: pd.DataFrame +) -> pd.DataFrame: """ Reads in discrete parameters from config file and appends permutations of them to the hyperparameter dataframe @@ -195,7 +201,9 @@ def _handle_discrete_config(param_cfg: DictConfig, hparam_df: pd.DataFrame) -> p return with_perm_df -def _handle_static_config(param_cfg: DictConfig, hparam_df: pd.DataFrame) -> pd.DataFrame: +def _handle_static_config( + param_cfg: DictConfig, hparam_df: pd.DataFrame +) -> pd.DataFrame: """ Reads in static parameters from config file and appends them to the hyperparameter dataframe