Skip to content

Commit

Permalink
Remove use of deprecated types
Browse files Browse the repository at this point in the history
  • Loading branch information
eivindjahren committed Dec 10, 2024
1 parent ad786d8 commit bda3f35
Show file tree
Hide file tree
Showing 17 changed files with 133 additions and 181 deletions.
14 changes: 4 additions & 10 deletions src/_ert/events.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
import sys
from datetime import datetime
from typing import Any, Dict, Final, Literal, Union

if sys.version_info < (3, 9):
from typing_extensions import Annotated
else:
from typing import Annotated
from typing import Annotated, Any, Final, Literal, Union

from pydantic import BaseModel, ConfigDict, Field, TypeAdapter

Expand Down Expand Up @@ -103,7 +97,7 @@ class ForwardModelStepChecksum(BaseEvent):
event_type: Id.FORWARD_MODEL_STEP_CHECKSUM_TYPE = Id.FORWARD_MODEL_STEP_CHECKSUM
ensemble: Union[str, None] = None
real: str
checksums: Dict[str, Dict[str, Any]]
checksums: dict[str, dict[str, Any]]


class RealizationBaseEvent(BaseEvent):
Expand Down Expand Up @@ -240,13 +234,13 @@ def event_from_json(raw_msg: Union[str, bytes]) -> Event:
return EventAdapter.validate_json(raw_msg)


def event_from_dict(dict_msg: Dict[str, Any]) -> Event:
def event_from_dict(dict_msg: dict[str, Any]) -> Event:
return EventAdapter.validate_python(dict_msg)


def event_to_json(event: Event) -> str:
return event.model_dump_json()


def event_to_dict(event: Event) -> Dict[str, Any]:
def event_to_dict(event: Event) -> dict[str, Any]:
return event.model_dump()
3 changes: 1 addition & 2 deletions src/_ert/forward_model_runner/client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import asyncio
import logging
import ssl
from typing import Any, AnyStr, Optional, Union
from typing import Any, AnyStr, Optional, Self, Union

from typing_extensions import Self
from websockets.asyncio.client import ClientConnection, connect
from websockets.datastructures import Headers
from websockets.exceptions import (
Expand Down
40 changes: 17 additions & 23 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,19 @@
Callable,
Generic,
Iterable,
List,
Optional,
Self,
Sequence,
Tuple,
TypeVar,
)

import iterative_ensemble_smoother as ies
import numpy as np
import polars
import psutil
from iterative_ensemble_smoother.experimental import (
AdaptiveESMDA,
)
from typing_extensions import Self
from iterative_ensemble_smoother.experimental import AdaptiveESMDA

from ert.config import (
GenKwConfig,
)
from ert.config import GenKwConfig

