Skip to main content

aoc/year2016/
day14.rs

1//! # One-Time Pad
2//!
3//! Brute force slog through all possible keys, parallelized as much as possible. An optimization
4//! for part two is a quick method to convert `u32` to 8 ASCII digits.
5use crate::util::md5::*;
6use crate::util::thread::*;
7use implementation::*;
8use std::collections::{BTreeMap, BTreeSet};
9use std::sync::Mutex;
10
11/// Atomics can be safely shared between threads.
12struct Shared<'a> {
13    input: &'a str,
14    part_two: bool,
15    iter: AtomicIter,
16    mutex: Mutex<Exclusive>,
17}
18
/// Search bookkeeping that must only be touched while holding `Shared::mutex`.
struct Exclusive {
    /// Indices confirmed as keys. Kept sorted so the 64th smallest can be read off directly.
    found: BTreeSet<i32>,
    /// Index -> one-hot mask (`1 << digit`) of the first triple in that index's hash.
    threes: BTreeMap<i32, u32>,
    /// Index -> mask with a bit set for every digit repeated 5 times in that index's hash.
    fives: BTreeMap<i32, u32>,
}
25
/// Strip surrounding whitespace from the puzzle input, leaving just the salt.
pub fn parse(input: &str) -> &str {
    let salt = input.trim_start();
    salt.trim_end()
}
29
30/// Hash each key once.
31pub fn part1(input: &str) -> i32 {
32    generate_pad(input, false)
33}
34
35/// Hash each key an additional 2016 times.
36pub fn part2(input: &str) -> i32 {
37    generate_pad(input, true)
38}
39
40/// Find the first 64 keys that satisfy the rules.
41fn generate_pad(input: &str, part_two: bool) -> i32 {
42    let step = if cfg!(feature = "simd") { 32 } else { 1 };
43
44    let exclusive =
45        Exclusive { threes: BTreeMap::new(), fives: BTreeMap::new(), found: BTreeSet::new() };
46    let shared =
47        Shared { input, part_two, iter: AtomicIter::new(0, step), mutex: Mutex::new(exclusive) };
48
49    // Use as many cores as possible to parallelize the search.
50    spawn(|| worker(&shared));
51
52    let exclusive = shared.mutex.into_inner().unwrap();
53    *exclusive.found.iter().nth(63).unwrap()
54}
55
/// Build the MD5 input `salt || decimal(n)` in a zero-filled 64 byte buffer.
///
/// Returns the buffer together with the number of meaningful bytes at its start.
/// `n` is expected to be non-negative; `0` is written as the single digit `"0"`.
fn format_string(prefix: &str, mut n: i32) -> ([u8; 64], usize) {
    let salt = prefix.as_bytes();
    let width = n.max(1).ilog10() as usize + 1;
    let end = salt.len() + width;

    let mut out = [0u8; 64];
    out[..salt.len()].copy_from_slice(salt);

    // Fill the digit slots right to left, peeling off the least significant digit each time.
    for slot in out[salt.len()..end].iter_mut().rev() {
        *slot = b'0' + (n % 10) as u8;
        n /= 10;
    }

    (out, end)
}
72
/// Convert a `u32` into its 8 lowercase hex ASCII characters, most significant nibble first,
/// without branching or per-digit loops.
#[inline]
fn to_ascii(n: u32) -> [u8; 8] {
    // A byte with value 1 in every lane, used to broadcast per-byte constants.
    const ONES: u64 = 0x0101_0101_0101_0101;

    // Spread each nibble into its own byte, e.g. `1234abcd` -> `010203040a0b0c0d`.
    let x = u64::from(n);
    let x = (x & 0x0000_0000_0000_ffff) | ((x << 16) & 0x0000_ffff_0000_0000);
    let x = (x & 0x0000_00ff_0000_00ff) | ((x << 8) & 0x00ff_0000_00ff_0000);
    let x = (x & 0x000f_000f_000f_000f) | ((x << 4) & 0x0f00_0f00_0f00_0f00);

    // Adding 6 pushes digits 10..=15 past 0x0f, so bit 4 of such a byte becomes set.
    // That bit is a per-byte "is a letter" flag.
    let letters = ((x + 6 * ONES) >> 4) & ONES;

    // '0' is 0x30; letters need a further 0x27 to land on 'a'..='f'.
    let ascii = x + 0x30 * ONES + 0x27 * letters;
    ascii.to_be_bytes()
}
95
96#[cfg(not(feature = "simd"))]
97mod implementation {
98    use super::*;
99
100    pub(super) fn worker(shared: &Shared<'_>) {
101        while let Some(n) = shared.iter.next() {
102            // Get the next key to check.
103            let n = n as i32;
104
105            // Calculate the hash.
106            let (mut buffer, size) = format_string(shared.input, n);
107            let mut result = hash(&mut buffer, size);
108
109            if shared.part_two {
110                for _ in 0..2016 {
111                    buffer[0..8].copy_from_slice(&to_ascii(result[0]));
112                    buffer[8..16].copy_from_slice(&to_ascii(result[1]));
113                    buffer[16..24].copy_from_slice(&to_ascii(result[2]));
114                    buffer[24..32].copy_from_slice(&to_ascii(result[3]));
115                    result = hash(&mut buffer, 32);
116                }
117            }
118
119            check(shared, n, result);
120        }
121    }
122
123    /// Check for sequences of 3 or 5 consecutive matching digits.
124    fn check(shared: &Shared<'_>, n: i32, hash: [u32; 4]) {
125        let [a, b, c, d] = hash;
126
127        let mut prev = u32::MAX;
128        let mut same = 1;
129        let mut three = 0;
130        let mut five = 0;
131
132        for mut word in [d, c, b, a] {
133            for _ in 0..8 {
134                let next = word & 0xf;
135
136                if next == prev {
137                    same += 1;
138                } else {
139                    same = 1;
140                }
141
142                if same == 3 {
143                    three = 1 << next;
144                }
145                if same == 5 {
146                    five |= 1 << next;
147                }
148
149                word >>= 4;
150                prev = next;
151            }
152        }
153
154        if three != 0 || five != 0 {
155            let mut exclusive = shared.mutex.lock().unwrap();
156            let mut candidates = Vec::new();
157
158            // Compare against all 5 digit sequences.
159            if three != 0 {
160                exclusive.threes.insert(n, three);
161
162                for (_, mask) in exclusive.fives.range(n + 1..n + 1001) {
163                    if three & mask != 0 {
164                        candidates.push(n);
165                    }
166                }
167            }
168
169            // Compare against all 3 digit sequences.
170            if five != 0 {
171                exclusive.fives.insert(n, five);
172
173                for (&index, &mask) in exclusive.threes.range(n - 1000..n) {
174                    if five & mask != 0 {
175                        candidates.push(index);
176                    }
177                }
178            }
179
180            // Add any matching keys found, finishing once we have at least 64 keys.
181            exclusive.found.extend(candidates);
182
183            if exclusive.found.len() >= 64 {
184                shared.iter.stop();
185            }
186        }
187    }
188}
189
#[cfg(feature = "simd")]
mod implementation {
    use super::*;
    use crate::util::bitset::*;
    use crate::util::md5::simd::hash_fixed;
    use std::simd::cmp::SimdPartialEq as _;
    use std::simd::*;

