aoc/util/
md5.rs

//! MD5 hash algorithm
//!
//! Computes a 128-bit [MD5 hash](https://en.wikipedia.org/wiki/MD5) for a slice of `u8`.
//! The hash is returned as an array of four `u32` values.
//!
//! The slice is modified in place. Its length must be a multiple of 64 bytes, with at least
//! 9 bytes to spare after the message for the MD5 padding. The [`buffer_size`] function
//! calculates the necessary size.
//!
//! To maximize speed, the loop for each of the four rounds used to create the hash is unrolled
//! and all internal utility functions are marked as
//! [`#[inline]`](https://doc.rust-lang.org/reference/attributes/codegen.html#the-inline-attribute).
//!
//! An optional SIMD variant (enabled by the `simd` feature) that computes multiple hashes in
//! parallel is also implemented.
/// Returns the smallest buffer length, in bytes, that can hold an `n`-byte message together
/// with its MD5 padding.
///
/// The padding needs one `0x80` marker byte plus an 8-byte length field. The total is rounded
/// up to a whole number of 64-byte blocks.
pub fn buffer_size(n: usize) -> usize {
    const BLOCK_LEN: usize = 64;
    const PADDING_LEN: usize = 9;

    let blocks = (n + PADDING_LEN).div_ceil(BLOCK_LEN);
    blocks * BLOCK_LEN
}
17
/// Computes the MD5 digest of the first `size` bytes of `buffer`, padding `buffer` in place.
///
/// The byte at index `size` is set to the `0x80` marker, and the last 8 bytes receive the
/// message length in bits as a little-endian `u64`. The length of `buffer` must be a multiple
/// of 64, with at least 9 bytes to spare after the message (see [`buffer_size`]). The bytes
/// between the marker and the length field are not cleared here. They must already be zero
/// for the result to be the correct MD5 digest.
///
/// Each returned word is byte-swapped. Formatting the four words with `{:08x}` and
/// concatenating them therefore gives the usual hexadecimal digest string.
///
/// # Panics
///
/// Panics if `size` is not a valid index into `buffer`, or if `buffer` is shorter than 8 bytes.
/// Debug builds also assert the length requirements above.
#[inline]
pub fn hash(buffer: &mut [u8], size: usize) -> [u32; 4] {
    // Violating either requirement silently yields a wrong digest: `chunks_exact` would drop a
    // partial trailing block, or the length field would overwrite message bytes.
    debug_assert!(buffer.len() % 64 == 0, "buffer length must be a multiple of 64");
    debug_assert!(size + 9 <= buffer.len(), "buffer needs 9 spare bytes for MD5 padding");

    let end = buffer.len() - 8;
    // MD5 stores the bit length as a 64-bit little-endian integer. Widen to `u64` explicitly so
    // that exactly 8 bytes are written on every target, including 32-bit ones where `usize`
    // has only 4 bytes.
    let bits = size as u64 * 8;

    buffer[size] = 0x80;
    buffer[end..].copy_from_slice(&bits.to_le_bytes());

    let mut m = [0; 16];
    let [mut a0, mut b0, mut c0, mut d0] = [0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476];

    for block in buffer.chunks_exact(64) {
        for (i, chunk) in block.chunks_exact(4).enumerate() {
            m[i] = u32::from_le_bytes(chunk.try_into().unwrap());
        }

        let [mut a, mut b, mut c, mut d] = [a0, b0, c0, d0];

        a = round1(a, b, c, d, m[0], 7, 0xd76aa478);
        d = round1(d, a, b, c, m[1], 12, 0xe8c7b756);
        c = round1(c, d, a, b, m[2], 17, 0x242070db);
        b = round1(b, c, d, a, m[3], 22, 0xc1bdceee);
        a = round1(a, b, c, d, m[4], 7, 0xf57c0faf);
        d = round1(d, a, b, c, m[5], 12, 0x4787c62a);
        c = round1(c, d, a, b, m[6], 17, 0xa8304613);
        b = round1(b, c, d, a, m[7], 22, 0xfd469501);
        a = round1(a, b, c, d, m[8], 7, 0x698098d8);
        d = round1(d, a, b, c, m[9], 12, 0x8b44f7af);
        c = round1(c, d, a, b, m[10], 17, 0xffff5bb1);
        b = round1(b, c, d, a, m[11], 22, 0x895cd7be);
        a = round1(a, b, c, d, m[12], 7, 0x6b901122);
        d = round1(d, a, b, c, m[13], 12, 0xfd987193);
        c = round1(c, d, a, b, m[14], 17, 0xa679438e);
        b = round1(b, c, d, a, m[15], 22, 0x49b40821);

        a = round2(a, b, c, d, m[1], 5, 0xf61e2562);
        d = round2(d, a, b, c, m[6], 9, 0xc040b340);
        c = round2(c, d, a, b, m[11], 14, 0x265e5a51);
        b = round2(b, c, d, a, m[0], 20, 0xe9b6c7aa);
        a = round2(a, b, c, d, m[5], 5, 0xd62f105d);
        d = round2(d, a, b, c, m[10], 9, 0x02441453);
        c = round2(c, d, a, b, m[15], 14, 0xd8a1e681);
        b = round2(b, c, d, a, m[4], 20, 0xe7d3fbc8);
        a = round2(a, b, c, d, m[9], 5, 0x21e1cde6);
        d = round2(d, a, b, c, m[14], 9, 0xc33707d6);
        c = round2(c, d, a, b, m[3], 14, 0xf4d50d87);
        b = round2(b, c, d, a, m[8], 20, 0x455a14ed);
        a = round2(a, b, c, d, m[13], 5, 0xa9e3e905);
        d = round2(d, a, b, c, m[2], 9, 0xfcefa3f8);
        c = round2(c, d, a, b, m[7], 14, 0x676f02d9);
        b = round2(b, c, d, a, m[12], 20, 0x8d2a4c8a);

        a = round3(a, b, c, d, m[5], 4, 0xfffa3942);
        d = round3(d, a, b, c, m[8], 11, 0x8771f681);
        c = round3(c, d, a, b, m[11], 16, 0x6d9d6122);
        b = round3(b, c, d, a, m[14], 23, 0xfde5380c);
        a = round3(a, b, c, d, m[1], 4, 0xa4beea44);
        d = round3(d, a, b, c, m[4], 11, 0x4bdecfa9);
        c = round3(c, d, a, b, m[7], 16, 0xf6bb4b60);
        b = round3(b, c, d, a, m[10], 23, 0xbebfbc70);
        a = round3(a, b, c, d, m[13], 4, 0x289b7ec6);
        d = round3(d, a, b, c, m[0], 11, 0xeaa127fa);
        c = round3(c, d, a, b, m[3], 16, 0xd4ef3085);
        b = round3(b, c, d, a, m[6], 23, 0x04881d05);
        a = round3(a, b, c, d, m[9], 4, 0xd9d4d039);
        d = round3(d, a, b, c, m[12], 11, 0xe6db99e5);
        c = round3(c, d, a, b, m[15], 16, 0x1fa27cf8);
        b = round3(b, c, d, a, m[2], 23, 0xc4ac5665);

        a = round4(a, b, c, d, m[0], 6, 0xf4292244);
        d = round4(d, a, b, c, m[7], 10, 0x432aff97);
        c = round4(c, d, a, b, m[14], 15, 0xab9423a7);
        b = round4(b, c, d, a, m[5], 21, 0xfc93a039);
        a = round4(a, b, c, d, m[12], 6, 0x655b59c3);
        d = round4(d, a, b, c, m[3], 10, 0x8f0ccc92);
        c = round4(c, d, a, b, m[10], 15, 0xffeff47d);
        b = round4(b, c, d, a, m[1], 21, 0x85845dd1);
        a = round4(a, b, c, d, m[8], 6, 0x6fa87e4f);
        d = round4(d, a, b, c, m[15], 10, 0xfe2ce6e0);
        c = round4(c, d, a, b, m[6], 15, 0xa3014314);
        b = round4(b, c, d, a, m[13], 21, 0x4e0811a1);
        a = round4(a, b, c, d, m[4], 6, 0xf7537e82);
        d = round4(d, a, b, c, m[11], 10, 0xbd3af235);
        c = round4(c, d, a, b, m[2], 15, 0x2ad7d2bb);
        b = round4(b, c, d, a, m[9], 21, 0xeb86d391);

        [a0, b0, c0, d0] =
            [a0.wrapping_add(a), b0.wrapping_add(b), c0.wrapping_add(c), d0.wrapping_add(d)];
    }

    // The digest is the little-endian byte serialisation of a0..d0. Swapping bytes
    // unconditionally (unlike `to_be`, which is a no-op on big-endian hosts) gives the same
    // result on every platform and matches the SIMD variant.
    [a0.swap_bytes(), b0.swap_bytes(), c0.swap_bytes(), d0.swap_bytes()]
}

