From 9d373596a1994b41dc7c8e0d2a73b47a8fdb1ee7 Mon Sep 17 00:00:00 2001 From: Stefan Harmuth Date: Sun, 16 Jan 2022 17:39:37 +0100 Subject: [PATCH] binary trees; slow af, but functional :D --- btree_test.py | 55 +++++++++ tools/lists.py | 29 ++++- tools/trees.py | 300 ++++++++++++++++++++++++++++++++++++++++++++----- 3 files changed, 353 insertions(+), 31 deletions(-) create mode 100644 btree_test.py diff --git a/btree_test.py b/btree_test.py new file mode 100644 index 0000000..ed5c261 --- /dev/null +++ b/btree_test.py @@ -0,0 +1,55 @@ +from heapq import heappop, heappush +from tools.trees import MinHeap, BinarySearchTree +from tools.stopwatch import StopWatch + + +b = BinarySearchTree() +for x in range(16): + b.add(x) +b.print() +print("---") +b.remove(7) +b.remove(6) +b.remove(4) +b.remove(5) +b.print() + +exit() + +# timing below :'-( + +s = StopWatch() +h = [] +for x in range(10_000): + heappush(h, x) +print("Heappush:", s.elapsed()) +while h: + heappop(h) +print("Heappop:", s.elapsed()) + +s = StopWatch() +h = MinHeap() +for x in range(10_000): + h.add(x) +print("MinHeap.add():", s.elapsed()) +while not h.empty(): + h.pop() +print("MinHeap.pop():", s.elapsed()) + +s = StopWatch() +b = set() +for x in range(1_000_000): + b.add(x) +print("set.add():", s.elapsed()) +for x in range(1_000_000): + _ = x in b +print("x in set:", s.elapsed()) + +s = StopWatch() +b = BinarySearchTree() +for x in range(1_000_000): + b.add(x) +print("AVL.add():", s.elapsed()) +for x in range(1_000_000): + _ = x in b +print("x in AVL:", s.elapsed()) diff --git a/tools/lists.py b/tools/lists.py index cb90fde..047a803 100644 --- a/tools/lists.py +++ b/tools/lists.py @@ -14,6 +14,25 @@ class LinkedList: _tail: Union[Node, None] = None size: int = 0 + def _get_head(self): + return self._head + + def _get_tail(self): + return self._tail + + def _set_head(self, node: Node): + node.next = self._head + self._head.prev = node + self._head = node + + def _set_tail(self, node: Node): + node.prev = self._tail + self._tail.next = node + self._tail = node + + head = property(_get_head, _set_head) + tail = property(_get_tail, _set_tail) + def _append(self, obj: Any): node = Node(obj) if self._head is None: @@ -110,9 +129,9 @@ class LinkedList: return x.value == obj def __add__(self, other: 'LinkedList') -> 'LinkedList': - self._tail.next = other._head - other._head.prev = self._tail - self._tail = other._tail + self._tail.next = other.head + other.head.prev = self._tail + self._tail = other.tail self.size += other.size return self @@ -161,5 +180,5 @@ class Queue(LinkedList): def peek(self) -> Any: return self._head.value - push = enqueue - pop = dequeue + push = put = enqueue + pop = get = dequeue diff --git a/tools/trees.py b/tools/trees.py index 513b1bc..af6f565 100644 --- a/tools/trees.py +++ b/tools/trees.py @@ -1,39 +1,287 @@ +from dataclasses import dataclass +from enum import Enum +from tools.lists import Queue from typing import Any, Union -class BinaryTreeNode: - data: Any - left: Union['BinaryTreeNode', None] - right: Union['BinaryTreeNode', None] +class Rotate(Enum): + LEFT = 0 + RIGHT = 1 - def __init__(self, data: Any): - self.data = data - self.left = None - self.right = None - def traverse_inorder(self): - if self.left: - self.left.traverse_inorder() +@dataclass +class TreeNode: + value: Any + parent: Union['TreeNode', None] = None + left: Union['TreeNode', None] = None + right: Union['TreeNode', None] = None + balance_factor: int = 0 + height: int = 0 - yield self.data + def __str__(self): + return "TreeNode:(%s; bf: %d, d: %d, p: %s, l: %s, r: %s)" \ + % (self.value, self.balance_factor, self.height, + self.parent.value if self.parent else "None", + self.left.value if self.left else "None", + self.right.value if self.right else "None") - if self.right: - self.right.traverse_inorder() + def __repr__(self): + return str(self) - def traverse_preorder(self): - yield self.data - if self.left: - self.left.traverse_preorder() +def update_node(node: TreeNode): + left_depth = node.left.height if node.left is not None else -1 + right_depth = node.right.height if node.right is not None else -1 + node.height = 1 + max(left_depth, right_depth) + node.balance_factor = right_depth - left_depth - if self.right: - self.right.traverse_preorder() - def traverse_postorder(self): - if self.left: - self.left.traverse_preorder() +class BinarySearchTree: + root: Union[TreeNode, None] = None + node_count: int = 0 - if self.right: - self.right.traverse_postorder() + def _balance(self, node: TreeNode) -> TreeNode: + if node.balance_factor == -2: + if node.left.balance_factor <= 0: + return self.rotate(Rotate.RIGHT, node) + else: + return self.rotate(Rotate.RIGHT, self.rotate(Rotate.LEFT, node.left)) + elif node.balance_factor == 2: + if node.right.balance_factor >= 0: + return self.rotate(Rotate.LEFT, node) + else: + return self.rotate(Rotate.LEFT, self.rotate(Rotate.RIGHT, node.right)) + else: + return node - yield self.data + def _insert(self, node: TreeNode, parent: TreeNode, obj: Any) -> TreeNode: + if node is None: + return TreeNode(obj, parent) + + if obj < node.value: + node.left = self._insert(node.left, node, obj) + else: + node.right = self._insert(node.right, node, obj) + + update_node(node) + return self._balance(node) + + def add(self, obj: Any): + if obj is None or obj in self: + raise ValueError("obj is None or already present in tree") + + self.root = self._insert(self.root, self.root, obj) + self.node_count += 1 + + def remove(self, obj: Any, root_node: TreeNode = None): + if self.root is None: + raise IndexError("remove from empty tree") + if root_node is None: + root_node = self.root + + node = root_node + while node is not None: + if obj < node.value: + node = node.left + continue + elif obj > node.value: + node = node.right + continue + else: + if node.left is None and node.right is None: # leaf node + if node.parent is not None: + if node.parent.left == node: + node.parent.left = None + else: + node.parent.right = None + elif node.left is not None and node.right is not None: # both subtrees present + d_node = node.left + while d_node.right is not None: + d_node = d_node.right + node.value = d_node.value + self.remove(node.value, d_node) + elif node.left is None: # only a subtree on the right + if node.parent is not None: + if node.parent.left == node: + node.parent.left = node.right + else: + node.parent.right = node.right + node.right.parent = node.parent + else: # only a subtree on the left + if node.parent is not None: + if node.parent.left == node: + node.parent.left = node.left + else: + node.parent.right = node.left + node.left.parent = node.parent + + update_node(root_node) + self._balance(root_node) + return + + raise ValueError("obj not in tree:", obj) + + def rotate(self, direction: Rotate, node: TreeNode = None) -> TreeNode: + if node is None: + node = self.root + + parent = node.parent + if direction == Rotate.LEFT: + pivot = node.right + node.right = pivot.left + if pivot.left is not None: + pivot.left.parent = node + pivot.left = node + else: + pivot = node.left + node.left = pivot.right + if pivot.right is not None: + pivot.right.parent = node + pivot.right = node + + node.parent = pivot + pivot.parent = parent + + if parent is not None: + if parent.left == node: + parent.left = pivot + else: + parent.right = pivot + + if node == self.root: + self.root = pivot + + update_node(node) + update_node(pivot) + + return pivot + + def print(self, node: TreeNode = None, level: int = 0): + if node is None: + if level == 0 and self.root is not None: + node = self.root + else: + return + + self.print(node.right, level + 1) + print(" " * 4 * level + '->', node) + self.print(node.left, level + 1) + + def __contains__(self, obj: Any) -> bool: + if self.root is None: + return False + + c_node = self.root + while c_node is not None: + if obj == c_node.value: + return True + elif obj < c_node.value: + c_node = c_node.left + else: + c_node = c_node.right + + return False + + def __len__(self) -> int: + return self.node_count + + +class Heap(BinarySearchTree): + def _find_left(self) -> TreeNode: + heap = Queue() + heap.put(self.root) + while c_node := heap.get(): + if c_node.left is None or c_node.right is None: + return c_node + else: + heap.put(c_node.left) + heap.put(c_node.right) + + def _find_right(self) -> TreeNode: + heap = [self.root] + while c_node := heap.pop(): + if c_node.left is None and c_node.right is None: + return c_node + else: + if c_node.left is not None: + heap.append(c_node.left) + if c_node.right is not None: + heap.append(c_node.right) + + def _sort_up(self, node: TreeNode): + pass + + def _heapify(self): + pass + + def empty(self) -> bool: + return self.root is None + + def add(self, obj: Any): + node = TreeNode(obj) + if self.root is None: + self.root = node + else: + t_node = self._find_left() + if t_node.left is None: + t_node.left = node + else: + t_node.right = node + node.parent = t_node + self._sort_up(node) + + def pop(self) -> Any: + if self.root is None: + raise IndexError("pop from empty heap") + + ret = self.root.value + if self.root.left is None and self.root.right is None: + self.root = None + else: + d_node = self._find_right() + self.root.value = d_node.value + if d_node.parent.left == d_node: + d_node.parent.left = None + else: + d_node.parent.right = None + self._heapify() + + return ret + + +class MinHeap(Heap): + def _sort_up(self, node: TreeNode): + while node.parent is not None and node.value < node.parent.value: + node.value, node.parent.value = node.parent.value, node.value + node = node.parent + + def _heapify(self): + node = self.root + while node is not None: + if node.left and node.left.value < node.value and (not node.right or node.right.value > node.left.value): + node.left.value, node.value = node.value, node.left.value + node = node.left + elif node.right and node.right.value < node.value and (not node.left or node.left.value > node.right.value): + node.right.value, node.value = node.value, node.right.value + node = node.right + else: + break + + +class MaxHeap(Heap): + def _sort_up(self, node: TreeNode): + while node.parent is not None and node.value > node.parent.value: + node.value, node.parent.value = node.parent.value, node.value + node = node.parent + + def _heapify(self): + node = self.root + while node is not None: + if node.left and node.left.value > node.value and (not node.right or node.right.value < node.left.value): + node.left.value, node.value = node.value, node.left.value + node = node.left + elif node.right and node.right.value > node.value and (not node.left or node.left.value < node.right.value): + node.right.value, node.value = node.value, node.right.value + node = node.right + else: + break