diff --git a/tools/grid.py b/tools/grid.py index 5c2b98e..48439d0 100644 --- a/tools/grid.py +++ b/tools/grid.py @@ -1,6 +1,8 @@ from __future__ import annotations -from .coordinate import Coordinate +from .coordinate import Coordinate, DistanceAlgorithm +from dataclasses import dataclass from enum import Enum +from math import inf from typing import Any, List, Union OFF = False @@ -190,6 +192,61 @@ class Grid: self.transform(GridTransformation.ROTATE_RIGHT) self.transform(GridTransformation.ROTATE_RIGHT) + def getPath(self, pos_from: Coordinate, pos_to: Coordinate, includeDiagonal: bool, walls: List[Any] = None)\ + -> List[Coordinate]: + @dataclass(frozen=True) + class Node: + f_cost: int + h_cost: int + parent: Coordinate + + if walls in None: + walls = [self.__default] + + openNodes: Dict[Coordinate, Node] = {} + closedNodes: Dict[Coordinate, Node] = {} # Dict[Coordinate, Node] + + openNodes[pos_from] = Node( + pos_from.getDistanceTo(pos_to, DistanceAlgorithm.MANHATTAN, includeDiagonal), + pos_from.getDistanceTo(pos_to, DistanceAlgorithm.MANHATTAN, includeDiagonal), + None + ) + + while openNodes: + currentCoord = list(sorted(openNodes, key=lambda n: openNodes[n].f_cost))[0] + currentNode = openNodes[currentCoord] + + closedNodes[currentCoord] = currentNode + del openNodes[currentCoord] + if currentCoord == pos_to: + break + + for neighbour in self.getNeighboursOf(currentCoord, True, includeDiagonal): + if self.get(neighbour) in walls or neighbour in closedNodes: + continue + + neighbourDist = currentCoord.getDistanceTo(neighbour, DistanceAlgorithm.MANHATTAN, includeDiagonal) + targetDist = neighbour.getDistanceTo(pos_to, DistanceAlgorithm.MANHATTAN, includeDiagonal) + neighbourNode = Node( + targetDist + neighbourDist + currentNode.h_cost, + currentNode[2] + neighbourDist, + currentCoord + ) + + if neighbour not in openNodes or neighbourNode.f_cost < openNodes[neighbour].f_cost: + openNodes[neighbour] = neighbourNode + + if pos_to not in closedNodes: + return None + else: + currentNode = closedNodes[pos_to] + pathCoords = [pos_to] + while currentNode.parent: + pathCoords.append(currentNode.parent) + currentNode = closedNodes[currentNode.parent] + + return pathCoords + def print(self, spacer: str = ""): for y in range(self.minY, self.maxY + 1): for x in range(self.minX, self.maxX + 1):