From eaea3e82102203df2e1fece7cf8c915f1533f2e0 Mon Sep 17 00:00:00 2001 From: lpthong90 Date: Thu, 22 Feb 2024 02:02:55 +0700 Subject: [PATCH] refactor --- .gitignore | 3 +- helper.py | 18 +++--- order_book/__init__.py | 40 ++++++++------ tests/common.py | 92 ++++++++----------------------- tests/test_advanced_avl_tree_1.py | 51 ----------------- tests/test_avl_tree_1.py | 0 6 files changed, 58 insertions(+), 146 deletions(-) delete mode 100644 tests/test_advanced_avl_tree_1.py delete mode 100644 tests/test_avl_tree_1.py diff --git a/.gitignore b/.gitignore index fc52d55..aafb37d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ **/__pycache__/ .coverage -.idea/ \ No newline at end of file +.idea/ +.DS_Store \ No newline at end of file diff --git a/helper.py b/helper.py index bc071ec..30141ad 100644 --- a/helper.py +++ b/helper.py @@ -2,13 +2,13 @@ from typing import List, Tuple, TypeVar, Optional from order_book import OrderData, Order, OrderBook, MatchingEngine -from order_book.advanced_avl_tree import AdvancedAVLTree +from py_simple_trees import AVLTree, AVLNode K = TypeVar("K") V = TypeVar("V") -def build_order(order_data: OrderData) -> Order: # pragma: no cover +def build_order(order_data: OrderData) -> Order: # pragma: no cover random_id = int(time.time() * 1e6) order = Order( random_id, @@ -19,7 +19,7 @@ def build_order(order_data: OrderData) -> Order: # pragma: no cover return order -def build_order_book(orders_data: list[OrderData]) -> [OrderBook, list[Order]]: # pragma: no cover +def build_order_book(orders_data: list[OrderData]) -> [OrderBook, list[Order]]: # pragma: no cover order_book = OrderBook() orders = [] for order_data in orders_data: @@ -29,7 +29,7 @@ def build_order_book(orders_data: list[OrderData]) -> [OrderBook, list[Order]]: return order_book, orders -def update_order_book(order_book: OrderBook, orders_data: list[OrderData]) -> [OrderBook, list[Order]]: # pragma: no cover +def update_order_book(order_book: OrderBook, orders_data: list[OrderData]) -> [OrderBook, list[Order]]: # pragma: no cover orders = [] for order_data in orders_data: order = build_order(order_data) @@ -38,15 +38,15 @@ def update_order_book(order_book: OrderBook, orders_data: list[OrderData]) -> [O return order_book, orders -def update_data_to_avl_tree(avl_tree: AdvancedAVLTree[K, V], kv_data: List[Tuple[K, V]]): # pragma: no cover +def update_data_to_avl_tree(avl_tree: AVLTree, kv_data: List[Tuple]): # pragma: no cover for action, key, value in kv_data: if action == 'insert': - avl_tree.insert_node(key, value) + avl_tree.insert(AVLNode(key, value)) if action == 'delete': - avl_tree.delete_node(key) + avl_tree.remove(AVLNode(key)) -def update_matching_engine(matching_engine: MatchingEngine, orders_data: list[OrderData]) -> MatchingEngine: # pragma: no cover +def update_matching_engine(matching_engine: MatchingEngine, orders_data: list[OrderData]) -> MatchingEngine: # pragma: no cover orders = [] for order_data in orders_data: order = build_order(order_data) @@ -55,7 +55,7 @@ def update_matching_engine(matching_engine: MatchingEngine, orders_data: list[Or return matching_engine, orders -def build_matching_engine(orders_data: list[OrderData], order_book: Optional[OrderBook] = None) -> MatchingEngine: # pragma: no cover +def build_matching_engine(orders_data: list[OrderData], order_book: Optional[OrderBook] = None) -> MatchingEngine: # pragma: no cover matching_engine = MatchingEngine(order_book) matching_engine, orders = update_matching_engine(matching_engine, orders_data) return matching_engine diff --git a/order_book/__init__.py b/order_book/__init__.py index 7366539..182d977 100644 --- a/order_book/__init__.py +++ b/order_book/__init__.py @@ -2,7 +2,7 @@ from typing import Union, Dict, NamedTuple, Optional from order_book.double_linked_list import LinkedList -from order_book.advanced_avl_tree import AdvancedAVLTree +from py_simple_trees import AVLTree, AVLNode ID_TYPE = Union[int, str] @@ -47,7 +47,7 @@ def other_side(self): return SideType.SELL return SideType.BUY - def print_out(self): # pragma: no cover + def print_out(self): # pragma: no cover print("Order ", self.id, " ", self.price, " ", self.volume, " ", self.origin_volume, " ", self.origin_volume - self.volume) @@ -83,12 +83,12 @@ def has_no_orders(self): class OrderBook: def __init__(self): - self.bids_tree: AdvancedAVLTree[float, PriceLevel] = AdvancedAVLTree[float, PriceLevel]() - self.asks_tree: AdvancedAVLTree[float, PriceLevel] = AdvancedAVLTree[float, PriceLevel]() + self.bids_tree: AVLTree = AVLTree() + self.asks_tree: AVLTree = AVLTree() self.best_bid_price_level: Optional[PriceLevel] = None self.best_ask_price_level: Optional[PriceLevel] = None - self.price_levels: Dict[float, PriceLevel] = {} + self.price_levels: Dict = {} def add_order(self, order: Order): if order.side == SideType.BUY: @@ -106,9 +106,9 @@ def execute_order(self, order: Order) -> dict: def _delete_price_level(self, side: SideType, price_level: PriceLevel): del self.price_levels[price_level.price] if side == SideType.BUY: - self.bids_tree.delete_node(price_level.price) + self.bids_tree.remove(AVLNode(price_level.price)) else: - self.asks_tree.delete_node(price_level.price) + self.asks_tree.remove(AVLNode(price_level.price)) def cancel_order(self, order: Order): price_level = order.price_level @@ -119,15 +119,15 @@ def cancel_order(self, order: Order): self._delete_price_level(order.side, price_level) def is_empty_bids(self) -> bool: - self.bids_tree.is_empty() + return self.bids_tree.root is None def is_empty_asks(self) -> bool: - self.asks_tree.is_empty() + return self.asks_tree.root is None - def _add_new_price_level(self, prices_tree: AdvancedAVLTree[float, PriceLevel], best_price_level: PriceLevel, order: Order) -> PriceLevel: + def _add_new_price_level(self, prices_tree: AVLTree, best_price_level: PriceLevel, order: Order) -> PriceLevel: price_level = PriceLevel(order.price) price_level.add_order(order) - prices_tree.insert_node(price_level.price, price_level) + prices_tree.insert(AVLNode(price_level.price, price_level)) self.price_levels[order.price] = price_level if best_price_level is None: @@ -138,14 +138,14 @@ def _add_new_price_level(self, prices_tree: AdvancedAVLTree[float, PriceLevel], return price_level return best_price_level - def _add_order(self, price_tree: AdvancedAVLTree[float, PriceLevel], + def _add_order(self, price_tree: AVLTree, best_price_level: PriceLevel, order: Order) -> PriceLevel: if order.price not in self.price_levels: return self._add_new_price_level(price_tree, best_price_level, order) else: price_level = self.price_levels[order.price] price_level.add_order(order) - price_tree.update(price_level.price, price_level) + price_tree.update(AVLNode(price_level.price, price_level)) return best_price_level def _next_best_price_level(self, side: SideType) -> Optional[PriceLevel]: @@ -154,16 +154,22 @@ def _next_best_price_level(self, side: SideType) -> Optional[PriceLevel]: return if self.best_bid_price_level.has_no_orders(): del self.price_levels[self.best_bid_price_level.price] - self.bids_tree.delete_node(self.best_bid_price_level.price) - self.best_bid_price_level = self.bids_tree.max_node_value + self.bids_tree.remove(AVLNode(self.best_bid_price_level.price)) + if self.bids_tree.root is not None: + self.best_bid_price_level = self.bids_tree.root.max_node.value + else: + self.best_bid_price_level = None return self.best_bid_price_level else: # side == SideType.SELL if self.best_ask_price_level is None: return if self.best_ask_price_level.has_no_orders(): del self.price_levels[self.best_ask_price_level.price] - self.asks_tree.delete_node(self.best_ask_price_level.price) - self.best_ask_price_level = self.asks_tree.min_node_value + self.asks_tree.remove(AVLNode(self.best_ask_price_level.price)) + if self.asks_tree.root is not None: + self.best_ask_price_level = self.asks_tree.root.min_node.value + else: + self.best_ask_price_level = None return self.best_ask_price_level def _execute_order(self, order: Order, best_price_level: Optional[PriceLevel]) -> (Optional[PriceLevel], dict): diff --git a/tests/common.py b/tests/common.py index 2e4ac5a..34044a1 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,19 +1,12 @@ -import sys -from typing import TypeVar, List, Tuple, Dict, Optional - -from helper import ( - # build_order_book, - build_matching_engine, - update_matching_engine, - # update_order_book, - update_data_to_avl_tree -) -from order_book import OrderBook, PriceLevel, Order, MatchingEngine -from order_book.avl_tree import AVLTree, TreeNode -from order_book.advanced_avl_tree import AdvancedAVLTree +from typing import TypeVar, List, Tuple + +from helper import update_matching_engine +from order_book import OrderBook, MatchingEngine +from py_simple_trees import AVLTree, AVLNode, TraversalType K = TypeVar("K") V = TypeVar("V") +AVLBN = TypeVar("AVLBN", bound=AVLNode) def check_order_book(inputs, outputs): @@ -32,34 +25,29 @@ def check_order_book(inputs, outputs): assert bid_price_levels == outputs[1] -def check_advanced_avl_tree(inputs, outputs): - avl_tree = TestAdvancedAVLTree[int, int]() - update_data_to_avl_tree(avl_tree, inputs) - - parents = avl_tree.get_parents() - print("parents: ", parents) - print("expected parents: ", outputs["parents"]) - - assert parents == outputs["parents"] - - class TestOrderBook(OrderBook): def __init__(self): super().__init__() - self.bids_tree: TestAdvancedAVLTree[float, PriceLevel] = TestAdvancedAVLTree[float, PriceLevel]() - self.asks_tree: TestAdvancedAVLTree[float, PriceLevel] = TestAdvancedAVLTree[float, PriceLevel]() + self.bids_tree: TestAVLTree = TestAVLTree() + self.asks_tree: TestAVLTree = TestAVLTree() def get_ask_price_levels(self): - price_levels = self.asks_tree.get_all_nodes() - price_levels.reverse() - price_levels = map(lambda pl: (pl[1].price, pl[1].total_volume), price_levels) - return price_levels + return map( + lambda node: (node.value.price, node.value.total_volume), + self.asks_tree.traversal( + traversal_type=TraversalType.IN_ORDER, + reverse=True + ), + ) def get_bid_price_levels(self): - price_levels = self.bids_tree.get_all_nodes() - price_levels.reverse() - price_levels = map(lambda pl: (pl[1].price, pl[1].total_volume), price_levels) - return price_levels + return map( + lambda node: (node.value.price, node.value.total_volume), + self.bids_tree.traversal( + traversal_type=TraversalType.IN_ORDER, + reverse=True + ), + ) def print_bids(self): # pragma: no cover price_levels = self.get_bid_price_levels() @@ -87,8 +75,8 @@ def print_filled_orders(self): # pragma: no cover order.print_out() -class TestAdvancedAVLTree(AdvancedAVLTree[K, V]): - def _parents(self, root: TreeNode) -> List[Tuple]: +class TestAVLTree(AVLTree[K, V, AVLBN]): + def _parents(self, root: AVLBN) -> List[Tuple]: results = [] if root is None: return results @@ -107,35 +95,3 @@ def print_parents(self): # pragma: no cover results = self.get_parents() print("Key | ParentKey") print(results) - - # Print the tree - def _print_helper(self, root: TreeNode[K, V], indent: str, last: bool): # pragma: no cover - if root is not None: - sys.stdout.write(indent) - if last: - sys.stdout.write("R----") - indent += " " - else: - sys.stdout.write("L----") - indent += "| " - print(root.key) - self._print_helper(root.left, indent, False) - self._print_helper(root.right, indent, True) - - def print_helper(self, indent: str, last: bool): # pragma: no cover - self._print_helper(self.root, indent, last) - - def _get_all_nodes(self, node: TreeNode[K, V]) -> List[tuple[K, V]]: - results = [] - if not node: - return results - - if node.left: - results += self._get_all_nodes(node.left) - results += [node.data] - if node.right: - results += self._get_all_nodes(node.right) - return results - - def get_all_nodes(self) -> List[tuple[K, V]]: - return self._get_all_nodes(self.root) \ No newline at end of file diff --git a/tests/test_advanced_avl_tree_1.py b/tests/test_advanced_avl_tree_1.py deleted file mode 100644 index b96206a..0000000 --- a/tests/test_advanced_avl_tree_1.py +++ /dev/null @@ -1,51 +0,0 @@ -# from order_book.advanced_avl_tree import AdvancedAVLTree -# from tests.common import TestAdvancedAVLTree -from tests.common import check_advanced_avl_tree - - -def test_1(): - inputs = [ - ('insert', 1, 1), - ('insert', 2, 2), - ('insert', 3, 3), - ('insert', 4, 4), - ('insert', 5, 5), - ('insert', 6, 6), - ] - outputs = { - "parents": [(1, 2), (2, 4), (3, 2), (4, None), (5, 4), (6, 5)] - } - check_advanced_avl_tree(inputs, outputs) - - -def test_2(): - inputs = [ - ('insert', 1, 1), - ('insert', 2, 2), - ('insert', 3, 3), - ('insert', 4, 4), - ('insert', 5, 5), - ('insert', 6, 6), - ('delete', 4, None), - ] - outputs = { - "parents": [(1, 2), (2, 5), (3, 2), (5, None), (6, 5)] - } - check_advanced_avl_tree(inputs, outputs) - - -def test_3(): - inputs = [ - ('insert', 1, 1), - ('insert', 2, 2), - ('insert', 3, 3), - ('insert', 4, 4), - ('insert', 5, 5), - ('insert', 6, 6), - ('delete', 4, None), - ('delete', 5, None), - ] - outputs = { - "parents": [(1, 2), (2, None), (3, 6), (6, 2)] - } - check_advanced_avl_tree(inputs, outputs) \ No newline at end of file diff --git a/tests/test_avl_tree_1.py b/tests/test_avl_tree_1.py deleted file mode 100644 index e69de29..0000000