/// Round 1 step, using F(b, c, d) = (b & c) | (!b & d).
#[inline]
fn round1(a: u32, b: u32, c: u32, d: u32, m: u32, s: u32, k: u32) -> u32 {
    let f = (b & c) | (!b & d);
    common(f, a, b, m, s, k)
}

/// Round 2 step, using G(b, c, d) = (b & d) | (c & !d).
#[inline]
fn round2(a: u32, b: u32, c: u32, d: u32, m: u32, s: u32, k: u32) -> u32 {
    let f = (b & d) | (c & !d);
    common(f, a, b, m, s, k)
}

/// Round 3 step, using H(b, c, d) = b ^ c ^ d.
#[inline]
fn round3(a: u32, b: u32, c: u32, d: u32, m: u32, s: u32, k: u32) -> u32 {
    let f = b ^ c ^ d;
    common(f, a, b, m, s, k)
}

/// Round 4 step, using I(b, c, d) = c ^ (b | !d).
#[inline]
fn round4(a: u32, b: u32, c: u32, d: u32, m: u32, s: u32, k: u32) -> u32 {
    let f = c ^ (b | !d);
    common(f, a, b, m, s, k)
}

/// Shared MD5 step: returns `b + ((f + a + k + m) <<< s)` using wrapping 32-bit arithmetic.
#[inline]
fn common(f: u32, a: u32, b: u32, m: u32, s: u32, k: u32) -> u32 {
    f.wrapping_add(a).wrapping_add(k).wrapping_add(m).rotate_left(s).wrapping_add(b)
}
139
/// Portable-SIMD variant that hashes `N` single-block messages in parallel, one per lane.
///
/// Only compiled when the `simd` feature is enabled.
#[cfg(feature = "simd")]
pub mod simd {
    use std::array::from_fn;
    use std::simd::num::SimdUint as _;
    use std::simd::*;

