diff --git a/py_simple_trees/__init__.py b/py_simple_trees/__init__.py index c462d02..6002e53 100644 --- a/py_simple_trees/__init__.py +++ b/py_simple_trees/__init__.py @@ -1,9 +1,10 @@ -from .core import Node # noqa: F401 -from .core import GenericTree # noqa: F401 -from .core import BinaryNode # noqa: F401 -from .core import BinaryTree # noqa: F401 -from .core import TraversalType # noqa: F401 -from .core import BinarySearchTree # noqa: F401 -from .core import BSTree # noqa: F401 -from .core import AVLTree # noqa: F401 -from .core import AVLNode # noqa: F401 +from .node import Node # noqa: F401 +from .node import BinaryNode # noqa: F401 +from .node import AVLNode # noqa: F401 + +from .tree import GenericTree # noqa: F401 +from .tree import BinaryTree # noqa: F401 +from .tree import TraversalType # noqa: F401 +from .tree import BinarySearchTree # noqa: F401 +from .tree import BSTree # noqa: F401 +from .tree import AVLTree # noqa: F401 diff --git a/py_simple_trees/node.py b/py_simple_trees/node.py new file mode 100644 index 0000000..36576b7 --- /dev/null +++ b/py_simple_trees/node.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from typing import TypeVar, Generic, Optional, List, Any + + +K = TypeVar("K") +V = TypeVar("V") + + +class Node(Generic[K, V]): + def __init__(self, key: K, value: Optional[V] = None): + self.key: K = key + self.value: Optional[V] = value + self.children: List[Any] = [] + + +class BinaryNode(Node[K, V]): + def __init__(self, key: K, value: Optional[V] = None): + super().__init__(key, value) + self.children: List[BinaryNode | None] = [None, None] + + @property + def left(self): + return self.children[0] + + @left.setter + def left(self, node: Optional[BinaryNode]): + self.children[0] = node + + @property + def right(self): + return self.children[1] + + @right.setter + def right(self, node: Optional[BinaryNode]): + self.children[1] = node + + @property + def min_node(self): + if self.left is None: + return self + return self.left.min_node + + @property + def max_node(self): + if self.right is None: + return self + return self.right.max_node + + +class AVLNode(BinaryNode[K, V]): + def __init__(self, key: K, value: Optional[V] = None): + super().__init__(key, value) + self.height = 1 + self.left_height = 0 + self.right_height = 0 + + @property + def balance(self): + return self.left_height - self.right_height diff --git a/py_simple_trees/core.py b/py_simple_trees/tree.py similarity index 80% rename from py_simple_trees/core.py rename to py_simple_trees/tree.py index c24356b..b0e0a75 100644 --- a/py_simple_trees/core.py +++ b/py_simple_trees/tree.py @@ -1,11 +1,17 @@ from __future__ import annotations -from typing import TypeVar, Generic, Optional, List, Any +from typing import TypeVar, Generic, Optional from enum import Enum +from py_simple_trees.node import Node, BinaryNode, AVLNode + K = TypeVar("K") V = TypeVar("V") +N = TypeVar("N", bound=Node) +BN = TypeVar("BN", bound=BinaryNode) +AVLBN = TypeVar("AVLBN", bound=AVLNode) + class NodeExistedError(RuntimeError): ... @@ -25,39 +31,6 @@ class TraversalType(Enum): POST_ORDER = "post_order" -class Node(Generic[K, V]): - def __init__(self, key: K, value: V): - self.key: K = key - self.value: V = value - self.children: List[Any] = [] - - -class BinaryNode(Node[K, V]): - def __init__(self, key: K, value: V): - super().__init__(key, value) - self.children: List[BinaryNode | None] = [None, None] - - @property - def left(self): - return self.children[0] - - @left.setter - def left(self, node: Optional[BinaryNode]): - self.children[0] = node - - @property - def right(self): - return self.children[1] - - @right.setter - def right(self, node: Optional[BinaryNode]): - self.children[1] = node - - -N = TypeVar("N", bound=Node) -BN = TypeVar("BN", bound=BinaryNode) - - class GenericTree(Generic[K, V, N]): def __init__(self, root: Optional[N] = None): self.root: Optional[N] = root @@ -71,6 +44,9 @@ def update(self, node: N): def search(self, node: N) -> Optional[N]: raise NotImplementedError + def remove(self, node: N): + raise NotImplementedError + def traversal( self, traversal_type: TraversalType = TraversalType.PRE_ORDER, @@ -154,6 +130,9 @@ def _post_order_traversal(self, root: BN, reverse: bool = False): yield node yield root + def remove(self, node: BN): + return None + def print(self): for node in self.traversal(traversal_type=TraversalType.PRE_ORDER): if node.left is not None: @@ -205,43 +184,29 @@ def _search(self, root: Optional[BN], node: BN) -> Optional[BN]: return None if root.key == node.key: return root - return self._search(root.left, node.key) or self._search(root.right, node.key) - - def remove(self, node: BN) -> bool: - return self._remove(self.root, node) + return self._search(root.left, node) or self._search(root.right, node) - def _remove(self, root: Optional[BN], node: BN) -> bool: - return False + def remove(self, node: BN): + self.root = self._remove(self.root, node) - -BSTree = BinarySearchTree - - -class AVLNode(BinaryNode[K, V]): - def __init__(self, key: K, value: V): - super().__init__(key, value) - self.height = 1 - self.left_height = 0 - self.right_height = 0 - - @property - def balance(self): - return self.left_height - self.right_height - - @property - def min_node(self): - if self.left is None: - return self - return self.left.min_node - - @property - def max_node(self): - if self.right is None: - return self - return self.right.max_node - - -AVLBN = TypeVar("AVLBN", bound=AVLNode) + def _remove(self, root: Optional[BN], node: BN) -> Optional[BN]: + if root is None: + return None + if node.key < root.key: + root.left = self._remove(root.left, node) + return root + if node.key > root.key: + root.right = self._remove(root.right, node) + return root + if root.left is None: + return root.right + elif root.right is None: + return root.left + temp = root.right.min_node + root.key = temp.key + root.value = temp.value + root.right = self._remove(root.right, temp) + return root class AVLTree(BinarySearchTree[K, V, AVLBN]): @@ -253,9 +218,6 @@ def insert(self, node: AVLBN): raise NodeTypeNotValidError self.root = self._insert(self.root, node) - def delete(self, node: AVLBN): - self.root = self._delete_node(self.root, node) - def _insert(self, root: Optional[AVLBN], node: AVLBN) -> AVLBN: if root is None: return node @@ -269,16 +231,19 @@ def _insert(self, root: Optional[AVLBN], node: AVLBN) -> AVLBN: root.right_height = 0 if root.right is None else root.right.height root.height = 1 + max(root.left_height, root.right_height) - return self._balance(root, node) + return self._balance(root) + + def remove(self, node: AVLBN): + self.root = self._remove(self.root, node) - def _delete_node(self, root: Optional[AVLBN], node: AVLBN) -> Optional[AVLBN]: + def _remove(self, root: Optional[AVLBN], node: AVLBN) -> Optional[AVLBN]: if root is None: return None elif node.key < root.key: - root.left = self._delete_node(root.left, node) + root.left = self._remove(root.left, node) root.left_height = 0 if root.left is None else root.left.height elif node.key > root.key: - root.right = self._delete_node(root.right, node) + root.right = self._remove(root.right, node) root.right_height = 0 if root.right is None else root.right.height else: if root.left is None: @@ -288,26 +253,23 @@ def _delete_node(self, root: Optional[AVLBN], node: AVLBN) -> Optional[AVLBN]: temp = root.right.min_node root.key = temp.key root.value = temp.value - root.right = self._delete_node(root.right, temp) + root.right = self._remove(root.right, temp) root.right_height = 0 if root.right is None else root.right.height - return self._balance(root, node) - - def _balance(self, root: AVLBN, node: AVLBN): - if root is None: - return + return self._balance(root) + def _balance(self, root: AVLBN) -> AVLBN: balance_factor = root.balance if balance_factor > 1: - if node.key < root.left.key: + if root.left.balance >= 0: return self._right_rotate(root) else: root.left = self._left_rotate(root.left) return self._right_rotate(root) if balance_factor < -1: - if node.key > root.right.key: + if root.right.balance <= 0: return self._left_rotate(root) else: root.right = self._right_rotate(root.right) @@ -351,3 +313,6 @@ def print(self): print(node.key, "--L-->", node.left.key) if node.right is not None: print(node.key, "--R-->", node.right.key) + + +BSTree = BinarySearchTree diff --git a/tests/test_avl_tree.py b/tests/test_avl_tree.py index 070f261..60e517b 100644 --- a/tests/test_avl_tree.py +++ b/tests/test_avl_tree.py @@ -85,3 +85,27 @@ def test_post_order_traversal_and_reverse(): tree.traversal(traversal_type=TraversalType.POST_ORDER, reverse=True), ) ) + + +def test_remove(): + tree = build_avl_tree() + + tmp_node = AVLNode(key=4) + tree.remove(tmp_node) + node = tree.search(tmp_node) + assert node is None + + tmp_node = AVLNode(key=6) + tree.remove(tmp_node) + node = tree.search(tmp_node) + assert node is None + + tmp_node = AVLNode(key=7) + tree.remove(tmp_node) + node = tree.search(tmp_node) + assert node is None + + assert tree.root.key == 2 + assert tree.root.left.key == 1 + assert tree.root.right.key == 5 + assert tree.root.right.left.key == 3 diff --git a/tests/test_binary_search_tree.py b/tests/test_binary_search_tree.py index 372688d..ddd7672 100644 --- a/tests/test_binary_search_tree.py +++ b/tests/test_binary_search_tree.py @@ -92,3 +92,31 @@ def test_post_order_traversal_and_reverse(): tree.traversal(traversal_type=TraversalType.POST_ORDER, reverse=True), ) ) + + +def test_update(): + tree = build_bstree() + + tree.update(BinaryNode(key=4, value=-4)) + assert tree.root.key == 4 + assert tree.root.value == -4 + + +def test_search(): + tree = build_bstree() + + node = tree.search(BinaryNode(key=4)) + + assert node is not None + assert node.key == 4 + assert node.value == 4 + + +def test_remove(): + tree = build_bstree() + + tmp_node = BinaryNode(key=4) + tree.remove(tmp_node) + node = tree.search(tmp_node) + + assert node is None