aoc/year2023/day12.rs
1//! # Hot Springs
2//!
3//! A [dynamic programming](https://en.wikipedia.org/wiki/Dynamic_programming) approach calculates
4//! the possible arrangements for each entry in `O(n * w)` complexity where:
5//!
6//! * `n` Number of springs
7//! * `w` "Wiggle" is the amount of extra free space the springs can slide around in the pattern.
8//!
9//! We build a table for each entry with a row for each spring and a column for every character
10//! in the pattern. Adding a trailing `.` character makes bounds checking easier without changing
11//! the number of arrangements. The result will be the bottom right value.
12//!
13//! Using the sample `?###???????? 3,2,1`:
14//!
15//! ```none
16//! n = 3
17//! w = 13 - (3 + 2 + 1) - 3 + 1 = 5
18//!
19//! ? # # # ? ? ? ? ? ? ? ? .
20//! ┌----------------------------------------
21//! 3 | 0 0 0 [0 1 1 1 1] 0 0 0 0 0
22//! 2 | 0 0 0 0 0 0 [0 1 2 3 4] 0 0
23//! 1 | 0 0 0 0 0 0 0 0 [0 1 3 6 10]
24//! ```
25//!
26//! Each pattern updates the total at the index one *after* its end, if it can fit at that location.
27//! For example the first spring can only match at indices `[1, 2, 3]` so it updates the total
28//! at index 4 to 1.
29//!
30//! The key insight is that the number of arrangements is then propagated as a prefix sum
31//! from left to right for each row as long as the character at the index is not a `#` or until
32//! `wiggle` characters are reached, whichever comes sooner.
33//!
34//! To calculate the next row, each matching pattern adds the value from the row above at the
35//! index one before its start. The first row is a special case where this value is always 1.
36//!
37//! As a nice side effect this approach always overwrites values so we can re-use the memory buffer
38//! for different entries without having to zero out values.
39//!
40//! ## Alternate approach
41//!
42//! Another way to look at the problem is to search to the left from each matching position
43//! until a `#` character is found. The previous pattern can't leave any trailing `#` characters
44//! otherwise it wouldn't be the previous pattern.
45//!
46//! Using the same example `?###???????? 3,2,1` and adding a trailing `.`:
47//!
48//! * `###` can only match at one location giving:
49//! ```none
50//! . # # # . . . . . . . . .
51//! [0 0 0 0 1 0 0 0 0 0 0 0 0]
52//! ````
53//!
54//!* The next `##` can match at 4 possible locations:
55//! ```none
56//! . # # # . # # . . . . . .
57//! [0 0 0 0 1 0 0 0 0 0 0 0 0]
58//! <<
59//! [0 0 0 0 0 0 0 1 0 0 0 0 0]
60//! ```
61//! * 2nd location:
62//! ```none
63//! . # # # . . # # . . . . .
64//! [0 0 0 0 1 0 0 0 0 0 0 0 0]
65//! <<<<
66//! [0 0 0 0 0 0 0 1 1 0 0 0 0]
67//! ```
68//! * 3rd location:
69//! ```none
70//! . # # # . . . # # . . . .
71//! [0 0 0 0 1 0 0 0 0 0 0 0 0]
72//! <<<<<<
73//! [0 0 0 0 0 0 0 1 1 1 0 0 0]
74//! ```
75//! * 4th location:
76//! ```none
77//! . # # # . . . . # # . . .
78//! [0 0 0 0 1 0 0 0 0 0 0 0 0]
79//! <<<<<<<<
80//! [0 0 0 0 0 0 0 1 1 1 1 0 0]
81//! ```
82//!* The final `#` can also match at 4 possible locations (for brevity only showing the 2nd pattern
83//! in a single position):
84//! ```none
85//! . # # # . # # . # . . . .
86//! [0 0 0 0 1 0 0 0 0 0 0 0 0]
87//! [0 0 0 0 0 0 0 1 1 1 1 0 0]
88//! <<
89//! [0 0 0 0 0 0 0 0 1 0 0 0 0]
90//! ```
91//! * 2nd location:
92//! ```none
93//! . # # # . # # . . # . . .
94//! [0 0 0 0 1 0 0 0 0 0 0 0 0]
95//! [0 0 0 0 0 0 0 1 1 1 1 0 0]
96//! <<<<
97//! [0 0 0 0 0 0 0 0 1 2 0 0 0]
98//! ```
99//! * 3rd location:
100//! ```none
101//! . # # # . # # . . . # . .
102//! [0 0 0 0 1 0 0 0 0 0 0 0 0]
103//! [0 0 0 0 0 0 0 1 1 1 1 0 0]
104//! <<<<<<
105//! [0 0 0 0 0 0 0 0 1 2 3 0 0]
106//! ```
107//! * 4th location:
108//! ```none
109//! . # # # . # # . . . . # .
110//! [0 0 0 0 1 0 0 0 0 0 0 0 0]
111//! [0 0 0 0 0 0 0 1 1 1 1 0 0]
112//! <<<<<<<<
113//! [0 0 0 0 0 0 0 0 1 2 3 4 0]
114//! ```
115//!
116//! The final result is then the sum of the bottom row with the nuance that any numbers before the
117//! last `#` don't count as they represent an invalid pattern.
118//!
119//! This is equivalent to the prefix sum approach described above. It is a little clearer to
120//! understand, but slower to calculate.
121use crate::util::parse::*;
122use crate::util::thread::*;
123use std::sync::atomic::{AtomicU64, Ordering};
124
125type Spring<'a> = (&'a [u8], Vec<usize>);
126
127pub fn parse(input: &str) -> Vec<Spring<'_>> {
128 input
129 .lines()
130 .map(|line| {
131 let (prefix, suffix) = line.split_once(' ').unwrap();
132 let first = prefix.as_bytes();
133 let second = suffix.iter_unsigned().collect();
134 (first, second)
135 })
136 .collect()
137}
138
139pub fn part1(input: &[Spring<'_>]) -> u64 {
140 solve(input.iter(), 1)
141}
142
143pub fn part2(input: &[Spring<'_>]) -> u64 {
144 // Use as many cores as possible to parallelize the calculation.
145 let shared = AtomicU64::new(0);
146 spawn_parallel_iterator(input, |iter| {
147 let partial = solve(iter, 5);
148 shared.fetch_add(partial, Ordering::Relaxed);
149 });
150 shared.load(Ordering::Relaxed)
151}
152
153pub fn solve<'a, I>(iter: I, repeat: usize) -> u64
154where
155 I: Iterator<Item = &'a Spring<'a>>,
156{
157 let mut result = 0;
158 let mut pattern = Vec::new();
159 let mut springs = Vec::new();
160 // Exact size is not too important as long as there's enough space.
161 let mut broken = vec![0; 200];
162 let mut table = vec![0; 200 * 50];
163
164 for (first, second) in iter {
165 // Create input sequence reusing the buffers to minimize memory allocations.
166 pattern.clear();
167 springs.clear();
168
169 for _ in 1..repeat {
170 pattern.extend_from_slice(first);
171 pattern.push(b'?');
172 springs.extend_from_slice(second);
173 }
174
175 // Add a trailing '.' so that we don't have to check bounds when testing the last pattern.
176 // This has no effect on the number of possible combinations.
177 pattern.extend_from_slice(first);
178 pattern.push(b'.');
179 springs.extend_from_slice(second);
180
181 // Calculate prefix sum of the number of broken springs and unknowns before each index
182 // to quickly check if a range can contain a broken spring without checking every element.
183 // For example `.??..??...?##` becomes `[0, 0, 1, 2, 2, 2, 3, 4, 4, 4, 4, 5, 6, 7, 7]`.
184 let mut sum = 0;
185 broken.push(0);
186
187 for (i, &b) in pattern.iter().enumerate() {
188 if b != b'.' {
189 sum += 1;
190 }
191 broken[i + 1] = sum;
192 }
193
194 // Determine how many spaces each pattern can slide around to speed things up.
195 // We only need to check at most that many spaces for each pattern.
196 let wiggle = pattern.len() - springs.iter().sum::<usize>() - springs.len() + 1;
197
198 // Count combinations, handling the first row as a special case.
199 let size = springs[0];
200 let mut sum = 0;
201 let mut valid = true;
202
203 for i in 0..wiggle {
204 // In order to be a broken spring, an interval must only contains `#` or `?`
205 // characters and not have a '#' character immediately before or after.
206 if pattern[i + size] == b'#' {
207 sum = 0;
208 } else if valid && broken[i + size] - broken[i] == size {
209 sum += 1;
210 }
211
212 table[i + size] = sum;
213
214 // The first pattern can't have any '#' characters anywhere to its left
215 // otherwise it wouldn't be the first pattern.
216 valid &= pattern[i] != b'#';
217 }
218
219 // Count each subsequent spring. The previous patterns take at least the sum of their size
220 // and 1 space afterwards so no need to check indices before that.
221 let mut start = size + 1;
222
223 for (row, &size) in springs.iter().enumerate().skip(1) {
224 // We're using a 1 dimensional vec to implement a two dimensional table.
225 // Calculate the starting index of current and previous row for convenience.
226 let previous = (row - 1) * pattern.len();
227 let current = row * pattern.len();
228
229 // Reset the running sum.
230 sum = 0;
231
232 for i in start..start + wiggle {
233 // As a minor optimization only check the pattern if the previous row
234 // will contribute a non-zero value.
235 if pattern[i + size] == b'#' {
236 sum = 0;
237 } else if table[previous + i - 1] > 0
238 && pattern[i - 1] != b'#'
239 && broken[i + size] - broken[i] == size
240 {
241 sum += table[previous + i - 1];
242 }
243
244 table[current + i + size] = sum;
245 }
246
247 start += size + 1;
248 }
249
250 // The final value of sum (the bottom right of the table) is the number of possible
251 // arrangements of the pattern.
252 result += sum;
253 }
254
255 result
256}