make things faster/cleaner

This commit is contained in:
Stefan Harmuth 2021-12-15 11:30:30 +01:00
parent 2c859033fd
commit 11604338e8
2 changed files with 12 additions and 23 deletions

View File

@ -12,38 +12,35 @@ class DistanceAlgorithm(Enum):
CHEBYSHEV = 2 CHEBYSHEV = 2
CHESSBOARD = 2 CHESSBOARD = 2
@dataclass(frozen=True, order=True) @dataclass(frozen=True, order=True)
class Coordinate: class Coordinate:
x: int x: int
y: int y: int
z: Optional[int] = None z: Optional[int] = None
def getDistanceTo(self, target: Coordinate, mode: DistanceAlgorithm = DistanceAlgorithm.EUCLIDEAN, def getDistanceTo(self, target: Coordinate, algorithm: DistanceAlgorithm = DistanceAlgorithm.EUCLIDEAN,
includeDiagonals: bool = False) -> Union[int, float]: includeDiagonals: bool = False) -> Union[int, float]:
""" """
Get distance to target Coordinate Get distance to target Coordinate
:param target: :param target:
:param mode: Calculation Mode (0 = Manhattan, 1 = Pythagoras) :param algorithm: Calculation Mode (0 = Manhattan, 1 = Pythagoras)
:param includeDiagonals: in Manhattan Mode specify if diagonal :param includeDiagonals: in Manhattan Mode specify if diagonal
movements are allowed (counts as 1.4 in 2D, 1.7 in 3D) movements are allowed (counts as 1.4 in 2D, 1.7 in 3D)
:return: Distance to Target :return: Distance to Target
""" """
assert isinstance(target, Coordinate) if algorithm == DistanceAlgorithm.EUCLIDEAN:
assert isinstance(mode, DistanceAlgorithm)
assert isinstance(includeDiagonals, bool)
if mode == DistanceAlgorithm.EUCLIDEAN:
if self.z is None: if self.z is None:
return sqrt(abs(self.x - target.x) ** 2 + abs(self.y - target.y) ** 2) return sqrt(abs(self.x - target.x) ** 2 + abs(self.y - target.y) ** 2)
else: else:
return sqrt(abs(self.x - target.x) ** 2 + abs(self.y - target.y) ** 2 + abs(self.z - target.z) ** 2) return sqrt(abs(self.x - target.x) ** 2 + abs(self.y - target.y) ** 2 + abs(self.z - target.z) ** 2)
elif mode == DistanceAlgorithm.CHEBYSHEV: elif algorithm == DistanceAlgorithm.CHEBYSHEV:
if self.z is None: if self.z is None:
return max(abs(target.x - self.x), abs(target.y - self.y)) return max(abs(target.x - self.x), abs(target.y - self.y))
else: else:
return max(abs(target.x - self.x), abs(target.y - self.y), abs(target.z - self.z)) return max(abs(target.x - self.x), abs(target.y - self.y), abs(target.z - self.z))
elif mode == DistanceAlgorithm.MANHATTAN: elif algorithm == DistanceAlgorithm.MANHATTAN:
if not includeDiagonals: if not includeDiagonals:
if self.z is None: if self.z is None:
return abs(self.x - target.x) + abs(self.y - target.y) return abs(self.x - target.x) + abs(self.y - target.y)
@ -93,9 +90,7 @@ class Coordinate:
nb_list = [(x, y, z) for x in [-1, 0, 1] for y in [-1, 0, 1] for z in [-1, 0, 1]] nb_list = [(x, y, z) for x in [-1, 0, 1] for y in [-1, 0, 1] for z in [-1, 0, 1]]
nb_list.remove((0, 0, 0)) nb_list.remove((0, 0, 0))
else: else:
nb_list = [ nb_list = [(-1, 0, 0), (0, -1, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1), (0, 0, -1)]
(-1, 0, 0), (0, -1, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1), (0, 0, -1)
]
for dx, dy, dz in nb_list: for dx, dy, dz in nb_list:
tx = self.x + dx tx = self.x + dx

View File

@ -146,18 +146,13 @@ class Grid:
if includeDefault: if includeDefault:
return neighbours return neighbours
else: else:
return [x for x in neighbours if self.get(x) != self.__default] return [x for x in neighbours if x in self.__grid]
def getNeighbourSum(self, pos: Coordinate, includeNegative: bool = True, includeDiagonal: bool = True) -> Numeric: def getNeighbourSum(self, pos: Coordinate, includeNegative: bool = True, includeDiagonal: bool = True) -> Numeric:
neighbour_sum = 0 neighbour_sum = 0
for neighbour in pos.getNeighbours( for neighbour in self.getNeighboursOf(pos, includeDefault=includeDiagonal):
includeDiagonal=includeDiagonal, if includeNegative or self.get(neighbour) > 0:
minX=self.minX, minY=self.minY, minZ=self.minZ, neighbour_sum += self.get(neighbour)
maxX=self.maxX, maxY=self.maxY, maxZ=self.maxZ
):
if neighbour in self.__grid:
if includeNegative or self.__grid[neighbour] > 0:
neighbour_sum += self.__grid[neighbour]
return neighbour_sum return neighbour_sum
@ -238,10 +233,9 @@ class Grid:
targetDist = neighbour.getDistanceTo(pos_to) targetDist = neighbour.getDistanceTo(pos_to)
f_cost = targetDist + neighbourDist + currentNode[1] f_cost = targetDist + neighbourDist + currentNode[1]
neighbourNode = (f_cost, currentNode[1] + neighbourDist, currentCoord)
if neighbour not in openNodes or f_cost < openNodes[neighbour][0]: if neighbour not in openNodes or f_cost < openNodes[neighbour][0]:
openNodes[neighbour] = neighbourNode openNodes[neighbour] = (f_cost, currentNode[1] + neighbourDist, currentCoord)
heappush(f_costs, (f_cost, neighbour)) heappush(f_costs, (f_cost, neighbour))
if pos_to not in closedNodes: if pos_to not in closedNodes: