grid.getPath(): Use heap to ease finding smallest f_cost node

This commit is contained in:
Stefan Harmuth 2021-12-15 11:09:58 +01:00
parent af2aea1a34
commit 2c859033fd

View File

@ -1,10 +1,8 @@
from __future__ import annotations
from .coordinate import Coordinate, DistanceAlgorithm
from dataclasses import dataclass
from enum import Enum
from heapq import heappop, heappush
from math import inf
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Union
OFF = False
ON = True
@ -206,33 +204,21 @@ class Grid:
def getPath(self, pos_from: Coordinate, pos_to: Coordinate, includeDiagonal: bool, walls: List[Any] = None,
weighted: bool = False) -> Union[None, List[Coordinate]]:
@dataclass(frozen=True, order=True)
class Node:
f_cost: Numeric
h_cost: Numeric
coord: Coordinate
parent: Optional[Coordinate]
f_costs = []
if walls is None:
walls = [self.__default]
openNodes: Dict[Coordinate, Node] = {}
closedNodes: Dict[Coordinate, Node] = {}
openNodes: Dict[Coordinate, tuple] = {}
closedNodes: Dict[Coordinate, tuple] = {}
openNodes[pos_from] = Node(
pos_from.getDistanceTo(pos_to),
pos_from.getDistanceTo(pos_to),
pos_from,
None
)
openNodes[pos_from] = (0, pos_from.getDistanceTo(pos_to), None)
heappush(f_costs, (0, pos_from))
while openNodes:
currentNode = min(openNodes.values())
currentCoord = currentNode.coord
#for c, n in openNodes.items():
# if n.f_cost < currentNode.f_cost:
# currentNode = n
# currentCoord = c
_, currentCoord = heappop(f_costs)
if currentCoord not in openNodes:
continue
currentNode = openNodes[currentCoord]
closedNodes[currentCoord] = currentNode
del openNodes[currentCoord]
@ -245,28 +231,27 @@ class Grid:
if weighted:
neighbourDist = self.get(neighbour)
elif not includeDiagonal:
neighbourDist = 1
else:
neighbourDist = currentCoord.getDistanceTo(neighbour, DistanceAlgorithm.MANHATTAN, includeDiagonal)
targetDist = neighbour.getDistanceTo(pos_to)
neighbourNode = Node(
targetDist + neighbourDist + currentNode.h_cost,
currentNode.h_cost + neighbourDist,
neighbour,
currentCoord
)
f_cost = targetDist + neighbourDist + currentNode[1]
neighbourNode = (f_cost, currentNode[1] + neighbourDist, currentCoord)
if neighbour not in openNodes or neighbourNode.f_cost < openNodes[neighbour].f_cost:
if neighbour not in openNodes or f_cost < openNodes[neighbour][0]:
openNodes[neighbour] = neighbourNode
heappush(f_costs, (f_cost, neighbour))
if pos_to not in closedNodes:
return None
else:
currentNode = closedNodes[pos_to]
pathCoords = [pos_to]
while currentNode.parent:
pathCoords.append(currentNode.parent)
currentNode = closedNodes[currentNode.parent]
while currentNode[2]:
pathCoords.append(currentNode[2])
currentNode = closedNodes[currentNode[2]]
return pathCoords