Skip to content

Commit

Permalink
f
Browse files Browse the repository at this point in the history
  • Loading branch information
verveerpj committed Dec 18, 2024
1 parent ba346e0 commit f8b6e8b
Showing 1 changed file with 28 additions and 20 deletions.
48 changes: 28 additions & 20 deletions src/ert/run_models/everest_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from collections.abc import Callable
from dataclasses import dataclass
from enum import IntEnum
from itertools import count
from pathlib import Path
from types import TracebackType
from typing import TYPE_CHECKING, Any, Protocol
Expand Down Expand Up @@ -130,12 +129,14 @@ def __init__(
self._display_all_jobs = display_all_jobs
self._result: OptimalResult | None = None
self._exit_code: EverestExitCode | None = None
self._evaluator_cache: _EvaluatorCache | None = None
if (
everest_config.simulator is not None
and everest_config.simulator.enable_cache
):
self._evaluator_cache = _EvaluatorCache()
self._evaluator_cache: _EvaluatorCache | None = (
_EvaluatorCache()
if (
everest_config.simulator is not None
and everest_config.simulator.enable_cache
)
else None
)
self._experiment: Experiment | None = None
self.eval_server_cfg: EvaluatorServerConfig | None = None
storage = open_storage(config.ens_path, mode="w")
Expand Down Expand Up @@ -773,12 +774,14 @@ class _EvaluatorCache:
EPS = float(np.finfo(np.float32).eps)

def __init__(self) -> None:
self._objectives: dict[int, NDArray[np.float64]] = {}
self._constraints: dict[int, NDArray[np.float64] | None] = {}
self._keys: defaultdict[int, list[tuple[NDArray[np.float64], int]]] = (
defaultdict(list)
)
self._counter = count()
self._data: defaultdict[
int,
list[
tuple[
NDArray[np.float64], NDArray[np.float64], NDArray[np.float64] | None
]
],
] = defaultdict(list)

def add(
self,
Expand All @@ -787,15 +790,20 @@ def add(
objectives: NDArray[np.float64],
constraints: NDArray[np.float64] | None,
) -> None:
key = next(self._counter)
self._keys[realization_id].append((control_values.copy(), key))
self._objectives[key] = objectives.copy()
self._constraints[key] = None if constraints is None else constraints.copy()
self._data[realization_id].append(
(
control_values.copy(),
objectives.copy(),
None if constraints is None else constraints.copy(),
),
)

def get(
self, realization_id: int, controls: NDArray[np.float64]
) -> tuple[Any, ...] | None:
for cached_result, key in self._keys.get(realization_id, []):
if np.allclose(controls, cached_result, rtol=0.0, atol=self.EPS):
return self._objectives[key], self._constraints[key]
for control_values, objectives, constraints in self._data.get(
realization_id, []
):
if np.allclose(controls, control_values, rtol=0.0, atol=self.EPS):
return objectives, constraints
return None

0 comments on commit f8b6e8b

Please sign in to comment.