Skip to content

Commit

Permalink
Update typing
Browse files Browse the repository at this point in the history
  • Loading branch information
jacob-evarts committed Dec 13, 2023
1 parent 3825ea2 commit d41151d
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions src/config_utils/assign_hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def fix_types(df: pd.DataFrame, types: dict[str, str]) -> pd.DataFrame:


def include_permutations(
list_of_lists: list[list[str]], param_names: list[str], df: pd.DataFrame
list_of_lists: tuple[Any, ...], param_names: tuple[Any, ...], df: pd.DataFrame
) -> pd.DataFrame:
"""
Permutes discrete hyperparameter values before adding them to hyperparameter samples dataframe
Expand All @@ -107,7 +107,9 @@ def include_permutations(
df["temp"] = pd.Series([permutations] * len(df))
temp_df = df.explode("temp")
temp_df.reset_index(inplace=True, drop=True)
df_with_permutations = temp_df.join(pd.DataFrame([*temp_df.temp], temp_df.index, param_names))
df_with_permutations = temp_df.join(
pd.DataFrame([*temp_df.temp], temp_df.index, param_names)
)
df_with_permutations.drop(columns=["temp"], inplace=True)
return df_with_permutations

Expand All @@ -120,7 +122,9 @@ def check_list_lengths(l1: list[Any], l2: list[Any]) -> None:
raise ValueError("Lists must be the same length.") from exc


def add_constant_params(names: list[str], values: list[Any], df: pd.DataFrame) -> pd.DataFrame:
def add_constant_params(
names: tuple[Any, ...], values: tuple[Any, ...], df: pd.DataFrame
) -> pd.DataFrame:
"""Adds any constant parameter values to the hyperparameter dataframe"""
for name, value in zip(names, values):
df[name] = value
Expand Down Expand Up @@ -175,7 +179,9 @@ def _handle_continuous_config(param_cfg: DictConfig, sobol_power: int) -> pd.Dat
return fixed_df


def _handle_discrete_config(param_cfg: DictConfig, hparam_df: pd.DataFrame) -> pd.DataFrame:
def _handle_discrete_config(
param_cfg: DictConfig, hparam_df: pd.DataFrame
) -> pd.DataFrame:
"""
Reads in discrete parameters from config file and appends permutations
of them to the hyperparameter dataframe
Expand All @@ -195,7 +201,9 @@ def _handle_discrete_config(param_cfg: DictConfig, hparam_df: pd.DataFrame) -> p
return with_perm_df


def _handle_static_config(param_cfg: DictConfig, hparam_df: pd.DataFrame) -> pd.DataFrame:
def _handle_static_config(
param_cfg: DictConfig, hparam_df: pd.DataFrame
) -> pd.DataFrame:
"""
Reads in static parameters from config file and appends them to the hyperparameter dataframe
Expand Down

0 comments on commit d41151d

Please sign in to comment.