diff --git a/aoc2024/src/day08/python/solution.py b/aoc2024/src/day08/python/solution.py index 8e88a18..0e3d370 100644 --- a/aoc2024/src/day08/python/solution.py +++ b/aoc2024/src/day08/python/solution.py @@ -19,6 +19,7 @@ # Antenas type Pos = tuple[int, int] + def _find_antennas(city_map: Sequence[str]) -> dict[str, set[Pos]]: """Finds antenna locations and groups them by frequency. Frequency is represented by character, for example 'a'.""" @@ -31,10 +32,14 @@ def _find_antennas(city_map: Sequence[str]) -> dict[str, set[Pos]]: return antennas -def count_antinodes(city_map: Sequence[str]) -> int: +def _within_bounds(pos: Pos, bounds: Pos) -> bool: + """Returns true if a position is within city bounds.""" + return 0 <= pos[0] < bounds[0] and 0 <= pos[1] < bounds[1] + + +def count_antinodes(city_map: Sequence[str], any_grid_position=False) -> int: """Counts the antinodes resulting from antennas in the city.""" - city_width = len(city_map[0]) - city_height = len(city_map) + city_bounds = (len(city_map[0]), len(city_map)) antennas: dict[chr, set[Pos]] = _find_antennas(city_map) antinode_positions: set[Pos] = set() for positions in antennas.values(): @@ -43,8 +48,13 @@ def count_antinodes(city_map: Sequence[str]) -> int: if curr_pos == other_pos: continue # Skip same antenna. distance = (other_pos[0] - curr_pos[0], other_pos[1] - curr_pos[1]) - antinode_pos = (other_pos[0] + distance[0], other_pos[1] + distance[1]) - if 0 <= antinode_pos[0] < city_height and 0 <= antinode_pos[1] < city_width: - # Antinode is within city bounds - antinode_positions.add(antinode_pos) - return len(antinode_positions) \ No newline at end of file + if any_grid_position: + next_pos = (curr_pos[0] + distance[0], curr_pos[1] + distance[1]) + while _within_bounds(next_pos, city_bounds): + antinode_positions.add(next_pos) + next_pos = (next_pos[0] + distance[0], next_pos[1] + distance[1]) + else: + antinode_pos = (other_pos[0] + distance[0], other_pos[1] + distance[1]) + if _within_bounds(pos=antinode_pos, bounds=city_bounds): + antinode_positions.add(antinode_pos) + return len(antinode_positions) diff --git a/aoc2024/test/day08/python/test_solution.py b/aoc2024/test/day08/python/test_solution.py index f321eb0..722492f 100644 --- a/aoc2024/test/day08/python/test_solution.py +++ b/aoc2024/test/day08/python/test_solution.py @@ -27,6 +27,12 @@ def test_part1_withExample_counts(self): def test_part1_withPuzzleInput_counts(self): self.assertEqual(327, count_antinodes(self.input)) + def test_part2_withExample_counts(self): + self.assertEqual(34, count_antinodes(self.examples[0], any_grid_position=True)) + + def test_part2_withPuzzleInput_counts(self): + self.assertEqual(1233, count_antinodes(self.input, any_grid_position=True)) + if __name__ == '__main__': unittest.main()