From 23f9d800f8b3a532b0e4c84a1780d901db0e79c7 Mon Sep 17 00:00:00 2001 From: Stefan Harmuth Date: Fri, 17 Dec 2021 16:37:00 +0100 Subject: [PATCH] day16: muuuuch nicer code --- day16.py | 144 +++++++++++++++++++++---------------------------------- 1 file changed, 54 insertions(+), 90 deletions(-) diff --git a/day16.py b/day16.py index 8ccb452..373dc9c 100644 --- a/day16.py +++ b/day16.py @@ -1,100 +1,64 @@ from math import prod - from tools.aoc import AOCDay from typing import Any, List -def get_packet(bit_rep: str) -> (str, int): # packet, packet_length - msg_type = int(bit_rep[3:6], 2) - index = 6 +class Packet: + def __init__(self, packet_str: str): + self.packet_str = packet_str + self.p_version = int(packet_str[:3], 2) + self.p_type = int(packet_str[3:6], 2) + self.p_len = 6 + self.value = 0 + self.sub_packets = [] + self.__parse() - if msg_type == 4: - while bit_rep[index] == "1": - index += 5 - index += 5 - else: - op_type = bit_rep[index] - index += 1 - if op_type == "0": - sub_pkg_len = int(bit_rep[index:index+15], 2) - index += 15 + sub_pkg_len + def __parse(self): + if self.p_type == 4: + while self.packet_str[self.p_len] == "1": + self.value = self.value * 16 + int(self.packet_str[self.p_len + 1:self.p_len + 5], 2) + self.p_len += 5 + self.value = self.value * 16 + int(self.packet_str[self.p_len + 1:self.p_len + 5], 2) + self.p_len += 5 else: - sub_pkg_count = int(bit_rep[index:index+11], 2) - index += 11 - for x in range(sub_pkg_count): - _, sub_pkg_len = get_packet(bit_rep[index:]) - index += sub_pkg_len + op_type = self.packet_str[self.p_len] + self.p_len += 1 + if op_type == "0": + sub_pkg_len = int(self.packet_str[self.p_len:self.p_len + 15], 2) + self.p_len += 15 + sub_pkg_len + sub_packages = self.packet_str[self.p_len - sub_pkg_len:self.p_len] + while "1" in sub_packages: + sub_packet = Packet(sub_packages) + self.sub_packets.append(sub_packet) + sub_packages = sub_packages[sub_packet.p_len:] + else: + sub_pkg_count = int(self.packet_str[self.p_len:self.p_len + 11], 2) + self.p_len += 11 + for x in range(sub_pkg_count): + sub_packet = Packet(self.packet_str[self.p_len:]) + self.sub_packets.append(sub_packet) + self.p_len += sub_packet.p_len - return bit_rep[:index], index + def get_versions(self) -> int: + return self.p_version + sum(p.get_versions() for p in self.sub_packets) - -def get_subpackages(packet: str) -> List[str]: - ptype = int(packet[3:6], 2) - if ptype == 4: - return [packet] - else: - plist = [] - op_type = packet[6] - if op_type == "0": - sub_pkg_len = int(packet[7:22], 2) - sub_packages = packet[22:22+sub_pkg_len] - while "1" in sub_packages: - sub_packet, sub_pkg_len = get_packet(sub_packages) - sub_packages = sub_packages[sub_pkg_len:] - plist.append(sub_packet) - else: - sub_pkg_count = int(packet[7:18], 2) - index = 18 - for x in range(sub_pkg_count): - sub_packet, sub_pkg_len = get_packet(packet[index:]) - index += sub_pkg_len - plist.append(sub_packet) - - return plist - - -def get_versions(packet: str) -> int: - version = int(packet[0:3], 2) - ptype = int(packet[3:6], 2) - if ptype == 4: - return version - else: - for p in get_subpackages(packet): - version += get_versions(p) - - return version - - -def get_value(packet: str) -> int: - while "1" in packet: - this_packet, pkg_len = get_packet(packet) - pkg_type = int(this_packet[3:6], 2) - packet = packet[pkg_len:] - if pkg_type == 0: - return sum(get_value(p) for p in get_subpackages(this_packet)) - elif pkg_type == 1: - return prod(get_value(p) for p in get_subpackages(this_packet)) - elif pkg_type == 2: - return min(get_value(p) for p in get_subpackages(this_packet)) - elif pkg_type == 3: - return max(get_value(p) for p in get_subpackages(this_packet)) - elif pkg_type == 4: - index = 6 - value = 0 - while this_packet[index] == "1": - value = value * 16 + int(this_packet[index + 1:index + 5], 2) - index += 5 - value = value * 16 + int(this_packet[index + 1:index + 5], 2) - return value - elif pkg_type == 5: - sub_pkg = get_subpackages(this_packet) - return get_value(sub_pkg[0]) > get_value(sub_pkg[1]) - elif pkg_type == 6: - sub_pkg = get_subpackages(this_packet) - return get_value(sub_pkg[0]) < get_value(sub_pkg[1]) - elif pkg_type == 7: - sub_pkg = get_subpackages(this_packet) - return get_value(sub_pkg[0]) == get_value(sub_pkg[1]) + def get_value(self) -> int: + if self.p_type == 0: + return sum(p.get_value() for p in self.sub_packets) + elif self.p_type == 1: + return prod(p.get_value() for p in self.sub_packets) + elif self.p_type == 2: + return min(p.get_value() for p in self.sub_packets) + elif self.p_type == 3: + return max(p.get_value() for p in self.sub_packets) + elif self.p_type == 4: + return self.value + elif self.p_type == 5: + return self.sub_packets[0].get_value() > self.sub_packets[1].get_value() + elif self.p_type == 6: + return self.sub_packets[0].get_value() < self.sub_packets[1].get_value() + elif self.p_type == 7: + return self.sub_packets[0].get_value() == self.sub_packets[1].get_value() class Day(AOCDay): @@ -106,7 +70,7 @@ class Day(AOCDay): return "0" * (len(self.getInput()) * 4 - len(bits)) + bits def part1(self) -> Any: - return get_versions(self.getBits()) + return Packet(self.getBits()).get_versions() def part2(self) -> Any: - return get_value(self.getBits()) + return Packet(self.getBits()).get_value()