Skip to content

Commit

Permalink
EverestRunModel: minor refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
verveerpj committed Dec 19, 2024
1 parent fdd85bf commit e6a8781
Showing 1 changed file with 37 additions and 33 deletions.
70 changes: 37 additions & 33 deletions src/ert/run_models/everest_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,28 +346,30 @@ def _on_before_forward_model_evaluation(
optimizer.abort_optimization()

def _forward_model_evaluator(
self, control_values: NDArray[np.float64], metadata: EvaluatorContext
self, control_values: NDArray[np.float64], evaluator_context: EvaluatorContext
) -> EvaluatorResult:
# Reset the current run status:
self._status = None

# Get cached_results:
cached_results = self._get_cached_results(control_values, metadata)
cached_results = self._get_cached_results(control_values, evaluator_context)

# Create the batch to run:
case_data = self._init_case_data(control_values, metadata, cached_results)
batch_data = self._init_batch_data(
control_values, evaluator_context, cached_results
)

# Initialize a new experiment in storage:
# Initialize a new ensemble in storage:
assert self._experiment is not None
ensemble = self._experiment.create_ensemble(
name=f"batch_{self._batch_id}",
ensemble_size=len(case_data),
ensemble_size=len(batch_data),
)
for sim_id, controls in enumerate(case_data.values()):
for sim_id, controls in enumerate(batch_data.values()):
self._setup_sim(sim_id, controls, ensemble)

# Evaluate the batch:
run_args = self._get_run_args(ensemble, metadata, case_data)
run_args = self._get_run_args(ensemble, evaluator_context, batch_data)
self._context_env.update(
{
"_ERT_EXPERIMENT_ID": str(ensemble.experiment_id),
Expand All @@ -384,14 +386,14 @@ def _forward_model_evaluator(
# Gather the results and create the result for ropt:
results = self._gather_simulation_results(ensemble)
evaluator_result = self._make_evaluator_result(
control_values, metadata, case_data, results, cached_results
control_values, evaluator_context, batch_data, results, cached_results
)

# Add the results from the evaluations to the cache:
self._add_results_to_cache(
control_values,
metadata,
case_data,
evaluator_context,
batch_data,
evaluator_result.objectives,
evaluator_result.constraints,
)
Expand All @@ -417,7 +419,7 @@ def _get_cached_results(
return cached_results

@staticmethod
def _init_case_data(
def _init_batch_data(
control_values: NDArray[np.float64],
evaluator_context: EvaluatorContext,
cached_results: dict[int, Any],
Expand All @@ -440,7 +442,7 @@ def add_control(
group[variable_name] = control_value
controls[group_name] = group

case_data = {}
batch_data = {}
for control_idx in range(control_values.shape[0]):
if control_idx not in cached_results and (
evaluator_context.active is None
Expand All @@ -454,8 +456,8 @@ def add_control(
strict=False,
):
add_control(controls, control_name, control_value)
case_data[control_idx] = controls
return case_data
batch_data[control_idx] = controls
return batch_data

def _setup_sim(
self,
Expand Down Expand Up @@ -514,17 +516,17 @@ def _check_suffix(
def _get_run_args(
self,
ensemble: Ensemble,
metadata: EvaluatorContext,
case_data: dict[int, Any],
evaluator_context: EvaluatorContext,
batch_data: dict[int, Any],
) -> list[RunArg]:
substitutions = self.ert_config.substitutions
substitutions["<BATCH_NAME>"] = ensemble.name
self.active_realizations = [True] * len(case_data)
assert metadata.config.realizations.names is not None
for sim_id, control_idx in enumerate(case_data.keys()):
realization = metadata.realizations[control_idx]
self.active_realizations = [True] * len(batch_data)
assert evaluator_context.config.realizations.names is not None
for sim_id, control_idx in enumerate(batch_data.keys()):
realization = evaluator_context.realizations[control_idx]
substitutions[f"<GEO_ID_{sim_id}_0>"] = str(
metadata.config.realizations.names[realization]
evaluator_context.config.realizations.names[realization]
)
run_paths = Runpaths(
jobname_format=self.ert_config.model_config.jobname_format_string,
Expand Down Expand Up @@ -584,26 +586,28 @@ def _gather_simulation_results(
def _make_evaluator_result(
self,
control_values: NDArray[np.float64],
metadata: EvaluatorContext,
case_data: dict[int, Any],
evaluator_context: EvaluatorContext,
batch_data: dict[int, Any],
results: list[dict[str, NDArray[np.float64]]],
cached_results: dict[int, Any],
) -> EvaluatorResult:
# We minimize the negative of the objectives:
assert evaluator_context.config.objectives.names is not None
objectives = -self._get_simulation_results(
results,
metadata.config.objectives.names, # type: ignore
evaluator_context.config.objectives.names,
control_values,
case_data,
batch_data,
)

constraints = None
if metadata.config.nonlinear_constraints is not None:
if evaluator_context.config.nonlinear_constraints is not None:
assert evaluator_context.config.nonlinear_constraints.names is not None
constraints = self._get_simulation_results(
results,
metadata.config.nonlinear_constraints.names, # type: ignore
evaluator_context.config.nonlinear_constraints.names,
control_values,
case_data,
batch_data,
)

if self._simulator_cache is not None:
Expand All @@ -617,7 +621,7 @@ def _make_evaluator_result(
constraints[control_idx, ...] = cached_constraints

sim_ids = np.full(control_values.shape[0], -1, dtype=np.intc)
sim_ids[list(case_data.keys())] = np.arange(len(case_data), dtype=np.intc)
sim_ids[list(batch_data.keys())] = np.arange(len(batch_data), dtype=np.intc)
return EvaluatorResult(
objectives=objectives,
constraints=constraints,
Expand All @@ -630,9 +634,9 @@ def _get_simulation_results(
results: list[dict[str, NDArray[np.float64]]],
names: tuple[str],
controls: NDArray[np.float64],
case_data: dict[int, Any],
batch_data: dict[int, Any],
) -> NDArray[np.float64]:
control_indices = list(case_data.keys())
control_indices = list(batch_data.keys())
values = np.zeros((controls.shape[0], len(names)), dtype=float64)
for func_idx, name in enumerate(names):
values[control_indices, func_idx] = np.fromiter(
Expand All @@ -645,13 +649,13 @@ def _add_results_to_cache(
self,
control_values: NDArray[np.float64],
evaluator_context: EvaluatorContext,
case_data: dict[int, Any],
batch_data: dict[int, Any],
objectives: NDArray[np.float64],
constraints: NDArray[np.float64] | None,
) -> None:
if self._simulator_cache is not None:
assert evaluator_context.config.realizations.names is not None
for control_idx in case_data:
for control_idx in batch_data:
realization = evaluator_context.realizations[control_idx]
self._simulator_cache.add(
evaluator_context.config.realizations.names[realization],
Expand Down

0 comments on commit e6a8781

Please sign in to comment.