day16 - p2 ~17 times faster (1m20s); probably still improvable
This commit is contained in:
parent
e4e916d7bb
commit
5a699478e2
157
day16.py
157
day16.py
@ -1,15 +1,50 @@
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
|
from dataclasses import dataclass
|
||||||
from itertools import product
|
from itertools import product
|
||||||
from tools.aoc import AOCDay
|
from tools.aoc import AOCDay
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
|
|
||||||
class Valve:
|
class Valve:
|
||||||
def __init__(self, name: str, flowrate: int) -> None:
|
def __init__(self, name: str, flowrate: int):
|
||||||
self.name = name
|
self.name: str = name
|
||||||
self.flowrate = flowrate
|
self.flowrate: int = flowrate
|
||||||
self.tunnels = []
|
self.tunnels: set = set()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Tunnel:
|
||||||
|
target: Valve
|
||||||
|
length: int = 1
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"Tunnel(target={self.target.name}, length={self.length})"
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return str(self)
|
||||||
|
|
||||||
|
|
||||||
|
def strip_zero_flow(root: Valve) -> None:
|
||||||
|
q = deque()
|
||||||
|
q.append(root)
|
||||||
|
v = set()
|
||||||
|
while q:
|
||||||
|
valve = q.popleft()
|
||||||
|
if valve.name in v:
|
||||||
|
continue
|
||||||
|
v.add(valve.name)
|
||||||
|
for tunnel in valve.tunnels.copy():
|
||||||
|
if tunnel.target.flowrate > 0:
|
||||||
|
continue
|
||||||
|
q.append(tunnel.target)
|
||||||
|
valve.tunnels.remove(tunnel)
|
||||||
|
for c_tunnel in tunnel.target.tunnels.copy():
|
||||||
|
if c_tunnel.target == valve:
|
||||||
|
continue
|
||||||
|
elif c_tunnel.target == tunnel.target:
|
||||||
|
tunnel.target.tunnels.add(Tunnel(valve, tunnel.length + c_tunnel.length))
|
||||||
|
else:
|
||||||
|
valve.tunnels.add(Tunnel(c_tunnel.target, tunnel.length + c_tunnel.length))
|
||||||
|
|
||||||
|
|
||||||
def get_openable_valve_tunnels(valve: Valve, open_valves: set, time_remaining: int) -> set:
|
def get_openable_valve_tunnels(valve: Valve, open_valves: set, time_remaining: int) -> set:
|
||||||
@ -22,87 +57,84 @@ def get_openable_valve_tunnels(valve: Valve, open_valves: set, time_remaining: i
|
|||||||
if v.name in visited:
|
if v.name in visited:
|
||||||
continue
|
continue
|
||||||
visited.add(v.name)
|
visited.add(v.name)
|
||||||
if v.name not in open_valves and v.flowrate > 0 and d + 2 <= time_remaining:
|
if v.name not in open_valves and d + 2 <= time_remaining:
|
||||||
tunnels.add((d, v))
|
tunnels.add((d, v))
|
||||||
for x in v.tunnels:
|
for x in v.tunnels:
|
||||||
queue.append((d + 1, x))
|
queue.append((d + x.length, x.target))
|
||||||
|
|
||||||
return tunnels
|
return tunnels
|
||||||
|
|
||||||
|
|
||||||
def get_max_flow(valve: Valve, open_valves: set, time_remaining: int = 30) -> int:
|
def get_max_flow(valve: Valve, open_valves: list, time_remaining: int = 30, depth: int = 0) -> int:
|
||||||
max_flow = 0
|
max_flow = 0
|
||||||
|
|
||||||
ov = get_openable_valve_tunnels(valve, open_valves, time_remaining)
|
ov = {t for t in valve.tunnels if t.target not in open_valves and t.length < time_remaining - 2}
|
||||||
if time_remaining <= 0 or not ov:
|
if time_remaining <= 0 or not ov:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
for d, v in ov:
|
for tunnel in ov:
|
||||||
this_open_valves = open_valves.copy()
|
this_open_flow = tunnel.target.flowrate * (time_remaining - tunnel.length - 1)
|
||||||
if d + 2 > time_remaining:
|
this_flow = get_max_flow(tunnel.target, open_valves + [tunnel.target], time_remaining - tunnel.length - 1, depth + 1)
|
||||||
continue
|
|
||||||
this_open_valves.add(v.name)
|
|
||||||
this_open_flow = v.flowrate * (time_remaining - d - 1)
|
|
||||||
this_flow = get_max_flow(v, this_open_valves, time_remaining - d - 1)
|
|
||||||
if this_flow + this_open_flow > max_flow:
|
if this_flow + this_open_flow > max_flow:
|
||||||
max_flow = this_flow + this_open_flow
|
max_flow = this_flow + this_open_flow
|
||||||
|
|
||||||
|
|
||||||
return max_flow
|
return max_flow
|
||||||
|
|
||||||
|
|
||||||
def get_max_flow_double(valve1: Valve, valve2: Valve, open_valves: set, DP: dict, time_remaining_1: int = 26, time_remaining_2: int = 26, depth: int = 0) -> int:
|
def get_max_flow_double(valve1: Valve, valve2: Valve, open_valves: list, DP: dict, time_remaining_1: int = 26, time_remaining_2: int = 26, depth: int = 0) -> int:
|
||||||
dp_key = valve1.name + valve2.name + "%02d" % time_remaining_1 + "%02d" % time_remaining_2 + "".join(list(sorted(open_valves)))
|
dp_key = valve1.name + valve2.name + "%02d" % time_remaining_1 + "%02d" % time_remaining_2 + "".join(open_valves)
|
||||||
dp_key2 = valve2.name + valve1.name + "%02d" % time_remaining_1 + "%02d" % time_remaining_2 + "".join(list(sorted(open_valves)))
|
dp_key2 = valve2.name + valve1.name + "%02d" % time_remaining_1 + "%02d" % time_remaining_2 + "".join(open_valves)
|
||||||
if dp_key in DP:
|
if dp_key in DP:
|
||||||
return DP[dp_key]
|
return DP[dp_key]
|
||||||
|
|
||||||
if time_remaining_1 <= 0 and time_remaining_2 <= 0:
|
if time_remaining_1 <= 0 and time_remaining_2 <= 0:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
ov1 = get_openable_valve_tunnels(valve1, open_valves, time_remaining_1)
|
ov1 = {t for t in valve1.tunnels if t.target.name not in open_valves and t.length < time_remaining_1 - 2}
|
||||||
ov2 = get_openable_valve_tunnels(valve2, open_valves, time_remaining_2)
|
ov2 = {t for t in valve2.tunnels if t.target.name not in open_valves and t.length < time_remaining_2 - 2}
|
||||||
|
|
||||||
if not ov1 and not ov2:
|
if not ov1 and not ov2:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
if not ov1:
|
if not ov1:
|
||||||
ov1 = {(99, valve1)}
|
ov1 = {Tunnel(valve1, 99)}
|
||||||
|
|
||||||
if not ov2:
|
if not ov2:
|
||||||
ov2 = {(99, valve2)}
|
ov2 = {Tunnel(valve2, 99)}
|
||||||
|
|
||||||
permut = product(ov1, ov2)
|
permut = product(ov1, ov2)
|
||||||
|
|
||||||
if depth == 0:
|
|
||||||
pbar = tqdm(total=210)
|
|
||||||
|
|
||||||
max_flow = 0
|
max_flow = 0
|
||||||
for v1, v2 in permut:
|
for v1, v2 in permut:
|
||||||
if v1[1].name == v2[1].name:
|
if v1.target.name == v2.target.name:
|
||||||
continue
|
continue
|
||||||
if depth == 0:
|
|
||||||
pbar.update(1)
|
|
||||||
this_open_values = open_valves.copy()
|
|
||||||
d1, tv1 = v1
|
|
||||||
d2, tv2 = v2
|
|
||||||
this_open_flow = 0
|
|
||||||
if d1 + 2 <= time_remaining_1:
|
|
||||||
this_open_values.add(tv1.name)
|
|
||||||
this_open_flow += tv1.flowrate * (time_remaining_1 - d1 - 1)
|
|
||||||
else:
|
|
||||||
tv1 = valve1
|
|
||||||
if d2 + 2 <= time_remaining_2:
|
|
||||||
this_open_values.add(tv2.name)
|
|
||||||
this_open_flow += tv2.flowrate * (time_remaining_2 - d2 - 1)
|
|
||||||
else:
|
|
||||||
tv2 = valve2
|
|
||||||
|
|
||||||
this_flow_rate = get_max_flow_double(tv1, tv2, this_open_values, DP, time_remaining_1 - d1 - 1, time_remaining_2 - d2 - 1, depth + 1)
|
this_open_flow = 0
|
||||||
|
|
||||||
|
t1 = valve1
|
||||||
|
if v1.length + 2 <= time_remaining_1:
|
||||||
|
this_open_flow += v1.target.flowrate * (time_remaining_1 - v1.length - 1)
|
||||||
|
t1 = v1.target
|
||||||
|
|
||||||
|
t2 = valve2
|
||||||
|
if v2.length + 2 <= time_remaining_2:
|
||||||
|
this_open_flow += v2.target.flowrate * (time_remaining_2 - v2.length - 1)
|
||||||
|
t2 = v2.target
|
||||||
|
|
||||||
|
this_flow_rate = get_max_flow_double(
|
||||||
|
t1,
|
||||||
|
t2,
|
||||||
|
sorted(open_valves + [v1.target.name, v2.target.name]),
|
||||||
|
DP,
|
||||||
|
time_remaining_1 - v1.length - 1,
|
||||||
|
time_remaining_2 - v2.length - 1,
|
||||||
|
depth + 1
|
||||||
|
)
|
||||||
|
|
||||||
if this_flow_rate + this_open_flow > max_flow:
|
if this_flow_rate + this_open_flow > max_flow:
|
||||||
max_flow = this_flow_rate + this_open_flow
|
max_flow = this_flow_rate + this_open_flow
|
||||||
|
|
||||||
if depth == 0:
|
|
||||||
pbar.close()
|
|
||||||
DP[dp_key] = max_flow
|
DP[dp_key] = max_flow
|
||||||
DP[dp_key2] = max_flow
|
DP[dp_key2] = max_flow
|
||||||
return max_flow
|
return max_flow
|
||||||
@ -122,29 +154,46 @@ class Day(AOCDay):
|
|||||||
|
|
||||||
def get_valve_graph(self) -> Valve:
|
def get_valve_graph(self) -> Valve:
|
||||||
valves = {}
|
valves = {}
|
||||||
|
tmp_tunnels = {}
|
||||||
for line in self.getInput():
|
for line in self.getInput():
|
||||||
p = line.split(" ")
|
p = line.split(" ")
|
||||||
valve_name = p[1]
|
valve_name = p[1]
|
||||||
flowrate = int(p[4][5:-1])
|
flowrate = int(p[4][5:-1])
|
||||||
tunnels = "".join(p[9:]).split(",")
|
tunnels = "".join(p[9:]).split(",")
|
||||||
valves[valve_name] = Valve(valve_name, flowrate)
|
valves[valve_name] = Valve(valve_name, flowrate)
|
||||||
valves[valve_name].tunnels = tunnels
|
tmp_tunnels[valve_name] = tunnels
|
||||||
|
|
||||||
|
for name, tunnels in tmp_tunnels.items():
|
||||||
|
valves[name].tunnels = {Tunnel(valves[t]) for t in tunnels}
|
||||||
|
|
||||||
for valve in valves.values():
|
for valve in valves.values():
|
||||||
valve.tunnels = [valves[x] for x in valve.tunnels]
|
tunnels = set()
|
||||||
for valve in valves.values():
|
queue = deque()
|
||||||
for v in valve.tunnels:
|
visited = set()
|
||||||
if valve not in v.tunnels:
|
queue.append((0, valve))
|
||||||
v.tunnels.append(valve)
|
while queue:
|
||||||
|
d, v = queue.popleft()
|
||||||
|
if v in visited:
|
||||||
|
continue
|
||||||
|
visited.add(v)
|
||||||
|
if v != valve and v.flowrate > 0:
|
||||||
|
tunnels.add(Tunnel(v, d))
|
||||||
|
for x in v.tunnels:
|
||||||
|
queue.append((d + 1, x.target))
|
||||||
|
|
||||||
|
tmp_tunnels[valve.name] = tunnels
|
||||||
|
|
||||||
|
for name, tunnels in tmp_tunnels.items():
|
||||||
|
valves[name].tunnels = tunnels
|
||||||
|
|
||||||
return valves["AA"]
|
return valves["AA"]
|
||||||
|
|
||||||
def part1(self) -> Any:
|
def part1(self) -> Any:
|
||||||
return get_max_flow(self.get_valve_graph(), set())
|
return get_max_flow(self.get_valve_graph(), [])
|
||||||
|
|
||||||
def part2(self) -> Any:
|
def part2(self) -> Any:
|
||||||
root = self.get_valve_graph()
|
root = self.get_valve_graph()
|
||||||
return get_max_flow_double(root, root, set(), {})
|
return get_max_flow_double(root, root, [], {})
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user