Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: cache results of cnf, dnf and _merge_single_markers #609

Merged
merged 2 commits into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/poetry/core/constraints/generic/constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ def intersect(self, other: BaseConstraint) -> BaseConstraint:
return other.intersect(self)

def union(self, other: BaseConstraint) -> BaseConstraint:
if isinstance(other, Constraint):
from poetry.core.constraints.generic.union_constraint import UnionConstraint
from poetry.core.constraints.generic.union_constraint import UnionConstraint

if isinstance(other, Constraint):
if other == self:
return self

Expand All @@ -140,6 +140,10 @@ def union(self, other: BaseConstraint) -> BaseConstraint:

return AnyConstraint()

# to preserve order (functionally not necessary)
if isinstance(other, UnionConstraint):
return UnionConstraint(self).union(other)

return other.union(self)

def is_any(self) -> bool:
Expand Down
8 changes: 2 additions & 6 deletions src/poetry/core/constraints/generic/multi_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,10 @@ def __eq__(self, other: object) -> bool:
if not isinstance(other, MultiConstraint):
return False

return set(self._constraints) == set(other._constraints)
return self._constraints == other._constraints

def __hash__(self) -> int:
h = hash("multi")
for constraint in self._constraints:
h ^= hash(constraint)

return h
return hash(("multi", *self._constraints))

def __str__(self) -> str:
constraints = [str(constraint) for constraint in self._constraints]
Expand Down
20 changes: 11 additions & 9 deletions src/poetry/core/constraints/generic/union_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,19 @@ def union(self, other: BaseConstraint) -> BaseConstraint:
our_new_constraints: list[BaseConstraint] = []
their_new_constraints: list[BaseConstraint] = []
merged_new_constraints: list[BaseConstraint] = []
for our_constraint in self._constraints:
for their_constraint in other.constraints:
for their_constraint in other.constraints:
for our_constraint in self._constraints:
union = our_constraint.union(their_constraint)
if union.is_any():
return AnyConstraint()
if isinstance(union, Constraint):
if union not in merged_new_constraints:
if union == our_constraint:
if union not in our_new_constraints:
our_new_constraints.append(union)
elif union == their_constraint:
if union not in their_new_constraints:
their_new_constraints.append(their_constraint)
elif union not in merged_new_constraints:
merged_new_constraints.append(union)
else:
if our_constraint not in our_new_constraints:
Expand Down Expand Up @@ -169,14 +175,10 @@ def __eq__(self, other: object) -> bool:
if not isinstance(other, UnionConstraint):
return False

return set(self._constraints) == set(other._constraints)
return self._constraints == other._constraints

def __hash__(self) -> int:
h = hash("union")
for constraint in self._constraints:
h ^= hash(constraint)

return h
return hash(("union", *self._constraints))

def __str__(self) -> str:
constraints = [str(constraint) for constraint in self._constraints]
Expand Down
58 changes: 32 additions & 26 deletions src/poetry/core/version/markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def __repr__(self) -> str:
return "<AnyMarker>"

def __hash__(self) -> int:
return hash(("<any>", "<any>"))
return hash("any")

def __eq__(self, other: object) -> bool:
if not isinstance(other, BaseMarker):
Expand Down Expand Up @@ -193,7 +193,7 @@ def __repr__(self) -> str:
return "<EmptyMarker>"

def __hash__(self) -> int:
return hash(("<empty>", "<empty>"))
return hash("empty")

def __eq__(self, other: object) -> bool:
if not isinstance(other, BaseMarker):
Expand Down Expand Up @@ -232,6 +232,10 @@ def name(self) -> str:
def constraint(self) -> SingleMarkerConstraint:
return self._constraint

@property
def _key(self) -> tuple[object, ...]:
return self._name, self._constraint

def validate(self, environment: dict[str, Any] | None) -> bool:
if environment is None:
return True
Expand Down Expand Up @@ -284,10 +288,10 @@ def __eq__(self, other: object) -> bool:
if not isinstance(other, SingleMarkerLike):
return NotImplemented

return self._name == other.name and self._constraint == other.constraint
return self._key == other._key

def __hash__(self) -> int:
return hash((self._name, self._constraint))
return hash(self._key)


