grid.getPath(): allow for grid values to be movement/distance weights

also: lambdas are sloooooooow
This commit is contained in:
Stefan Harmuth 2021-12-15 07:49:57 +01:00
parent 1c83a41fd2
commit 235a545c70

View File

@ -2,6 +2,7 @@ from __future__ import annotations
from .coordinate import Coordinate, DistanceAlgorithm
from dataclasses import dataclass
from enum import Enum
from math import inf
from typing import Any, Dict, List, Optional, Union
OFF = False
@ -202,29 +203,33 @@ class Grid:
self.transform(GridTransformation.ROTATE_RIGHT)
self.transform(GridTransformation.ROTATE_RIGHT)
def getPath(self, pos_from: Coordinate, pos_to: Coordinate, includeDiagonal: bool, walls: List[Any] = None)\
-> Union[None, List[Coordinate]]:
def getPath(self, pos_from: Coordinate, pos_to: Coordinate, includeDiagonal: bool, walls: List[Any] = None,
weighted: bool = False) -> Union[None, List[Coordinate]]:
@dataclass(frozen=True)
class Node:
f_cost: int
h_cost: int
f_cost: Numeric
h_cost: Numeric
parent: Optional[Coordinate]
if walls in None:
if walls is None:
walls = [self.__default]
openNodes: Dict[Coordinate, Node] = {}
closedNodes: Dict[Coordinate, Node] = {}
openNodes[pos_from] = Node(
pos_from.getDistanceTo(pos_to, DistanceAlgorithm.MANHATTAN, includeDiagonal),
pos_from.getDistanceTo(pos_to, DistanceAlgorithm.MANHATTAN, includeDiagonal),
pos_from.getDistanceTo(pos_to),
pos_from.getDistanceTo(pos_to),
None
)
while openNodes:
currentCoord = list(sorted(openNodes, key=lambda n: openNodes[n].f_cost))[0]
currentNode = openNodes[currentCoord]
currentNode = Node(inf, 0, None)
currentCoord = None
for c, n in openNodes.items():
if n.f_cost < currentNode.f_cost:
currentNode = n
currentCoord = c
closedNodes[currentCoord] = currentNode
del openNodes[currentCoord]
@ -235,8 +240,12 @@ class Grid:
if self.get(neighbour) in walls or neighbour in closedNodes:
continue
if weighted:
neighbourDist = self.get(neighbour)
else:
neighbourDist = currentCoord.getDistanceTo(neighbour, DistanceAlgorithm.MANHATTAN, includeDiagonal)
targetDist = neighbour.getDistanceTo(pos_to, DistanceAlgorithm.MANHATTAN, includeDiagonal)
targetDist = neighbour.getDistanceTo(pos_to)
neighbourNode = Node(
targetDist + neighbourDist + currentNode.h_cost,
currentNode.h_cost + neighbourDist,