diff --git a/src/day16.rs b/src/day16.rs index 8717e7b..5215f00 100644 --- a/src/day16.rs +++ b/src/day16.rs @@ -1,5 +1,5 @@ use std::{ - collections::{BinaryHeap, HashMap}, + collections::{BinaryHeap, HashMap, HashSet}, fs::read_to_string, }; @@ -33,7 +33,7 @@ impl Dir { } } -struct Step(Coord, Dir, usize, Vec); +struct Step(Coord, Dir, usize, Option); impl Ord for Step { fn cmp(&self, other: &Self) -> std::cmp::Ordering { @@ -59,7 +59,7 @@ fn part1(input: &str) -> RiddleResult { let end = maze.entries().find(|(_p, c)| **c == 'E').unwrap().0; let mut visited = HashMap::<(Coord, Dir), usize>::new(); let mut stack: BinaryHeap = BinaryHeap::new(); - stack.push(Step(start, E, 0, vec![])); + stack.push(Step(start, E, 0, None)); while let Some(Step(np, nd, cost, _)) = stack.pop() { if visited.contains_key(&(np, nd)) { continue; @@ -70,7 +70,7 @@ fn part1(input: &str) -> RiddleResult { } for d in nd.nexts() { if !visited.contains_key(&(np, d)) { - stack.push(Step(np, d, cost + 1000, vec![])); + stack.push(Step(np, d, cost + 1000, None)); } } let forward = match nd { @@ -80,38 +80,52 @@ fn part1(input: &str) -> RiddleResult { W => (np.0 - 1, np.1), }; if maze[forward] != '#' && !visited.contains_key(&(forward, nd)) { - stack.push(Step(forward, nd, cost + 1, vec![])); + stack.push(Step(forward, nd, cost + 1, None)); } } panic!("no path found") } +type Node = (Coord, Dir); + fn part2(input: &str) -> RiddleResult { use Dir::*; let maze = Grid::parse(input); let start = maze.entries().find(|(_p, c)| **c == 'S').unwrap().0; let end = maze.entries().find(|(_p, c)| **c == 'E').unwrap().0; - let mut visited = HashMap::<(Coord, Dir), usize>::new(); + let mut visited = HashMap::)>::new(); let mut stack: BinaryHeap = BinaryHeap::new(); - stack.push(Step(start, E, 0, vec![])); + stack.push(Step(start, E, 0, None)); let mut best: Option = None; - let mut visited_by_best: Vec = vec![start, end]; - while let Some(Step(np, nd, cost, mut pred)) = stack.pop() { + while let Some(Step(np, nd, cost, pred)) = stack.pop() { if let Some(b) = best { if b < cost { break; // can't reach the end point with best cost anymore } } - visited.insert((np, nd), cost); + let entry = visited.entry((np, nd)).or_insert_with(|| { + ( + cost, + if let Some(pred) = pred { + vec![pred] + } else { + vec![] + }, + ) + }); + if entry.0 < cost { + continue; + } + if let Some(pred) = pred { + entry.1.push(pred); + } if np == end { best = Some(cost); - visited_by_best.append(&mut pred); } - pred.push(np); for d in nd.nexts() { if !visited.contains_key(&(np, d)) { - stack.push(Step(np, d, cost + 1000, pred.clone())); + stack.push(Step(np, d, cost + 1000, Some((np, nd)))); } } let forward = match nd { @@ -121,10 +135,26 @@ fn part2(input: &str) -> RiddleResult { W => (np.0 - 1, np.1), }; if maze[forward] != '#' && !visited.contains_key(&(forward, nd)) { - stack.push(Step(forward, nd, cost + 1, pred)); + stack.push(Step(forward, nd, cost + 1, Some((np, nd)))); } } - visited_by_best.into_iter().unique().count() + let mut accounted = HashSet::::new(); + let mut stack: Vec = visited + .iter() + .filter(|((pos, _dir), _)| *pos == end) + .map(|(node, _)| *node) + .to_owned() + .collect_vec(); + while let Some(node) = stack.pop() { + if accounted.contains(&node) { + continue; + } + accounted.insert(node); + for pred in visited[&node].1.iter() { + stack.push(*pred); + } + } + accounted.into_iter().map(|(pos, _)| pos).unique().count() } #[cfg(test)]