py-tools/tools/trees.py
2022-01-16 17:39:37 +01:00

288 lines
9.2 KiB
Python

from dataclasses import dataclass
from enum import Enum
from tools.lists import Queue
from typing import Any, Union
class Rotate(Enum):
    """Direction in which a subtree is rotated around one of its nodes."""

    # The rotated node moves down to the left; its right child moves up.
    LEFT = 0
    # The rotated node moves down to the right; its left child moves up.
    RIGHT = 1
@dataclass
class TreeNode:
    """A single node of a binary tree, linked to its parent and children."""

    # Payload stored in this node.
    value: Any
    # Node directly above this one; None when there is no parent.
    parent: Union['TreeNode', None] = None
    # Left and right children; None when the respective child is absent.
    left: Union['TreeNode', None] = None
    right: Union['TreeNode', None] = None
    # Right subtree height minus left subtree height (set by update_node).
    balance_factor: int = 0
    # Height of the subtree rooted here; 0 for a leaf (set by update_node).
    height: int = 0

    def __str__(self):
        """Return a one-line summary showing the values of the linked nodes."""

        def shown(link):
            # Linked nodes are rendered by their value only, so the output
            # never recurses through the parent/child references.
            return link.value if link else "None"

        parts = (
            self.value,
            self.balance_factor,
            self.height,
            shown(self.parent),
            shown(self.left),
            shown(self.right),
        )
        return "TreeNode:(%s; bf: %d, d: %d, p: %s, l: %s, r: %s)" % parts

    def __repr__(self):
        """Use the same compact text as ``str()``."""
        return str(self)
def update_node(node: TreeNode):
    """Recompute ``height`` and ``balance_factor`` of *node* from its children.

    A missing child counts as height -1, so a leaf ends up with height 0 and
    balance factor 0. Only the direct children's stored heights are read;
    they are expected to be up to date already.
    """
    left_height, right_height = (
        -1 if child is None else child.height
        for child in (node.left, node.right)
    )
    node.balance_factor = right_height - left_height
    node.height = max(left_height, right_height) + 1
