diff --git a/src/main/scala/org/lemon/advent/lib/graph/search.scala b/src/main/scala/org/lemon/advent/lib/graph/search.scala index b21f6f9..13d0437 100644 --- a/src/main/scala/org/lemon/advent/lib/graph/search.scala +++ b/src/main/scala/org/lemon/advent/lib/graph/search.scala @@ -27,44 +27,35 @@ def pathFind[N, D: Numeric](adjacency: N => Seq[(N, D)], start: N, ends: N => Bo queue.headOption -def allShortestPaths[N, D: Numeric]( - adjacency: N => Seq[(N, D)], - start: N, - ends: N => Boolean, - best: D -): Set[Path[N, D]] = - given Ordering[Path[N, D]] = Ordering.by[Path[N, D], D](_.distance).reverse +/** Finds all shortest paths in the graph from `start` to `end`. + * + * @param adjacency function to return edges for a given node + * @param start the start node + * @param ends function to check if a node is an ending node + * @return set of all shortest paths between `start` and `end` + * @tparam N the node type + * @tparam D the distance type + */ +def allShortestPaths[N, D: Numeric](adjacency: N => Seq[(N, D)], start: N, ends: N => Boolean): Set[Path[N, D]] = val paths = mutable.Set.empty[Path[N, D]] - val queue = mutable.PriorityQueue(Path(path = Seq(start), distance = summon[Numeric[D]].zero)) + val queue = mutable.Queue(Path(path = Seq(start), distance = summon[Numeric[D]].zero)) + val costs = mutable.Map(start -> summon[Numeric[D]].zero) - var exit = false - while !queue.isEmpty && !exit do + while !queue.isEmpty do val node @ Path(path, distance) = queue.dequeue if ends(node.at) then - if best == distance then - println(s"match: $distance $node") - paths.add(node) - else - queue ++= adjacency(node.at) - .filter((_, d) => distance + d <= best) - .map((neigh, dist) => Path(neigh +: path, distance + dist)) - - paths.toSet + if paths.isEmpty || distance < paths.head.distance then paths.clear() + if paths.isEmpty || distance <= paths.head.distance then paths.add(node) -def allShortestPaths2[N, D: Numeric](adjacency: N => Seq[(N, D)], start: N, ends: N => Boolean): Set[Path[N, D]] = - val visited = mutable.Set.empty[N] - val paths = mutable.Set.empty[Path[N, D]] + queue ++= adjacency(node.at) + .filter((neigh, dist) => + val costTo = distance + dist + costs.get(neigh) match + case Some(known) if known < costTo => false + case _ => costs(neigh) = costTo; true + ) + .map((neigh, dist) => Path(neigh +: path, distance + dist)) - def dfs(loc: N, dist: D, path: Seq[N]): Unit = - if !visited(loc) then - if ends(loc) then - println(s"found: $dist $path") - paths.add(Path(loc +: path, dist)) - else - visited.add(loc) - adjacency(loc).foreach((neigh, d) => dfs(neigh, dist + d, loc +: path)) - visited.remove(loc) - dfs(start, summon[Numeric[D]].zero, Seq.empty) paths.toSet /** Performs a dijkstra's search of the graph from `start` to `end`, returning diff --git a/src/main/scala/org/lemon/advent/year2024/Day16.scala b/src/main/scala/org/lemon/advent/year2024/Day16.scala index abf7454..39f5944 100644 --- a/src/main/scala/org/lemon/advent/year2024/Day16.scala +++ b/src/main/scala/org/lemon/advent/year2024/Day16.scala @@ -15,22 +15,18 @@ private object Day16: else Seq((coord, dir.turnLeft) -> 1000, (coord, dir.turnRight) -> 1000, (coord + dir, dir) -> 1) - def bestPath(grid: Map[Coord, Char]) = + def part1(input: String) = + val grid = parse(input) val start = grid.find(_._2 == 'S').get._1 val end = grid.find(_._2 == 'E').get._1 val facing = Direction.Right - pathFind(adjacency(grid), (start, facing), _._1 == end) - - def part1(input: String) = - val grid = parse(input) - bestPath(grid).get.distance + pathFind(adjacency(grid), (start, facing), _._1 == end).get.distance def part2(input: String) = val grid = parse(input) val start = grid.find(_._2 == 'S').get._1 val end = grid.find(_._2 == 'E').get._1 val facing = Direction.Right - val best = bestPath(grid).get.distance - allShortestPaths(adjacency(grid), (start, facing), _._1 == end, best) + allShortestPaths(adjacency(grid), (start, facing), _._1 == end) .flatMap(path => path.path.map(_._1)) .size diff --git a/src/test/scala/org/lemon/advent/year2024/Day16Test.scala b/src/test/scala/org/lemon/advent/year2024/Day16Test.scala index ddcf676..96b518c 100644 --- a/src/test/scala/org/lemon/advent/year2024/Day16Test.scala +++ b/src/test/scala/org/lemon/advent/year2024/Day16Test.scala @@ -68,27 +68,27 @@ class Day16Test extends UnitTest: part2(in) shouldBe 45 } - // test("part 2 example 2") { - // val in = """|################# - // |#...#...#...#..E# - // |#.#.#.#.#.#.#.#.# - // |#.#.#.#...#...#.# - // |#.#.#.#.###.#.#.# - // |#...#.#.#.....#.# - // |#.#.#.#.#.#####.# - // |#.#...#.#.#.....# - // |#.#.#####.#.###.# - // |#.#.#.......#...# - // |#.#.###.#####.### - // |#.#.#...#.....#.# - // |#.#.#.#####.###.# - // |#.#.#.........#.# - // |#.#.#.#########.# - // |#S#.............# - // |#################""".stripMargin - // part2(in) shouldBe 64 - // } + test("part 2 example 2") { + val in = """|################# + |#...#...#...#..E# + |#.#.#.#.#.#.#.#.# + |#.#.#.#...#...#.# + |#.#.#.#.###.#.#.# + |#...#.#.#.....#.# + |#.#.#.#.#.#####.# + |#.#...#.#.#.....# + |#.#.#####.#.###.# + |#.#.#.......#...# + |#.#.###.#####.### + |#.#.#...#.....#.# + |#.#.#.#####.###.# + |#.#.#.........#.# + |#.#.#.#########.# + |#S#.............# + |#################""".stripMargin + part2(in) shouldBe 64 + } - // test("part 2") { - // part2(read(file(2024)(16))) shouldBe 0 - // } + test("part 2") { + part2(read(file(2024)(16))) shouldBe 559 + }