From 622613a0ad501b4316b7b03ae5f4e6847e282033 Mon Sep 17 00:00:00 2001 From: Stefan Harmuth Date: Mon, 24 Jan 2022 13:37:38 +0100 Subject: [PATCH] introduce unbalanced binary trees (incredibly fun when inserting already sorted values!) --- btree_test.py | 20 +++- tools/stopwatch.py | 2 + tools/trees.py | 258 +++++++++++++++++++++++++++------------------ 3 files changed, 178 insertions(+), 102 deletions(-) diff --git a/btree_test.py b/btree_test.py index 648c074..97d5b77 100644 --- a/btree_test.py +++ b/btree_test.py @@ -5,18 +5,20 @@ from tools.stopwatch import StopWatch s = StopWatch() h = [] -for x in range(1_000_000): +for x in range(100_000): heappush(h, x) print("Heappush:", s.elapsed()) +s.reset() while h: heappop(h) print("Heappop:", s.elapsed()) s = StopWatch() h = MinHeap() -for x in range(1_000_000): +for x in range(100_000): h.add(x) print("MinHeap.add():", s.elapsed()) +s.reset() while not h.empty(): h.pop() print("MinHeap.pop():", s.elapsed()) @@ -26,6 +28,7 @@ b = set() for x in range(1_000_000): b.add(x) print("set.add():", s.elapsed()) +s.reset() for x in range(1_000_000): _ = x in b print("x in set:", s.elapsed()) @@ -35,6 +38,19 @@ b = BinarySearchTree() for x in range(1_000_000): b.add(x) print("AVL.add():", s.elapsed()) +s.reset() for x in range(1_000_000): _ = x in b print("x in AVL:", s.elapsed()) + +print("DFS/BFS Test") +b = BinarySearchTree() +for x in range(20): + b.add(x) +b.print() +print("DFS:") +for x in b.iter_depth_first(): + print(x) +print("BFS:") +for x in b.iter_breadth_first(): + print(x) diff --git a/tools/stopwatch.py b/tools/stopwatch.py index aff17c3..1c520a0 100644 --- a/tools/stopwatch.py +++ b/tools/stopwatch.py @@ -18,6 +18,8 @@ class StopWatch: self.stopped = time() return self.elapsed() + reset = start + def elapsed(self) -> float: if self.stopped is None: return time() - self.started diff --git a/tools/trees.py b/tools/trees.py index 8e21c59..1bca07d 100644 --- a/tools/trees.py +++ b/tools/trees.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from enum import Enum -from tools.lists import Queue +from tools.lists import Queue, Stack from typing import Any, Union @@ -36,10 +36,157 @@ def update_node(node: TreeNode): node.balance_factor = right_depth - left_depth -class BinarySearchTree: +class BinaryTree: root: Union[TreeNode, None] = None node_count: int = 0 + def _insert(self, node: TreeNode, parent: TreeNode, obj: Any) -> TreeNode: + new_node = TreeNode(obj, parent) + if node is None: + return new_node + + found = False + while not found: + if obj < node.value: + if node.left is not None: + node = node.left + else: + node.left = new_node + found = True + elif obj > node.value: + if node.right is not None: + node = node.right + else: + node.right = new_node + found = True + else: + raise ValueError("obj already present in tree: %s" % obj) + + new_node.parent = node + return new_node + + def _remove(self, node: TreeNode): + if node.left is None and node.right is None: # leaf node + if node.parent is not None: + if node.parent.left == node: + node.parent.left = None + else: + node.parent.right = None + else: + self.root = None + elif node.left is not None and node.right is not None: # both subtrees present + d_node = node.left + while d_node.right is not None: + d_node = d_node.right + node.value = d_node.value + self.remove(node.value, d_node) + elif node.left is None: # only a subtree on the right + if node.parent is not None: + if node.parent.left == node: + node.parent.left = node.right + else: + node.parent.right = node.right + else: + self.root = node.right + node.right.parent = node.parent + else: # only a subtree on the left + if node.parent is not None: + if node.parent.left == node: + node.parent.left = node.left + else: + node.parent.right = node.left + else: + self.root = node.left + node.left.parent = node.parent + + self.node_count -= 1 + + def _get_node_by_value(self, obj: Any, root_node: TreeNode = None) -> TreeNode: + if self.root is None: + raise IndexError("get node from empty tree") + + if root_node is None: + root_node = self.root + + node = root_node + while node is not None: + if obj < node.value: + node = node.left + continue + elif obj > node.value: + node = node.right + continue + else: + return node + + raise ValueError("obj not in tree:", obj) + + def add(self, obj: Any): + if obj is None: + return + + new_node = self._insert(self.root, self.root, obj) + if self.root is None: + self.root = new_node + self.node_count += 1 + + def remove(self, obj: Any, root_node: TreeNode = None): + node = self._get_node_by_value(obj, root_node) + self._remove(node) + + def iter_depth_first(self): + stack = Stack() + stack.push(self.root) + while len(stack): + node = stack.pop() + if node.right is not None: + stack.push(node.right) + if node.left is not None: + stack.push(node.left) + yield node.value + + def iter_breadth_first(self): + queue = Queue() + queue.push(self.root) + while len(queue): + node = queue.pop() + if node.left is not None: + queue.push(node.left) + if node.right is not None: + queue.push(node.right) + yield node.value + + def print(self, node: TreeNode = None, level: int = 0): + if node is None: + if level == 0 and self.root is not None: + node = self.root + else: + return + + self.print(node.right, level + 1) + print(" " * 4 * level + '->', node) + self.print(node.left, level + 1) + + def __contains__(self, obj: Any) -> bool: + if self.root is None: + return False + + c_node = self.root + while c_node is not None: + if obj == c_node.value: + return True + elif obj < c_node.value: + c_node = c_node.left + else: + c_node = c_node.right + + return False + + def __len__(self) -> int: + return self.node_count + + +class BinarySearchTree(BinaryTree): def _balance(self, node: TreeNode) -> TreeNode: if node.balance_factor == -2: if node.left.balance_factor <= 0: @@ -55,80 +202,20 @@ class BinarySearchTree: return node def _insert(self, node: TreeNode, parent: TreeNode, obj: Any) -> TreeNode: - if node is None: - return TreeNode(obj, parent) - - if obj < node.value: - node.left = self._insert(node.left, node, obj) - elif obj > node.value: - node.right = self._insert(node.right, node, obj) - else: - raise ValueError("obj already present in tree: %s" % obj) - - update_node(node) - return self._balance(node) - - def add(self, obj: Any): - if obj is None: - return - - self.root = self._insert(self.root, self.root, obj) - self.node_count += 1 + node = super()._insert(node, parent, obj) + if self.root is not None: + while node is not None: + update_node(node) + node = self._balance(node).parent + return node def remove(self, obj: Any, root_node: TreeNode = None): - if self.root is None: - raise IndexError("remove from empty tree") if root_node is None: root_node = self.root - node = root_node - while node is not None: - if obj < node.value: - node = node.left - continue - elif obj > node.value: - node = node.right - continue - else: - if node.left is None and node.right is None: # leaf node - if node.parent is not None: - if node.parent.left == node: - node.parent.left = None - else: - node.parent.right = None - else: - self.root = None - elif node.left is not None and node.right is not None: # both subtrees present - d_node = node.left - while d_node.right is not None: - d_node = d_node.right - node.value = d_node.value - self.remove(node.value, d_node) - elif node.left is None: # only a subtree on the right - if node.parent is not None: - if node.parent.left == node: - node.parent.left = node.right - else: - node.parent.right = node.right - else: - self.root = node.right - node.right.parent = node.parent - else: # only a subtree on the left - if node.parent is not None: - if node.parent.left == node: - node.parent.left = node.left - else: - node.parent.right = node.left - else: - self.root = node.left - node.left.parent = node.parent - - update_node(root_node) - self._balance(root_node) - self.node_count -= 1 - return - - raise ValueError("obj not in tree:", obj) + super().remove(obj, root_node) + update_node(root_node) + self._balance(root_node) def rotate(self, direction: Rotate, node: TreeNode = None) -> TreeNode: if node is None: @@ -165,35 +252,6 @@ class BinarySearchTree: return pivot - def print(self, node: TreeNode = None, level: int = 0): - if node is None: - if level == 0 and self.root is not None: - node = self.root - else: - return - - self.print(node.right, level + 1) - print(" " * 4 * level + '->', node) - self.print(node.left, level + 1) - - def __contains__(self, obj: Any) -> bool: - if self.root is None: - return False - - c_node = self.root - while c_node is not None: - if obj == c_node.value: - return True - elif obj < c_node.value: - c_node = c_node.left - else: - c_node = c_node.right - - return False - - def __len__(self) -> int: - return self.node_count - class Heap(BinarySearchTree): def empty(self):