Skip to content

Commit

Permalink
fix(model): Rework list filtering [WIP]
Browse files Browse the repository at this point in the history
TODO: Update `_ListFilter.__iter__` and `__contains__` with new logic

This commit reworks the code that handles 'by_X' style list filtering.
The new code can now also filter on attributes that contain lists, where
an element matches if a filtering target is contained (or not contained)
in the list.

This also introduces the 'by_class' filter, which uses a subclass check
to filter a list. It replaces 'by_type' wherever it was used to filter
on the class name. The 'by_type' special handling will eventually be
removed, so that it always filters on the 'type' attribute of the list
members; until that happens, a warning will be emitted when using
'by_type' to filter on the class name.
  • Loading branch information
Wuestengecko committed Dec 23, 2024
1 parent 316f1d0 commit e9ea787
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 27 deletions.
117 changes: 90 additions & 27 deletions src/capellambse/model/_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,7 @@ class ElementList(cabc.MutableSequence[T], t.Generic[T]):
"""Provides access to elements without affecting the underlying model."""

__slots__ = (
"_ElementList__legacy_by_type",
"_ElementList__mapkey",
"_ElementList__mapvalue",
"_elemclass",
Expand All @@ -990,6 +991,7 @@ def __init__(
*,
mapkey: str | None = None,
mapvalue: str | None = None,
legacy_by_type: bool = False,
) -> None:
assert None not in elements
self._model = model
Expand Down Expand Up @@ -1020,6 +1022,9 @@ def __init__(
else:
self.__mapkey = mapkey
self.__mapvalue = mapvalue
self.__legacy_by_type = legacy_by_type or isinstance(
self, MixedElementList
)

def __eq__(self, other: object) -> bool:
if not isinstance(other, cabc.Sequence):
Expand Down Expand Up @@ -1142,14 +1147,34 @@ def __contains__(self, obj: t.Any) -> bool:
return any(i == obj for i in self)

def __getattr__(self, attr: str) -> _ListFilter:
if self.__legacy_by_type and attr == "by_type":
if isinstance(self, ElementListCouplingMixin):
acc = type(self)._accessor
text = f"'by_type' on {acc._qualname}"
else:
text = "This 'by_type'"
text = (
f"{text} will soon change to filter"
" on the 'type' attribute of the contained elements,"
" change calls to use 'by_class' instead"
)
warnings.warn(text, UserWarning, stacklevel=2)
attr = "by_class"

if attr.startswith("by_"):
attr = attr[len("by_") :]
if attr in {"name", "uuid"}:
return _ListFilter(self, attr, single=True)
if attr == "class":
return _ListFilter(self, attr, case_insensitive=True)
return _ListFilter(self, attr)

if attr.startswith("exclude_") and attr.endswith("s"):
attr = attr[len("exclude_") : -len("s")]
if attr == "classe":
return _ListFilter(
self, "class", positive=False, case_insensitive=True
)
return _ListFilter(self, attr, positive=False)

return getattr(super(), attr)
Expand All @@ -1172,6 +1197,8 @@ def filterable_attrs() -> cabc.Iterator[str]:

attrs = list(super().__dir__())
attrs.extend(filterable_attrs())
if self.__legacy_by_type:
attrs.extend(("by_type", "exclude_types"))
return attrs

def __repr__(self) -> str: # pragma: no cover
Expand Down Expand Up @@ -1404,7 +1431,7 @@ def extend(self, values: cabc.Iterable[t.Any]) -> None: ...
class _ListFilter(t.Generic[T]):
"""Filters this list based on an extractor function."""

__slots__ = ("_attr", "_parent", "_positive", "_single")
__slots__ = ("_attr", "_lower", "_parent", "_positive", "_single")

def __init__(
self,
Expand All @@ -1413,6 +1440,7 @@ def __init__(
*,
positive: bool = True,
single: bool = False,
case_insensitive: bool = False,
) -> None:
"""Create a filter object.
Expand All @@ -1436,11 +1464,14 @@ def __init__(
instead. If multiple elements match, it is an error; if
none match, a ``KeyError`` is raised. Can be overridden
at call time.
case_insensitive
Use case-insensitive matching.
"""
self._attr = attr
self._parent = parent
self._positive = positive
self._single = single
self._lower = case_insensitive

def extract_key(self, element: T) -> t.Any:
extractor = operator.attrgetter(self._attr)
Expand Down Expand Up @@ -1490,22 +1521,72 @@ def __call__(
"""
if single is None:
single = self._single
valueset = self.make_values_container(*values)
indices = []
elements = []
for i, elm in enumerate(self._parent):
if self.ismatch(elm, valueset):
indices.append(i)
elements.append(self._parent._elements[i])

if ".class." in self._attr or self._attr.startswith("class."):
raise ValueError(
"'class' must be the last component of the filter attribute"
)

valueset: tuple[t.Any, ...]
if self._attr == "class" or self._attr.endswith(".class"):
valueset = tuple(
v
if isinstance(v, type)
else self._parent._model.resolve_class(v)
for v in values
)

def ismatch(o: t.Any) -> bool:
return any(issubclass(o, v) for v in valueset)
else:
if self._lower:
valueset = tuple(
value.lower() if isinstance(value, str) else value
for value in values
)
else:
valueset = values

def ismatch(o: t.Any) -> bool:
if isinstance(o, cabc.Iterable):
return any(i in valueset for i in o)
return o in valueset

candidates: list[tuple[etree._Element, t.Any]] = [
(i._element, i) for i in self._parent
]
for attr in self._attr.split("."):
if not attr:
raise ValueError(f"Invalid filter attribute {self._attr!r}")
if attr == "class":
candidates = [(e, type(i)) for e, i in candidates]
else:
next_candidates: list[tuple[etree._Element, t.Any]] = []
for e, o in candidates:
if isinstance(o, cabc.Iterable):
o = [getattr(c, attr) for c in o if hasattr(c, attr)]
else:
try:
o = getattr(o, attr)
except AttributeError:
continue
next_candidates.append((e, o))
candidates = next_candidates

elements: list[etree._Element] = []
for e, o in candidates:
if self._positive == ismatch(o):
elements.append(e)

if not single:
return self._parent._newlist(elements)

if len(elements) > 1:
value = values[0] if len(values) == 1 else values
raise KeyError(f"Multiple matches for {value!r}")
if len(elements) == 0:
raise KeyError(values[0] if len(values) == 1 else values)
return self._parent[indices[0]] # Ensure proper construction
return wrap_xml(self._parent._model, elements[0])

def __iter__(self) -> cabc.Iterator[t.Any]:
"""Yield values that result in a non-empty list when filtered for.
Expand Down Expand Up @@ -1540,16 +1621,6 @@ def __getattr__(self, attr: str) -> te.Self:
)


class _LowercaseListFilter(_ListFilter[T], t.Generic[T]):
def extract_key(self, element: T) -> t.Any:
value = super().extract_key(element)
assert isinstance(value, str)
return value.lower()

def make_values_container(self, *values: t.Any) -> cabc.Iterable[t.Any]:
return tuple(map(operator.methodcaller("lower"), values))


class CachedElementList(ElementList[T], t.Generic[T]):
"""An ElementList that caches the constructed proxies by UUID."""

Expand Down Expand Up @@ -1625,14 +1696,6 @@ def __init__(
del elemclass
super().__init__(model, elements, None, **kw)

def __getattr__(self, attr: str) -> _ListFilter[ModelElement]:
if attr == "by_type":
return _LowercaseListFilter(self, "__class__.__name__")
return super().__getattr__(attr)

def __dir__(self) -> list[str]: # pragma: no cover
return [*super().__dir__(), "by_type", "exclude_types"]


class ElementListMapKeyView(cabc.Sequence):
def __init__(self, parent, /) -> None:
Expand Down
53 changes: 53 additions & 0 deletions tests/test_model_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,59 @@ def test_ElementList_filter_by_type(model: m.MelodyModel):
assert diags[0].type is m.DiagramType.OCB


@pytest.mark.parametrize(
"filter_arg",
[
pytest.param(mm.oa.Entity, id="type-object"),
pytest.param("Entity", id="type-name"),
pytest.param((mm.oa.NS, "Entity"), id="classname-tuple"),
],
)
def test_filtering_lists_by_the_only_contained_class_doesnt_change_the_content(
model: m.MelodyModel, filter_arg
) -> None:
pkg = model.oa.entity_pkg
assert pkg is not None
base = pkg.entities
base_ids = [i.uuid for i in base]
assert all(type(i) is mm.oa.Entity for i in base)

filtered = base.by_class(filter_arg)
filtered_ids = [i.uuid for i in filtered]

assert filtered_ids == base_ids


def test_filtering_dotted_names_filters_on_nested_attributes(
model: m.MelodyModel,
) -> None:
base = model.la.all_component_exchanges
assert len(base) > 1
expected = [
"c31491db-817d-44b3-a27c-67e9cc1e06a2", # Care
]

filtered = base.by_target.parent.name("Whomping Willow")

assert isinstance(filtered, m.ElementList)
found = [i.uuid for i in filtered]
assert found == expected


def test_filtering_on_list_attributes_returns_match_if_any_member_matches(
model: m.MelodyModel,
) -> None:
base = model.la.all_components
willow = model.by_uuid("3bdd4fa2-5646-44a1-9fa6-80c68433ddb7")
expected = [willow.parent.uuid]

filtered = base.by_components(willow)

assert isinstance(filtered, m.ElementList)
found = [i.uuid for i in filtered]
assert found == expected


def test_ElementList_dictlike_getitem(model: m.MelodyModel):
obj = model.search("LogicalComponent").by_name("Whomping Willow")
assert isinstance(obj, mm.la.LogicalComponent)
Expand Down

0 comments on commit e9ea787

Please sign in to comment.