aoc2024/day24.py

260 lines
8.8 KiB
Python

from __future__ import annotations
import itertools
from tools.aoc import AOCDay
from typing import Any
# Boolean gate operations keyed by the operator name used in the puzzle input.
# Each callable takes two bit values and returns the combined bit.
OPS = dict(
    AND=lambda a, b: a & b,
    OR=lambda a, b: a | b,
    XOR=lambda a, b: a ^ b,
)
def padded_bin(number: int, pad: int = 45):
    """Return the binary digits of *number* left-padded with "0" to *pad* characters.

    The "0b" prefix is stripped. A result already longer than *pad* is returned
    as-is (never truncated).
    NOTE(review): intended for non-negative ints; bin() of a negative number
    starts with "-0b", so the stripped text is not a valid bit string.
    """
    return bin(number)[2:].rjust(pad, "0")
class Gate:
def __init__(self, input_1: Wire, input_2: Wire, op: str, out: Wire = None) -> None:
self.input_1 = input_1
self.input_2 = input_2
self.op_name = op
self.op = OPS[op]
self.out = out
def get_state(self) -> int:
return self.op(self.input_1.get_state(), self.input_2.get_state())
class Wire:
def __init__(self, name: str, gate: Gate | None = None, state: int = 0) -> None:
self.name = name
self.gate = gate
self.state = state
def get_state(self) -> int:
if self.gate is not None:
self.state = self.gate.get_state()
return self.state
def print(self, depth: int = 0) -> None:
pad = " " * depth
if self.gate is not None:
print(f"{pad} {self.name} <- {self.gate.input_1.name} {self.gate.op_name} {self.gate.input_2.name}")
self.gate.input_1.print(depth + 2)
self.gate.input_2.print(depth + 2)
class Wires(list):
def __init__(self):
super().__init__()
self.wires = {}
self.lines = {"x": [], "y": [], "z": []}
self.wire_count = {"x": 0, "y": 0, "z": 0}
def get_number(self, line: str = "z") -> int:
return int("".join(str(x.get_state()) for x in self.lines[line][::-1]), 2)
def set_number(self, line: str, number: int) -> None:
wires = self.lines[line][::-1]
for i, c in enumerate(padded_bin(number, len(wires))):
wires[i].state = int(c)
def swap_gates(self, wire_1: Wire | str, wire_2: Wire | str) -> None:
if isinstance(wire_1, str):
wire_1 = self.wires[wire_1]
if isinstance(wire_2, str):
wire_2 = self.wires[wire_2]
wire_1.gate, wire_2.gate = wire_2.gate, wire_1.gate
def get_bit_diff(self, number: int) -> list[int]:
my_z = self.get_number()
if my_z == number:
return []
number_bin = padded_bin(number, self.wire_count["z"])
my_bin_z = padded_bin(my_z, self.wire_count["z"])
return [i for i, c in enumerate(number_bin) if c != my_bin_z[i]]
def append(self, wire: Wire) -> None:
self.wires[wire.name] = wire
if wire.name[0] in self.wire_count:
self.wire_count[wire.name[0]] += 1
if wire.name[0] in self.lines:
self.lines[wire.name[0]] = list(sorted(self.lines[wire.name[0]] + [wire], key=lambda w: w.name))
super().append(wire)
def __getitem__(self, item: int | str) -> Wire:
if isinstance(item, int):
return super()[item]
else:
return self.wires[item]
def __contains__(self, item: Wire | str) -> bool:
if isinstance(item, str):
return item in self.wires
else:
return super().__contains__(item)
class Day(AOCDay):
    """Advent of Code 2024, day 24: evaluate a network of AND/OR/XOR gates."""

    inputs = [
        [
            (4, "input24_test"),
            (2024, "input24_test2"),
            (42883464055378, "input24"),
        ],
        [
            (None, "input24"),
        ],
    ]

    # Pairs of output wires whose driving gates are swapped in "input24".
    # Deduced by hand from the diagnostics printed in part2 (see _fishy_wires);
    # they only apply to that input.
    # TODO: find an automated solution.
    SWAPS = (("dtk", "vgs"), ("z39", "pfw"), ("z21", "shh"), ("z33", "dqr"))

    def parse_input(self) -> tuple[Wires, list[Gate]]:
        """Build the wire network from the puzzle input.

        Assumes getMultiLineInputAsArray() yields two groups (TODO confirm
        against tools.aoc): initial "name: value" lines, then gate lines of
        five tokens "a OP b <sep> out" (the fourth token is ignored).
        Gates are created once both input wires exist, so the gate list is
        in dependency order.

        Raises:
            ValueError: if some gates can never be built because an input wire
                is never produced (missing wire or a cycle).
        """
        init_state, init_gates = self.getMultiLineInputAsArray()
        wires = Wires()
        for line in init_state:
            name, state = line.split(": ")
            wires.append(Wire(name=name, state=int(state)))

        gates = []
        pending = list(init_gates)
        while pending:
            blocked = []
            for line in pending:
                input_1, op, input_2, _, output = line.split()
                if input_1 not in wires or input_2 not in wires:
                    blocked.append(line)
                    continue
                gate = Gate(wires[input_1], wires[input_2], op)
                gates.append(gate)
                wires.append(Wire(name=output, gate=gate))
                gate.out = wires[output]
            # A full pass that built nothing would repeat forever.
            if len(blocked) == len(pending):
                raise ValueError(f"cannot resolve inputs for gates: {blocked}")
            pending = blocked
        return wires, gates

    def part1(self) -> Any:
        """Return the number formed by the z-wires after evaluating the circuit."""
        return self.parse_input()[0].get_number()

    def part2(self) -> Any:
        """Undo the gate swaps listed in SWAPS and return the swapped wire names.

        The answer is every name in SWAPS, sorted and comma-joined. Wires that
        still match the suspicious-wiring heuristics after the swaps are
        printed for manual inspection.
        """
        wires, _ = self.parse_input()
        for wire_1, wire_2 in self.SWAPS:
            wires.swap_gates(wire_1, wire_2)
        # Diagnostics printed before the swaps were applied (used to deduce SWAPS):
        # z21 -> Input Gate: AND from jsq (XOR) and vcj (OR)
        # z33 -> Input Gate: OR from mtw (AND) and fdv (AND)
        # z39 -> Input Gate: AND from x39 (None) and y39 (None)
        # sjk -> Input Gate: OR from shh (XOR) and kcq (AND)
        # z26 -> Input Gate: XOR from skh (OR) and vgs (AND)
        # bmg -> Input Gate: AND from vgs (AND) and skh (OR)
        # qcr -> Input Gate: OR from bmg (AND) and dtk (XOR)
        # jqk -> Input Gate: OR from pfw (XOR) and kdd (AND)
        for w in self._fishy_wires(wires):
            print(
                f"{w.name} -> Input Gate: {w.gate.op_name} from {w.gate.input_1.name} ({w.gate.input_1.gate.op_name if w.gate.input_1.gate is not None else 'None'}) and {w.gate.input_2.name} ({w.gate.input_2.gate.op_name if w.gate.input_2.gate is not None else 'None'})"
            )
        return ",".join(sorted(name for pair in self.SWAPS for name in pair))

    @staticmethod
    def _fishy_wires(wires: Wires) -> list[Wire]:
        """Return gate-driven wires that break the expected adder wiring, each listed once.

        A wire is flagged when any of these holds:
        - it is a z-wire not driven by an XOR gate;
        - it is driven by an OR gate, and an input is not the output of an AND gate;
        - it is driven by an XOR/AND gate, and an input comes from a gate other
          than XOR/OR (primary inputs are fine).
        Some legitimately wired gates can match too (e.g. the most significant
        z-wire is driven by an OR), so the result needs manual review.
        """
        fishy = []
        for w in wires:
            gate = w.gate
            if gate is None:
                continue
            sources = (gate.input_1.gate, gate.input_2.gate)
            bad_z = w.name.startswith("z") and gate.op_name != "XOR"
            if gate.op_name == "OR":
                bad_inputs = any(src is None or src.op_name != "AND" for src in sources)
            else:
                bad_inputs = any(src is not None and src.op_name not in ("XOR", "OR") for src in sources)
            if bad_z or bad_inputs:
                fishy.append(w)
        return fishy
if __name__ == "__main__":
    # Solve both parts for 2024 day 24 and print the results.
    Day(2024, 24).run(verbose=True)