Skip to main content

aoc/util/
thread.rs

//! Utility methods to spawn a number of
//! [scoped](https://doc.rust-lang.org/stable/std/thread/fn.scope.html)
//! threads equal to the machine's available parallelism (usually its number of logical cores).
//! Unlike normal threads, scoped threads can borrow data from their environment.
5use std::iter::repeat_with;
6use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering::Relaxed};
7use std::thread::*;
8
/// Number of worker threads to spawn.
///
/// This is [`available_parallelism`], typically the number of logical CPUs (hardware threads)
/// available to this process. OS scheduling or container limits can make it smaller than the
/// machine's core count. If the platform cannot report a value, this falls back to a single
/// thread instead of panicking.
pub fn threads() -> usize {
    available_parallelism().map_or(1, |n| n.get())
}
13
14/// Spawn `n` scoped threads, where `n` is the available parallelism.
15pub fn spawn<F, R>(f: F) -> Vec<R>
16where
17    F: Fn() -> R + Copy + Send,
18    R: Send,
19{
20    scope(|scope| {
21        let handles: Vec<_> = repeat_with(|| scope.spawn(f)).take(threads()).collect();
22        handles.into_iter().flat_map(ScopedJoinHandle::join).collect()
23    })
24}
25
26/// Spawns `n` scoped threads that each receive a
27/// [work stealing](https://en.wikipedia.org/wiki/Work_stealing) iterator.
28/// Work stealing is an efficient strategy that keeps each CPU core busy when some items take longer
29/// than others to process, used by popular libraries such as [rayon](https://github.com/rayon-rs/rayon).
30/// Processing at different rates also happens on many modern CPUs with
31/// [heterogeneous performance and efficiency cores](https://en.wikipedia.org/wiki/ARM_big.LITTLE).
32pub fn spawn_parallel_iterator<F, R, T>(items: &[T], f: F) -> Vec<R>
33where
34    F: Fn(ParIter<'_, T>) -> R + Copy + Send,
35    R: Send,
36    T: Sync,
37{
38    let threads = threads();
39    let size = items.len().div_ceil(threads);
40
41    // Initially divide work as evenly as possible among the worker threads.
42    let workers: Vec<_> = (0..threads)
43        .map(|id| {
44            let start = (id * size).min(items.len());
45            let end = (start + size).min(items.len());
46            CachePadding::new(pack(start, end))
47        })
48        .collect();
49    let workers = workers.as_slice();
50
51    scope(|scope| {
52        let handles: Vec<_> =
53            (0..threads).map(|id| scope.spawn(move || f(ParIter { id, items, workers }))).collect();
54        handles.into_iter().flat_map(ScopedJoinHandle::join).collect()
55    })
56}
57
58pub struct ParIter<'a, T> {
59    id: usize,
60    items: &'a [T],
61    workers: &'a [CachePadding],
62}
63
64impl<'a, T> Iterator for ParIter<'a, T> {
65    type Item = &'a T;
66
67    fn next(&mut self) -> Option<&'a T> {
68        // First try taking from our own queue.
69        let worker = &self.workers[self.id];
70        let current = worker.increment();
71        let (start, end) = unpack(current);
72
73        // There are still items to process.
74        if start < end {
75            return Some(&self.items[start]);
76        }
77
78        // Steal from another worker, [spinlocking](https://en.wikipedia.org/wiki/Spinlock)
79        // until we acquire new items to process or there's nothing left to do.
80        loop {
81            // Find worker with the most remaining items, breaking out of the loop
82            // and returning `None` if there is no work remaining.
83            let (other, current, size) = self
84                .workers
85                .iter()
86                .filter_map(|other| {
87                    let current = other.load();
88                    let (start, end) = unpack(current);
89                    let size = end.saturating_sub(start);
90
91                    (size > 0).then_some((other, current, size))
92                })
93                .max_by_key(|&(_, _, size)| size)?;
94
95            // Split the work items into two roughly equal piles.
96            let (start, end) = unpack(current);
97            let middle = start + size.div_ceil(2);
98
99            let next = pack(middle, end);
100            let stolen = pack(start + 1, middle);
101
102            // We could be preempted by another thread stealing or by the owning worker
103            // thread finishing an item, so check indices are still unmodified.
104            if other.compare_exchange(current, next) {
105                worker.store(stolen);
106                break Some(&self.items[start]);
107            }
108        }
109    }
110}
111
/// Intentionally forces alignment to 128 bytes, a best effort attempt to place each atomic on
/// its own cache line. This reduces contention and improves performance for common CPU caching
/// protocols such as [MESI](https://en.wikipedia.org/wiki/MESI_protocol).
///
/// In this module the wrapped value is a `start..end` index range packed by `pack`, with
/// `start` in the low 32 bits and `end` in the high 32 bits.
#[derive(Debug)]
#[repr(align(128))]
pub struct CachePadding {
    atomic: AtomicUsize,
}