    /// Use SIMD to compute hashes in parallel in blocks of 32.
    ///
    /// Each value claimed from `shared.iter` is the first of 32 consecutive indices
    /// (`start..start + 32`). Lane `i` of every vector holds the digest of index `start + i`.
    #[expect(clippy::needless_range_loop)]
    pub(super) fn worker(shared: &Shared<'_>) {
        let mut result = [Simd::splat(0); 4];
        let mut buffers = [[0; 64]; 32];

        while let Some(start) = shared.iter.next() {
            // First index of this block of 32.
            let start = start as i32;

            // Initial hash of salt + index, computed one lane at a time.
            for i in 0..32 {
                let (mut buffer, size) = format_string(shared.input, start + i as i32);
                let [a, b, c, d] = hash(&mut buffer, size);

                result[0][i] = a;
                result[1][i] = b;
                result[2][i] = c;
                result[3][i] = d;
            }

            // Part two: write every lane's hex digest into its buffer and re-hash all 32
            // buffers with a single `hash_fixed` call, 2016 times.
            if shared.part_two {
                for _ in 0..2016 {
                    for i in 0..32 {
                        buffers[i][0..8].copy_from_slice(&to_ascii(result[0][i]));
                        buffers[i][8..16].copy_from_slice(&to_ascii(result[1][i]));
                        buffers[i][16..24].copy_from_slice(&to_ascii(result[2][i]));
                        buffers[i][24..32].copy_from_slice(&to_ascii(result[3][i]));
                    }
                    result = hash_fixed(&mut buffers, 32);
                }
            }

            check(shared, start, &result);
        }
    }

