From 77fa511421ef1a08aa71272a80674562d8067236 Mon Sep 17 00:00:00 2001 From: Eivind Jahren Date: Fri, 14 Jun 2024 10:26:08 +0200 Subject: [PATCH] Fix signature of overrides from QAbstractItemModel Also fixes some typing issues in ert.gui.model so that type checking can be turned on for parts of it. --- .mypy.ini | 7 ++- .../gui/ertwidgets/models/storage_model.py | 44 +++++++++---- src/ert/gui/model/job_list.py | 47 +++++++++----- src/ert/gui/model/progress_proxy.py | 63 ++++++++++++------- src/ert/gui/model/real_list.py | 52 ++++++++++----- src/ert/gui/model/snapshot.py | 28 ++++++--- .../tools/plot/data_type_keys_list_model.py | 28 ++++++--- 7 files changed, 185 insertions(+), 84 deletions(-) diff --git a/.mypy.ini b/.mypy.ini index ee120cac22f..c3abf98fcbb 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -39,10 +39,15 @@ ignore_missing_imports = True ignore_missing_imports = True ignore_errors = True -[mypy-ert.gui.model.*] +[mypy-ert.gui.model.node] ignore_missing_imports = True ignore_errors = True +[mypy-ert.gui.model.snapshot] +ignore_missing_imports = True +ignore_errors = True + + [mypy-ert.gui.simulation.*] ignore_missing_imports = True ignore_errors = True diff --git a/src/ert/gui/ertwidgets/models/storage_model.py b/src/ert/gui/ertwidgets/models/storage_model.py index 6fc86280910..1061e898fbf 100644 --- a/src/ert/gui/ertwidgets/models/storage_model.py +++ b/src/ert/gui/ertwidgets/models/storage_model.py @@ -1,15 +1,17 @@ from enum import IntEnum -from typing import Any, List +from typing import Any, List, Optional, overload from uuid import UUID import humanize from qtpy.QtCore import ( QAbstractItemModel, QModelIndex, + QObject, Qt, Slot, ) from qtpy.QtWidgets import QApplication +from typing_extensions import override from ert.storage import Ensemble, Experiment, Storage @@ -165,11 +167,14 @@ def _load_storage(self, storage: Storage) -> None: self._children.append(ex) - @staticmethod - def columnCount(parent: QModelIndex) -> int: + @override + def columnCount(self, parent: Optional[QModelIndex] = None) -> int: return _NUM_COLUMNS - def rowCount(self, parent: QModelIndex) -> int: + @override + def rowCount(self, parent: Optional[QModelIndex] = None) -> int: + if parent is None: + parent = QModelIndex() if parent.isValid(): if isinstance(parent.internalPointer(), RealizationModel): return 0 @@ -177,11 +182,16 @@ def rowCount(self, parent: QModelIndex) -> int: else: return len(self._children) - def parent(self, index: QModelIndex) -> QModelIndex: - if not index.isValid(): + @overload + def parent(self, child: QModelIndex) -> QModelIndex: ... + @overload + def parent(self) -> Optional[QObject]: ... + @override + def parent(self, child: Optional[QModelIndex] = None) -> Optional[QObject]: + if child is None or not child.isValid(): return QModelIndex() - child_item = index.internalPointer() + child_item = child.internalPointer() parentItem = child_item._parent if parentItem == self: @@ -189,21 +199,31 @@ def parent(self, index: QModelIndex) -> QModelIndex: return self.createIndex(parentItem.row(), 0, parentItem) - @staticmethod - def headerData(section: int, orientation: int, role: int) -> Any: + @override + def headerData( + self, + section: int, + orientation: Qt.Orientation, + role: int = Qt.ItemDataRole.DisplayRole, + ) -> Any: if role != Qt.ItemDataRole.DisplayRole: return None return _COLUMN_TEXT[_Column(section)] - @staticmethod - def data(index: QModelIndex, role=Qt.ItemDataRole.DisplayRole) -> Any: + @override + def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> Any: if not index.isValid(): return None return index.internalPointer().data(index, role) - def index(self, row: int, column: int, parent: QModelIndex) -> QModelIndex: + @override + def index( + self, row: int, column: int, parent: Optional[QModelIndex] = None + ) -> QModelIndex: + if parent is None: + parent = QModelIndex() parentItem = parent.internalPointer() if parent.isValid() else self try: childItem = parentItem._children[row] diff --git a/src/ert/gui/model/job_list.py b/src/ert/gui/model/job_list.py index 488290180a9..dd8a6c9b14d 100644 --- a/src/ert/gui/model/job_list.py +++ b/src/ert/gui/model/job_list.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional +from typing import Any, List, Optional, overload from qtpy.QtCore import ( QAbstractItemModel, @@ -9,6 +9,7 @@ QVariant, Slot, ) +from typing_extensions import override from ert.ensemble_evaluator import identifiers as ids from ert.gui.model.node import NodeType @@ -70,7 +71,8 @@ def _get_source_parent_index(self) -> QModelIndex: source_parent = self.mapToSource(start).parent() return source_parent - def setSourceModel(self, sourceModel: QAbstractItemModel) -> None: + @override + def setSourceModel(self, sourceModel: Optional[QAbstractItemModel]) -> None: if not sourceModel: raise ValueError("need source model") self.beginResetModel() @@ -79,26 +81,31 @@ def setSourceModel(self, sourceModel: QAbstractItemModel) -> None: self._connect() self.endResetModel() - @staticmethod - def headerData(section: int, orientation: Qt.Orientation, role: Qt.UserRole) -> Any: - if role != Qt.DisplayRole: + @override + def headerData( + self, + section: int, + orientation: Qt.Orientation, + role: int = Qt.ItemDataRole.DisplayRole, + ) -> Any: + if role != Qt.ItemDataRole.DisplayRole: return QVariant() - if orientation == Qt.Horizontal: + if orientation == Qt.Orientation.Horizontal: header = COLUMNS[NodeType.REAL][section] if header in [ids.STDOUT, ids.STDERR]: return header.upper() if header in [ids.CURRENT_MEMORY_USAGE, ids.MAX_MEMORY_USAGE]: header = header.replace("_", " ") return header.capitalize() - if orientation == Qt.Vertical: + if orientation == Qt.Orientation.Vertical: return section return QVariant() - @staticmethod - def columnCount(parent: QModelIndex = None): + @override + def columnCount(self, parent: Optional[QModelIndex] = None) -> int: return len(COLUMNS[NodeType.REAL]) - def rowCount(self, parent=None) -> int: + def rowCount(self, parent: Optional[QModelIndex] = None) -> int: if parent is None: parent = QModelIndex() if parent.isValid(): @@ -106,13 +113,22 @@ def rowCount(self, parent=None) -> int: source_index = self._get_source_parent_index() if not source_index.isValid(): return 0 - return self.sourceModel().rowCount(source_index) - - @staticmethod - def parent(_index: QModelIndex): + source_model = self.sourceModel() + assert source_model is not None + return source_model.rowCount(source_index) + + @overload + def parent(self, child: QModelIndex) -> QModelIndex: ... + @overload + def parent(self) -> Optional[QObject]: ... + @override + def parent(self, child: Optional[QModelIndex] = None) -> Optional[QObject]: return QModelIndex() - def index(self, row: int, column: int, parent=None) -> QModelIndex: + @override + def index( + self, row: int, column: int, parent: Optional[QModelIndex] = None + ) -> QModelIndex: if parent is None: parent = QModelIndex() if parent.isValid(): @@ -124,6 +140,7 @@ def mapToSource(self, proxyIndex: QModelIndex) -> QModelIndex: if not proxyIndex.isValid(): return QModelIndex() source_model = self.sourceModel() + assert source_model is not None iter_index = source_model.index(self._iter, 0, QModelIndex()) if not iter_index.isValid() or not source_model.hasChildren(iter_index): return QModelIndex() diff --git a/src/ert/gui/model/progress_proxy.py b/src/ert/gui/model/progress_proxy.py index 48979e50914..25564452a04 100644 --- a/src/ert/gui/model/progress_proxy.py +++ b/src/ert/gui/model/progress_proxy.py @@ -1,19 +1,20 @@ from collections import defaultdict -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, overload -from qtpy.QtCore import QAbstractItemModel, QModelIndex, QSize, Qt, QVariant +from qtpy.QtCore import QAbstractItemModel, QModelIndex, QObject, QSize, Qt, QVariant from qtpy.QtGui import QColor, QFont +from typing_extensions import override from ert.gui.model.snapshot import IsEnsembleRole, ProgressRole, StatusRole class ProgressProxyModel(QAbstractItemModel): def __init__( - self, source_model: QAbstractItemModel, parent: QModelIndex = None + self, source_model: QAbstractItemModel, parent: Optional[QModelIndex] = None ) -> None: QAbstractItemModel.__init__(self, parent) self._source_model: QAbstractItemModel = source_model - self._progress: Optional[Dict[str, Union[dict, int]]] = None + self._progress: Optional[Dict[str, Union[dict[Any, Any], int]]] = None self._connect() def _connect(self) -> None: @@ -28,66 +29,84 @@ def _connect(self) -> None: if last_iter >= 0: self._recalculate_progress(last_iter) - @staticmethod - def columnCount(parent: QModelIndex = None) -> int: + @override + def columnCount(self, parent: Optional[QModelIndex] = None) -> int: if parent is None: parent = QModelIndex() if parent.isValid(): return 0 return 1 - @staticmethod - def rowCount(parent: QModelIndex = None) -> int: + @override + def rowCount(self, parent: Optional[QModelIndex] = None) -> int: if parent is None: parent = QModelIndex() if parent.isValid(): return 0 return 1 - def index(self, row: int, column: int, parent: QModelIndex = None) -> QModelIndex: + @override + def index( + self, row: int, column: int, parent: Optional[QModelIndex] = None + ) -> QModelIndex: if parent is None: parent = QModelIndex() if parent.isValid(): return QModelIndex() return self.createIndex(row, column, None) - @staticmethod - def parent(_index: QModelIndex) -> QModelIndex: + @overload + def parent(self, child: QModelIndex) -> QModelIndex: ... + @overload + def parent(self) -> Optional[QObject]: ... + @override + def parent(self, child: Optional[QModelIndex] = None) -> Optional[QObject]: return QModelIndex() - @staticmethod - def hasChildren(parent: QModelIndex) -> bool: + @override + def hasChildren(self, parent: Optional[QModelIndex] = None) -> bool: + if parent is None: + return QModelIndex().isValid() return not parent.isValid() - def data(self, index: QModelIndex, role=Qt.DisplayRole) -> QVariant: + @override + def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> Any: if not index.isValid(): return QVariant() - if role == Qt.TextAlignmentRole: - return Qt.AlignCenter + if role == Qt.ItemDataRole.TextAlignmentRole: + return Qt.AlignmentFlag.AlignCenter if role == ProgressRole: return self._progress - if role in (Qt.StatusTipRole, Qt.WhatsThisRole, Qt.ToolTipRole): + if role in ( + Qt.ItemDataRole.StatusTipRole, + Qt.ItemDataRole.WhatsThisRole, + Qt.ItemDataRole.ToolTipRole, + ): return "" - if role == Qt.SizeHintRole: + if role == Qt.ItemDataRole.SizeHintRole: return QSize(30, 30) - if role == Qt.FontRole: + if role == Qt.ItemDataRole.FontRole: return QFont() - if role in (Qt.BackgroundRole, Qt.ForegroundRole, Qt.DecorationRole): + if role in ( + Qt.ItemDataRole.BackgroundRole, + Qt.ItemDataRole.ForegroundRole, + Qt.ItemDataRole.DecorationRole, + ): return QColor() - if role == Qt.DisplayRole: + if role == Qt.ItemDataRole.DisplayRole: return "" return QVariant() def _recalculate_progress(self, iter_: int) -> None: - status_counts = defaultdict(int) + status_counts: Dict[Any, int] = defaultdict(int) nr_reals: int = 0 current_iter_index = self._source_model.index(iter_, 0, QModelIndex()) if current_iter_index.internalPointer() is None: diff --git a/src/ert/gui/model/real_list.py b/src/ert/gui/model/real_list.py index 3317c4ca5bb..8c0b4b329e0 100644 --- a/src/ert/gui/model/real_list.py +++ b/src/ert/gui/model/real_list.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, overload from qtpy.QtCore import ( QAbstractItemModel, @@ -8,6 +8,7 @@ Signal, Slot, ) +from typing_extensions import override from ert.gui.model.snapshot import IsEnsembleRole, IsRealizationRole, NodeRole @@ -19,7 +20,7 @@ def __init__( iter_: int, ) -> None: super().__init__(parent=parent) - self._iter = iter_ + self._iter: int = iter_ def get_iter(self) -> int: return self._iter @@ -30,7 +31,7 @@ def get_iter(self) -> int: def setIter(self, iter_: int) -> None: self._disconnect() self.modelAboutToBeReset.emit() - self._iter: int = iter_ + self._iter = iter_ self.modelReset.emit() self._connect() self.iter_changed.emit(iter_) @@ -59,7 +60,8 @@ def _connect(self) -> None: source_model.modelAboutToBeReset.connect(self.modelAboutToBeReset) source_model.modelReset.connect(self.modelReset) - def setSourceModel(self, sourceModel: QAbstractItemModel) -> None: + @override + def setSourceModel(self, sourceModel: Optional[QAbstractItemModel]) -> None: if not sourceModel: raise ValueError("need source model") self.beginResetModel() @@ -68,31 +70,43 @@ def setSourceModel(self, sourceModel: QAbstractItemModel) -> None: self._connect() self.endResetModel() - def columnCount(self, parent: QModelIndex = None) -> int: + @override + def columnCount(self, parent: Optional[QModelIndex] = None) -> int: if parent is None: parent = QModelIndex() if parent.isValid(): return 0 - iter_index = self.sourceModel().index(self._iter, 0, QModelIndex()) + source_model = self.sourceModel() + assert source_model is not None + iter_index = source_model.index(self._iter, 0, QModelIndex()) if not iter_index.isValid(): return 0 - return self.sourceModel().columnCount(iter_index) + return source_model.columnCount(iter_index) - def rowCount(self, parent: QModelIndex = None) -> int: + def rowCount(self, parent: Optional[QModelIndex] = None) -> int: if parent is None: parent = QModelIndex() if parent.isValid(): return 0 - iter_index = self.sourceModel().index(self._iter, 0, QModelIndex()) + source_model = self.sourceModel() + assert source_model is not None + iter_index = source_model.index(self._iter, 0, QModelIndex()) if not iter_index.isValid(): return 0 - return self.sourceModel().rowCount(iter_index) - - @staticmethod - def parent(_index: QModelIndex): + return source_model.rowCount(iter_index) + + @overload + def parent(self, child: QModelIndex) -> QModelIndex: ... + @overload + def parent(self) -> Optional[QObject]: ... + @override + def parent(self, child: Optional[QModelIndex] = None) -> Optional[QObject]: return QModelIndex() - def index(self, row: int, column: int, parent: QModelIndex = None) -> QModelIndex: + @override + def index( + self, row: int, column: int, parent: Optional[QModelIndex] = None + ) -> QModelIndex: if parent is None: parent = QModelIndex() if parent.isValid(): @@ -101,18 +115,24 @@ def index(self, row: int, column: int, parent: QModelIndex = None) -> QModelInde ret_index = self.createIndex(row, column, real_index.data(NodeRole)) return ret_index - def hasChildren(self, parent: QModelIndex) -> bool: + @override + def hasChildren(self, parent: Optional[QModelIndex] = None) -> bool: # Reimplemented, since in the source model, the realizations have # children (i.e. valid indices.). Realizations do not have children in # this model. + if parent is None: + parent = QModelIndex() if parent.isValid(): return False - return self.sourceModel().hasChildren(self.mapToSource(parent)) + source_model = self.sourceModel() + assert source_model is not None + return source_model.hasChildren(self.mapToSource(parent)) def mapToSource(self, proxyIndex: QModelIndex) -> QModelIndex: if not proxyIndex.isValid(): return QModelIndex() sm = self.sourceModel() + assert sm is not None iter_index = sm.index(self._iter, 0, QModelIndex()) if not iter_index.isValid() or not sm.hasChildren(iter_index): return QModelIndex() diff --git a/src/ert/gui/model/snapshot.py b/src/ert/gui/model/snapshot.py index 1cf43f2b6c4..e92010196c5 100644 --- a/src/ert/gui/model/snapshot.py +++ b/src/ert/gui/model/snapshot.py @@ -2,11 +2,12 @@ import logging from collections import defaultdict from contextlib import ExitStack -from typing import Dict, Final, List, Mapping, Optional, Sequence, Union +from typing import Any, Dict, Final, List, Mapping, Optional, Sequence, Union, overload from dateutil import tz from qtpy.QtCore import QAbstractItemModel, QModelIndex, QObject, QSize, Qt, QVariant from qtpy.QtGui import QColor, QFont +from typing_extensions import override from ert.ensemble_evaluator import PartialSnapshot, Snapshot, state from ert.ensemble_evaluator import identifiers as ids @@ -315,8 +316,8 @@ def _add_snapshot(self, snapshot: Snapshot, iter_: int) -> None: self.root.add_child(snapshot_tree, node_id=iter_) self.rowsInserted.emit(parent, snapshot_tree.row(), snapshot_tree.row()) - @staticmethod - def columnCount(parent: QModelIndex = None): + @override + def columnCount(self, parent: Optional[QModelIndex] = None) -> int: if parent is None: parent = QModelIndex() parent_node = parent.internalPointer() @@ -333,7 +334,7 @@ def columnCount(parent: QModelIndex = None): count = len(COLUMNS[NodeType.JOB]) return count - def rowCount(self, parent: QModelIndex = None): + def rowCount(self, parent: Optional[QModelIndex] = None): if parent is None: parent = QModelIndex() parent_item = self.root if not parent.isValid() else parent.internalPointer() @@ -343,17 +344,23 @@ def rowCount(self, parent: QModelIndex = None): return len(parent_item.children) - def parent(self, index: QModelIndex): - if not index.isValid(): + @overload + def parent(self, child: QModelIndex) -> QModelIndex: ... + @overload + def parent(self) -> Optional[QObject]: ... + @override + def parent(self, child: Optional[QModelIndex] = None) -> Optional[QObject]: + if child is None or not child.isValid(): return QModelIndex() - parent_item = index.internalPointer().parent + parent_item = child.internalPointer().parent if parent_item == self.root: return QModelIndex() return self.createIndex(parent_item.row(), 0, parent_item) - def data(self, index: QModelIndex, role=Qt.DisplayRole): + @override + def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> Any: if not index.isValid(): return QVariant() @@ -539,7 +546,10 @@ def _job_data(index: QModelIndex, node: ForwardModelStepNode, role: int): return QVariant() - def index(self, row: int, column: int, parent: QModelIndex = None) -> QModelIndex: + @override + def index( + self, row: int, column: int, parent: Optional[QModelIndex] = None + ) -> QModelIndex: if parent is None: parent = QModelIndex() if not self.hasIndex(row, column, parent): diff --git a/src/ert/gui/tools/plot/data_type_keys_list_model.py b/src/ert/gui/tools/plot/data_type_keys_list_model.py index 127142674e2..6876e4a31fd 100644 --- a/src/ert/gui/tools/plot/data_type_keys_list_model.py +++ b/src/ert/gui/tools/plot/data_type_keys_list_model.py @@ -1,7 +1,8 @@ -from typing import List, Optional +from typing import Any, List, Optional, overload -from qtpy.QtCore import QAbstractItemModel, QModelIndex, Qt +from qtpy.QtCore import QAbstractItemModel, QModelIndex, QObject, Qt from qtpy.QtGui import QColor, QIcon +from typing_extensions import override from ert.gui.tools.plot.plot_api import PlotApiKeyDefinition @@ -16,21 +17,30 @@ def __init__(self, keys: List[PlotApiKeyDefinition]): self._keys = keys self.__icon = QIcon("img:star_filled.svg") - def index(self, row, column, parent=None, *args, **kwargs): + @override + def index( + self, row: int, column: int, parent: Optional[QModelIndex] = None + ) -> QModelIndex: return self.createIndex(row, column) - @staticmethod - def parent(index=None): + @overload + def parent(self, child: QModelIndex) -> QModelIndex: ... + @overload + def parent(self) -> Optional[QObject]: ... + @override + def parent(self, child: Optional[QModelIndex] = None) -> Optional[QObject]: return QModelIndex() - def rowCount(self, parent=None, *args, **kwargs): + @override + def rowCount(self, parent: Optional[QModelIndex] = None) -> int: return len(self._keys) - @staticmethod - def columnCount(QModelIndex_parent=None, *args, **kwargs): + @override + def columnCount(self, parent: Optional[QModelIndex] = None) -> int: return 1 - def data(self, index, role=None): + @override + def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> Any: assert isinstance(index, QModelIndex) if index.isValid():