    /// Computes the MD5 digests of `N` equally sized messages in parallel, one per SIMD lane.
    ///
    /// Each of the `buffers` holds a `size`-byte message at its start, and all messages must
    /// have the same `size`. The byte at index `size` of every buffer is overwritten with the
    /// `0x80` padding marker. The message length is not written into the buffers; it is
    /// supplied directly as message words 14 and 15.
    ///
    /// Only a single 64-byte block is processed, so `size` must be at most 55. That leaves room
    /// for the marker byte and the 8-byte length field. Bytes after the marker inside the same
    /// 4-byte word must be zero. Words that start past `size` are treated as zero, whatever the
    /// buffers contain.
    ///
    /// Returns one vector per digest word. Each lane is byte-swapped, so formatting the four
    /// words of a lane with `{:08x}` and concatenating them gives that message's hex digest.
    ///
    /// # Panics
    ///
    /// Panics if `size >= 64`. Debug builds also assert `size <= 55`.
    #[inline]
    pub fn hash_fixed<const N: usize>(buffers: &mut [[u8; 64]; N], size: usize) -> [Simd<u32, N>; 4]
    where
        LaneCount<N>: SupportedLaneCount,
    {
        // Above 55 bytes the marker would fall into words 14/15. Those words are synthesised
        // below and never read from the buffers, so the digest would be silently wrong.
        debug_assert!(size <= 55, "hash_fixed only supports messages of at most 55 bytes");

        // Assume all buffers are the same size.
        for buffer in buffers.iter_mut() {
            buffer[size] = 0x80;
        }

        let [a0, b0, c0, d0] = [
            Simd::splat(0x67452301),
            Simd::splat(0xefcdab89),
            Simd::splat(0x98badcfe),
            Simd::splat(0x10325476),
        ];
        let [mut a, mut b, mut c, mut d] = [a0, b0, c0, d0];

        // Message words are loaded lazily in round 1, interleaved with their first use.
        let m0 = message(buffers, 0, size);
        a = round1(a, b, c, d, m0, 7, 0xd76aa478);
        let m1 = message(buffers, 4, size);
        d = round1(d, a, b, c, m1, 12, 0xe8c7b756);
        let m2 = message(buffers, 8, size);
        c = round1(c, d, a, b, m2, 17, 0x242070db);
        let m3 = message(buffers, 12, size);
        b = round1(b, c, d, a, m3, 22, 0xc1bdceee);
        let m4 = message(buffers, 16, size);
        a = round1(a, b, c, d, m4, 7, 0xf57c0faf);
        let m5 = message(buffers, 20, size);
        d = round1(d, a, b, c, m5, 12, 0x4787c62a);
        let m6 = message(buffers, 24, size);
        c = round1(c, d, a, b, m6, 17, 0xa8304613);
        let m7 = message(buffers, 28, size);
        b = round1(b, c, d, a, m7, 22, 0xfd469501);
        let m8 = message(buffers, 32, size);
        a = round1(a, b, c, d, m8, 7, 0x698098d8);
        let m9 = message(buffers, 36, size);
        d = round1(d, a, b, c, m9, 12, 0x8b44f7af);
        let m10 = message(buffers, 40, size);
        c = round1(c, d, a, b, m10, 17, 0xffff5bb1);
        let m11 = message(buffers, 44, size);
        b = round1(b, c, d, a, m11, 22, 0x895cd7be);
        let m12 = message(buffers, 48, size);
        a = round1(a, b, c, d, m12, 7, 0x6b901122);
        let m13 = message(buffers, 52, size);
        d = round1(d, a, b, c, m13, 12, 0xfd987193);
        // Words 14 and 15 are the 64-bit bit length (low word first). For a single-block
        // message the high word is always zero.
        let m14 = Simd::splat(size as u32 * 8);
        c = round1(c, d, a, b, m14, 17, 0xa679438e);
        let m15 = Simd::splat(0);
        b = round1(b, c, d, a, m15, 22, 0x49b40821);

        a = round2(a, b, c, d, m1, 5, 0xf61e2562);
        d = round2(d, a, b, c, m6, 9, 0xc040b340);
        c = round2(c, d, a, b, m11, 14, 0x265e5a51);
        b = round2(b, c, d, a, m0, 20, 0xe9b6c7aa);
        a = round2(a, b, c, d, m5, 5, 0xd62f105d);
        d = round2(d, a, b, c, m10, 9, 0x02441453);
        c = round2(c, d, a, b, m15, 14, 0xd8a1e681);
        b = round2(b, c, d, a, m4, 20, 0xe7d3fbc8);
        a = round2(a, b, c, d, m9, 5, 0x21e1cde6);
        d = round2(d, a, b, c, m14, 9, 0xc33707d6);
        c = round2(c, d, a, b, m3, 14, 0xf4d50d87);
        b = round2(b, c, d, a, m8, 20, 0x455a14ed);
        a = round2(a, b, c, d, m13, 5, 0xa9e3e905);
        d = round2(d, a, b, c, m2, 9, 0xfcefa3f8);
        c = round2(c, d, a, b, m7, 14, 0x676f02d9);
        b = round2(b, c, d, a, m12, 20, 0x8d2a4c8a);

        a = round3(a, b, c, d, m5, 4, 0xfffa3942);
        d = round3(d, a, b, c, m8, 11, 0x8771f681);
        c = round3(c, d, a, b, m11, 16, 0x6d9d6122);
        b = round3(b, c, d, a, m14, 23, 0xfde5380c);
        a = round3(a, b, c, d, m1, 4, 0xa4beea44);
        d = round3(d, a, b, c, m4, 11, 0x4bdecfa9);
        c = round3(c, d, a, b, m7, 16, 0xf6bb4b60);
        b = round3(b, c, d, a, m10, 23, 0xbebfbc70);
        a = round3(a, b, c, d, m13, 4, 0x289b7ec6);
        d = round3(d, a, b, c, m0, 11, 0xeaa127fa);
        c = round3(c, d, a, b, m3, 16, 0xd4ef3085);
        b = round3(b, c, d, a, m6, 23, 0x04881d05);
        a = round3(a, b, c, d, m9, 4, 0xd9d4d039);
        d = round3(d, a, b, c, m12, 11, 0xe6db99e5);
        c = round3(c, d, a, b, m15, 16, 0x1fa27cf8);
        b = round3(b, c, d, a, m2, 23, 0xc4ac5665);

        a = round4(a, b, c, d, m0, 6, 0xf4292244);
        d = round4(d, a, b, c, m7, 10, 0x432aff97);
        c = round4(c, d, a, b, m14, 15, 0xab9423a7);
        b = round4(b, c, d, a, m5, 21, 0xfc93a039);
        a = round4(a, b, c, d, m12, 6, 0x655b59c3);
        d = round4(d, a, b, c, m3, 10, 0x8f0ccc92);
        c = round4(c, d, a, b, m10, 15, 0xffeff47d);
        b = round4(b, c, d, a, m1, 21, 0x85845dd1);
        a = round4(a, b, c, d, m8, 6, 0x6fa87e4f);
        d = round4(d, a, b, c, m15, 10, 0xfe2ce6e0);
        c = round4(c, d, a, b, m6, 15, 0xa3014314);
        b = round4(b, c, d, a, m13, 21, 0x4e0811a1);
        a = round4(a, b, c, d, m4, 6, 0xf7537e82);
        d = round4(d, a, b, c, m11, 10, 0xbd3af235);
        c = round4(c, d, a, b, m2, 15, 0x2ad7d2bb);
        b = round4(b, c, d, a, m9, 21, 0xeb86d391);

        // Byte-swap each lane so its hex representation matches the digest byte order.
        [(a0 + a).swap_bytes(), (b0 + b).swap_bytes(), (c0 + c).swap_bytes(), (d0 + d).swap_bytes()]
    }

