Skip to content

Commit

Permalink
Fix signature of overrides from QAbstractItemModel
Browse files Browse the repository at this point in the history
Also fixes some typing issues in ert.gui.model so that type checking can
be turned on for parts of it.
  • Loading branch information
eivindjahren authored Jun 14, 2024
1 parent 2b21402 commit 77fa511
Show file tree
Hide file tree
Showing 7 changed files with 185 additions and 84 deletions.
7 changes: 6 additions & 1 deletion .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,15 @@ ignore_missing_imports = True
ignore_missing_imports = True
ignore_errors = True

[mypy-ert.gui.model.*]
[mypy-ert.gui.model.node]
ignore_missing_imports = True
ignore_errors = True

[mypy-ert.gui.model.snapshot]
ignore_missing_imports = True
ignore_errors = True


[mypy-ert.gui.simulation.*]
ignore_missing_imports = True
ignore_errors = True
Expand Down
44 changes: 32 additions & 12 deletions src/ert/gui/ertwidgets/models/storage_model.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from enum import IntEnum
from typing import Any, List
from typing import Any, List, Optional, overload
from uuid import UUID

import humanize
from qtpy.QtCore import (
QAbstractItemModel,
QModelIndex,
QObject,
Qt,
Slot,
)
from qtpy.QtWidgets import QApplication
from typing_extensions import override

from ert.storage import Ensemble, Experiment, Storage

Expand Down Expand Up @@ -165,45 +167,63 @@ def _load_storage(self, storage: Storage) -> None:

self._children.append(ex)

@staticmethod
def columnCount(parent: QModelIndex) -> int:
@override
def columnCount(self, parent: Optional[QModelIndex] = None) -> int:
return _NUM_COLUMNS

def rowCount(self, parent: QModelIndex) -> int:
@override
def rowCount(self, parent: Optional[QModelIndex] = None) -> int:
if parent is None:
parent = QModelIndex()
if parent.isValid():
if isinstance(parent.internalPointer(), RealizationModel):
return 0
return len(parent.internalPointer()._children)
else:
return len(self._children)

def parent(self, index: QModelIndex) -> QModelIndex:
if not index.isValid():
@overload
def parent(self, child: QModelIndex) -> QModelIndex: ...
@overload
def parent(self) -> Optional[QObject]: ...
@override
def parent(self, child: Optional[QModelIndex] = None) -> Optional[QObject]:
if child is None or not child.isValid():
return QModelIndex()

child_item = index.internalPointer()
child_item = child.internalPointer()
parentItem = child_item._parent

if parentItem == self:
return QModelIndex()

return self.createIndex(parentItem.row(), 0, parentItem)

@staticmethod
def headerData(section: int, orientation: int, role: int) -> Any:
@override
def headerData(
self,
section: int,
orientation: Qt.Orientation,
role: int = Qt.ItemDataRole.DisplayRole,
) -> Any:
if role != Qt.ItemDataRole.DisplayRole:
return None

return _COLUMN_TEXT[_Column(section)]

@staticmethod
def data(index: QModelIndex, role=Qt.ItemDataRole.DisplayRole) -> Any:
@override
def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> Any:
if not index.isValid():
return None

return index.internalPointer().data(index, role)

def index(self, row: int, column: int, parent: QModelIndex) -> QModelIndex:
@override
def index(
self, row: int, column: int, parent: Optional[QModelIndex] = None
) -> QModelIndex:
if parent is None:
parent = QModelIndex()
parentItem = parent.internalPointer() if parent.isValid() else self
try:
childItem = parentItem._children[row]
Expand Down
47 changes: 32 additions & 15 deletions src/ert/gui/model/job_list.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List, Optional
from typing import Any, List, Optional, overload

from qtpy.QtCore import (
QAbstractItemModel,
Expand All @@ -9,6 +9,7 @@
QVariant,
Slot,
)
from typing_extensions import override

from ert.ensemble_evaluator import identifiers as ids
from ert.gui.model.node import NodeType
Expand Down Expand Up @@ -70,7 +71,8 @@ def _get_source_parent_index(self) -> QModelIndex:
source_parent = self.mapToSource(start).parent()
return source_parent

def setSourceModel(self, sourceModel: QAbstractItemModel) -> None:
@override
def setSourceModel(self, sourceModel: Optional[QAbstractItemModel]) -> None:
if not sourceModel:
raise ValueError("need source model")
self.beginResetModel()
Expand All @@ -79,40 +81,54 @@ def setSourceModel(self, sourceModel: QAbstractItemModel) -> None:
self._connect()
self.endResetModel()

@staticmethod
def headerData(section: int, orientation: Qt.Orientation, role: Qt.UserRole) -> Any:
if role != Qt.DisplayRole:
@override
def headerData(
self,
section: int,
orientation: Qt.Orientation,
role: int = Qt.ItemDataRole.DisplayRole,
) -> Any:
if role != Qt.ItemDataRole.DisplayRole:
return QVariant()
if orientation == Qt.Horizontal:
if orientation == Qt.Orientation.Horizontal:
header = COLUMNS[NodeType.REAL][section]
if header in [ids.STDOUT, ids.STDERR]:
return header.upper()
if header in [ids.CURRENT_MEMORY_USAGE, ids.MAX_MEMORY_USAGE]:
header = header.replace("_", " ")
return header.capitalize()
if orientation == Qt.Vertical:
if orientation == Qt.Orientation.Vertical:
return section
return QVariant()

@staticmethod
def columnCount(parent: QModelIndex = None):
@override
def columnCount(self, parent: Optional[QModelIndex] = None) -> int:
return len(COLUMNS[NodeType.REAL])

