283 lines
10 KiB
Python
283 lines
10 KiB
Python
from __future__ import annotations
|
|
from .coordinate import Coordinate, DistanceAlgorithm
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from heapq import heappop, heappush
|
|
from math import inf
|
|
from typing import Any, Dict, List, Optional, Union
|
|
|
|
# Readable names for the two boolean cell states (not referenced elsewhere in this file).
OFF, ON = False, True

# Plain numbers; used to annotate the arithmetic helpers on Grid (add/sub/mul/div).
Numeric = Union[int, float]
|
|
|
|
|
|
class GridTransformation(Enum):
    """Transformation modes accepted by ``Grid.transform``."""

    # Mirror the columns (left <-> right).
    FLIP_X = 1
    FLIP_HORIZONTALLY = FLIP_X  # alias for FLIP_X; prep for 3d-transformations
    # Mirror the rows (top <-> bottom).
    FLIP_VERTICALLY = 2
    # Implemented as ROTATE_LEFT followed by FLIP_HORIZONTALLY.
    FLIP_DIAGONALLY = 3
    # Implemented as ROTATE_RIGHT followed by FLIP_HORIZONTALLY.
    FLIP_DIAGONALLY_REV = 4
    # Quarter turns.
    ROTATE_RIGHT = 5
    ROTATE_LEFT = 6
    # Half turn (two ROTATE_RIGHT steps).
    ROTATE_TWICE = 7