    /// Check all 32 lanes for runs of 3 or 5 matching hex digits and record them.
    ///
    /// Digits are visited from the last hex character back to the first. Per lane, `three`
    /// ends up as the one-hot mask of the *first* triple in the string, and `five` collects
    /// a bit for every digit repeated 5 times in a row.
    #[inline]
    fn check(shared: &Shared<'_>, start: i32, hash: &[Simd<u32, 32>; 4]) {
        let &[a, b, c, d] = hash;

        let mut prev: Simd<u32, 32> = Simd::splat(u32::MAX);
        let mut same: Simd<u32, 32> = Simd::splat(1);
        let mut three: Simd<u32, 32> = Simd::splat(0);
        let mut five: Simd<u32, 32> = Simd::splat(0);

        for mut word in [d, c, b, a] {
            for _ in 0..8 {
                let next = word & Simd::splat(0xf);
                same = next.simd_eq(prev).select(same + Simd::splat(1), Simd::splat(1));

                // Overwrite, so the triple nearest the start of the string wins.
                three = same.simd_eq(Simd::splat(3)).select(Simd::splat(1) << next, three);
                five |= same.simd_eq(Simd::splat(5)).select(Simd::splat(1) << next, Simd::splat(0));

                word >>= 4;
                prev = next;
            }
        }

        // One bit per lane that contains at least one triple / quintuple.
        let three_mask = three.simd_ne(Simd::splat(0)).to_bitmask();
        let five_mask = five.simd_ne(Simd::splat(0)).to_bitmask();

        if three_mask != 0 || five_mask != 0 {
            let mut exclusive = shared.mutex.lock().unwrap();
            let mut candidates = Vec::new();

            for i in three_mask.biterator() {
                let three = three[i];
                let n = start + i as i32;

                // A triple at `n` is a key if a matching quintuple was already seen in the
                // next 1000 indices.
                if three != 0 {
                    exclusive.threes.insert(n, three);

                    for (_, mask) in exclusive.fives.range(n + 1..n + 1001) {
                        if three & mask != 0 {
                            candidates.push(n);
                        }
                    }
                }
            }

            for i in five_mask.biterator() {
                let five = five[i];
                let n = start + i as i32;

                // A quintuple at `n` confirms every matching triple already seen in the
                // previous 1000 indices.
                if five != 0 {
                    exclusive.fives.insert(n, five);

                    for (&index, &mask) in exclusive.threes.range(n - 1000..n) {
                        if five & mask != 0 {
                            candidates.push(index);
                        }
                    }
                }
            }

            exclusive.found.extend(candidates);

            // Keys are not confirmed in index order: the key at index `k` is only confirmed
            // once a quintuple in `k + 1..=k + 1000` has been hashed. Stopping as soon as 64
            // keys are known could miss a smaller key whose quintuple is still ahead.
            //
            // So only stop once this block reaches at least 1000 past the 64th known key.
            // This assumes `AtomicIter` hands out blocks in increasing order and that workers
            // finish the block they already hold.
            let last = start + 31;
            if exclusive.found.iter().nth(63).is_some_and(|&key| last >= key + 1000) {
                shared.iter.stop();
            }
        }
    }
}