grid.getPath(): Use heap to ease finding smallest f_cost node

This commit is contained in:
Stefan Harmuth 2021-12-15 11:09:58 +01:00
parent af2aea1a34
commit 2c859033fd

View File

@ -1,10 +1,8 @@
from __future__ import annotations from __future__ import annotations
from .coordinate import Coordinate, DistanceAlgorithm from .coordinate import Coordinate, DistanceAlgorithm
from dataclasses import dataclass
from enum import Enum from enum import Enum
from heapq import heappop, heappush from heapq import heappop, heappush
from math import inf from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union
OFF = False OFF = False
ON = True ON = True
@ -206,33 +204,21 @@ class Grid:
def getPath(self, pos_from: Coordinate, pos_to: Coordinate, includeDiagonal: bool, walls: List[Any] = None, def getPath(self, pos_from: Coordinate, pos_to: Coordinate, includeDiagonal: bool, walls: List[Any] = None,
weighted: bool = False) -> Union[None, List[Coordinate]]: weighted: bool = False) -> Union[None, List[Coordinate]]:
@dataclass(frozen=True, order=True) f_costs = []
class Node:
f_cost: Numeric
h_cost: Numeric
coord: Coordinate
parent: Optional[Coordinate]
if walls is None: if walls is None:
walls = [self.__default] walls = [self.__default]
openNodes: Dict[Coordinate, Node] = {} openNodes: Dict[Coordinate, tuple] = {}
closedNodes: Dict[Coordinate, Node] = {} closedNodes: Dict[Coordinate, tuple] = {}
openNodes[pos_from] = Node( openNodes[pos_from] = (0, pos_from.getDistanceTo(pos_to), None)
pos_from.getDistanceTo(pos_to), heappush(f_costs, (0, pos_from))
pos_from.getDistanceTo(pos_to),
pos_from,
None
)
while openNodes: while openNodes:
currentNode = min(openNodes.values()) _, currentCoord = heappop(f_costs)
currentCoord = currentNode.coord if currentCoord not in openNodes:
#for c, n in openNodes.items(): continue
# if n.f_cost < currentNode.f_cost: currentNode = openNodes[currentCoord]
# currentNode = n
# currentCoord = c
closedNodes[currentCoord] = currentNode closedNodes[currentCoord] = currentNode
del openNodes[currentCoord] del openNodes[currentCoord]
@ -245,28 +231,27 @@ class Grid:
if weighted: if weighted:
neighbourDist = self.get(neighbour) neighbourDist = self.get(neighbour)
elif not includeDiagonal:
neighbourDist = 1
else: else:
neighbourDist = currentCoord.getDistanceTo(neighbour, DistanceAlgorithm.MANHATTAN, includeDiagonal) neighbourDist = currentCoord.getDistanceTo(neighbour, DistanceAlgorithm.MANHATTAN, includeDiagonal)
targetDist = neighbour.getDistanceTo(pos_to) targetDist = neighbour.getDistanceTo(pos_to)
neighbourNode = Node( f_cost = targetDist + neighbourDist + currentNode[1]
targetDist + neighbourDist + currentNode.h_cost, neighbourNode = (f_cost, currentNode[1] + neighbourDist, currentCoord)
currentNode.h_cost + neighbourDist,
neighbour,
currentCoord
)
if neighbour not in openNodes or neighbourNode.f_cost < openNodes[neighbour].f_cost: if neighbour not in openNodes or f_cost < openNodes[neighbour][0]:
openNodes[neighbour] = neighbourNode openNodes[neighbour] = neighbourNode
heappush(f_costs, (f_cost, neighbour))
if pos_to not in closedNodes: if pos_to not in closedNodes:
return None return None
else: else:
currentNode = closedNodes[pos_to] currentNode = closedNodes[pos_to]
pathCoords = [pos_to] pathCoords = [pos_to]
while currentNode.parent: while currentNode[2]:
pathCoords.append(currentNode.parent) pathCoords.append(currentNode[2])
currentNode = closedNodes[currentNode.parent] currentNode = closedNodes[currentNode[2]]
return pathCoords return pathCoords