diff --git a/src/tools/aoc.py b/src/tools/aoc.py index 748f45c..1164de8 100644 --- a/src/tools/aoc.py +++ b/src/tools/aoc.py @@ -1,14 +1,16 @@ from __future__ import annotations import os import re -import subprocess import requests +import subprocess +import sys import time +import uuid import webbrowser from bs4 import BeautifulSoup from .datafiles import JSONFile from .stopwatch import StopWatch -from typing import Any, Callable, List, Tuple, Type +from typing import Any, Callable, Type from tqdm.auto import tqdm from .tools import get_script_dir @@ -19,9 +21,9 @@ INPUTS_PATH = os.path.join(BASE_PATH, "inputs") class AOCDay: year: int day: int - input: List[str] # our input is always a list of str/lines - inputs: List[List[Tuple[Any, str]]] - part_func: List[Callable] + input: list[str] # our input is always a list of str/lines + inputs: list[list[tuple[Any, str]]] + part_func: list[Callable] def __init__(self, year: int, day: int): self.day = day @@ -29,6 +31,8 @@ class AOCDay: self.part_func = [self.part1, self.part2] self._current_test_file = None self._current_test_solution = None + self.__main_progress_bar_id = None + self.progress_bars = {} def part1(self) -> Any: raise NotImplementedError() @@ -37,7 +41,14 @@ class AOCDay: raise NotImplementedError() def is_test(self) -> bool: - return self._current_test_solution is not None + return "test" in self._current_test_file + + def _call_part_func(self, func: Callable) -> Any: + ans = func() + for p, pbar in self.progress_bars.items(): + pbar.close() + self.progress_bars = {} + return ans def run_part( self, @@ -54,12 +65,12 @@ class AOCDay: self._load_input(input_file) if not measure_runtime or case_count < len(self.inputs[part]) - 1: - answer = self.part_func[part]() + answer = self._call_part_func(self.part_func[part]) else: stopwatch = StopWatch(auto_start=False) for _ in tqdm(range(timeit_number), desc=f"Part {part+1}", leave=False): stopwatch.start() - answer = self.part_func[part]() + answer = self._call_part_func(self.part_func[part]) stopwatch.stop() exec_time = stopwatch.avg_string(timeit_number) @@ -222,7 +233,7 @@ class AOCDay: def getMultiLineInputAsArray( self, return_type: Type = None, join_char: str = None - ) -> List: + ) -> list[list[Any]]: """ get input for day x as 2d array, split by empty lines """ @@ -248,8 +259,8 @@ class AOCDay: return return_array def getInputAsArraySplit( - self, split_char: str = ",", return_type: Type | List[Type] = None - ) -> List: + self, split_char: str = ",", return_type: Type | list[Type] = None + ) -> list[Any] | list[list[Any]]: """ get input for day x with the lines split by split_char if input has only one line, returns a 1d array with the values @@ -270,6 +281,24 @@ class AOCDay: return return_array + def progress(self, total: int, add: int = 1, bar_id: str = None) -> None: + if bar_id is None: + if self.__main_progress_bar_id is None: + self.__main_progress_bar_id = uuid.uuid4() + bar_id = self.__main_progress_bar_id + + if bar_id not in self.progress_bars: + pbar = tqdm( + total=total, + position=len(self.progress_bars), + leave=False, + file=sys.stdout, + ) + self.progress_bars[bar_id] = pbar + + pbar = self.progress_bars[bar_id] + pbar.update(add) + def print_solution( day: int, @@ -305,7 +334,7 @@ def print_solution( print("Day %s, Part %s - Average run time: %s" % (day, part, exec_time)) -def split_line(line, split_char: str = ",", return_type: Type | List[Type] = None): +def split_line(line, split_char: str = ",", return_type: Type | list[Type] = None): if split_char: line = line.split(split_char)