aoc/year2016/
day14.rs

1//! # One-Time Pad
2//!
3//! Brute force slog through all possible keys, parallelized as much as possible. An optimization
4//! for part two is a quick method to convert `u32` to 8 ASCII digits.
5use crate::util::md5::*;
6use crate::util::thread::*;
7use std::collections::{BTreeMap, BTreeSet};
8use std::sync::Mutex;
9
10/// Atomics can be safely shared between threads.
11struct Shared<'a> {
12    input: &'a str,
13    part_two: bool,
14    iter: AtomicIter,
15    mutex: Mutex<Exclusive>,
16}
17
/// Plain collections that are only accessed while holding the mutex in [`Shared`].
struct Exclusive {
    /// Index -> one-hot bitmask (bit `d` set for hex digit `d`) of the first triple in that hash.
    threes: BTreeMap<i32, u32>,
    /// Index -> bitmask of every hex digit that occurs five times in a row in that hash.
    fives: BTreeMap<i32, u32>,
    /// Indices confirmed as keys, kept sorted so the 64th smallest can be read off directly.
    found: BTreeSet<i32>,
}
24
/// Returns the salt: the puzzle input with surrounding whitespace removed.
pub fn parse(input: &str) -> &str {
    input.trim_start().trim_end()
}
28
29/// Hash each key once.
30pub fn part1(input: &str) -> i32 {
31    generate_pad(input, false)
32}
33
34/// Hash each key an additional 2016 times.
35pub fn part2(input: &str) -> i32 {
36    generate_pad(input, true)
37}
38
39/// Find the first 64 keys that satisfy the rules.
40fn generate_pad(input: &str, part_two: bool) -> i32 {
41    let step = if cfg!(feature = "simd") { 32 } else { 1 };
42
43    let exclusive =
44        Exclusive { threes: BTreeMap::new(), fives: BTreeMap::new(), found: BTreeSet::new() };
45    let shared =
46        Shared { input, part_two, iter: AtomicIter::new(0, step), mutex: Mutex::new(exclusive) };
47
48    // Use as many cores as possible to parallelize the search.
49    #[cfg(not(feature = "simd"))]
50    spawn(|| scalar::worker(&shared));
51    #[cfg(feature = "simd")]
52    spawn(|| simd::worker(&shared));
53
54    let exclusive = shared.mutex.into_inner().unwrap();
55    *exclusive.found.iter().nth(63).unwrap()
56}
57
/// Writes the salt followed by the decimal form of `n` as ASCII into a zeroed 64 byte buffer.
///
/// Returns the buffer together with the number of bytes written. The output matches
/// `format!("{prefix}{n}")` for every `i32`, including negative values (leading `-`),
/// but avoids a heap allocation.
///
/// # Panics
///
/// Panics if the salt plus the rendered number does not fit in 64 bytes.
fn format_string(prefix: &str, n: i32) -> ([u8; 64], usize) {
    let prefix_len = prefix.len();

    // Work on the magnitude so that negative values (even `i32::MIN`) are rendered correctly
    // instead of producing out-of-range "digits".
    let sign_len = usize::from(n < 0);
    let mut magnitude = n.unsigned_abs();

    // `ilog10` is undefined for zero, which still needs one digit.
    let digits = magnitude.max(1).ilog10() as usize + 1;
    let size = prefix_len + sign_len + digits;

    let mut buffer = [0; 64];
    buffer[..prefix_len].copy_from_slice(prefix.as_bytes());

    if n < 0 {
        buffer[prefix_len] = b'-';
    }

    // Fill the digits from least significant (rightmost) to most significant.
    for slot in buffer[prefix_len + sign_len..size].iter_mut().rev() {
        *slot = b'0' + (magnitude % 10) as u8;
        magnitude /= 10;
    }

    (buffer, size)
}
74
/// Converts a `u32` into its 8 lowercase hexadecimal ASCII characters, most significant first.
///
/// All eight digits are processed at once inside a single `u64` instead of one at a time.
#[inline]
fn to_ascii(n: u32) -> [u8; 8] {
    // Spread the eight nibbles so that each one sits in the low half of its own byte,
    // for example `1234abcd` becomes `010203040a0b0c0d`. Each step doubles the spacing:
    // 16 bit halves, then bytes, then nibbles.
    let wide = u64::from(n);
    let halves = (wide | (wide << 16)) & 0x0000ffff0000ffff;
    let bytes = (halves | (halves << 8)) & 0x00ff00ff00ff00ff;
    let nibbles = (bytes | (bytes << 4)) & 0x0f0f0f0f0f0f0f0f;

    // Digits 0 to 9 need `0x30` added to become ASCII '0' to '9'. Digits 10 to 15 need a
    // further `0x27` to land on lowercase 'a' to 'f'.
    //
    // Adding 6 to a digit of 10 or more carries into bit 4 of its byte, so shifting right
    // by 4 and keeping the lowest bit of each byte yields 1 exactly for the letter digits.
    // For example the mask of `010203040a0b0c0d` is `0000000001010101`.
    let letters = ((nibbles + 0x0606060606060606) >> 4) & 0x0101010101010101;
    let ascii = nibbles + 0x3030303030303030 + letters * 0x27;

    // Big endian order puts the most significant digit first.
    ascii.to_be_bytes()
}
97
98#[cfg(not(feature = "simd"))]
99mod scalar {
100    use super::*;
101
102    pub(super) fn worker(shared: &Shared<'_>) {
103        while let Some(n) = shared.iter.next() {
104            // Get the next key to check.
105            let n = n as i32;
106
107            // Calculate the hash.
108            let (mut buffer, size) = format_string(shared.input, n);
109            let mut result = hash(&mut buffer, size);
110
111            if shared.part_two {
112                for _ in 0..2016 {
113                    buffer[0..8].copy_from_slice(&to_ascii(result[0]));
114                    buffer[8..16].copy_from_slice(&to_ascii(result[1]));
115                    buffer[16..24].copy_from_slice(&to_ascii(result[2]));
116                    buffer[24..32].copy_from_slice(&to_ascii(result[3]));
117                    result = hash(&mut buffer, 32);
118                }
119            }
120
121            check(shared, n, result);
122        }
123    }
124
125    /// Check for sequences of 3 or 5 consecutive matching digits.
126    fn check(shared: &Shared<'_>, n: i32, hash: [u32; 4]) {
127        let [a, b, c, d] = hash;
128
129        let mut prev = u32::MAX;
130        let mut same = 1;
131        let mut three = 0;
132        let mut five = 0;
133
134        for mut word in [d, c, b, a] {
135            for _ in 0..8 {
136                let next = word & 0xf;
137
138                if next == prev {
139                    same += 1;
140                } else {
141                    same = 1;
142                }
143
144                if same == 3 {
145                    three = 1 << next;
146                }
147                if same == 5 {
148                    five |= 1 << next;
149                }
150
151                word >>= 4;
152                prev = next;
153            }
154        }
155
156        if three != 0 || five != 0 {
157            let mut exclusive = shared.mutex.lock().unwrap();
158            let mut candidates = Vec::new();
159
160            // Compare against all 5 digit sequences.
161            if three != 0 {
162                exclusive.threes.insert(n, three);
163
164                for (_, mask) in exclusive.fives.range(n + 1..n + 1001) {
165                    if three & mask != 0 {
166                        candidates.push(n);
167                    }
168                }
169            }
170
171            // Compare against all 3 digit sequences.
172            if five != 0 {
173                exclusive.fives.insert(n, five);
174
175                for (&index, &mask) in exclusive.threes.range(n - 1000..n) {
176                    if five & mask != 0 {
177                        candidates.push(index);
178                    }
179                }
180            }
181
182            // Add any matching keys found, finishing once we have at least 64 keys.
183            exclusive.found.extend(candidates);
184
185            if exclusive.found.len() >= 64 {
186                shared.iter.stop();
187            }
188        }
189    }
190}
191
#[cfg(feature = "simd")]
mod simd {
    use super::*;
    use crate::util::bitset::*;
    use crate::util::md5::simd::hash_fixed;
    use std::simd::cmp::SimdPartialEq as _;
    use std::simd::*;