def rowCount(self, parent=None) -> int:
def rowCount(self, parent: Optional[QModelIndex] = None) -> int:
if parent is None:
parent = QModelIndex()
if parent.isValid():
return 0
source_index = self._get_source_parent_index()
if not source_index.isValid():
return 0
return self.sourceModel().rowCount(source_index)

@staticmethod
def parent(_index: QModelIndex):
source_model = self.sourceModel()
assert source_model is not None
return source_model.rowCount(source_index)

@overload
def parent(self, child: QModelIndex) -> QModelIndex: ...
@overload
def parent(self) -> Optional[QObject]: ...
@override
def parent(self, child: Optional[QModelIndex] = None) -> Optional[QObject]:
return QModelIndex()

def index(self, row: int, column: int, parent=None) -> QModelIndex:
@override
def index(
self, row: int, column: int, parent: Optional[QModelIndex] = None
) -> QModelIndex:
if parent is None:
parent = QModelIndex()
if parent.isValid():
Expand All @@ -124,6 +140,7 @@ def mapToSource(self, proxyIndex: QModelIndex) -> QModelIndex:
if not proxyIndex.isValid():
return QModelIndex()
source_model = self.sourceModel()
assert source_model is not None
iter_index = source_model.index(self._iter, 0, QModelIndex())
if not iter_index.isValid() or not source_model.hasChildren(iter_index):
return QModelIndex()
Expand Down
63 changes: 41 additions & 22 deletions src/ert/gui/model/progress_proxy.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from collections import defaultdict
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, overload

from qtpy.QtCore import QAbstractItemModel, QModelIndex, QSize, Qt, QVariant
from qtpy.QtCore import QAbstractItemModel, QModelIndex, QObject, QSize, Qt, QVariant
from qtpy.QtGui import QColor, QFont
from typing_extensions import override

from ert.gui.model.snapshot import IsEnsembleRole, ProgressRole, StatusRole


class ProgressProxyModel(QAbstractItemModel):
def __init__(
self, source_model: QAbstractItemModel, parent: QModelIndex = None
self, source_model: QAbstractItemModel, parent: Optional[QModelIndex] = None
) -> None:
QAbstractItemModel.__init__(self, parent)
self._source_model: QAbstractItemModel = source_model
self._progress: Optional[Dict[str, Union[dict, int]]] = None
self._progress: Optional[Dict[str, Union[dict[Any, Any], int]]] = None
self._connect()

def _connect(self) -> None:
Expand All @@ -28,66 +29,84 @@ def _connect(self) -> None:
if last_iter >= 0:
self._recalculate_progress(last_iter)

@staticmethod
def columnCount(parent: QModelIndex = None) -> int:
@override
def columnCount(self, parent: Optional[QModelIndex] = None) -> int:
if parent is None:
parent = QModelIndex()
if parent.isValid():
return 0
return 1

@staticmethod
def rowCount(parent: QModelIndex = None) -> int:
@override
def rowCount(self, parent: Optional[QModelIndex] = None) -> int:
if parent is None:
parent = QModelIndex()
if parent.isValid():
return 0
return 1

def index(self, row: int, column: int, parent: QModelIndex = None) -> QModelIndex:
@override
def index(
self, row: int, column: int, parent: Optional[QModelIndex] = None
) -> QModelIndex:
if parent is None:
parent = QModelIndex()
if parent.isValid():
return QModelIndex()
return self.createIndex(row, column, None)

@staticmethod
def parent(_index: QModelIndex) -> QModelIndex:
@overload
def parent(self, child: QModelIndex) -> QModelIndex: ...
@overload
def parent(self) -> Optional[QObject]: ...
@override
def parent(self, child: Optional[QModelIndex] = None) -> Optional[QObject]:
return QModelIndex()

@staticmethod
def hasChildren(parent: QModelIndex) -> bool:
@override
def hasChildren(self, parent: Optional[QModelIndex] = None) -> bool:
if parent is None:
return QModelIndex().isValid()
return not parent.isValid()

def data(self, index: QModelIndex, role=Qt.DisplayRole) -> QVariant:
@override
def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> Any:
if not index.isValid():
return QVariant()

if role == Qt.TextAlignmentRole:
return Qt.AlignCenter
if role == Qt.ItemDataRole.TextAlignmentRole:
return Qt.AlignmentFlag.AlignCenter

if role == ProgressRole:
return self._progress

if role in (Qt.StatusTipRole, Qt.WhatsThisRole, Qt.ToolTipRole):
if role in (
Qt.ItemDataRole.StatusTipRole,
Qt.ItemDataRole.WhatsThisRole,
Qt.ItemDataRole.ToolTipRole,
):
return ""

if role == Qt.SizeHintRole:
if role == Qt.ItemDataRole.SizeHintRole:
return QSize(30, 30)

if role == Qt.FontRole:
if role == Qt.ItemDataRole.FontRole:
return QFont()

if role in (Qt.BackgroundRole, Qt.ForegroundRole, Qt.DecorationRole):
if role in (
Qt.ItemDataRole.BackgroundRole,
Qt.ItemDataRole.ForegroundRole,
Qt.ItemDataRole.DecorationRole,
):
return QColor()

if role == Qt.DisplayRole:
if role == Qt.ItemDataRole.DisplayRole:
return ""

return QVariant()

def _recalculate_progress(self, iter_: int) -> None:
status_counts = defaultdict(int)
status_counts: Dict[Any, int] = defaultdict(int)
nr_reals: int = 0
current_iter_index = self._source_model.index(iter_, 0, QModelIndex())
if current_iter_index.internalPointer() is None:
Expand Down
Loading

0 comments on commit 77fa511

Please sign in to comment.