class BinarySearchTree:
    """Self-balancing (AVL-style) binary search tree of unique values.

    Values must be mutually comparable with ``<``, ``>`` and ``==``.
    ``height`` and ``balance_factor`` of the nodes are maintained through
    ``update_node`` and the tree is rebalanced with rotations after every
    insertion and removal.
    """

    # Topmost node of the tree; None while the tree is empty.
    root: Union[TreeNode, None] = None
    # Number of values currently stored; reported by len().
    node_count: int = 0

    def _balance(self, node: TreeNode) -> TreeNode:
        """Restore the balance of the subtree rooted at *node*.

        Expects ``node.balance_factor`` to be up to date. Returns the node
        that is the root of the subtree afterwards (*node* itself when no
        rotation was needed).
        """
        if node.balance_factor == -2:
            # Left-heavy. If the left child leans right, straighten it first
            # (left-right case); the second rotation is then done on *node*
            # itself, not on the pivot returned by the first rotation.
            if node.left.balance_factor > 0:
                self.rotate(Rotate.LEFT, node.left)
            return self.rotate(Rotate.RIGHT, node)
        if node.balance_factor == 2:
            # Right-heavy: mirror image of the case above.
            if node.right.balance_factor < 0:
                self.rotate(Rotate.RIGHT, node.right)
            return self.rotate(Rotate.LEFT, node)
        return node

    def _insert(self, node: TreeNode, parent: TreeNode, obj: Any) -> TreeNode:
        """Insert *obj* below *node* and return the new root of that subtree.

        *parent* becomes the parent of the newly created node when *node*
        is None (i.e. the insertion point has been reached).
        """
        if node is None:
            return TreeNode(obj, parent)
        if obj < node.value:
            node.left = self._insert(node.left, node, obj)
        else:
            node.right = self._insert(node.right, node, obj)
        # Heights are refreshed bottom-up while the recursion unwinds.
        update_node(node)
        return self._balance(node)

    def add(self, obj: Any):
        """Insert *obj* into the tree.

        Raises:
            ValueError: if *obj* is None or already contained in the tree.
        """
        if obj is None or obj in self:
            raise ValueError("obj is None or already present in tree")
        self.root = self._insert(self.root, None, obj)
        self.node_count += 1

    def remove(self, obj: Any, root_node: TreeNode = None):
        """Remove *obj* from the tree and rebalance every affected ancestor.

        Args:
            obj: value to remove.
            root_node: node at which the search starts; defaults to the
                tree root.

        Raises:
            IndexError: if the tree is empty.
            ValueError: if *obj* is not found.
        """
        if self.root is None:
            raise IndexError("remove from empty tree")
        node = self.root if root_node is None else root_node
        while node is not None:
            if obj < node.value:
                node = node.left
            elif obj > node.value:
                node = node.right
            else:
                break
        if node is None:
            raise ValueError("obj not in tree:", obj)

        if node.left is not None and node.right is not None:
            # Two children: take over the value of the in-order predecessor
            # (rightmost node of the left subtree) and unlink that node
            # instead; it has no right child by construction.
            predecessor = node.left
            while predecessor.right is not None:
                predecessor = predecessor.right
            node.value = predecessor.value
            node = predecessor

        # From here on *node* has at most one child, which takes its place.
        child = node.left if node.left is not None else node.right
        parent = node.parent
        if child is not None:
            child.parent = parent
        if parent is None:
            if node is self.root:
                self.root = child
        elif parent.left is node:
            parent.left = child
        else:
            parent.right = child
        self.node_count -= 1

        # Walk back up to the root: every ancestor may have changed height
        # and may need a rotation. _balance returns the (possibly new) root
        # of the subtree, whose parent is the next ancestor to visit.
        while parent is not None:
            update_node(parent)
            parent = self._balance(parent).parent

    def rotate(self, direction: Rotate, node: TreeNode = None) -> TreeNode:
        """Rotate the subtree rooted at *node* and return its new root.

        Args:
            direction: ``Rotate.LEFT`` lifts the right child of *node*,
                ``Rotate.RIGHT`` lifts its left child.
            node: node to rotate around; defaults to the tree root. It must
                have the child that is lifted.

        Parent links, the parent's child link and ``self.root`` are kept
        consistent. Heights of *node* and the lifted child are recomputed;
        ancestors are left for the caller to update.
        """
        if node is None:
            node = self.root
        parent = node.parent
        if direction == Rotate.LEFT:
            pivot = node.right
            node.right = pivot.left
            if pivot.left is not None:
                pivot.left.parent = node
            pivot.left = node
        else:
            pivot = node.left
            node.left = pivot.right
            if pivot.right is not None:
                pivot.right.parent = node
            pivot.right = node
        node.parent = pivot
        pivot.parent = parent
        # Identity checks: TreeNode's dataclass-generated == compares fields.
        if parent is not None:
            if parent.left is node:
                parent.left = pivot
            else:
                parent.right = pivot
        if node is self.root:
            self.root = pivot
        # *node* is now below *pivot*, so it has to be updated first.
        update_node(node)
        update_node(pivot)
        return pivot

    def print(self, node: TreeNode = None, level: int = 0):
        """Print the tree sideways (right subtree first), indented by depth.

        Called without arguments it prints the whole tree; an empty tree
        prints nothing.
        """
        if node is None:
            if level == 0 and self.root is not None:
                node = self.root
            else:
                return
        self.print(node.right, level + 1)
        print(" " * 4 * level + '->', node)
        self.print(node.left, level + 1)

    def __contains__(self, obj: Any) -> bool:
        """Return True if *obj* is stored in the tree."""
        c_node = self.root
        while c_node is not None:
            if obj == c_node.value:
                return True
            c_node = c_node.left if obj < c_node.value else c_node.right
        return False

    def __len__(self) -> int:
        """Return the number of values stored in the tree."""
        return self.node_count