    /// Worker loop: claims blocks of 32 consecutive indices until the shared iterator is
    /// stopped, keeping one hash per SIMD lane.
    #[expect(clippy::needless_range_loop)]
    pub(super) fn worker(shared: &Shared<'_>) {
        // `digests[w][lane]` is word `w` of the hash belonging to index `start + lane`.
        let mut digests = [Simd::splat(0); 4];
        let mut buffers = [[0; 64]; 32];

        while let Some(next) = shared.iter.next() {
            let start = next as i32;

            // The first hash has a variable length input (salt plus decimal index), so it is
            // computed one lane at a time.
            for lane in 0..32 {
                let (mut buffer, size) = format_string(shared.input, start + lane as i32);
                let [w0, w1, w2, w3] = hash(&mut buffer, size);

                digests[0][lane] = w0;
                digests[1][lane] = w1;
                digests[2][lane] = w2;
                digests[3][lane] = w3;
            }

            // Key stretching always hashes exactly 32 hex characters, so all lanes can be
            // hashed together.
            if shared.part_two {
                for _ in 0..2016 {
                    for lane in 0..32 {
                        buffers[lane][0..8].copy_from_slice(&to_ascii(digests[0][lane]));
                        buffers[lane][8..16].copy_from_slice(&to_ascii(digests[1][lane]));
                        buffers[lane][16..24].copy_from_slice(&to_ascii(digests[2][lane]));
                        buffers[lane][24..32].copy_from_slice(&to_ascii(digests[3][lane]));
                    }
                    digests = hash_fixed(&mut buffers, 32);
                }
            }

            check(shared, start, &digests);
        }
    }