/// Convenience wrappers around atomic operations.
///
/// Both the start and end indices are packed into a single atomic, so each range update is one
/// operation on one memory location. This allows the fastest and easiest to reason about
/// `Relaxed` ordering.
impl CachePadding {
    /// Wraps `n` in a new padded atomic.
    #[inline]
    fn new(n: usize) -> Self {
        CachePadding { atomic: AtomicUsize::new(n) }
    }

    /// Atomically adds one and returns the value from *before* the increment.
    ///
    /// For a packed range, this claims index `start`.
    #[inline]
    fn increment(&self) -> usize {
        self.atomic.fetch_add(1, Relaxed)
    }

    /// Reads the current value.
    #[inline]
    fn load(&self) -> usize {
        self.atomic.load(Relaxed)
    }

    /// Overwrites the current value with `n`.
    #[inline]
    fn store(&self, n: usize) {
        self.atomic.store(n, Relaxed);
    }

    /// Replaces the value with `new` only if it still equals `current`, and reports whether the
    /// swap happened.
    #[inline]
    fn compare_exchange(&self, current: usize, new: usize) -> bool {
        self.atomic.compare_exchange(current, new, Relaxed, Relaxed).is_ok()
    }
}
149
/// Combines a `start..end` range into one word.
///
/// `start` occupies the low 32 bits and `end` the high 32 bits.
#[inline]
fn pack(start: usize, end: usize) -> usize {
    const HALF: u32 = 32;
    start | (end << HALF)
}
154
/// Splits a word produced by `pack` back into `(start, end)`.
///
/// The low 32 bits give `start` and the high bits give `end`.
#[inline]
fn unpack(both: usize) -> (usize, usize) {
    const LOW_MASK: usize = 0xffff_ffff;
    let start = both & LOW_MASK;
    let end = both >> 32;
    (start, end)
}
159
/// Shares a monotonically increasing value between multiple threads.
///
/// Each call to [`AtomicIter::next`] hands out the current value and advances it by `step`.
/// This continues until [`AtomicIter::stop`] is called.
#[derive(Debug)]
pub struct AtomicIter {
    running: AtomicBool,
    index: AtomicU32,
    step: u32,
}

impl AtomicIter {
    /// Creates a shared counter that starts at `start` and advances by `step` per call.
    pub fn new(start: u32, step: u32) -> Self {
        AtomicIter { running: AtomicBool::new(true), index: AtomicU32::from(start), step }
    }

    /// Returns the current value and advances the shared counter by `step`. Returns `None`
    /// once the iterator has been stopped.
    ///
    /// The counter wraps around on `u32` overflow (the semantics of `fetch_add`).
    pub fn next(&self) -> Option<u32> {
        self.running.load(Relaxed).then(|| self.index.fetch_add(self.step, Relaxed))
    }

    /// Signals every caller of [`AtomicIter::next`] to stop.
    ///
    /// The flag is written with `Relaxed` ordering. Other threads may therefore observe it a
    /// little later and hand out a few more values first.
    pub fn stop(&self) {
        self.running.store(false, Relaxed);
    }
}