aoc/year2016/
day05.rs

1//! # How About a Nice Game of Chess?
2//!
3//! Essentially a repeat of [`Year 2015 Day 4`]. We brute force MD5 hashes as quickly as
4//! possible in parallel in blocks of 1000 at a time.
5//!
6//! [`Year 2015 Day 4`]: crate::year2015::day04
7use crate::util::md5::*;
8use crate::util::thread::*;
9use std::sync::Mutex;
10use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
11
12struct Shared {
13    prefix: String,
14    done: AtomicBool,
15    counter: AtomicU32,
16    mutex: Mutex<Exclusive>,
17}
18
/// Search results guarded by the mutex in `Shared`.
struct Exclusive {
    /// Bit `p` is set once a hash whose sixth hex digit is `p` has been recorded.
    mask: u16,
    /// `(index, first 32-bit word of the MD5 hash)` for every hash starting with five zero hex
    /// digits, in the order threads happened to record them (not necessarily sorted).
    found: Vec<(u32, u32)>,
}
23
24pub fn parse(input: &str) -> Vec<u32> {
25    let shared = Shared {
26        prefix: input.trim().to_owned(),
27        done: AtomicBool::new(false),
28        counter: AtomicU32::new(1000),
29        mutex: Mutex::new(Exclusive { found: vec![], mask: 0 }),
30    };
31
32    // Handle the first 999 numbers specially as the number of digits varies.
33    for n in 1..1000 {
34        let (mut buffer, size) = format_string(&shared.prefix, n);
35        check_hash(&mut buffer, size, n, &shared);
36    }
37
38    // Use as many cores as possible to parallelize the remaining search.
39    spawn(|| {
40        #[cfg(not(feature = "simd"))]
41        worker(&shared);
42        #[cfg(feature = "simd")]
43        simd::worker(&shared);
44    });
45
46    let mut found = shared.mutex.into_inner().unwrap().found;
47    found.sort_unstable();
48    found.iter().map(|&(_, n)| n).collect()
49}
50
/// Builds the part one password: one hex digit per match, taken from the first eight matches in
/// index order, where each digit is the stored word shifted right by 8 (its sixth hex digit).
pub fn part1(input: &[u32]) -> String {
    let first_eight = &input[..input.len().min(8)];

    let mut password: u32 = 0;
    for &word in first_eight {
        password = (password << 4) | (word >> 8);
    }

    format!("{password:08x}")
}
55
/// Builds the part two password. The sixth hex digit of a match selects a position (matches
/// pointing at 8 or beyond are ignored) and the seventh hex digit is the character placed there.
/// Only the first match for each position, in index order, is used.
pub fn part2(input: &[u32]) -> String {
    let mut password: u32 = 0;
    // One all-ones nibble for every position that is still unfilled.
    let mut unfilled: u32 = !0;

    for &word in input {
        let position = word >> 8;
        if position >= 8 {
            continue;
        }

        let shift = 4 * (7 - position);
        let character = (word >> 4) & 0xf;
        password |= (character << shift) & unfilled;
        unfilled &= !(0xf << shift);
    }

    format!("{password:08x}")
}
72
/// Writes `prefix` followed by the decimal digits of `n` into a zero-filled 64-byte buffer.
///
/// Returns the buffer together with the number of bytes written.
///
/// # Panics
///
/// Panics if the combined text is longer than 64 bytes.
fn format_string(prefix: &str, n: u32) -> ([u8; 64], usize) {
    let text = format!("{prefix}{n}");
    let len = text.len();

    let mut buffer = [0u8; 64];
    for (slot, &byte) in buffer[..len].iter_mut().zip(text.as_bytes()) {
        *slot = byte;
    }

    (buffer, len)
}
82
83fn check_hash(buffer: &mut [u8], size: usize, n: u32, shared: &Shared) {
84    let (result, ..) = hash(buffer, size);
85
86    if result & 0xfffff000 == 0 {
87        let mut exclusive = shared.mutex.lock().unwrap();
88
89        exclusive.found.push((n, result));
90        exclusive.mask |= 1 << (result >> 8);
91
92        if exclusive.mask & 0xff == 0xff {
93            shared.done.store(true, Ordering::Relaxed);
94        }
95    }
96}
97
98#[cfg(not(feature = "simd"))]
99fn worker(shared: &Shared) {
100    while !shared.done.load(Ordering::Relaxed) {
101        let offset = shared.counter.fetch_add(1000, Ordering::Relaxed);
102        let (mut buffer, size) = format_string(&shared.prefix, offset);
103
104        for n in 0..1000 {
105            // Format macro is very slow, so update digits directly
106            buffer[size - 3] = b'0' + (n / 100) as u8;
107            buffer[size - 2] = b'0' + ((n / 10) % 10) as u8;
108            buffer[size - 1] = b'0' + (n % 10) as u8;
109
110            check_hash(&mut buffer, size, offset + n, shared);
111        }
112    }
113}
114
/// SIMD variant of the search, hashing several candidate indices per call.
#[cfg(feature = "simd")]
mod simd {
    use super::*;
    use crate::util::md5::simd::hash;
    use std::simd::{LaneCount, SupportedLaneCount};

    /// Checks the `N` consecutive indices `start + offset .. start + offset + N` in one batch.
    ///
    /// Each of `buffers[0..N]` is expected to already hold `prefix + start`; only its last three
    /// digits are overwritten here. Matches are recorded exactly as in the scalar `check_hash`.
    /// NOTE(review): assumes `hash::<N>` hashes the first `N` buffers and returns one first
    /// word per lane, indexable as `result[i]` — confirm against `util::md5::simd`.
    // The range loops index both `buffers` and `result` and also use `i` arithmetically,
    // hence the clippy expectation.
    #[expect(clippy::needless_range_loop)]
    fn check_hash_simd<const N: usize>(
        buffers: &mut [[u8; 64]],
        size: usize,
        start: u32,
        offset: u32,
        shared: &Shared,
    ) where
        LaneCount<N>: SupportedLaneCount,
    {
        // Format macro is very slow, so update digits directly
        for i in 0..N {
            let n = offset + i as u32;
            buffers[i][size - 3] = b'0' + (n / 100) as u8;
            buffers[i][size - 2] = b'0' + ((n / 10) % 10) as u8;
            buffers[i][size - 1] = b'0' + (n % 10) as u8;
        }

        let (result, ..) = hash::<N>(buffers, size);

        for i in 0..N {
            // Five leading zero hex digits: the top 20 bits of the first word are clear.
            if result[i] & 0xfffff000 == 0 {
                let mut exclusive = shared.mutex.lock().unwrap();

                exclusive.found.push((start + offset + i as u32, result[i]));
                exclusive.mask |= 1 << (result[i] >> 8);

                if exclusive.mask & 0xff == 0xff {
                    shared.done.store(true, Ordering::Relaxed);
                }
            }
        }
    }

    /// SIMD search loop: claims blocks of 1000 indices until `shared.done` is raised.
    pub(super) fn worker(shared: &Shared) {
        while !shared.done.load(Ordering::Relaxed) {
            let start = shared.counter.fetch_add(1000, Ordering::Relaxed);
            let (prefix, size) = format_string(&shared.prefix, start);
            let mut buffers = [prefix; 32];

            // 1000 = 31 * 32 + 8: thirty-one 32-lane batches cover offsets 0..992...
            for offset in (0..992).step_by(32) {
                check_hash_simd::<32>(&mut buffers, size, start, offset, shared);
            }

            // ...and one 8-lane batch covers the remaining offsets 992..1000.
            check_hash_simd::<8>(&mut buffers, size, start, 992, shared);
        }
    }
}