|
|
|
|
|
|
class Grid:
    """Sparse 2D/3D grid that stores only the cells differing from a default value.

    Cells are kept in a dict keyed by coordinate. Reading a coordinate that is
    not stored yields the default passed to the constructor. The bounding box
    (``minX`` .. ``maxZ``) starts at the origin and only ever grows; it is
    extended whenever a non-default value is written and is never shrunk when
    cells are removed.
    """

    def __init__(self, default=False):
        """Create an empty grid whose unset cells read as ``default``."""
        self.__default = default
        self.__grid = {}
        self.minX = 0
        self.minY = 0
        self.maxX = 0
        self.maxY = 0
        self.minZ = 0
        self.maxZ = 0
        self.mode3D = False

    def __trackBoundaries(self, pos: Coordinate):
        """Grow the bounding box so that it contains ``pos``."""
        self.minX = min(pos.x, self.minX)
        self.minY = min(pos.y, self.minY)
        self.maxX = max(pos.x, self.maxX)
        self.maxY = max(pos.y, self.maxY)
        # Z is only meaningful once a coordinate with a z component was seen.
        if self.mode3D:
            self.minZ = min(pos.z, self.minZ)
            self.maxZ = max(pos.z, self.maxZ)

    def rangeX(self):
        """Return the inclusive range of x values covered by the bounding box."""
        return range(self.minX, self.maxX + 1)

    def rangeY(self):
        """Return the inclusive range of y values covered by the bounding box."""
        return range(self.minY, self.maxY + 1)

    def rangeZ(self):
        """Return the inclusive range of z values covered by the bounding box."""
        return range(self.minZ, self.maxZ + 1)

    def toggle(self, pos: Coordinate):
        """Remove ``pos`` if it is stored, otherwise store ``not default`` there."""
        # Same 3D detection as set(), so Z bounds are tracked for toggled cells too.
        if pos.z is not None:
            self.mode3D = True

        if pos in self.__grid:
            del self.__grid[pos]
        else:
            self.__trackBoundaries(pos)
            self.__grid[pos] = not self.__default

    def set(self, pos: Coordinate, value: Any = True) -> Any:
        """Write ``value`` at ``pos`` and return ``value``.

        Writing the default value removes the cell from storage instead of
        storing it; writing anything else extends the bounding box.
        """
        if pos.z is not None:
            self.mode3D = True

        if value != self.__default:
            self.__trackBoundaries(pos)
            self.__grid[pos] = value
        elif pos in self.__grid:
            del self.__grid[pos]

        return value

    def add(self, pos: Coordinate, value: Numeric = 1) -> Numeric:
        """Add ``value`` to the cell at ``pos`` and return the new cell value."""
        return self.set(pos, self.get(pos) + value)

    def sub(self, pos: Coordinate, value: Numeric = 1) -> Numeric:
        """Subtract ``value`` from the cell at ``pos`` and return the new cell value."""
        return self.set(pos, self.get(pos) - value)

    def mul(self, pos: Coordinate, value: Numeric = 1) -> Numeric:
        """Multiply the cell at ``pos`` by ``value`` and return the new cell value."""
        return self.set(pos, self.get(pos) * value)

    def div(self, pos: Coordinate, value: Numeric = 1) -> Numeric:
        """Divide the cell at ``pos`` by ``value`` (true division) and return the result."""
        return self.set(pos, self.get(pos) / value)

    def get(self, pos: Coordinate) -> Any:
        """Return the value at ``pos``, or the grid default if nothing is stored."""
        return self.__grid.get(pos, self.__default)

    def getOnCount(self) -> int:
        """Return the number of stored (non-default) cells."""
        return len(self.__grid)

    def isSet(self, pos: Coordinate) -> bool:
        """Return True if a non-default value is stored at ``pos``."""
        return pos in self.__grid

    def getCorners(self) -> List[Coordinate]:
        """Return the corners of the bounding box (4 in 2D mode, 8 in 3D mode)."""
        xs = (self.minX, self.maxX)
        ys = (self.minY, self.maxY)
        if not self.mode3D:
            return [Coordinate(x, y) for x in xs for y in ys]

        zs = (self.minZ, self.maxZ)
        return [Coordinate(x, y, z) for x in xs for y in ys for z in zs]

    def isCorner(self, pos: Coordinate) -> bool:
        """Return True if ``pos`` is one of the bounding-box corners."""
        return pos in self.getCorners()

    def isWithinBoundaries(self, pos: Coordinate) -> bool:
        """Return True if ``pos`` lies inside the bounding box (Z checked in 3D mode only)."""
        inPlane = self.minX <= pos.x <= self.maxX and self.minY <= pos.y <= self.maxY
        if self.mode3D:
            return inPlane and self.minZ <= pos.z <= self.maxZ
        return inPlane

    def getActiveCells(self, x: Optional[int] = None, y: Optional[int] = None) -> List[Coordinate]:
        """Return the coordinates of all stored cells, optionally filtered.

        :param x: if given, only cells in this column are returned
        :param y: if given (and ``x`` is not), only cells in this row are returned
        """
        # Compare against None explicitly: 0 is a perfectly valid column/row.
        if x is not None:
            return [c for c in self.__grid if c.x == x]
        if y is not None:
            return [c for c in self.__grid if c.y == y]
        return list(self.__grid)

    def getSum(self, includeNegative: bool = True) -> Numeric:
        """Return the sum of all stored values.

        :param includeNegative: when False, only values greater than 0 are summed
        """
        return sum(v for v in self.__grid.values() if includeNegative or v > 0)

    def getNeighboursOf(self, pos: Coordinate, includeDefault: bool = False, includeDiagonal: bool = True) \
            -> List[Coordinate]:
        """Return the neighbours of ``pos`` as produced by ``pos.getNeighbours``.

        The grid's bounding box is passed along as min/max limits.

        :param includeDefault: when False, neighbours holding the default value are dropped
        :param includeDiagonal: forwarded to ``pos.getNeighbours``
        """
        neighbours = pos.getNeighbours(
            includeDiagonal=includeDiagonal,
            minX=self.minX, minY=self.minY, minZ=self.minZ,
            maxX=self.maxX, maxY=self.maxY, maxZ=self.maxZ
        )
        if includeDefault:
            return neighbours
        return [n for n in neighbours if self.get(n) != self.__default]

    def getNeighbourSum(self, pos: Coordinate, includeNegative: bool = True, includeDiagonal: bool = True) -> Numeric:
        """Return the sum of the stored values of all neighbours of ``pos``.

        :param includeNegative: when False, only values greater than 0 are summed
        :param includeDiagonal: forwarded to ``pos.getNeighbours``
        """
        neighbours = pos.getNeighbours(
            includeDiagonal=includeDiagonal,
            minX=self.minX, minY=self.minY, minZ=self.minZ,
            maxX=self.maxX, maxY=self.maxY, maxZ=self.maxZ
        )
        return sum(
            self.__grid[n]
            for n in neighbours
            if n in self.__grid and (includeNegative or self.__grid[n] > 0)
        )

    def flip(self, c1: Coordinate, c2: Coordinate):
        """Swap the values of the cells at ``c1`` and ``c2``."""
        buf = self.get(c1)
        self.set(c1, self.get(c2))
        self.set(c2, buf)

    def transform(self, mode: GridTransformation):
        """Flip or rotate the grid in place.

        Flips mirror the cells inside the current bounding box. Rotations
        rebuild the grid, so afterwards the bounding box is re-derived from the
        rotated non-default cells; the grid default is preserved.

        :raises NotImplementedError: if the grid is in 3D mode
        """
        if self.mode3D:
            raise NotImplementedError()  # that will take some time and thought

        if mode == GridTransformation.FLIP_HORIZONTALLY:
            # A column and its mirror image always add up to minX + maxX, which
            # also holds when the bounding box does not start at 0.
            mirrorSum = self.minX + self.maxX
            halfWidth = (self.maxX - self.minX + 1) // 2
            for x in range(self.minX, self.minX + halfWidth):
                for y in range(self.minY, self.maxY + 1):
                    self.flip(Coordinate(x, y), Coordinate(mirrorSum - x, y))
        elif mode == GridTransformation.FLIP_VERTICALLY:
            mirrorSum = self.minY + self.maxY
            halfHeight = (self.maxY - self.minY + 1) // 2
            for y in range(self.minY, self.minY + halfHeight):
                for x in range(self.minX, self.maxX + 1):
                    self.flip(Coordinate(x, y), Coordinate(x, mirrorSum - y))
        elif mode == GridTransformation.FLIP_DIAGONALLY:
            self.transform(GridTransformation.ROTATE_LEFT)
            self.transform(GridTransformation.FLIP_HORIZONTALLY)
        elif mode == GridTransformation.FLIP_DIAGONALLY_REV:
            self.transform(GridTransformation.ROTATE_RIGHT)
            self.transform(GridTransformation.FLIP_HORIZONTALLY)
        elif mode == GridTransformation.ROTATE_LEFT:
            # Build the rotated grid with the same default; otherwise the
            # __dict__ swap below would silently reset the default to False.
            newGrid = Grid(self.__default)
            for x in range(self.maxX, self.minX - 1, -1):
                for y in range(self.minY, self.maxY + 1):
                    newGrid.set(Coordinate(y, self.maxX - x), self.get(Coordinate(x, y)))

            # Adopt the rotated grid's cells and boundaries wholesale.
            self.__dict__.update(newGrid.__dict__)
        elif mode == GridTransformation.ROTATE_RIGHT:
            newGrid = Grid(self.__default)
            for x in range(self.minX, self.maxX + 1):
                for y in range(self.maxY, self.minY - 1, -1):
                    newGrid.set(Coordinate(self.maxY - y, x), self.get(Coordinate(x, y)))

            self.__dict__.update(newGrid.__dict__)
        elif mode == GridTransformation.ROTATE_TWICE:
            self.transform(GridTransformation.ROTATE_RIGHT)
            self.transform(GridTransformation.ROTATE_RIGHT)

    def getPath(self, pos_from: Coordinate, pos_to: Coordinate, includeDiagonal: bool, walls: List[Any] = None,
                weighted: bool = False) -> Union[None, List[Coordinate]]:
        """Search a path from ``pos_from`` to ``pos_to`` (A*-style best-first search).

        :param includeDiagonal: whether diagonal steps are allowed
        :param walls: cell values that cannot be entered; defaults to the grid
            default, i.e. only stored cells are walkable
        :param weighted: when True, the cost of entering a cell is the cell's
            value; otherwise it is the distance between the two coordinates
        :return: the coordinates of the path in reverse order (starting with
            ``pos_to`` and ending with ``pos_from``), or None if ``pos_to``
            cannot be reached
        """
        @dataclass(frozen=True, order=True)
        class Node:
            # Field order matters: nodes are ordered by f_cost first, then by
            # g_cost, then by coordinate.
            f_cost: Numeric  # accumulated cost plus estimated distance to the target
            g_cost: Numeric  # cost accumulated on the way to this node
            coord: Coordinate
            parent: Optional[Coordinate]

        if walls is None:
            walls = [self.__default]

        # The start node carries its distance to the target as accumulated
        # cost. That offsets every cost in the search by the same constant and
        # therefore does not influence which node is expanded next.
        startDist = pos_from.getDistanceTo(pos_to)
        startNode = Node(startDist, startDist, pos_from, None)

        # openNodes holds the best known node per coordinate; openHeap yields
        # the cheapest one without scanning all open nodes on every iteration.
        openNodes: Dict[Coordinate, Node] = {pos_from: startNode}
        openHeap: List[Node] = [startNode]
        closedNodes: Dict[Coordinate, Node] = {}

        while openHeap:
            currentNode = heappop(openHeap)
            currentCoord = currentNode.coord
            # Lazy deletion: skip heap entries that were superseded by a
            # cheaper node for the same coordinate (or already expanded).
            if openNodes.get(currentCoord) is not currentNode:
                continue

            closedNodes[currentCoord] = currentNode
            del openNodes[currentCoord]
            if currentCoord == pos_to:
                break

            for neighbour in self.getNeighboursOf(currentCoord, True, includeDiagonal):
                if self.get(neighbour) in walls or neighbour in closedNodes:
                    continue

                if weighted:
                    neighbourDist = self.get(neighbour)
                else:
                    neighbourDist = currentCoord.getDistanceTo(neighbour, DistanceAlgorithm.MANHATTAN, includeDiagonal)

                targetDist = neighbour.getDistanceTo(pos_to)
                neighbourNode = Node(
                    targetDist + neighbourDist + currentNode.g_cost,
                    currentNode.g_cost + neighbourDist,
                    neighbour,
                    currentCoord
                )

                if neighbour not in openNodes or neighbourNode.f_cost < openNodes[neighbour].f_cost:
                    openNodes[neighbour] = neighbourNode
                    heappush(openHeap, neighbourNode)

        if pos_to not in closedNodes:
            return None

        # Walk the parent links back from the target to the start.
        currentNode = closedNodes[pos_to]
        pathCoords = [pos_to]
        while currentNode.parent is not None:
            pathCoords.append(currentNode.parent)
            currentNode = closedNodes[currentNode.parent]

        return pathCoords

    def print(self, spacer: str = "", true_char: str = None):
        """Print the X/Y bounding box row by row to stdout.

        :param spacer: text printed after every cell
        :param true_char: if given, truthy cells are printed as this character
            and all other cells as a blank; otherwise the raw cell value is printed
        """
        for y in range(self.minY, self.maxY + 1):
            for x in range(self.minX, self.maxX + 1):
                if true_char:
                    print(true_char if self.get(Coordinate(x, y)) else " ", end="")
                else:
                    print(self.get(Coordinate(x, y)), end="")
                print(spacer, end="")

            print()
|