diff --git a/tools/grid.py b/tools/grid.py index d2d9190..9a46e58 100644 --- a/tools/grid.py +++ b/tools/grid.py @@ -2,6 +2,7 @@ from __future__ import annotations from .coordinate import Coordinate, DistanceAlgorithm from dataclasses import dataclass from enum import Enum +from heapq import heappop, heappush from math import inf from typing import Any, Dict, List, Optional, Union @@ -205,10 +206,11 @@ class Grid: def getPath(self, pos_from: Coordinate, pos_to: Coordinate, includeDiagonal: bool, walls: List[Any] = None, weighted: bool = False) -> Union[None, List[Coordinate]]: - @dataclass(frozen=True) + @dataclass(frozen=True, order=True) class Node: f_cost: Numeric h_cost: Numeric + coord: Coordinate parent: Optional[Coordinate] if walls is None: @@ -220,16 +222,17 @@ class Grid: openNodes[pos_from] = Node( pos_from.getDistanceTo(pos_to), pos_from.getDistanceTo(pos_to), + pos_from, None ) while openNodes: - currentNode = Node(inf, 0, None) - currentCoord = None - for c, n in openNodes.items(): - if n.f_cost < currentNode.f_cost: - currentNode = n - currentCoord = c + currentNode = min(openNodes.values()) + currentCoord = currentNode.coord + #for c, n in openNodes.items(): + # if n.f_cost < currentNode.f_cost: + # currentNode = n + # currentCoord = c closedNodes[currentCoord] = currentNode del openNodes[currentCoord] @@ -249,6 +252,7 @@ class Grid: neighbourNode = Node( targetDist + neighbourDist + currentNode.h_cost, currentNode.h_cost + neighbourDist, + neighbour, currentCoord )