aoc2021/day16.py
2021-12-17 16:39:20 +01:00

77 lines
2.9 KiB
Python

from math import prod
from tools.aoc import AOCDay
from typing import Any
class Packet:
def __init__(self, packet_str: str):
self.packet_str = packet_str
self.p_version = int(packet_str[:3], 2)
self.p_type = int(packet_str[3:6], 2)
self.p_len = 6
self.value = 0
self.sub_packets = []
self.__parse()
def __parse(self):
if self.p_type == 4:
while True:
self.value = self.value * 16 + int(self.packet_str[self.p_len + 1:self.p_len + 5], 2)
self.p_len += 5
if self.packet_str[self.p_len - 5] == "0":
break
else:
op_type = self.packet_str[self.p_len]
self.p_len += 1
if op_type == "0":
sub_pkg_len = int(self.packet_str[self.p_len:self.p_len + 15], 2)
self.p_len += 15 + sub_pkg_len
sub_packages = self.packet_str[self.p_len - sub_pkg_len:self.p_len]
while "1" in sub_packages:
sub_packet = Packet(sub_packages)
self.sub_packets.append(sub_packet)
sub_packages = sub_packages[sub_packet.p_len:]
else:
sub_pkg_count = int(self.packet_str[self.p_len:self.p_len + 11], 2)
self.p_len += 11
for x in range(sub_pkg_count):
sub_packet = Packet(self.packet_str[self.p_len:])
self.sub_packets.append(sub_packet)
self.p_len += sub_packet.p_len
def get_versions(self) -> int:
return self.p_version + sum(p.get_versions() for p in self.sub_packets)
def get_value(self) -> int:
if self.p_type == 0:
return sum(p.get_value() for p in self.sub_packets)
elif self.p_type == 1:
return prod(p.get_value() for p in self.sub_packets)
elif self.p_type == 2:
return min(p.get_value() for p in self.sub_packets)
elif self.p_type == 3:
return max(p.get_value() for p in self.sub_packets)
elif self.p_type == 4:
return self.value
elif self.p_type == 5:
return self.sub_packets[0].get_value() > self.sub_packets[1].get_value()
elif self.p_type == 6:
return self.sub_packets[0].get_value() < self.sub_packets[1].get_value()
elif self.p_type == 7:
return self.sub_packets[0].get_value() == self.sub_packets[1].get_value()
class Day(AOCDay):
    """Advent of Code 2021 day 16: decode a hex-encoded BITS transmission."""

    # Presumably the expected answers for the example inputs run by AOCDay's
    # test harness (part 1 then part 2) — confirm against tools.aoc.
    test_solutions_p1 = [16, 12, 23, 31, 1007]
    test_solutions_p2 = [3, 54, 7, 9, 1, 0, 0, 1, 834151779165]

    def getBits(self) -> str:
        """Return the puzzle input expanded to binary, 4 bits per input character.

        Leading zero nibbles are kept by zero-padding to the full width;
        ``bin(int(...))`` alone would drop them. The input is read only once.
        """
        hex_input = self.getInput()
        width = len(hex_input) * 4
        return f"{int(hex_input, 16):0{width}b}"

    def part1(self) -> Any:
        """Return the sum of all packet version numbers."""
        return Packet(self.getBits()).get_versions()

    def part2(self) -> Any:
        """Return the value of the outermost packet's expression."""
        return Packet(self.getBits()).get_value()