diff --git a/day16.py b/day16.py index 6348a0e..05b4d2a 100644 --- a/day16.py +++ b/day16.py @@ -1,15 +1,50 @@ from collections import deque +from dataclasses import dataclass from itertools import product from tools.aoc import AOCDay from typing import Any -from tqdm import tqdm class Valve: - def __init__(self, name: str, flowrate: int) -> None: - self.name = name - self.flowrate = flowrate - self.tunnels = [] + def __init__(self, name: str, flowrate: int): + self.name: str = name + self.flowrate: int = flowrate + self.tunnels: set = set() + + +@dataclass(frozen=True) +class Tunnel: + target: Valve + length: int = 1 + + def __str__(self): + return f"Tunnel(target={self.target.name}, length={self.length})" + + def __repr__(self): + return str(self) + + +def strip_zero_flow(root: Valve) -> None: + q = deque() + q.append(root) + v = set() + while q: + valve = q.popleft() + if valve.name in v: + continue + v.add(valve.name) + for tunnel in valve.tunnels.copy(): + if tunnel.target.flowrate > 0: + continue + q.append(tunnel.target) + valve.tunnels.remove(tunnel) + for c_tunnel in tunnel.target.tunnels.copy(): + if c_tunnel.target == valve: + continue + elif c_tunnel.target == tunnel.target: + tunnel.target.tunnels.add(Tunnel(valve, tunnel.length + c_tunnel.length)) + else: + valve.tunnels.add(Tunnel(c_tunnel.target, tunnel.length + c_tunnel.length)) def get_openable_valve_tunnels(valve: Valve, open_valves: set, time_remaining: int) -> set: @@ -22,87 +57,84 @@ def get_openable_valve_tunnels(valve: Valve, open_valves: set, time_remaining: i if v.name in visited: continue visited.add(v.name) - if v.name not in open_valves and v.flowrate > 0 and d + 2 <= time_remaining: + if v.name not in open_valves and d + 2 <= time_remaining: tunnels.add((d, v)) for x in v.tunnels: - queue.append((d + 1, x)) + queue.append((d + x.length, x.target)) return tunnels -def get_max_flow(valve: Valve, open_valves: set, time_remaining: int = 30) -> int: +def get_max_flow(valve: Valve, open_valves: list, time_remaining: int = 30, depth: int = 0) -> int: max_flow = 0 - ov = get_openable_valve_tunnels(valve, open_valves, time_remaining) + ov = {t for t in valve.tunnels if t.target not in open_valves and t.length < time_remaining - 2} if time_remaining <= 0 or not ov: return 0 - for d, v in ov: - this_open_valves = open_valves.copy() - if d + 2 > time_remaining: - continue - this_open_valves.add(v.name) - this_open_flow = v.flowrate * (time_remaining - d - 1) - this_flow = get_max_flow(v, this_open_valves, time_remaining - d - 1) + for tunnel in ov: + this_open_flow = tunnel.target.flowrate * (time_remaining - tunnel.length - 1) + this_flow = get_max_flow(tunnel.target, open_valves + [tunnel.target], time_remaining - tunnel.length - 1, depth + 1) if this_flow + this_open_flow > max_flow: max_flow = this_flow + this_open_flow + return max_flow -def get_max_flow_double(valve1: Valve, valve2: Valve, open_valves: set, DP: dict, time_remaining_1: int = 26, time_remaining_2: int = 26, depth: int = 0) -> int: - dp_key = valve1.name + valve2.name + "%02d" % time_remaining_1 + "%02d" % time_remaining_2 + "".join(list(sorted(open_valves))) - dp_key2 = valve2.name + valve1.name + "%02d" % time_remaining_1 + "%02d" % time_remaining_2 + "".join(list(sorted(open_valves))) +def get_max_flow_double(valve1: Valve, valve2: Valve, open_valves: list, DP: dict, time_remaining_1: int = 26, time_remaining_2: int = 26, depth: int = 0) -> int: + dp_key = valve1.name + valve2.name + "%02d" % time_remaining_1 + "%02d" % time_remaining_2 + "".join(open_valves) + dp_key2 = valve2.name + valve1.name + "%02d" % time_remaining_1 + "%02d" % time_remaining_2 + "".join(open_valves) if dp_key in DP: return DP[dp_key] if time_remaining_1 <= 0 and time_remaining_2 <= 0: return 0 - ov1 = get_openable_valve_tunnels(valve1, open_valves, time_remaining_1) - ov2 = get_openable_valve_tunnels(valve2, open_valves, time_remaining_2) + ov1 = {t for t in valve1.tunnels if t.target.name not in open_valves and t.length < time_remaining_1 - 2} + ov2 = {t for t in valve2.tunnels if t.target.name not in open_valves and t.length < time_remaining_2 - 2} if not ov1 and not ov2: return 0 if not ov1: - ov1 = {(99, valve1)} + ov1 = {Tunnel(valve1, 99)} if not ov2: - ov2 = {(99, valve2)} + ov2 = {Tunnel(valve2, 99)} permut = product(ov1, ov2) - if depth == 0: - pbar = tqdm(total=210) - max_flow = 0 for v1, v2 in permut: - if v1[1].name == v2[1].name: + if v1.target.name == v2.target.name: continue - if depth == 0: - pbar.update(1) - this_open_values = open_valves.copy() - d1, tv1 = v1 - d2, tv2 = v2 - this_open_flow = 0 - if d1 + 2 <= time_remaining_1: - this_open_values.add(tv1.name) - this_open_flow += tv1.flowrate * (time_remaining_1 - d1 - 1) - else: - tv1 = valve1 - if d2 + 2 <= time_remaining_2: - this_open_values.add(tv2.name) - this_open_flow += tv2.flowrate * (time_remaining_2 - d2 - 1) - else: - tv2 = valve2 - this_flow_rate = get_max_flow_double(tv1, tv2, this_open_values, DP, time_remaining_1 - d1 - 1, time_remaining_2 - d2 - 1, depth + 1) + this_open_flow = 0 + + t1 = valve1 + if v1.length + 2 <= time_remaining_1: + this_open_flow += v1.target.flowrate * (time_remaining_1 - v1.length - 1) + t1 = v1.target + + t2 = valve2 + if v2.length + 2 <= time_remaining_2: + this_open_flow += v2.target.flowrate * (time_remaining_2 - v2.length - 1) + t2 = v2.target + + this_flow_rate = get_max_flow_double( + t1, + t2, + sorted(open_valves + [v1.target.name, v2.target.name]), + DP, + time_remaining_1 - v1.length - 1, + time_remaining_2 - v2.length - 1, + depth + 1 + ) + if this_flow_rate + this_open_flow > max_flow: max_flow = this_flow_rate + this_open_flow - if depth == 0: - pbar.close() DP[dp_key] = max_flow DP[dp_key2] = max_flow return max_flow @@ -122,29 +154,46 @@ class Day(AOCDay): def get_valve_graph(self) -> Valve: valves = {} + tmp_tunnels = {} for line in self.getInput(): p = line.split(" ") valve_name = p[1] flowrate = int(p[4][5:-1]) tunnels = "".join(p[9:]).split(",") valves[valve_name] = Valve(valve_name, flowrate) - valves[valve_name].tunnels = tunnels + tmp_tunnels[valve_name] = tunnels + + for name, tunnels in tmp_tunnels.items(): + valves[name].tunnels = {Tunnel(valves[t]) for t in tunnels} for valve in valves.values(): - valve.tunnels = [valves[x] for x in valve.tunnels] - for valve in valves.values(): - for v in valve.tunnels: - if valve not in v.tunnels: - v.tunnels.append(valve) + tunnels = set() + queue = deque() + visited = set() + queue.append((0, valve)) + while queue: + d, v = queue.popleft() + if v in visited: + continue + visited.add(v) + if v != valve and v.flowrate > 0: + tunnels.add(Tunnel(v, d)) + for x in v.tunnels: + queue.append((d + 1, x.target)) + + tmp_tunnels[valve.name] = tunnels + + for name, tunnels in tmp_tunnels.items(): + valves[name].tunnels = tunnels return valves["AA"] def part1(self) -> Any: - return get_max_flow(self.get_valve_graph(), set()) + return get_max_flow(self.get_valve_graph(), []) def part2(self) -> Any: root = self.get_valve_graph() - return get_max_flow_double(root, root, set(), {}) + return get_max_flow_double(root, root, [], {}) if __name__ == '__main__':