diff --git a/tools/grid.py b/tools/grid.py index 9a46e58..80b0f05 100644 --- a/tools/grid.py +++ b/tools/grid.py @@ -1,10 +1,8 @@ from __future__ import annotations from .coordinate import Coordinate, DistanceAlgorithm -from dataclasses import dataclass from enum import Enum from heapq import heappop, heappush -from math import inf -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Union OFF = False ON = True @@ -206,33 +204,21 @@ class Grid: def getPath(self, pos_from: Coordinate, pos_to: Coordinate, includeDiagonal: bool, walls: List[Any] = None, weighted: bool = False) -> Union[None, List[Coordinate]]: - @dataclass(frozen=True, order=True) - class Node: - f_cost: Numeric - h_cost: Numeric - coord: Coordinate - parent: Optional[Coordinate] - + f_costs = [] if walls is None: walls = [self.__default] - openNodes: Dict[Coordinate, Node] = {} - closedNodes: Dict[Coordinate, Node] = {} + openNodes: Dict[Coordinate, tuple] = {} + closedNodes: Dict[Coordinate, tuple] = {} - openNodes[pos_from] = Node( - pos_from.getDistanceTo(pos_to), - pos_from.getDistanceTo(pos_to), - pos_from, - None - ) + openNodes[pos_from] = (0, pos_from.getDistanceTo(pos_to), None) + heappush(f_costs, (0, pos_from)) while openNodes: - currentNode = min(openNodes.values()) - currentCoord = currentNode.coord - #for c, n in openNodes.items(): - # if n.f_cost < currentNode.f_cost: - # currentNode = n - # currentCoord = c + _, currentCoord = heappop(f_costs) + if currentCoord not in openNodes: + continue + currentNode = openNodes[currentCoord] closedNodes[currentCoord] = currentNode del openNodes[currentCoord] @@ -245,28 +231,27 @@ class Grid: if weighted: neighbourDist = self.get(neighbour) + elif not includeDiagonal: + neighbourDist = 1 else: neighbourDist = currentCoord.getDistanceTo(neighbour, DistanceAlgorithm.MANHATTAN, includeDiagonal) targetDist = neighbour.getDistanceTo(pos_to) - neighbourNode = Node( - targetDist + neighbourDist + currentNode.h_cost, - currentNode.h_cost + neighbourDist, - neighbour, - currentCoord - ) + f_cost = targetDist + neighbourDist + currentNode[1] + neighbourNode = (f_cost, currentNode[1] + neighbourDist, currentCoord) - if neighbour not in openNodes or neighbourNode.f_cost < openNodes[neighbour].f_cost: + if neighbour not in openNodes or f_cost < openNodes[neighbour][0]: openNodes[neighbour] = neighbourNode + heappush(f_costs, (f_cost, neighbour)) if pos_to not in closedNodes: return None else: currentNode = closedNodes[pos_to] pathCoords = [pos_to] - while currentNode.parent: - pathCoords.append(currentNode.parent) - currentNode = closedNodes[currentNode.parent] + while currentNode[2]: + pathCoords.append(currentNode[2]) + currentNode = closedNodes[currentNode[2]] return pathCoords