diff --git a/tools/grid.py b/tools/grid.py index fa481af..d2d9190 100644 --- a/tools/grid.py +++ b/tools/grid.py @@ -2,6 +2,7 @@ from __future__ import annotations from .coordinate import Coordinate, DistanceAlgorithm from dataclasses import dataclass from enum import Enum +from math import inf from typing import Any, Dict, List, Optional, Union OFF = False @@ -202,29 +203,33 @@ class Grid: self.transform(GridTransformation.ROTATE_RIGHT) self.transform(GridTransformation.ROTATE_RIGHT) - def getPath(self, pos_from: Coordinate, pos_to: Coordinate, includeDiagonal: bool, walls: List[Any] = None)\ - -> Union[None, List[Coordinate]]: + def getPath(self, pos_from: Coordinate, pos_to: Coordinate, includeDiagonal: bool, walls: List[Any] = None, + weighted: bool = False) -> Union[None, List[Coordinate]]: @dataclass(frozen=True) class Node: - f_cost: int - h_cost: int + f_cost: Numeric + h_cost: Numeric parent: Optional[Coordinate] - if walls in None: + if walls is None: walls = [self.__default] openNodes: Dict[Coordinate, Node] = {} closedNodes: Dict[Coordinate, Node] = {} openNodes[pos_from] = Node( - pos_from.getDistanceTo(pos_to, DistanceAlgorithm.MANHATTAN, includeDiagonal), - pos_from.getDistanceTo(pos_to, DistanceAlgorithm.MANHATTAN, includeDiagonal), + pos_from.getDistanceTo(pos_to), + pos_from.getDistanceTo(pos_to), None ) while openNodes: - currentCoord = list(sorted(openNodes, key=lambda n: openNodes[n].f_cost))[0] - currentNode = openNodes[currentCoord] + currentNode = Node(inf, 0, None) + currentCoord = None + for c, n in openNodes.items(): + if n.f_cost < currentNode.f_cost: + currentNode = n + currentCoord = c closedNodes[currentCoord] = currentNode del openNodes[currentCoord] @@ -235,8 +240,12 @@ class Grid: if self.get(neighbour) in walls or neighbour in closedNodes: continue - neighbourDist = currentCoord.getDistanceTo(neighbour, DistanceAlgorithm.MANHATTAN, includeDiagonal) - targetDist = neighbour.getDistanceTo(pos_to, DistanceAlgorithm.MANHATTAN, includeDiagonal) + if weighted: + neighbourDist = self.get(neighbour) + else: + neighbourDist = currentCoord.getDistanceTo(neighbour, DistanceAlgorithm.MANHATTAN, includeDiagonal) + + targetDist = neighbour.getDistanceTo(pos_to) neighbourNode = Node( targetDist + neighbourDist + currentNode.h_cost, currentNode.h_cost + neighbourDist,