    /// Loads the little-endian `u32` at byte offset `i` of every buffer into one vector.
    ///
    /// A word starting after `size` lies entirely in the zero padding. It is returned as zero
    /// without reading the buffers.
    #[inline]
    fn message<const N: usize>(buffers: &[[u8; 64]; N], i: usize, size: usize) -> Simd<u32, N>
    where
        LaneCount<N>: SupportedLaneCount,
    {
        if i > size {
            Simd::splat(0)
        } else {
            Simd::from_array(from_fn(|lane| {
                let slice = &buffers[lane][i..i + 4];
                u32::from_le_bytes(slice.try_into().unwrap())
            }))
        }
    }

    /// Lane-wise round 1 step, using F(b, c, d) = (b & c) | (!b & d).
    #[inline]
    fn round1<const N: usize>(
        a: Simd<u32, N>,
        b: Simd<u32, N>,
        c: Simd<u32, N>,
        d: Simd<u32, N>,
        m: Simd<u32, N>,
        s: u32,
        k: u32,
    ) -> Simd<u32, N>
    where
        LaneCount<N>: SupportedLaneCount,
    {
        let f = (b & c) | (!b & d);
        common(f, a, b, m, s, k)
    }

    /// Lane-wise round 2 step, using G(b, c, d) = (b & d) | (c & !d).
    #[inline]
    fn round2<const N: usize>(
        a: Simd<u32, N>,
        b: Simd<u32, N>,
        c: Simd<u32, N>,
        d: Simd<u32, N>,
        m: Simd<u32, N>,
        s: u32,
        k: u32,
    ) -> Simd<u32, N>
    where
        LaneCount<N>: SupportedLaneCount,
    {
        let f = (b & d) | (c & !d);
        common(f, a, b, m, s, k)
    }

