from collections import deque
from dataclasses import dataclass
from enum import Enum
from typing import Any, Iterator, Union

from tools.lists import Queue


class Rotate(Enum):
    """Direction of a single rotation around a tree node."""

    LEFT = 0
    RIGHT = 1


@dataclass
class TreeNode:
    """A binary-tree node with a parent back-link and cached AVL metadata.

    ``height`` is the number of edges on the longest downward path to a leaf
    (a leaf has height 0) and ``balance_factor`` is
    ``height(right) - height(left)``; both are recomputed by ``update_node``.

    The dataclass-generated ``__eq__`` compares every field, including the
    linked nodes, so tree code in this module compares nodes by identity
    (``is``) rather than with ``==``.
    """

    value: Any
    parent: Union['TreeNode', None] = None
    left: Union['TreeNode', None] = None
    right: Union['TreeNode', None] = None
    balance_factor: int = 0
    height: int = 0

    def __str__(self):
        return "TreeNode:(%s; bf: %d, d: %d, p: %s, l: %s, r: %s)" \
            % (self.value, self.balance_factor, self.height,
               self.parent.value if self.parent else "None",
               self.left.value if self.left else "None",
               self.right.value if self.right else "None")

    def __repr__(self):
        return str(self)


def update_node(node: TreeNode):
    """Recompute ``node.height`` and ``node.balance_factor`` from its children.

    A missing child counts as height -1, so a leaf ends up with height 0.
    The children's own ``height`` values must already be up to date.
    """
    left_depth = node.left.height if node.left is not None else -1
    right_depth = node.right.height if node.right is not None else -1
    node.height = 1 + max(left_depth, right_depth)
    node.balance_factor = right_depth - left_depth