class Heap(BinarySearchTree):
    """Base class for binary heaps built from linked ``TreeNode`` objects.

    The ordering is supplied by subclasses through ``_sort_up`` (restore
    the heap property after an insertion) and ``_heapify`` (restore it
    after the root value has been replaced); in this base class both are
    no-ops.

    NOTE(review): the class derives from ``BinarySearchTree`` but does not
    keep values in search-tree order, so the inherited ``remove``,
    ``rotate`` and ``__contains__`` look unsuitable for heaps - confirm
    before relying on them.
    """

    def _find_left(self) -> TreeNode:
        """Return the first node in breadth-first order that lacks a child.

        Returns None when no such node is produced (e.g. the heap is empty).
        """
        # Queue comes from tools.lists; presumably FIFO, with get() returning
        # a falsy value once it is empty - TODO confirm against that module.
        pending = Queue()
        pending.put(self.root)
        while c_node := pending.get():
            if c_node.left is None or c_node.right is None:
                return c_node
            pending.put(c_node.left)
            pending.put(c_node.right)
        return None

    def _find_right(self) -> TreeNode:
        """Return the leaf reached by descending right wherever possible.

        Where a node has no right child the descent continues to the left.
        Returns None for an empty heap.
        """
        c_node = self.root
        while c_node is not None:
            if c_node.right is not None:
                c_node = c_node.right
            elif c_node.left is not None:
                c_node = c_node.left
            else:
                return c_node
        return None

    def _sort_up(self, node: TreeNode):
        """Hook: move the value of *node* up until the heap property holds."""
        pass

    def _heapify(self):
        """Hook: move the root value down until the heap property holds."""
        pass

    def empty(self) -> bool:
        """Return True if the heap holds no values."""
        return self.root is None

    def add(self, obj: Any):
        """Insert *obj* into the heap."""
        node = TreeNode(obj)
        if self.root is None:
            self.root = node
        else:
            t_node = self._find_left()
            if t_node.left is None:
                t_node.left = node
            else:
                t_node.right = node
            node.parent = t_node
        # Keep len() in sync with the number of stored values.
        self.node_count += 1
        self._sort_up(node)

    def pop(self) -> Any:
        """Remove and return the value stored at the root.

        Raises:
            IndexError: if the heap is empty.
        """
        if self.root is None:
            raise IndexError("pop from empty heap")
        ret = self.root.value
        if self.root.left is None and self.root.right is None:
            self.root = None
        else:
            # Move a leaf's value into the root, unlink that leaf, then let
            # the subclass sift the value down to its place.
            d_node = self._find_right()
            self.root.value = d_node.value
            # Identity check: == on TreeNode compares fields, so a sibling
            # holding an equal value could be mistaken for d_node.
            if d_node.parent.left is d_node:
                d_node.parent.left = None
            else:
                d_node.parent.right = None
            self._heapify()
        self.node_count -= 1
        return ret
class MinHeap(Heap):
    """Heap whose root always holds the smallest stored value."""

    def _sort_up(self, node: TreeNode):
        """Swap the value of *node* upwards while it is smaller than its parent's."""
        while node.parent is not None and node.value < node.parent.value:
            node.value, node.parent.value = node.parent.value, node.value
            node = node.parent

    def _heapify(self):
        """Sift the root value down until no child holds a smaller value.

        At every step the value is swapped with the smaller child. When both
        children hold the same value the left one is used, so a tie between
        the children no longer stops the descent early.
        """
        node = self.root
        while node is not None:
            smallest = node
            if node.left is not None and node.left.value < smallest.value:
                smallest = node.left
            if node.right is not None and node.right.value < smallest.value:
                smallest = node.right
            if smallest is node:
                break
            node.value, smallest.value = smallest.value, node.value
            node = smallest
class MaxHeap(Heap):
    """Heap whose root always holds the largest stored value."""

    def _sort_up(self, node: TreeNode):
        """Swap the value of *node* upwards while it is larger than its parent's."""
        while node.parent is not None and node.value > node.parent.value:
            node.value, node.parent.value = node.parent.value, node.value
            node = node.parent

    def _heapify(self):
        """Sift the root value down until no child holds a larger value.

        At every step the value is swapped with the larger child. When both
        children hold the same value the left one is used, so a tie between
        the children no longer stops the descent early.
        """
        node = self.root
        while node is not None:
            largest = node
            if node.left is not None and node.left.value > largest.value:
                largest = node.left
            if node.right is not None and node.right.value > largest.value:
                largest = node.right
            if largest is node:
                break
            node.value, largest.value = largest.value, node.value
            node = largest