From d916790cb1967150a217e046e6ffbafadf08d6c0 Mon Sep 17 00:00:00 2001 From: Stefan Harmuth Date: Tue, 13 Dec 2022 12:05:58 +0100 Subject: [PATCH] day13 - now it's a correctly behaving comparator --- day13.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/day13.py b/day13.py index 6e6cd6d..f1c9223 100644 --- a/day13.py +++ b/day13.py @@ -7,21 +7,21 @@ from typing import Any def packet_compare(left: list, right: list) -> int: for l_value, r_value in zip(left, right): - if isinstance(l_value, list) and isinstance(r_value, int): - r_value = [r_value] - elif isinstance(l_value, int) and isinstance(r_value, list): - l_value = [l_value] + if l_value == r_value: + continue - if r_value != l_value: - if isinstance(l_value, int): + match l_value, r_value: + case int(), int(): return compare(l_value, r_value) - else: - return packet_compare(l_value, r_value) + case int(), list(): + l_value = [l_value] + case list(), int(): + r_value = [r_value] - if len(left) != len(right): - return compare(len(left), len(right)) + if c := packet_compare(l_value, r_value): + return c - return -1 + return compare(len(left), len(right)) class Day(AOCDay): @@ -45,7 +45,7 @@ class Day(AOCDay): packets = self.parse_input() index_sum = 0 for x in range(len(packets) // 2): - if packet_compare(packets[x * 2], packets[x * 2 + 1]) == -1: + if packet_compare(packets[x * 2], packets[x * 2 + 1]) < 1: index_sum += x + 1 return index_sum @@ -53,9 +53,9 @@ class Day(AOCDay): def part2(self) -> Any: packets = self.parse_input() packets.extend([[[2]], [[6]]]) - sorted_packets = sorted(packets, key=cmp_to_key(packet_compare)) + packets.sort(key=cmp_to_key(packet_compare)) - return (sorted_packets.index([[2]]) + 1) * (sorted_packets.index([[6]]) + 1) + return (packets.index([[2]]) + 1) * (packets.index([[6]]) + 1) if __name__ == '__main__':