class BinarySearchTree:
    """Self-balancing (AVL) binary search tree of unique, comparable values.

    Nodes carry parent links; ``node_count`` tracks how many values are
    stored and backs ``len()``.
    """

    root: Union[TreeNode, None] = None
    node_count: int = 0

    def _balance(self, node: TreeNode) -> TreeNode:
        """Rebalance the subtree rooted at *node*; return its (new) root.

        Only balance factors of -2 and +2 trigger rotations; any other node
        is returned unchanged. *node* and its children must have up-to-date
        metadata.
        """
        if node.balance_factor == -2:
            if node.left.balance_factor > 0:
                # Left-right case: rotate the left child first so this becomes
                # a left-left case, then rotate *node* itself (not the child).
                self.rotate(Rotate.LEFT, node.left)
            return self.rotate(Rotate.RIGHT, node)
        if node.balance_factor == 2:
            if node.right.balance_factor < 0:
                # Right-left case: mirror image of the above.
                self.rotate(Rotate.RIGHT, node.right)
            return self.rotate(Rotate.LEFT, node)
        return node

    def _insert(self, node: Union[TreeNode, None],
                parent: Union[TreeNode, None], obj: Any) -> TreeNode:
        """Insert *obj* into the subtree at *node* and return its new root.

        *parent* becomes the parent of the freshly created leaf when *node*
        is None. Metadata is refreshed and the subtree rebalanced on the way
        back up the recursion.
        """
        if node is None:
            return TreeNode(obj, parent)
        if obj < node.value:
            node.left = self._insert(node.left, node, obj)
        else:
            node.right = self._insert(node.right, node, obj)
        update_node(node)
        return self._balance(node)

    def add(self, obj: Any):
        """Insert *obj* into the tree.

        Raises:
            ValueError: if *obj* is None or already present.
        """
        if obj is None or obj in self:
            raise ValueError("obj is None or already present in tree")
        self.root = self._insert(self.root, None, obj)
        self.node_count += 1

    def _find_node(self, obj: Any, start: Union[TreeNode, None]) -> Union[TreeNode, None]:
        """Return the node holding *obj* in the subtree at *start*, or None."""
        node = start
        while node is not None:
            if obj < node.value:
                node = node.left
            elif obj > node.value:
                node = node.right
            else:
                return node
        return None

    def _retrace(self, node: Union[TreeNode, None]):
        """Refresh metadata and rebalance every node from *node* up to the root."""
        while node is not None:
            update_node(node)
            # _balance may replace *node* by a rotated-up child; that child
            # inherits node's parent, so continue upward from there.
            node = self._balance(node).parent

    def remove(self, obj: Any, root_node: TreeNode = None):
        """Remove *obj* from the tree and rebalance it.

        Args:
            obj: the value to remove.
            root_node: node whose subtree is searched for *obj*; defaults to
                the root.

        Raises:
            IndexError: if the tree is empty.
            ValueError: if *obj* is not found.
        """
        if self.root is None:
            raise IndexError("remove from empty tree")
        start = self.root if root_node is None else root_node
        node = self._find_node(obj, start)
        if node is None:
            raise ValueError("obj not in tree:", obj)

        if node.left is not None and node.right is not None:
            # Two children: adopt the in-order predecessor's value and unlink
            # the predecessor instead (it has no right child by construction).
            pred = node.left
            while pred.right is not None:
                pred = pred.right
            node.value = pred.value
            node = pred

        # *node* now has at most one child, which takes its place.
        child = node.left if node.left is not None else node.right
        parent = node.parent
        if child is not None:
            child.parent = parent
        if parent is None:
            self.root = child
        elif parent.left is node:
            parent.left = child
        else:
            parent.right = child
        node.parent = node.left = node.right = None

        self.node_count -= 1
        self._retrace(parent)

    def rotate(self, direction: Rotate, node: TreeNode = None) -> TreeNode:
        """Rotate the subtree rooted at *node* (default: the root).

        A LEFT rotation lifts ``node.right`` into *node*'s place; a RIGHT
        rotation lifts ``node.left``. Parent links, the parent's child
        pointer, ``self.root`` and the metadata of both moved nodes are
        updated. Returns the node that now roots the subtree.
        """
        if node is None:
            node = self.root
        parent = node.parent
        if direction == Rotate.LEFT:
            pivot = node.right
            node.right = pivot.left
            if pivot.left is not None:
                pivot.left.parent = node
            pivot.left = node
        else:
            pivot = node.left
            node.left = pivot.right
            if pivot.right is not None:
                pivot.right.parent = node
            pivot.right = node
        node.parent = pivot
        pivot.parent = parent
        if parent is not None:
            if parent.left is node:
                parent.left = pivot
            else:
                parent.right = pivot
        if node is self.root:
            self.root = pivot
        # *node* is now below *pivot*, so refresh it first.
        update_node(node)
        update_node(pivot)
        return pivot

    def print(self, node: TreeNode = None, level: int = 0):
        """Print the tree sideways: right subtree above, left below.

        Each level is indented by four spaces. Called with no arguments it
        starts at the root; it prints nothing for an empty tree.
        """
        if node is None:
            if level == 0 and self.root is not None:
                node = self.root
            else:
                return
        self.print(node.right, level + 1)
        print(" " * 4 * level + '->', node)
        self.print(node.left, level + 1)

    def __contains__(self, obj: Any) -> bool:
        """Return True if *obj* is stored in the tree (binary search)."""
        if self.root is None:
            return False
        c_node = self.root
        while c_node is not None:
            if obj == c_node.value:
                return True
            elif obj < c_node.value:
                c_node = c_node.left
            else:
                c_node = c_node.right
        return False

    def __len__(self) -> int:
        return self.node_count


class Heap(BinarySearchTree):
    """Binary heap stored as a complete binary tree of linked nodes.

    Subclasses define the ordering by overriding ``_sort_up`` and
    ``_heapify``; in this base class both are no-ops.
    NOTE(review): the inherited BST methods ``remove`` and ``rotate`` assume
    search-tree ordering and are not meaningful on a heap.
    """

    def _find_left(self) -> TreeNode:
        """Return the first node, in level order, that lacks a child.

        That node is the parent of the next insertion slot. Expects a
        non-empty heap.
        """
        heap = Queue()
        heap.put(self.root)
        while c_node := heap.get():
            if c_node.left is None or c_node.right is None:
                return c_node
            heap.put(c_node.left)
            heap.put(c_node.right)

    def _level_order(self) -> Iterator[TreeNode]:
        """Yield every node of the heap in breadth-first (level) order."""
        if self.root is None:
            return
        pending = deque([self.root])
        while pending:
            node = pending.popleft()
            yield node
            if node.left is not None:
                pending.append(node.left)
            if node.right is not None:
                pending.append(node.right)

    def _find_right(self) -> Union[TreeNode, None]:
        """Return the last node in level order, or None for an empty heap.

        Removing that node keeps the tree complete.
        """
        # A deque with maxlen=1 keeps only the final item of the traversal.
        tail = deque(self._level_order(), maxlen=1)
        return tail[0] if tail else None

    def _sort_up(self, node: TreeNode):
        pass

    def _heapify(self):
        pass

    def empty(self) -> bool:
        """Return True if the heap holds no values."""
        return self.root is None

    def add(self, obj: Any):
        """Append *obj* at the next free slot, then restore heap order upward."""
        node = TreeNode(obj)
        if self.root is None:
            self.root = node
        else:
            t_node = self._find_left()
            if t_node.left is None:
                t_node.left = node
            else:
                t_node.right = node
            node.parent = t_node
            self._sort_up(node)
        self.node_count += 1

    def pop(self) -> Any:
        """Remove and return the root value.

        The last node's value moves to the root, the last node is unlinked,
        and heap order is restored downward.

        Raises:
            IndexError: if the heap is empty.
        """
        if self.root is None:
            raise IndexError("pop from empty heap")
        ret = self.root.value
        if self.root.left is None and self.root.right is None:
            self.root = None
        else:
            d_node = self._find_right()
            self.root.value = d_node.value
            # Identity check: a sibling holding an equal value would compare
            # equal under the dataclass __eq__.
            if d_node.parent.left is d_node:
                d_node.parent.left = None
            else:
                d_node.parent.right = None
            d_node.parent = None
            self._heapify()
        self.node_count -= 1
        return ret

    def __contains__(self, obj: Any) -> bool:
        """Return True if any node's value equals *obj*.

        This is a linear scan, because heap order does not support binary
        search.
        """
        return any(node.value == obj for node in self._level_order())


class MinHeap(Heap):
    """Heap whose root always holds the smallest value."""

    def _sort_up(self, node: TreeNode):
        """Swap *node*'s value upward while it is smaller than its parent's."""
        while node.parent is not None and node.value < node.parent.value:
            node.value, node.parent.value = node.parent.value, node.value
            node = node.parent

    def _heapify(self):
        """Sift the root's value down until neither child is smaller.

        At each step it swaps with the smaller child, preferring the left
        child on ties.
        """
        node = self.root
        while node is not None:
            child = node.left
            if node.right is not None and (child is None or node.right.value < child.value):
                child = node.right
            if child is None or not child.value < node.value:
                break
            child.value, node.value = node.value, child.value
            node = child


class MaxHeap(Heap):
    """Heap whose root always holds the largest value."""

    def _sort_up(self, node: TreeNode):
        """Swap *node*'s value upward while it is larger than its parent's."""
        while node.parent is not None and node.value > node.parent.value:
            node.value, node.parent.value = node.parent.value, node.value
            node = node.parent

    def _heapify(self):
        """Sift the root's value down until neither child is larger.

        At each step it swaps with the larger child, preferring the left
        child on ties.
        """
        node = self.root
        while node is not None:
            child = node.left
            if node.right is not None and (child is None or node.right.value > child.value):
                child = node.right
            if child is None or not child.value > node.value:
                break
            child.value, node.value = node.value, child.value
            node = child