aoc/year2024/day22.rs

//! # Monkey Market
//!
//! Solves both parts simultaneously, parallelizing the work over multiple threads since
//! each secret number is independent. The process of generating the next secret number is a
//! [linear feedback shift register](https://en.wikipedia.org/wiki/Linear-feedback_shift_register)
//! with a cycle of 2²⁴.
//!
//! Interestingly this means that with some clever math it's possible to generate the `n`th number
//! from any starting secret number with only 24 calculations. Unfortunately this doesn't help for
//! part two since we need to check every possible sequence of price changes. However, to speed
//! things up we can make several optimizations:
//!
//! * First, each price change is converted from -9..=9 to a base 19 digit of 0..=18, so that a
//!   sequence of 4 changes becomes a single index (see the worked example below).
//! * Whether a monkey has seen a sequence before and the total bananas for each sequence are
//!   stored in an array. This is much faster than a `HashMap`. Using base 19 gives much better
//!   cache locality, needing only 130321 elements, compared for example to shifting each new
//!   price change by 5 bits and storing in an array of 2²⁰ = 1048576 elements. Multiplication on
//!   modern processors is cheap (and several instructions can issue at once) but random memory
//!   access is expensive.
//!
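//! As a worked example (the values are illustrative, not taken from any real input), the change
//! sequence (-2, 1, -1, 3) becomes the base 19 digits (7, 10, 8, 12) and maps to index
//! 7 * 6859 + 10 * 361 + 8 * 19 + 12 = 51787. The largest possible index is
//! 18 * (6859 + 361 + 19 + 1) = 130320, hence the array size of 130321 = 19⁴.
//!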
//! A SIMD variant processes 8 hashes at a time, taking about 60% of the time of the scalar
//! version. The bottleneck is that disjoint indices must be written in sequence, reducing the
//! amount of work that can be parallelized.
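//!
//! The "24 calculations" claim above follows from the hash being linear over GF(2): every output
//! bit is an XOR of input bits, so one step can be written as a 24×24 bit matrix and the LFSR
//! jumped forward by square-and-multiply. A minimal sketch, not used by the solution (the helper
//! names `to_matrix`, `apply`, `compose` and `jump` are illustrative only):
//!
//! ```ignore
//! /// Column `i` of the matrix is `step` applied to the i-th basis vector.
//! fn to_matrix(step: impl Fn(u32) -> u32) -> [u32; 24] {
//!     std::array::from_fn(|i| step(1_u32 << i))
//! }
//!
//! /// Apply a linear map: XOR together the columns selected by the set bits of `x`.
//! fn apply(m: &[u32; 24], x: u32) -> u32 {
//!     (0..24).filter(|&i| x >> i & 1 != 0).fold(0, |acc, i| acc ^ m[i])
//! }
//!
//! /// Compose two linear maps (matrix multiplication over GF(2)).
//! fn compose(a: &[u32; 24], b: &[u32; 24]) -> [u32; 24] {
//!     std::array::from_fn(|i| apply(a, b[i]))
//! }
//!
//! /// Advance `x` by `n` steps of `step` using binary exponentiation of the matrix.
//! /// Since the cycle length is at most 2²⁴, the loop runs at most 24 times.
//! fn jump(mut x: u32, mut n: u32, step: impl Fn(u32) -> u32) -> u32 {
//!     let mut m = to_matrix(step);
//!     while n > 0 {
//!         if n & 1 == 1 {
//!             x = apply(&m, x);
//!         }
//!         m = compose(&m, &m);
//!         n >>= 1;
//!     }
//!     x
//! }
//! ```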
use crate::util::parse::*;
use crate::util::thread::*;
use std::sync::Mutex;

type Input = (u64, u16);

struct Exclusive {
    part_one: u64,
    part_two: Vec<u16>,
}

pub fn parse(input: &str) -> Input {
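    // 19⁴ = 130321 possible sequences of four consecutive price changes, one counter per sequence.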
    let mutex = Mutex::new(Exclusive { part_one: 0, part_two: vec![0; 130321] });

    #[cfg(not(feature = "simd"))]
    scalar::parallel(input, &mutex);
    #[cfg(feature = "simd")]
    simd::parallel(input, &mutex);

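    // Part one sums each buyer's 2000th secret number; part two is the best total over every sequence.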
    let Exclusive { part_one, part_two } = mutex.into_inner().unwrap();
    (part_one, *part_two.iter().max().unwrap())
}

pub fn part1(input: &Input) -> u64 {
    input.0
}

pub fn part2(input: &Input) -> u16 {
    input.1
}

#[cfg(not(feature = "simd"))]
mod scalar {
    use super::*;

    // Use as many cores as possible to parallelize the remaining search.
    pub(super) fn parallel(input: &str, mutex: &Mutex<Exclusive>) {
        let numbers: Vec<_> = input.iter_unsigned().collect();
        spawn_parallel_iterator(&numbers, |iter| worker(mutex, iter));
    }

    fn worker(mutex: &Mutex<Exclusive>, iter: ParIter<'_, u32>) {
        let mut part_one = 0;
        let mut part_two = vec![0; 130321];
        let mut seen = vec![u16::MAX; 130321];

        for (id, number) in iter.enumerate() {
            let id = id as u16;

            let zeroth = *number;
            let first = hash(zeroth);
            let second = hash(first);
            let third = hash(second);

            let mut a;
            let mut b = to_index(zeroth, first);
            let mut c = to_index(first, second);
            let mut d = to_index(second, third);

            let mut number = third;
            let mut previous = third % 10;

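            // The three hashes above seed the first window of price changes; this loop
            // generates the remaining 1997 of each buyer's 2000 new secret numbers.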
            for _ in 3..2000 {
                number = hash(number);
                let price = number % 10;

                // Compute index into the array.
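                // The four most recent changes are treated as base 19 digits: 6859 = 19³, 361 = 19².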
                (a, b, c, d) = (b, c, d, to_index(previous, price));
                let index = (6859 * a + 361 * b + 19 * c + d) as usize;
                previous = price;

                // Only sell the first time we see a sequence.
                // By storing the id in the array we don't need to zero it for every monkey, which is faster.
                if seen[index] != id {
                    part_two[index] += price as u16;
                    seen[index] = id;
                }
            }

            part_one += number as u64;
        }

        // Merge into global results.
        let mut exclusive = mutex.lock().unwrap();
        exclusive.part_one += part_one;
        exclusive.part_two.iter_mut().zip(part_two).for_each(|(a, b)| *a += b);
    }

    /// Compute the next secret number using a
    /// [Xorshift LFSR](https://en.wikipedia.org/wiki/Linear-feedback_shift_register#Xorshift_LFSRs).
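    /// The three xorshift steps correspond to the puzzle's multiply by 64, divide by 32 and
    /// multiply by 2048 operations, each mixed in with XOR and pruned modulo 2²⁴ by the mask.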
    fn hash(mut n: u32) -> u32 {
        n = (n ^ (n << 6)) & 0xffffff;
        n = (n ^ (n >> 5)) & 0xffffff;
        (n ^ (n << 11)) & 0xffffff
    }

    /// Convert the price difference from -9..=9 to an index digit in 0..=18.
    fn to_index(previous: u32, current: u32) -> u32 {
        9 + current % 10 - previous % 10
    }
}

#[cfg(feature = "simd")]
mod simd {
    use super::*;
    use std::simd::Simd;
    use std::simd::num::SimdUint as _;

    type Vector = Simd<u32, 8>;

    pub(super) fn parallel(input: &str, mutex: &Mutex<Exclusive>) {
        let mut numbers: Vec<_> = input.iter_unsigned().collect();

        // Add zero elements so that size is a multiple of 8.
        // Zero always hashes to zero and does not contribute to the score.
        numbers.resize(numbers.len().next_multiple_of(8), 0);
        let chunks: Vec<_> = numbers.chunks_exact(8).collect();

        spawn_parallel_iterator(&chunks, |iter| worker(mutex, iter));
    }

    /// Similar to the scalar version but using SIMD vectors instead.
    /// 8 lanes is the sweet spot for performance as the bottleneck is the scalar loop writing
    /// to disjoint indices after each step.
    fn worker(mutex: &Mutex<Exclusive>, iter: ParIter<'_, &[u32]>) {
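        // Broadcast constants to all lanes: 10 extracts prices, while 6859 = 19³, 361 = 19²
        // and 19 are the base 19 place values for the index.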
        let ten = Simd::splat(10);
        let x = Simd::splat(6859);
        let y = Simd::splat(361);
        let z = Simd::splat(19);

        let mut part_one = 0;
        let mut part_two = vec![0; 130321];

        for slice in iter {
            // Each lane uses a different bit to track if a sequence has been seen before.
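            // `u8::MAX` starts with all eight bits set, meaning no lane has sold this sequence yet.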
            let mut seen = vec![u8::MAX; 130321];

            let zeroth = Simd::from_slice(slice);
            let first = hash(zeroth);
            let second = hash(first);
            let third = hash(second);

            let mut a;
            let mut b = to_index(zeroth, first);
            let mut c = to_index(first, second);
            let mut d = to_index(second, third);

            let mut number = third;
            let mut previous = third % ten;

            for _ in 3..2000 {
                number = hash(number);
                let prices = number % ten;

                // Compute index into the array.
                (a, b, c, d) = (b, c, d, to_index(previous, prices));
                let indices = x * a + y * b + z * c + d;
                previous = prices;

                // Only sell the first time we see a sequence.
                let indices = indices.to_array();
                let prices = prices.to_array();

                for i in 0..8 {
                    let index = indices[i] as usize;

                    // Avoid branching to improve speed; instead multiply by either 0 or 1,
                    // depending on whether the sequence has been seen before.
                    let bit = (seen[index] >> i) & 1;
                    seen[index] &= !(1 << i);

                    part_two[index] += prices[i] as u16 * bit as u16;
                }
            }

            part_one += number.reduce_sum() as u64;
        }

        // Merge into global results.
        let mut exclusive = mutex.lock().unwrap();
        exclusive.part_one += part_one;
        exclusive.part_two.iter_mut().zip(part_two).for_each(|(a, b)| *a += b);
    }

    /// SIMD vector arguments are passed in memory, so these functions are inlined to avoid slow
    /// transfers to and from memory.
    #[inline]
    fn hash(mut n: Vector) -> Vector {
        let mask = Simd::splat(0xffffff);
        n = (n ^ (n << 6)) & mask;
        n = (n ^ (n >> 5)) & mask;
        (n ^ (n << 11)) & mask
    }

    #[inline]
    fn to_index(previous: Vector, current: Vector) -> Vector {
        let nine = Simd::splat(9);
        let ten = Simd::splat(10);
        nine + (current % ten) - (previous % ten)
    }
}