still trying to make grid.getPath() faster

This commit is contained in:
Stefan Harmuth 2021-12-15 09:42:10 +01:00
parent 235a545c70
commit af2aea1a34

View File

@ -2,6 +2,7 @@ from __future__ import annotations
from .coordinate import Coordinate, DistanceAlgorithm from .coordinate import Coordinate, DistanceAlgorithm
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from heapq import heappop, heappush
from math import inf from math import inf
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
@ -205,10 +206,11 @@ class Grid:
def getPath(self, pos_from: Coordinate, pos_to: Coordinate, includeDiagonal: bool, walls: List[Any] = None, def getPath(self, pos_from: Coordinate, pos_to: Coordinate, includeDiagonal: bool, walls: List[Any] = None,
weighted: bool = False) -> Union[None, List[Coordinate]]: weighted: bool = False) -> Union[None, List[Coordinate]]:
@dataclass(frozen=True) @dataclass(frozen=True, order=True)
class Node: class Node:
f_cost: Numeric f_cost: Numeric
h_cost: Numeric h_cost: Numeric
coord: Coordinate
parent: Optional[Coordinate] parent: Optional[Coordinate]
if walls is None: if walls is None:
@ -220,16 +222,17 @@ class Grid:
openNodes[pos_from] = Node( openNodes[pos_from] = Node(
pos_from.getDistanceTo(pos_to), pos_from.getDistanceTo(pos_to),
pos_from.getDistanceTo(pos_to), pos_from.getDistanceTo(pos_to),
pos_from,
None None
) )
while openNodes: while openNodes:
currentNode = Node(inf, 0, None) currentNode = min(openNodes.values())
currentCoord = None currentCoord = currentNode.coord
for c, n in openNodes.items(): #for c, n in openNodes.items():
if n.f_cost < currentNode.f_cost: # if n.f_cost < currentNode.f_cost:
currentNode = n # currentNode = n
currentCoord = c # currentCoord = c
closedNodes[currentCoord] = currentNode closedNodes[currentCoord] = currentNode
del openNodes[currentCoord] del openNodes[currentCoord]
@ -249,6 +252,7 @@ class Grid:
neighbourNode = Node( neighbourNode = Node(
targetDist + neighbourDist + currentNode.h_cost, targetDist + neighbourDist + currentNode.h_cost,
currentNode.h_cost + neighbourDist, currentNode.h_cost + neighbourDist,
neighbour,
currentCoord currentCoord
) )