from math import prod
from tools.aoc import AOCDay
from typing import Any


class Packet:
    """One BITS packet decoded from a string of '0'/'1' characters.

    The packet is parsed starting at index 0 of ``packet_str``; any bits
    after the end of the packet are ignored. After construction ``p_len``
    holds the number of bits this packet occupies, so a caller can skip
    past it. Operator packets parse their children recursively into
    ``sub_packets``.
    """

    def __init__(self, packet_str: str):
        self.packet_str = packet_str             # bits; this packet starts at index 0
        self.p_version = int(packet_str[:3], 2)  # 3-bit version header
        self.p_type = int(packet_str[3:6], 2)    # 3-bit type id header
        self.p_len = 6                           # bits consumed so far (header already read)
        self.value = 0                           # literal value; only filled for type 4
        self.sub_packets = []                    # child packets; only filled for operator types
        self.__parse()

    def __parse(self):
        """Decode the packet body that follows the 6-bit header, advancing ``p_len``."""
        if self.p_type == 4:
            self.__parse_literal()
        else:
            self.__parse_operator()

    def __parse_literal(self):
        """Read 5-bit groups (1 flag bit + 4 data bits) until a group's flag bit is '0'."""
        bits = self.packet_str
        while True:
            group_start = self.p_len
            self.value = self.value * 16 + int(bits[group_start + 1:group_start + 5], 2)
            self.p_len += 5
            if bits[group_start] == "0":  # a leading '0' marks the last group
                break

    def __parse_operator(self):
        """Read the length-type bit and then the sub-packets it describes."""
        bits = self.packet_str
        length_type = bits[self.p_len]
        self.p_len += 1
        if length_type == "0":
            # The next 15 bits give the total length, in bits, of all sub-packets.
            total_len = int(bits[self.p_len:self.p_len + 15], 2)
            self.p_len += 15
            end = self.p_len + total_len
            # Consume sub-packets until the declared bit budget is used up,
            # rather than guessing termination from the content of the remaining bits.
            while self.p_len < end:
                self.__add_sub_packet(bits[self.p_len:end])
            # The declared length is authoritative for how far this packet extends.
            self.p_len = end
        else:
            # The next 11 bits give the number of immediately contained sub-packets.
            count = int(bits[self.p_len:self.p_len + 11], 2)
            self.p_len += 11
            for _ in range(count):
                self.__add_sub_packet(bits[self.p_len:])

    def __add_sub_packet(self, bits: str):
        """Parse one child packet from the start of ``bits``, record it and skip past it."""
        sub_packet = Packet(bits)
        self.sub_packets.append(sub_packet)
        self.p_len += sub_packet.p_len

    def get_versions(self) -> int:
        """Return the sum of the version numbers of this packet and all nested sub-packets."""
        return self.p_version + sum(p.get_versions() for p in self.sub_packets)

    def get_value(self) -> int:
        """Evaluate the expression rooted at this packet.

        Types 0-3 reduce the sub-packet values with sum / product / min / max.
        Type 4 returns the literal value.
        Types 5-7 compare the first two sub-packets with >, < and ==. They
        return the comparison result as a bool, which counts as 1/0 in the
        arithmetic of enclosing packets.
        """
        if self.p_type == 0:
            return sum(p.get_value() for p in self.sub_packets)
        elif self.p_type == 1:
            return prod(p.get_value() for p in self.sub_packets)
        elif self.p_type == 2:
            return min(p.get_value() for p in self.sub_packets)
        elif self.p_type == 3:
            return max(p.get_value() for p in self.sub_packets)
        elif self.p_type == 4:
            return self.value
        elif self.p_type == 5:
            return self.sub_packets[0].get_value() > self.sub_packets[1].get_value()
        elif self.p_type == 6:
            return self.sub_packets[0].get_value() < self.sub_packets[1].get_value()
        elif self.p_type == 7:
            return self.sub_packets[0].get_value() == self.sub_packets[1].get_value()


class Day(AOCDay):
    # One list per puzzle part. Each entry is presumably an
    # (expected_answer, input_file_name) pair consumed by the AOCDay runner;
    # confirm against tools.aoc.
    inputs = [
        [
            (16, "test_input16_1_0"),
            (12, "test_input16_1_1"),
            (23, "test_input16_1_2"),
            (31, "test_input16_1_3"),
            (1007, "input16"),
        ],
        [
            (3, "test_input16_2_0"),
            (54, "test_input16_2_1"),
            (7, "test_input16_2_2"),
            (9, "test_input16_2_3"),
            (True, "test_input16_2_4"),
            (False, "test_input16_2_5"),
            (False, "test_input16_2_6"),
            (True, "test_input16_2_7"),
            (834151779165, "input16"),
        ],
    ]

    def getBits(self) -> str:
        """Return the hex puzzle input as a bit string, 4 zero-padded bits per hex digit."""
        # Strip once: int(..., 16) tolerates surrounding whitespace, but counting it
        # toward the pad width would prepend spurious zero bits.
        hex_str = self.getInput().strip()
        return format(int(hex_str, 16), f"0{len(hex_str) * 4}b")

    def part1(self) -> Any:
        """Sum of all packet version numbers in the transmission."""
        return Packet(self.getBits()).get_versions()

    def part2(self) -> Any:
        """Value of the expression encoded by the outermost packet."""
        return Packet(self.getBits()).get_value()


if __name__ == '__main__':
    day = Day(2021, 16)
    day.run(verbose=True)