    /// Check for sequences of 3 or 5 consecutive matching digits in all 32 lanes at once and
    /// record any keys that they confirm.
    #[inline]
    fn check(shared: &Shared<'_>, start: i32, hash: &[Simd<u32, 32>; 4]) {
        let zero: Simd<u32, 32> = Simd::splat(0);
        let one: Simd<u32, 32> = Simd::splat(1);
        let nibble: Simd<u32, 32> = Simd::splat(0xf);
        let run_of_three: Simd<u32, 32> = Simd::splat(3);
        let run_of_five: Simd<u32, 32> = Simd::splat(5);

        let mut prev: Simd<u32, 32> = Simd::splat(u32::MAX);
        let mut same = one;
        let mut three = zero;
        let mut five = zero;

        // Walk the digits from the last one to the first one: words in reverse order and
        // nibbles from least to most significant. Each new triple overwrites the previous
        // one, so what remains is the triple closest to the start of each hash.
        for mut word in hash.iter().rev().copied() {
            for _ in 0..8 {
                let next = word & nibble;
                let bit = one << next;

                same = next.simd_eq(prev).select(same + one, one);
                three = same.simd_eq(run_of_three).select(bit, three);
                five |= same.simd_eq(run_of_five).select(bit, zero);

                word >>= 4;
                prev = next;
            }
        }

        // One bit per lane that contains a run.
        let three_mask = three.simd_ne(zero).to_bitmask();
        let five_mask = five.simd_ne(zero).to_bitmask();

        // Most blocks contain no run at all, so skip taking the lock for them.
        if three_mask == 0 && five_mask == 0 {
            return;
        }

        let mut exclusive = shared.mutex.lock().unwrap();
        let mut candidates = Vec::new();

        // Record every triple of this block first so that the quintuple pass below can also
        // confirm triples from earlier lanes of the same block.
        for i in three_mask.biterator() {
            let lane_three = three[i];
            let n = start + i as i32;

            // Compare against all 5 digit sequences in the following 1000 indices.
            if lane_three != 0 {
                exclusive.threes.insert(n, lane_three);

                let later = exclusive.fives.range(n + 1..n + 1001);
                candidates.extend(later.filter(|&(_, &mask)| lane_three & mask != 0).map(|_| n));
            }
        }

        for i in five_mask.biterator() {
            let lane_five = five[i];
            let n = start + i as i32;

            // Compare against all 3 digit sequences in the previous 1000 indices.
            if lane_five != 0 {
                exclusive.fives.insert(n, lane_five);

                let earlier = exclusive.threes.range(n - 1000..n);
                candidates.extend(
                    earlier.filter(|&(_, &mask)| lane_five & mask != 0).map(|(&index, _)| index),
                );
            }
        }

        // Duplicates collapse in the set. Ask all workers to finish once there are at least
        // 64 keys.
        exclusive.found.extend(candidates);

        if exclusive.found.len() >= 64 {
            shared.iter.stop();
        }
    }
}