Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
lpthong90 committed Feb 21, 2024
1 parent eef95d5 commit 4597802
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 93 deletions.
19 changes: 10 additions & 9 deletions py_simple_trees/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .core import Node # noqa: F401
from .core import GenericTree # noqa: F401
from .core import BinaryNode # noqa: F401
from .core import BinaryTree # noqa: F401
from .core import TraversalType # noqa: F401
from .core import BinarySearchTree # noqa: F401
from .core import BSTree # noqa: F401
from .core import AVLTree # noqa: F401
from .core import AVLNode # noqa: F401
from .node import Node # noqa: F401
from .node import BinaryNode # noqa: F401
from .node import AVLNode # noqa: F401

from .tree import GenericTree # noqa: F401
from .tree import BinaryTree # noqa: F401
from .tree import TraversalType # noqa: F401
from .tree import BinarySearchTree # noqa: F401
from .tree import BSTree # noqa: F401
from .tree import AVLTree # noqa: F401
60 changes: 60 additions & 0 deletions py_simple_trees/node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from __future__ import annotations

from typing import TypeVar, Generic, Optional, List, Any


K = TypeVar("K")
V = TypeVar("V")


class Node(Generic[K, V]):
def __init__(self, key: K, value: Optional[V] = None):
self.key: K = key
self.value: Optional[V] = value
self.children: List[Any] = []


class BinaryNode(Node[K, V]):
def __init__(self, key: K, value: Optional[V] = None):
super().__init__(key, value)
self.children: List[BinaryNode | None] = [None, None]

@property
def left(self):
return self.children[0]

@left.setter
def left(self, node: Optional[BinaryNode]):
self.children[0] = node

@property
def right(self):
return self.children[1]

@right.setter
def right(self, node: Optional[BinaryNode]):
self.children[1] = node

@property
def min_node(self):
if self.left is None:
return self
return self.left.min_node

@property
def max_node(self):
if self.right is None:
return self
return self.right.max_node


class AVLNode(BinaryNode[K, V]):
def __init__(self, key: K, value: Optional[V] = None):
super().__init__(key, value)
self.height = 1
self.left_height = 0
self.right_height = 0

@property
def balance(self):
return self.left_height - self.right_height
133 changes: 49 additions & 84 deletions py_simple_trees/core.py → py_simple_trees/tree.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from __future__ import annotations

from typing import TypeVar, Generic, Optional, List, Any
from typing import TypeVar, Generic, Optional
from enum import Enum

from py_simple_trees.node import Node, BinaryNode, AVLNode

K = TypeVar("K")
V = TypeVar("V")

N = TypeVar("N", bound=Node)
BN = TypeVar("BN", bound=BinaryNode)
AVLBN = TypeVar("AVLBN", bound=AVLNode)


class NodeExistedError(RuntimeError):
...
Expand All @@ -25,39 +31,6 @@ class TraversalType(Enum):
POST_ORDER = "post_order"


class Node(Generic[K, V]):
def __init__(self, key: K, value: V):
self.key: K = key
self.value: V = value
self.children: List[Any] = []


class BinaryNode(Node[K, V]):
def __init__(self, key: K, value: V):
super().__init__(key, value)
self.children: List[BinaryNode | None] = [None, None]

@property
def left(self):
return self.children[0]

@left.setter
def left(self, node: Optional[BinaryNode]):
self.children[0] = node

@property
def right(self):
return self.children[1]

@right.setter
def right(self, node: Optional[BinaryNode]):
self.children[1] = node


N = TypeVar("N", bound=Node)
BN = TypeVar("BN", bound=BinaryNode)


class GenericTree(Generic[K, V, N]):
def __init__(self, root: Optional[N] = None):
self.root: Optional[N] = root
Expand All @@ -71,6 +44,9 @@ def update(self, node: N):
def search(self, node: N) -> Optional[N]:
raise NotImplementedError

def remove(self, node: N):
raise NotImplementedError