from ..config.analysis_config import ObservationGroups, UpdateSettings
from ..config.analysis_module import ESSettings, IESSettings
Expand Down Expand Up @@ -222,8 +216,8 @@ def _get_observations_and_responses(


def _expand_wildcards(
input_list: npt.NDArray[np.str_], patterns: List[str]
) -> List[str]:
input_list: npt.NDArray[np.str_], patterns: list[str]
) -> list[str]:
"""
Returns a sorted list of unique strings from `input_list` that match any of the specified wildcard patterns.
Expand All @@ -250,14 +244,14 @@ def _load_observations_and_responses(
global_std_scaling: float,
iens_active_index: npt.NDArray[np.int_],
selected_observations: Iterable[str],
auto_scale_observations: Optional[List[ObservationGroups]],
auto_scale_observations: Optional[list[ObservationGroups]],
progress_callback: Callable[[AnalysisEvent], None],
) -> Tuple[
) -> tuple[
npt.NDArray[np.float64],
Tuple[
tuple[
npt.NDArray[np.float64],
npt.NDArray[np.float64],
List[ObservationAndResponseSnapshot],
list[ObservationAndResponseSnapshot],
],
]:
# cols: response_key, index, observation_key, observations, std, *[1, ...nreals]
Expand Down Expand Up @@ -412,7 +406,7 @@ def _load_observations_and_responses(

def _split_by_batchsize(
arr: npt.NDArray[np.int_], batch_size: int
) -> List[npt.NDArray[np.int_]]:
) -> list[npt.NDArray[np.int_]]:
"""
Splits an array into sub-arrays of a specified batch size.
Expand Down Expand Up @@ -496,8 +490,8 @@ def _copy_unupdated_parameters(
This is necessary because users can choose not to update parameters but may still want to analyse them.
Parameters:
all_parameter_groups (List[str]): A list of all parameter groups.
updated_parameter_groups (List[str]): A list of parameter groups that have already been updated.
all_parameter_groups (list[str]): A list of all parameter groups.
updated_parameter_groups (list[str]): A list of parameter groups that have already been updated.
iens_active_index (npt.NDArray[np.int_]): An array of indices for the active realizations in the
target ensemble.
source_ensemble (Ensemble): The file system of the source ensemble, from which parameters are copied.
Expand Down Expand Up @@ -532,7 +526,7 @@ def analysis_ES(
source_ensemble: Ensemble,
target_ensemble: Ensemble,
progress_callback: Callable[[AnalysisEvent], None],
auto_scale_observations: Optional[List[ObservationGroups]],
auto_scale_observations: Optional[list[ObservationGroups]],
) -> None:
iens_active_index = np.flatnonzero(ens_mask)

Expand Down Expand Up @@ -611,7 +605,7 @@ def adaptive_localization_progress_callback(

def correlation_callback(
cross_correlations_of_batch: npt.NDArray[np.float64],
cross_correlations_accumulator: List[npt.NDArray[np.float64]],
cross_correlations_accumulator: list[npt.NDArray[np.float64]],
) -> None:
cross_correlations_accumulator.append(cross_correlations_of_batch)

Expand All @@ -632,7 +626,7 @@ def correlation_callback(
progress_callback(AnalysisStatusEvent(msg=log_msg))

start = time.time()
cross_correlations: List[npt.NDArray[np.float64]] = []
cross_correlations: list[npt.NDArray[np.float64]] = []
for param_batch_idx in batches:
X_local = param_ensemble_array[param_batch_idx, :]
if isinstance(config_node, GenKwConfig):
Expand Down Expand Up @@ -712,7 +706,7 @@ def analysis_IES(
target_ensemble: Ensemble,
sies_smoother: Optional[ies.SIES],
progress_callback: Callable[[AnalysisEvent], None],
auto_scale_observations: List[ObservationGroups],
auto_scale_observations: list[ObservationGroups],
sies_step_length: Callable[[int], float],
initial_mask: npt.NDArray[np.bool_],
) -> ies.SIES:
Expand Down Expand Up @@ -911,7 +905,7 @@ def iterative_smoother_update(
rng: Optional[np.random.Generator] = None,
progress_callback: Optional[Callable[[AnalysisEvent], None]] = None,
global_scaling: float = 1.0,
) -> Tuple[SmootherSnapshot, ies.SIES]:
) -> tuple[SmootherSnapshot, ies.SIES]:
if not progress_callback:
progress_callback = noop_progress_callback
if rng is None:
Expand Down
11 changes: 5 additions & 6 deletions src/ert/dark_storage/endpoints/records.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import io
from typing import Any, Dict, List, Mapping, Union
from typing import Annotated, Any, Mapping, Union
from urllib.parse import unquote
from uuid import UUID, uuid4

import numpy as np
from fastapi import APIRouter, Body, Depends, File, Header, HTTPException, status
from fastapi.responses import Response
from typing_extensions import Annotated

from ert.dark_storage import json_schema as js
from ert.dark_storage.common import (
Expand Down Expand Up @@ -34,7 +33,7 @@ async def get_record_observations(
storage: Storage = DEFAULT_STORAGE,
ensemble_id: UUID,
response_name: str,
) -> List[js.ObservationOut]:
) -> list[js.ObservationOut]:
response_name = unquote(response_name)
ensemble = storage.get_ensemble(ensemble_id)
obs_keys = get_observation_keys_for_response(ensemble, response_name)
Expand Down Expand Up @@ -104,10 +103,10 @@ async def get_ensemble_record(
)


@router.get("/ensembles/{ensemble_id}/parameters", response_model=List[Dict[str, Any]])
@router.get("/ensembles/{ensemble_id}/parameters", response_model=list[dict[str, Any]])
async def get_ensemble_parameters(
*, storage: Storage = DEFAULT_STORAGE, ensemble_id: UUID
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
return ensemble_parameters(storage, ensemble_id)


Expand All @@ -119,7 +118,7 @@ def get_ensemble_responses(
storage: Storage = DEFAULT_STORAGE,
ensemble_id: UUID,
) -> Mapping[str, js.RecordOut]:
response_map: Dict[str, js.RecordOut] = {}
response_map: dict[str, js.RecordOut] = {}
ensemble = storage.get_ensemble(ensemble_id)

response_names_with_observations = set()
Expand Down
8 changes: 1 addition & 7 deletions src/ert/dark_storage/json_schema/prior.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
import sys
from typing import Union
from typing import Literal, Union

from pydantic import BaseModel

if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal


class PriorConst(BaseModel):
"""
Expand Down
37 changes: 16 additions & 21 deletions src/ert/ensemble_evaluator/snapshot.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,19 @@
import logging
import typing
from collections import defaultdict
from datetime import datetime
from typing import (
Any,
Counter,
DefaultDict,
Dict,
List,
Mapping,
Optional,
Tuple,
TypedDict,
TypeVar,
Union,
cast,
get_args,
)

from qtpy.QtGui import QColor
from typing_extensions import TypedDict

from _ert.events import (
EESnapshot,
Expand Down Expand Up @@ -90,11 +85,11 @@ def convert_iso8601_to_datetime(

class EnsembleSnapshotMetadata(TypedDict):
# contains the QColor used in the GUI for each fm_step
aggr_fm_step_status_colors: DefaultDict[RealId, Dict[FmStepId, QColor]]
aggr_fm_step_status_colors: defaultdict[RealId, dict[FmStepId, QColor]]
# contains the QColor used in the GUI for each real
real_status_colors: Dict[RealId, QColor]
sorted_real_ids: List[RealId]
sorted_fm_step_ids: DefaultDict[RealId, List[FmStepId]]
real_status_colors: dict[RealId, QColor]
sorted_real_ids: list[RealId]
sorted_fm_step_ids: defaultdict[RealId, list[FmStepId]]


class EnsembleSnapshot:
Expand All @@ -105,13 +100,13 @@ class EnsembleSnapshot:
"""

def __init__(self) -> None:
self._realization_snapshots: DefaultDict[
self._realization_snapshots: defaultdict[
RealId,
RealizationSnapshot,
] = defaultdict(RealizationSnapshot) # type: ignore

self._fm_step_snapshots: DefaultDict[
Tuple[RealId, FmStepId], FMStepSnapshot
self._fm_step_snapshots: defaultdict[
tuple[RealId, FmStepId], FMStepSnapshot
] = defaultdict(FMStepSnapshot) # type: ignore

self._ensemble_state: Optional[str] = None
Expand Down Expand Up @@ -159,9 +154,9 @@ def merge_snapshot(self, ensemble: "EnsembleSnapshot") -> "EnsembleSnapshot":
def merge_metadata(self, metadata: EnsembleSnapshotMetadata) -> None:
self._metadata.update(metadata)

def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""used to send snapshot updates"""
_dict: Dict[str, Any] = {}
_dict: dict[str, Any] = {}
if self._metadata:
_dict["metadata"] = self._metadata
if self._ensemble_state:
Expand Down Expand Up @@ -191,12 +186,12 @@ def metadata(self) -> EnsembleSnapshotMetadata:

def get_all_fm_steps(
self,
) -> Mapping[Tuple[RealId, FmStepId], "FMStepSnapshot"]:
) -> Mapping[tuple[RealId, FmStepId], "FMStepSnapshot"]:
return self._fm_step_snapshots.copy()

def get_fm_steps_for_all_reals(
self,
) -> Mapping[Tuple[RealId, FmStepId], str]:
) -> Mapping[tuple[RealId, FmStepId], str]:
return {
idx: fm_step_snapshot["status"]
for idx, fm_step_snapshot in self._fm_step_snapshots.items()
Expand All @@ -209,7 +204,7 @@ def reals(self) -> Mapping[RealId, "RealizationSnapshot"]:

def get_fm_steps_for_real(
self, real_id: RealId
) -> Dict[FmStepId, "FMStepSnapshot"]:
) -> dict[FmStepId, "FMStepSnapshot"]:
return {
fm_step_idx[1]: fm_step_snapshot.copy()
for fm_step_idx, fm_step_snapshot in self._fm_step_snapshots.items()
Expand All @@ -222,7 +217,7 @@ def get_real(self, real_id: RealId) -> "RealizationSnapshot":
def get_fm_step(self, real_id: RealId, fm_step_id: FmStepId) -> "FMStepSnapshot":
return self._fm_step_snapshots[real_id, fm_step_id].copy()

def get_successful_realizations(self) -> typing.List[int]:
def get_successful_realizations(self) -> list[int]:
return [
int(real_idx)
for real_idx, real_data in self._realization_snapshots.items()
Expand Down Expand Up @@ -400,12 +395,12 @@ class RealizationSnapshot(TypedDict, total=False):
start_time: Optional[datetime]
end_time: Optional[datetime]
exec_hosts: Optional[str]
fm_steps: Dict[str, FMStepSnapshot]
fm_steps: dict[str, FMStepSnapshot]
message: Optional[str]


def _realization_dict_to_realization_snapshot(
source: Dict[str, Any],
source: dict[str, Any],
) -> RealizationSnapshot:
realization = RealizationSnapshot(
status=source.get("status"),
Expand Down
3 changes: 2 additions & 1 deletion src/ert/gui/suggestor/_suggestor_message.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Any, Self

from qtpy import QtSvg
from qtpy.QtCore import Qt
from qtpy.QtGui import QColor
Expand All @@ -11,7 +13,6 @@
QVBoxLayout,
QWidget,
)
from typing_extensions import Any, Self

from ._colors import (
BLUE_BACKGROUND,
Expand Down
4 changes: 1 addition & 3 deletions src/ert/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from functools import wraps
from typing import Callable

from typing_extensions import Any, ParamSpec
from typing import Any, Callable, ParamSpec

from .plugin_manager import (
ErtPluginContext,
Expand Down
Loading

0 comments on commit bda3f35

Please sign in to comment.