    /// Lane-wise round 3 step, using H(b, c, d) = b ^ c ^ d.
    #[inline]
    fn round3<const N: usize>(
        a: Simd<u32, N>,
        b: Simd<u32, N>,
        c: Simd<u32, N>,
        d: Simd<u32, N>,
        m: Simd<u32, N>,
        s: u32,
        k: u32,
    ) -> Simd<u32, N>
    where
        LaneCount<N>: SupportedLaneCount,
    {
        let f = b ^ c ^ d;
        common(f, a, b, m, s, k)
    }

    /// Lane-wise round 4 step, using I(b, c, d) = c ^ (b | !d).
    #[inline]
    fn round4<const N: usize>(
        a: Simd<u32, N>,
        b: Simd<u32, N>,
        c: Simd<u32, N>,
        d: Simd<u32, N>,
        m: Simd<u32, N>,
        s: u32,
        k: u32,
    ) -> Simd<u32, N>
    where
        LaneCount<N>: SupportedLaneCount,
    {
        let f = c ^ (b | !d);
        common(f, a, b, m, s, k)
    }

    /// Shared MD5 step: returns `b + ((f + a + k + m) <<< s)` for every lane.
    ///
    /// `Simd` integer arithmetic wraps on overflow, matching the scalar `wrapping_add` chain.
    /// The left rotate is spelled out with two shifts. Every caller in this module passes `s`
    /// between 4 and 23, so neither shift amount reaches 32.
    #[inline]
    fn common<const N: usize>(
        f: Simd<u32, N>,
        a: Simd<u32, N>,
        b: Simd<u32, N>,
        m: Simd<u32, N>,
        s: u32,
        k: u32,
    ) -> Simd<u32, N>
    where
        LaneCount<N>: SupportedLaneCount,
    {
        let k = Simd::splat(k);
        let first = f + a + k + m;
        let second = (first << s) | (first >> (32 - s));
        second + b
    }
}