Skip to content

Commit

Permalink
Fixup
Browse files Browse the repository at this point in the history
Fixed typing for nodes
Removed suggestor notifier
Removed PyQt5 dependencies
Fixed more typing
  • Loading branch information
JHolba committed Dec 30, 2024
1 parent 0a93046 commit 28e191c
Show file tree
Hide file tree
Showing 9 changed files with 79 additions and 62 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ dependencies = [
"psutil",
"pyarrow", # extra dependency for pandas (parquet)
"pydantic > 2",
"PyQt5",
"python-dateutil",
"python-multipart", # extra dependency for fastapi
"pyyaml",
Expand Down
2 changes: 1 addition & 1 deletion src/ert/gui/ertwidgets/searchbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def focusOutEvent(self, arg__1: QFocusEvent) -> None:
self.exitSearch()

def keyPressEvent(self, arg__1: QKeyEvent) -> None:
if arg__1 is not None and arg__1.key() == Qt.Key.Key_Escape:
if arg__1.key() == Qt.Key.Key_Escape:
self.clear()
self.clearFocus()
else:
Expand Down
1 change: 0 additions & 1 deletion src/ert/gui/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ def continue_action() -> None:
continue_action,
plugin_manager.get_help_links() if plugin_manager is not None else {},
)
suggestor.notifier = main_window.notifier
return (
suggestor,
ert_config.ens_path,
Expand Down
86 changes: 52 additions & 34 deletions src/ert/gui/model/node.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,50 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from abc import ABC
from dataclasses import dataclass, field
from typing import cast

from PySide6.QtGui import QColor

from ert.ensemble_evaluator.snapshot import FMStepSnapshot


@dataclass
class Node(ABC):
class _NodeBase(ABC):
id_: str
parent: RootNode | IterNode | RealNode | None = None
children: (
dict[str, IterNode] | dict[str, RealNode] | dict[str, ForwardModelStepNode]
) = field(default_factory=dict)
_index: int | None = None

def __repr__(self) -> str:
parent = "no " if self.parent is None else ""
children = "no " if len(self.children) == 0 else f"{len(self.children)} "
return f"Node<{type(self).__name__}>@{self.id_} with {parent}parent and {children}children"

@abstractmethod
def add_child(self, node: Node) -> None:
pass
def _repr(node: _Node) -> str:
parent = "no " if node.parent is None else ""
children = "no " if not node.children else f"{len(node.children)} "
return f"Node<{type(node).__name__}>@{node.id_} with {parent}parent and {children}children"

def row(self) -> int:
if not self._index:
if self.parent:
self._index = list(self.parent.children.keys()).index(self.id_)
else:
raise ValueError(f"{self} had no parent")
return self._index

def _row(node: _Node) -> int:
if not node._index:
if node.parent:
node._index = list(node.parent.children.keys()).index(node.id_)
else:
raise ValueError(f"{node} had no parent")
return node._index


@dataclass
class RootNode(Node):
class RootNode(_NodeBase):
parent: None = field(default=None, init=False)
children: dict[str, IterNode] = field(default_factory=dict)
max_memory_usage: int | None = None

def add_child(self, node: Node) -> None:
node = cast(IterNode, node)
def add_child(self, node: IterNode) -> None:
node.parent = self
self.children[node.id_] = node

def row(self) -> int:
return _row(self)

def __repr__(self) -> str:
return _repr(self)


@dataclass
class IterNodeData:
Expand All @@ -55,16 +53,21 @@ class IterNodeData:


@dataclass
class IterNode(Node):
class IterNode(_NodeBase):
parent: RootNode | None = None
data: IterNodeData = field(default_factory=IterNodeData)
children: dict[str, RealNode] = field(default_factory=dict)

def add_child(self, node: Node) -> None:
node = cast(RealNode, node)
def add_child(self, node: RealNode) -> None:
node.parent = self
self.children[node.id_] = node

def row(self) -> int:
return _row(self)

def __repr__(self) -> str:
return _repr(self)


@dataclass
class RealNodeData:
Expand All @@ -80,21 +83,36 @@ class RealNodeData:


@dataclass
class RealNode(Node):
class RealNode(_NodeBase):
parent: IterNode | None = None
data: RealNodeData = field(default_factory=RealNodeData)
children: dict[str, ForwardModelStepNode] = field(default_factory=dict)

def add_child(self, node: Node) -> None:
node = cast(ForwardModelStepNode, node)
def add_child(self, node: ForwardModelStepNode) -> None:
node.parent = self
self.children[node.id_] = node

def row(self) -> int:
return _row(self)

def __repr__(self) -> str:
return _repr(self)


@dataclass
class ForwardModelStepNode(Node):
parent: RealNode | None
class ForwardModelStepNode(_NodeBase):
parent: RealNode | None = None
data: FMStepSnapshot = field(default_factory=lambda: FMStepSnapshot()) # noqa: PLW0108
children: None = None

def add_child(self, _: _NodeBase) -> None:
raise RuntimeError(f"Can not add children to {self.__class__.__name__}")

def row(self) -> int:
return _row(self)

def __repr__(self) -> str:
return _repr(self)


def add_child(self, node: Node) -> None:
pass
_Node = RootNode | IterNode | RealNode | ForwardModelStepNode
31 changes: 19 additions & 12 deletions src/ert/gui/model/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections.abc import Sequence
from contextlib import ExitStack
from datetime import datetime, timedelta
from typing import Any, Final, overload
from typing import Any, Final, cast, overload

from PySide6.QtCore import (
QAbstractItemModel,
Expand Down Expand Up @@ -52,9 +52,9 @@
IsFMStepRole = UserRole + 10
StatusRole = UserRole + 11

DURATION = "Duration"
DURATION: Final[str] = "Duration"

FM_STEP_COLUMNS: Sequence[str] = [
FM_STEP_COLUMNS: Final[Sequence[str]] = [
ids.NAME,
ids.ERROR,
ids.STATUS,
Expand Down Expand Up @@ -295,12 +295,16 @@ def rowCount(
) -> int:
if parent is None:
parent = QModelIndex()
parent_item = self.root if not parent.isValid() else parent.internalPointer()
parent_item = (
self.root
if not parent.isValid()
else cast(RootNode | IterNode | RealNode, parent.internalPointer())
)

if parent.column() > 0:
return 0

return len(parent_item.children) # type: ignore
return len(parent_item.children)

@overload
def parent(self) -> QObject: ...
Expand All @@ -313,10 +317,13 @@ def parent(
if child is None or not child.isValid():
return QModelIndex()

parent_item = child.internalPointer().parent # type:ignore
parent_item = cast(
IterNode | RealNode | ForwardModelStepNode, child.internalPointer()
).parent
if parent_item == self.root:
return QModelIndex()

assert parent_item
return self.createIndex(parent_item.row(), 0, parent_item)

@override
Expand All @@ -331,7 +338,7 @@ def data(
if role == Qt.ItemDataRole.TextAlignmentRole:
return Qt.AlignmentFlag.AlignCenter

node: IterNode | RealNode | ForwardModelStepNode = index.internalPointer() # type:ignore
node = cast(IterNode | RealNode | ForwardModelStepNode, index.internalPointer())
if role == NodeRole:
return node

Expand All @@ -343,7 +350,7 @@ def data(
return isinstance(node, ForwardModelStepNode)

if isinstance(node, ForwardModelStepNode):
return self._fm_step_data(index, node, role)
return self._fm_step_data(index, node, Qt.ItemDataRole(role))
if isinstance(node, RealNode):
return self._real_data(index, node, role)

Expand Down Expand Up @@ -408,7 +415,7 @@ def _real_data(
def _fm_step_data(
index: QModelIndex | QPersistentModelIndex,
node: ForwardModelStepNode,
role: int, # Qt.ItemDataRole
role: Qt.ItemDataRole,
) -> Any:
node_id = str(node.id_)

Expand Down Expand Up @@ -437,9 +444,9 @@ def _fm_step_data(

if role == Qt.ItemDataRole.DisplayRole:
data_name = FM_STEP_COLUMNS[index.column()]
if data_name in {ids.MAX_MEMORY_USAGE}:
if data_name == ids.MAX_MEMORY_USAGE:
data = node.data
bytes_: str | None = data.get(data_name) # type: ignore
bytes_ = cast(str | None, data.get(data_name))
if bytes_:
return byte_with_unit(float(bytes_))

Expand All @@ -448,7 +455,7 @@ def _fm_step_data(
return "-"
return "View" if data_name in node.data else None

if data_name in {DURATION}:
if data_name == DURATION:
start_time = node.data.get(ids.START_TIME)
if start_time is None:
return None
Expand Down
4 changes: 2 additions & 2 deletions src/ert/gui/simulation/run_dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from ert.gui.ertnotifier import ErtNotifier
from ert.gui.ertwidgets.message_box import ErtMessageBox
from ert.gui.model.fm_step_list import FMStepListProxyModel
from ert.gui.model.node import Node
from ert.gui.model.node import IterNode
from ert.gui.model.real_list import RealListModel
from ert.gui.model.snapshot import (
FM_STEP_COLUMNS,
Expand Down Expand Up @@ -306,7 +306,7 @@ def on_snapshot_new_iteration(
) -> None:
if not parent.isValid():
index = self._snapshot_model.index(start, 0, parent)
iteration = cast(Node, index.internalPointer()).id_
iteration = cast(IterNode, index.internalPointer()).id_
iter_row = start
self._iteration_progress_label.setText(
f"Progress for iteration {iteration}"
Expand Down
8 changes: 3 additions & 5 deletions src/ert/gui/simulation/view/progress_widget.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from __future__ import annotations

from typing import Any

from PySide6.QtGui import QColor
from PySide6.QtGui import QColor, QResizeEvent
from PySide6.QtWidgets import (
QFrame,
QHBoxLayout,
Expand All @@ -16,7 +14,7 @@

class ProgressWidget(QFrame):
def __init__(self) -> None:
QFrame.__init__(self)
super().__init__()
self.setFixedHeight(70)

self._vertical_layout = QVBoxLayout(self)
Expand Down Expand Up @@ -109,5 +107,5 @@ def update_progress(self, status: dict[str, int], realization_count: int) -> Non
self.stop_waiting_progress_bar()
self.repaint_components()

def resizeEvent(self, event: Any = None) -> None:
def resizeEvent(self, event: QResizeEvent) -> None:
self.repaint_components()
4 changes: 0 additions & 4 deletions src/ert/gui/suggestor/suggestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
QWidget,
)

from ert.gui.ertnotifier import ErtNotifier

from ._colors import BLUE_TEXT
from ._suggestor_message import SuggestorMessage

Expand Down Expand Up @@ -113,7 +111,6 @@ def __init__(
deprecations: list[WarningInfo],
continue_action: Callable[[], None] | None,
help_links: dict[str, str] | None = None,
notifier: ErtNotifier | None = None,
) -> None:
super().__init__()
self._continue_action = continue_action
Expand All @@ -134,7 +131,6 @@ def __init__(
self.setStyleSheet(f"background-color: {LIGHT_GREY}; color: black")
self.__layout.setContentsMargins(32, 47, 32, 16)
self.__layout.setSpacing(32)
self.notifier = notifier

data_layout = QHBoxLayout()
data_widget.setLayout(data_layout)
Expand Down
4 changes: 2 additions & 2 deletions tests/ert/unit_tests/gui/simulation/view/test_realization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
FORWARD_MODEL_STATE_START,
REALIZATION_STATE_UNKNOWN,
)
from ert.gui.model.node import Node
from ert.gui.model.node import _Node
from ert.gui.model.snapshot import SnapshotModel
from ert.gui.simulation.view.realization import RealizationWidget
from tests.ert import SnapshotBuilder
Expand Down Expand Up @@ -104,7 +104,7 @@ def test_selection_success(large_snapshot, qtbot):

def check_selection_cb(index):
node = index.internalPointer()
return isinstance(node, Node) and str(node.id_) == str(selection_id)
return isinstance(node, _Node) and str(node.id_) == str(selection_id)

with qtbot.waitSignal(
widget.itemClicked, timeout=30000, check_params_cb=check_selection_cb
Expand Down

0 comments on commit 28e191c

Please sign in to comment.