class SingleMarker(SingleMarkerLike[Union[BaseConstraint, VersionConstraint]]):
Expand Down Expand Up @@ -368,6 +372,10 @@ def operator(self) -> str:
def value(self) -> str:
return self._value

@property
def _key(self) -> tuple[object, ...]:
return self._name, self._operator, self._value

def invert(self) -> BaseMarker:
if self._operator in ("===", "=="):
operator = "!="
Expand Down Expand Up @@ -413,6 +421,15 @@ def invert(self) -> BaseMarker:

return parse_marker(f"{self._name} {operator} '{self._value}'")

def __eq__(self, other: object) -> bool:
if not isinstance(other, SingleMarker):
return NotImplemented

return self._key == other._key

def __hash__(self) -> int:
return hash(self._key)

def __str__(self) -> str:
return f'{self._name} {self._operator} "{self._value}"'

Expand Down Expand Up @@ -494,10 +511,10 @@ def _flatten_markers(

class MultiMarker(BaseMarker):
def __init__(self, *markers: BaseMarker) -> None:
self._markers = _flatten_markers(markers, MultiMarker)
self._markers = tuple(_flatten_markers(markers, MultiMarker))

@property
def markers(self) -> list[BaseMarker]:
def markers(self) -> tuple[BaseMarker, ...]:
return self._markers

@property
Expand Down Expand Up @@ -645,14 +662,10 @@ def __eq__(self, other: object) -> bool:
if not isinstance(other, MultiMarker):
return False

return set(self._markers) == set(other.markers)
return self._markers == other.markers

def __hash__(self) -> int:
h = hash("multi")
for m in self._markers:
h ^= hash(m)

return h
return hash(("multi", *self._markers))

def __str__(self) -> str:
elements = []
Expand All @@ -667,10 +680,10 @@ def __str__(self) -> str:

class MarkerUnion(BaseMarker):
def __init__(self, *markers: BaseMarker) -> None:
self._markers = _flatten_markers(markers, MarkerUnion)
self._markers = tuple(_flatten_markers(markers, MarkerUnion))

@property
def markers(self) -> list[BaseMarker]:
def markers(self) -> tuple[BaseMarker, ...]:
return self._markers

@property
Expand Down Expand Up @@ -735,12 +748,6 @@ def of(cls, *markers: BaseMarker) -> BaseMarker:

return MarkerUnion(*new_markers)

def append(self, marker: BaseMarker) -> None:
if marker in self._markers:
return

self._markers.append(marker)

def intersect(self, other: BaseMarker) -> BaseMarker:
return intersection(self, other)

Expand Down Expand Up @@ -825,14 +832,10 @@ def __eq__(self, other: object) -> bool:
if not isinstance(other, MarkerUnion):
return False

return set(self._markers) == set(other.markers)
return self._markers == other.markers

def __hash__(self) -> int:
h = hash("union")
for m in self._markers:
h ^= hash(m)

return h
return hash(("union", *self._markers))

def __str__(self) -> str:
return " or ".join(str(m) for m in self._markers)
Expand Down Expand Up @@ -898,6 +901,7 @@ def _compact_markers(
return union(*sub_markers)


@functools.lru_cache(maxsize=None)
def cnf(marker: BaseMarker) -> BaseMarker:
"""Transforms the marker into CNF (conjunctive normal form)."""
if isinstance(marker, MarkerUnion):
Expand All @@ -915,6 +919,7 @@ def cnf(marker: BaseMarker) -> BaseMarker:
return marker


@functools.lru_cache(maxsize=None)
def dnf(marker: BaseMarker) -> BaseMarker:
"""Transforms the marker into DNF (disjunctive normal form)."""
if isinstance(marker, MultiMarker):
Expand Down Expand Up @@ -957,6 +962,7 @@ def union(*markers: BaseMarker) -> BaseMarker:
return min(disjunction, conjunction, unnormalized, key=lambda x: x.complexity)


@functools.lru_cache(maxsize=None)
def _merge_single_markers(
marker1: SingleMarkerLike[SingleMarkerConstraint],
marker2: SingleMarkerLike[SingleMarkerConstraint],
Expand Down
77 changes: 56 additions & 21 deletions tests/constraints/generic/test_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,10 @@ def test_invert(constraint: BaseConstraint, inverted: BaseConstraint) -> None:
(
Constraint("win32", "!="),
Constraint("linux", "!="),
MultiConstraint(Constraint("win32", "!="), Constraint("linux", "!=")),
(
MultiConstraint(Constraint("win32", "!="), Constraint("linux", "!=")),
MultiConstraint(Constraint("linux", "!="), Constraint("win32", "!=")),
),
),
(
Constraint("win32", "!="),
Expand Down Expand Up @@ -222,21 +225,30 @@ def test_invert(constraint: BaseConstraint, inverted: BaseConstraint) -> None:
(
MultiConstraint(Constraint("win32", "!="), Constraint("linux", "!=")),
MultiConstraint(Constraint("win32", "!="), Constraint("darwin", "!=")),
MultiConstraint(
Constraint("win32", "!="),
Constraint("linux", "!="),
Constraint("darwin", "!="),
(
MultiConstraint(
Constraint("win32", "!="),
Constraint("linux", "!="),
Constraint("darwin", "!="),
),
MultiConstraint(
Constraint("win32", "!="),
Constraint("darwin", "!="),
Constraint("linux", "!="),
),
),
),
],
)
def test_intersect(
constraint1: BaseConstraint,
constraint2: BaseConstraint,
expected: BaseConstraint,
expected: BaseConstraint | tuple[BaseConstraint, BaseConstraint],
) -> None:
assert constraint1.intersect(constraint2) == expected
assert constraint2.intersect(constraint1) == expected
if not isinstance(expected, tuple):
expected = (expected, expected)
assert constraint1.intersect(constraint2) == expected[0]
assert constraint2.intersect(constraint1) == expected[1]


@pytest.mark.parametrize(
Expand Down Expand Up @@ -295,7 +307,10 @@ def test_intersect(
(
Constraint("win32"),
Constraint("linux"),
UnionConstraint(Constraint("win32"), Constraint("linux")),
(
UnionConstraint(Constraint("win32"), Constraint("linux")),
UnionConstraint(Constraint("linux"), Constraint("win32")),
),
),
(
Constraint("win32"),
Expand Down Expand Up @@ -324,8 +339,13 @@ def test_intersect(
(
Constraint("win32"),
UnionConstraint(Constraint("linux"), Constraint("linux2")),
UnionConstraint(
Constraint("win32"), Constraint("linux"), Constraint("linux2")
(
UnionConstraint(
Constraint("win32"), Constraint("linux"), Constraint("linux2")
),
UnionConstraint(
Constraint("linux"), Constraint("linux2"), Constraint("win32")
),
),
),
(
Expand Down Expand Up @@ -366,8 +386,13 @@ def test_intersect(
(
UnionConstraint(Constraint("win32"), Constraint("linux")),
UnionConstraint(Constraint("win32"), Constraint("darwin")),
UnionConstraint(
Constraint("win32"), Constraint("linux"), Constraint("darwin")
(
UnionConstraint(
Constraint("win32"), Constraint("linux"), Constraint("darwin")
),
UnionConstraint(
Constraint("win32"), Constraint("darwin"), Constraint("linux")
),
),
),
(
Expand All @@ -377,11 +402,19 @@ def test_intersect(
UnionConstraint(
Constraint("win32"), Constraint("cygwin"), Constraint("darwin")
),
UnionConstraint(
Constraint("win32"),
Constraint("linux"),
Constraint("darwin"),
Constraint("cygwin"),
(
UnionConstraint(
Constraint("win32"),
Constraint("linux"),
Constraint("darwin"),
Constraint("cygwin"),
),
UnionConstraint(
Constraint("win32"),
Constraint("cygwin"),
Constraint("darwin"),
Constraint("linux"),
),
),
),
(
Expand Down Expand Up @@ -412,10 +445,12 @@ def test_intersect(
def test_union(
constraint1: BaseConstraint,
constraint2: BaseConstraint,
expected: BaseConstraint,
expected: BaseConstraint | tuple[BaseConstraint, BaseConstraint],
) -> None:
assert constraint1.union(constraint2) == expected
assert constraint2.union(constraint1) == expected
if not isinstance(expected, tuple):
expected = (expected, expected)
assert constraint1.union(constraint2) == expected[0]
assert constraint2.union(constraint1) == expected[1]


def test_difference() -> None:
Expand Down
Loading