Skip to content

Commit

Permalink
typing: Fix up some files, ignore others
Browse files Browse the repository at this point in the history
  • Loading branch information
eddiebergman committed May 2, 2024
1 parent 9ced118 commit 072bbb5
Show file tree
Hide file tree
Showing 14 changed files with 110 additions and 80 deletions.
2 changes: 1 addition & 1 deletion neps/optimizers/base_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
if patience < 1:
raise ValueError("Patience should be at least 1")

self.used_budget = 0
self.used_budget: float = 0.0
self.budget = budget
self.pipeline_space = pipeline_space
self.patience = patience
Expand Down
2 changes: 1 addition & 1 deletion neps/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def plot( # noqa: C901, PLR0913
costs.append(cost)
max_costs.append(max_cost)

is_last_row = lambda idx: idx >= (nrows - 1) * ncols
is_last_row = benchmark_idx >= (nrows - 1) * ncols
is_first_column = benchmark_idx % ncols == 0
xlabel = "Evaluations" if key_to_extract is None else key_to_extract.upper()
_plot_incumbent(
Expand Down
7 changes: 4 additions & 3 deletions neps/plot/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from pathlib import Path
from typing import Any

import matplotlib.axes
import matplotlib.figure
Expand Down Expand Up @@ -80,7 +81,7 @@ def _plot_incumbent(
log_x: bool = False,
log_y: bool = False,
x_range: tuple | None = None,
**plotting_kwargs,
**plotting_kwargs: Any,
) -> None:
df = _interpolate_time(incumbents=y, costs=x, x_range=x_range, scale_x=scale_x)
df = _df_to_x_range(df, x_range=x_range)
Expand Down Expand Up @@ -134,7 +135,7 @@ def _interpolate_time(
df = pd.DataFrame.from_dict(df_dict)

# important step to plot func evals on x-axis
df.index = df.index if scale_x is None else df.index.to_numpy() / scale_x # type: ignore
df.index = df.index if scale_x is None else df.index.to_numpy() / scale_x

if x_range is not None:
min_b, max_b = x_range
Expand Down Expand Up @@ -193,7 +194,7 @@ def _set_legend(
frameon=True,
)

for legend_item in legend.legendHandles: # type: ignore
for legend_item in legend.legend_handles:
legend_item.set_linewidth(2.0)


Expand Down
4 changes: 2 additions & 2 deletions neps/plot/read_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ def process_seed(

global_start = stats[min(stats.keys())].metadata["time_sampled"]

def get_cost(idx):
def get_cost(idx: str) -> float:
if key_to_extract is not None:
return stats[idx].result["info_dict"][key_to_extract]
return float(stats[idx].result["info_dict"][key_to_extract])
return 1.0

losses = []
Expand Down
16 changes: 12 additions & 4 deletions neps/plot/tensorboard_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import math
from pathlib import Path
from typing import Any, ClassVar, Mapping
from typing_extensions import override

import numpy as np
import torch
Expand Down Expand Up @@ -33,7 +34,13 @@ class SummaryWriter_(SummaryWriter): # noqa: N801
metrics with better formatting.
"""

def add_hparams(self, hparam_dict: dict, metric_dict: dict, global_step: int) -> None:
@override
def add_hparams( # type: ignore
self,
hparam_dict: dict[str, Any],
metric_dict: dict[str, Any],
global_step: int,
) -> None:
"""Add a set of hyperparameters to be logged."""
if not isinstance(hparam_dict, dict) or not isinstance(metric_dict, dict):
raise TypeError("hparam_dict and metric_dict should be dictionary.")
Expand All @@ -51,7 +58,7 @@ def add_hparams(self, hparam_dict: dict, metric_dict: dict, global_step: int) ->
class tblogger: # noqa: N801
"""The tblogger class provides a simplified interface for logging to tensorboard."""

config_id: ClassVar[str | None] | None = None
config_id: ClassVar[str | None] = None
config: ClassVar[Mapping[str, Any] | None] = None
config_working_directory: ClassVar[Path | None] = None
optimizer_dir: ClassVar[Path | None] = None
Expand Down Expand Up @@ -321,10 +328,11 @@ def _write_image_config(
else:
if not isinstance(seed, np.random.RandomState):
seed = np.random.RandomState(seed)

# We do not interfere with any randomness from the pipeline
num_total_images = len(image)
indices = seed.choice(num_total_images, num_images, replace=False)
subset_images = image[indices]
subset_images = image[indices] # type: ignore

resized_images = torch.nn.functional.interpolate(
subset_images,
Expand Down Expand Up @@ -452,7 +460,7 @@ def enable() -> None:
tblogger.disable_logging = False

@staticmethod
def get_status():
def get_status() -> bool:
"""Returns the currect state of tblogger ie. whether the logger is
enabled or not.
"""
Expand Down
64 changes: 31 additions & 33 deletions neps/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,11 @@
from .optimizers.base_optimizer import BaseOptimizer

# Wait time between each successive poll to see if state can be grabbed
DEFAULT_STATE_POLL = 1
DEFAULT_STATE_POLL: float = 1.0
ENVIRON_STATE_POLL_KEY = "NEPS_STATE_POLL"

# Timeout before giving up on trying to grab the state, raising an error
DEFAULT_STATE_TIMEOUT = None
DEFAULT_STATE_TIMEOUT: float | None = None
ENVIRON_STATE_TIMEOUT_KEY = "NEPS_STATE_TIMEOUT"


Expand Down Expand Up @@ -119,7 +119,7 @@ class Trial:

disk: Trial.Disk = field(init=False)

def __post_init__(self):
def __post_init__(self) -> None:
if self.prev_config_id is not None:
self.metadata["previous_config_id"] = self.prev_config_id
self.disk = Trial.Disk(pipeline_dir=self.pipeline_dir)
Expand Down Expand Up @@ -158,7 +158,7 @@ class State(str, Enum):
CORRUPTED = "corrupted"
"""The trial is not in one of the previous states and should be removed."""

def __str__(self):
def __str__(self) -> str:
return self.value

@dataclass
Expand Down Expand Up @@ -189,7 +189,7 @@ class Disk:
previous_pipeline_dir: Path | None = field(init=False)
lock: Locker = field(init=False)

def __post_init__(self):
def __post_init__(self) -> None:
self.id = self.pipeline_dir.name[len("config_") :]
self.config_file = self.pipeline_dir / "config.yaml"
self.result_file = self.pipeline_dir / "result.yaml"
Expand Down Expand Up @@ -278,12 +278,11 @@ def to_result(
config = deserialize(self.config_file)
result = deserialize(self.result_file)
metadata = deserialize(self.metadata_file)
if config_transform is not None:
config = config_transform(config)
_config = config_transform(config) if config_transform is not None else config

return ConfigResult(
id=self.id,
config=config,
config=_config,
result=result,
metadata=metadata,
)
Expand Down Expand Up @@ -371,7 +370,7 @@ def _evaluate_config(
trial: Trial,
evaluation_fn: Callable[..., float | Mapping[str, Any]],
logger: logging.Logger,
) -> tuple[ERROR | dict, float]:
) -> tuple[ERROR | dict[str, Any], float]:
config = trial.config
config_id = trial.id
pipeline_directory = trial.disk.pipeline_dir
Expand All @@ -383,7 +382,7 @@ def _evaluate_config(

# If pipeline_directory and previous_pipeline_directory are included in the
# signature we supply their values, otherwise we simply do nothing.
directory_params = []
directory_params: list[Path | None] = []

evaluation_fn_params = inspect.signature(evaluation_fn).parameters
if "pipeline_directory" in evaluation_fn_params:
Expand All @@ -392,29 +391,29 @@ def _evaluate_config(
directory_params.append(previous_pipeline_directory)

try:
result = evaluation_fn(*directory_params, **config)
eval_result = evaluation_fn(*directory_params, **config)
except Exception as e:
logger.error(f"An error occured evaluating config '{config_id}': {config}.")
logger.exception(e)
result = "error"
return "error", time.time()

# Ensure the results have correct format that can be exploited by other functions
result: dict[str, Any] = {}
if isinstance(eval_result, Mapping):
result = dict(eval_result)
if "loss" not in result:
raise KeyError("The 'loss' should be provided in the evaluation result")
loss = result["loss"]
else:
# Ensure the results have correct format that can be exploited by other functions
if isinstance(result, Mapping):
result = dict(result)
if "loss" not in result:
raise KeyError("The 'loss' should be provided in the evaluation result")
loss = result["loss"]
else:
loss = result
result = {}
loss = eval_result

try:
result["loss"] = float(loss)
except (TypeError, ValueError) as e:
raise ValueError(
"The evaluation result should be a dictionnary or a float but got"
f" a `{type(loss)}` with value of {loss}",
) from e
try:
result["loss"] = float(loss)
except (TypeError, ValueError) as e:
raise ValueError(
"The evaluation result should be a dictionnary or a float but got"
f" a `{type(loss)}` with value of {loss}",
) from e

time_end = time.time()
return result, time_end
Expand Down Expand Up @@ -522,10 +521,9 @@ def launch_runtime( # noqa: PLR0913, C901, PLR0915

_poll = float(os.environ.get(ENVIRON_STATE_POLL_KEY, DEFAULT_STATE_POLL))
_timeout = os.environ.get(ENVIRON_STATE_TIMEOUT_KEY, DEFAULT_STATE_TIMEOUT)
if _timeout is not None:
_timeout = float(_timeout)
timeout = float(_timeout) if _timeout is not None else None

with shared_state.lock(poll=_poll, timeout=_timeout):
with shared_state.lock(poll=_poll, timeout=timeout):
if not shared_state.optimizer_info_file.exists():
serialize(optimizer_info, shared_state.optimizer_info_file, sort_keys=False)
else:
Expand All @@ -540,7 +538,7 @@ def launch_runtime( # noqa: PLR0913, C901, PLR0915
logger.info("Maximum evaluations per run is reached, shutting down")
break

with shared_state.lock(poll=_poll, timeout=_timeout):
with shared_state.lock(poll=_poll, timeout=timeout):
refs = shared_state.trial_refs()

_try_remove_corrupted_configs(refs[Trial.State.CORRUPTED], logger)
Expand Down Expand Up @@ -644,7 +642,7 @@ def launch_runtime( # noqa: PLR0913, C901, PLR0915
trial.results = result
trial.metadata.update(meta)

with shared_state.lock(poll=_poll, timeout=_timeout):
with shared_state.lock(poll=_poll, timeout=timeout):
trial.write_to_disk()
if account_for_cost:
assert eval_cost is not None
Expand Down
21 changes: 13 additions & 8 deletions neps/search_spaces/search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from copy import deepcopy
from itertools import product
from pathlib import Path
from typing import Mapping, Any
from typing import Mapping, Any, Iterator

import ConfigSpace as CS
import numpy as np
Expand All @@ -22,7 +22,7 @@
NumericalParameter,
)
from .architecture.graph import Graph
from .parameter import Parameter
from neps.search_spaces.parameter import Parameter
from .yaml_search_space_utils import (
SearchSpaceFromYamlFileError,
deduce_and_validate_param_type,
Expand Down Expand Up @@ -191,7 +191,7 @@ def pipeline_space_from_yaml(

class SearchSpace(Mapping[str, Any]):
def __init__(self, **hyperparameters):
self.hyperparameters = OrderedDict()
self.hyperparameters: dict[str, Parameter] = {}

self.fidelity = None
self.has_prior = False
Expand Down Expand Up @@ -561,19 +561,24 @@ def set_hyperparameters_from_dict(
hp.lower = new_hp_value
hp.upper = new_hp_value

def __getitem__(self, key):
def __getitem__(self, key: str) -> Parameter:
return self.hyperparameters[key]

def __iter__(self):
def __iter__(self) -> Iterator[str]:
return iter(self.hyperparameters)

def __len__(self):
def __len__(self) -> int:
return len(self.hyperparameters)

def __str__(self):
def __str__(self) -> str:
return pprint.pformat(self.hyperparameters)

def is_equal_value(self, other, include_fidelity=True, on_decimal=8):
def is_equal_value(
self,
other: SearchSpace,
include_fidelity: bool = True,
on_decimal: int = 8,
) -> bool:
# This does NOT check that the entire SearchSpace is equal (and thus it is
# not a dunder method), but only checks the configuration
if self.hyperparameters.keys() != other.hyperparameters.keys():
Expand Down
2 changes: 1 addition & 1 deletion neps/status/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def get_summary_dict(
pending = [r.load() for r in trial_refs[Trial.State.PENDING]]
in_progress = [r.load() for r in trial_refs[Trial.State.IN_PROGRESS]]

summary = {}
summary: dict[str, Any] = {}

if add_details:
summary["previous_results"] = {c.id: c for c in evaluated}
Expand Down
3 changes: 2 additions & 1 deletion neps/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Primitive types to be used in NePS or consumers of NePS."""

from __future__ import annotations

import logging
Expand Down Expand Up @@ -55,7 +56,7 @@ class ConfigResult:
class AttrDict(dict):
"""Dictionary that allows access to keys as attributes."""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
"""Initialize like a dict."""
super().__init__(*args, **kwargs)
self.__dict__ = self
Loading

0 comments on commit 072bbb5

Please sign in to comment.