def traversal(
self,
traversal_type: TraversalType = TraversalType.PRE_ORDER,
Expand Down Expand Up @@ -154,6 +130,9 @@ def _post_order_traversal(self, root: BN, reverse: bool = False):
yield node
yield root

def remove(self, node: BN):
return None

def print(self):
for node in self.traversal(traversal_type=TraversalType.PRE_ORDER):
if node.left is not None:
Expand Down Expand Up @@ -205,43 +184,29 @@ def _search(self, root: Optional[BN], node: BN) -> Optional[BN]:
return None
if root.key == node.key:
return root
return self._search(root.left, node.key) or self._search(root.right, node.key)

def remove(self, node: BN) -> bool:
return self._remove(self.root, node)
return self._search(root.left, node) or self._search(root.right, node)

def _remove(self, root: Optional[BN], node: BN) -> bool:
return False
def remove(self, node: BN):
self.root = self._remove(self.root, node)


BSTree = BinarySearchTree


class AVLNode(BinaryNode[K, V]):
def __init__(self, key: K, value: V):
super().__init__(key, value)
self.height = 1
self.left_height = 0
self.right_height = 0

@property
def balance(self):
return self.left_height - self.right_height

@property
def min_node(self):
if self.left is None:
return self
return self.left.min_node

@property
def max_node(self):
if self.right is None:
return self
return self.right.max_node


AVLBN = TypeVar("AVLBN", bound=AVLNode)
def _remove(self, root: Optional[BN], node: BN) -> Optional[BN]:
if root is None:
return None
if node.key < root.key:
root.left = self._remove(root.left, node)
return root
if node.key > root.key:
root.right = self._remove(root.right, node)
return root
if root.left is None:
return root.right
elif root.right is None:
return root.left
temp = root.right.min_node
root.key = temp.key
root.value = temp.value
root.right = self._remove(root.right, temp)
return root


class AVLTree(BinarySearchTree[K, V, AVLBN]):
Expand All @@ -253,9 +218,6 @@ def insert(self, node: AVLBN):
raise NodeTypeNotValidError
self.root = self._insert(self.root, node)

def delete(self, node: AVLBN):
self.root = self._delete_node(self.root, node)

def _insert(self, root: Optional[AVLBN], node: AVLBN) -> AVLBN:
if root is None:
return node
Expand All @@ -269,16 +231,19 @@ def _insert(self, root: Optional[AVLBN], node: AVLBN) -> AVLBN:
root.right_height = 0 if root.right is None else root.right.height

root.height = 1 + max(root.left_height, root.right_height)
return self._balance(root, node)
return self._balance(root)

def remove(self, node: AVLBN):
self.root = self._remove(self.root, node)

def _delete_node(self, root: Optional[AVLBN], node: AVLBN) -> Optional[AVLBN]:
def _remove(self, root: Optional[AVLBN], node: AVLBN) -> Optional[AVLBN]:
if root is None:
return None
elif node.key < root.key:
root.left = self._delete_node(root.left, node)
root.left = self._remove(root.left, node)
root.left_height = 0 if root.left is None else root.left.height
elif node.key > root.key:
root.right = self._delete_node(root.right, node)
root.right = self._remove(root.right, node)
root.right_height = 0 if root.right is None else root.right.height
else:
if root.left is None:
Expand All @@ -288,26 +253,23 @@ def _delete_node(self, root: Optional[AVLBN], node: AVLBN) -> Optional[AVLBN]:
temp = root.right.min_node
root.key = temp.key
root.value = temp.value
root.right = self._delete_node(root.right, temp)
root.right = self._remove(root.right, temp)
root.right_height = 0 if root.right is None else root.right.height

return self._balance(root, node)

def _balance(self, root: AVLBN, node: AVLBN):
if root is None:
return
return self._balance(root)

def _balance(self, root: AVLBN) -> AVLBN:
balance_factor = root.balance

if balance_factor > 1:
if node.key < root.left.key:
if root.left.balance >= 0:
return self._right_rotate(root)
else:
root.left = self._left_rotate(root.left)
return self._right_rotate(root)

if balance_factor < -1:
if node.key > root.right.key:
if root.right.balance <= 0:
return self._left_rotate(root)
else:
root.right = self._right_rotate(root.right)
Expand Down Expand Up @@ -351,3 +313,6 @@ def print(self):
print(node.key, "--L-->", node.left.key)
if node.right is not None:
print(node.key, "--R-->", node.right.key)


BSTree = BinarySearchTree
24 changes: 24 additions & 0 deletions tests/test_avl_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,27 @@ def test_post_order_traversal_and_reverse():
tree.traversal(traversal_type=TraversalType.POST_ORDER, reverse=True),
)
)


def test_remove():
tree = build_avl_tree()

tmp_node = AVLNode(key=4)
tree.remove(tmp_node)
node = tree.search(tmp_node)
assert node is None

tmp_node = AVLNode(key=6)
tree.remove(tmp_node)
node = tree.search(tmp_node)
assert node is None

tmp_node = AVLNode(key=7)
tree.remove(tmp_node)
node = tree.search(tmp_node)
assert node is None

assert tree.root.key == 2
assert tree.root.left.key == 1
assert tree.root.right.key == 5
assert tree.root.right.left.key == 3
28 changes: 28 additions & 0 deletions tests/test_binary_search_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,31 @@ def test_post_order_traversal_and_reverse():
tree.traversal(traversal_type=TraversalType.POST_ORDER, reverse=True),
)
)


def test_update():
tree = build_bstree()

tree.update(BinaryNode(key=4, value=-4))
assert tree.root.key == 4
assert tree.root.value == -4


def test_search():
tree = build_bstree()

node = tree.search(BinaryNode(key=4))

assert node is not None
assert node.key == 4
assert node.value == 4


def test_remove():
tree = build_bstree()

tmp_node = BinaryNode(key=4)
tree.remove(tmp_node)
node = tree.search(tmp_node)

assert node is None

0 comments on commit 4597802

Please sign in to comment.