aoc/util/
md5.rs

//! MD5 hash algorithm
//!
//! Computes a 128-bit [MD5 hash](https://en.wikipedia.org/wiki/MD5) for a slice of `u8`.
//! The hash is returned as a tuple of four `u32` values.
//!
//! The slice is modified in place and must be a multiple of 64 bytes long with at least 9 bytes
//! spare for the MD5 padding. The [`buffer_size`] function calculates the necessary size.
//!
//! To maximize speed the loop for each of the four rounds used to create the hash is unrolled and
//! all internal utility functions are marked as
//! [`#[inline]`](https://doc.rust-lang.org/reference/attributes/codegen.html#the-inline-attribute).
//!
//! An optional SIMD variant that computes multiple hashes in parallel is also implemented.

/// Returns the smallest buffer length that can hold an `n` byte message plus its MD5 padding.
///
/// The result is `n + 9` rounded up to the next multiple of 64: one byte for the `0x80` marker
/// and eight bytes for the message length that [`hash`] writes into the buffer.
#[inline]
#[must_use]
pub fn buffer_size(n: usize) -> usize {
    (n + 9).next_multiple_of(64)
}

19pub fn hash(mut buffer: &mut [u8], size: usize) -> (u32, u32, u32, u32) {
20    let end = buffer.len() - 8;
21    let bits = size * 8;
22
23    buffer[size] = 0x80;
24    buffer[end..].copy_from_slice(&bits.to_le_bytes());
25
26    let mut m = [0; 16];
27    let mut a0: u32 = 0x67452301;
28    let mut b0: u32 = 0xefcdab89;
29    let mut c0: u32 = 0x98badcfe;
30    let mut d0: u32 = 0x10325476;
31
32    while !buffer.is_empty() {
33        let (prefix, suffix) = buffer.split_at_mut(64);
34        buffer = suffix;
35
36        for (i, chunk) in prefix.chunks_exact(4).enumerate() {
37            m[i] = u32::from_le_bytes(chunk.try_into().unwrap());
38        }
39
40        let mut a = a0;
41        let mut b = b0;
42        let mut c = c0;
43        let mut d = d0;
44
45        a = round1(a, b, c, d, m[0], 7, 0xd76aa478);
46        d = round1(d, a, b, c, m[1], 12, 0xe8c7b756);
47        c = round1(c, d, a, b, m[2], 17, 0x242070db);
48        b = round1(b, c, d, a, m[3], 22, 0xc1bdceee);
49        a = round1(a, b, c, d, m[4], 7, 0xf57c0faf);
50        d = round1(d, a, b, c, m[5], 12, 0x4787c62a);
51        c = round1(c, d, a, b, m[6], 17, 0xa8304613);
52        b = round1(b, c, d, a, m[7], 22, 0xfd469501);
53        a = round1(a, b, c, d, m[8], 7, 0x698098d8);
54        d = round1(d, a, b, c, m[9], 12, 0x8b44f7af);
55        c = round1(c, d, a, b, m[10], 17, 0xffff5bb1);
56        b = round1(b, c, d, a, m[11], 22, 0x895cd7be);
57        a = round1(a, b, c, d, m[12], 7, 0x6b901122);
58        d = round1(d, a, b, c, m[13], 12, 0xfd987193);
59        c = round1(c, d, a, b, m[14], 17, 0xa679438e);
60        b = round1(b, c, d, a, m[15], 22, 0x49b40821);
61
62        a = round2(a, b, c, d, m[1], 5, 0xf61e2562);
63        d = round2(d, a, b, c, m[6], 9, 0xc040b340);
64        c = round2(c, d, a, b, m[11], 14, 0x265e5a51);
65        b = round2(b, c, d, a, m[0], 20, 0xe9b6c7aa);
66        a = round2(a, b, c, d, m[5], 5, 0xd62f105d);
67        d = round2(d, a, b, c, m[10], 9, 0x02441453);
68        c = round2(c, d, a, b, m[15], 14, 0xd8a1e681);
69        b = round2(b, c, d, a, m[4], 20, 0xe7d3fbc8);
70        a = round2(a, b, c, d, m[9], 5, 0x21e1cde6);
71        d = round2(d, a, b, c, m[14], 9, 0xc33707d6);
72        c = round2(c, d, a, b, m[3], 14, 0xf4d50d87);
73        b = round2(b, c, d, a, m[8], 20, 0x455a14ed);
74        a = round2(a, b, c, d, m[13], 5, 0xa9e3e905);
75        d = round2(d, a, b, c, m[2], 9, 0xfcefa3f8);
76        c = round2(c, d, a, b, m[7], 14, 0x676f02d9);
77        b = round2(b, c, d, a, m[12], 20, 0x8d2a4c8a);
78
79        a = round3(a, b, c, d, m[5], 4, 0xfffa3942);
80        d = round3(d, a, b, c, m[8], 11, 0x8771f681);
81        c = round3(c, d, a, b, m[11], 16, 0x6d9d6122);
82        b = round3(b, c, d, a, m[14], 23, 0xfde5380c);
83        a = round3(a, b, c, d, m[1], 4, 0xa4beea44);
84        d = round3(d, a, b, c, m[4], 11, 0x4bdecfa9);
85        c = round3(c, d, a, b, m[7], 16, 0xf6bb4b60);
86        b = round3(b, c, d, a, m[10], 23, 0xbebfbc70);
87        a = round3(a, b, c, d, m[13], 4, 0x289b7ec6);
88        d = round3(d, a, b, c, m[0], 11, 0xeaa127fa);
89        c = round3(c, d, a, b, m[3], 16, 0xd4ef3085);
90        b = round3(b, c, d, a, m[6], 23, 0x04881d05);
91        a = round3(a, b, c, d, m[9], 4, 0xd9d4d039);
92        d = round3(d, a, b, c, m[12], 11, 0xe6db99e5);
93        c = round3(c, d, a, b, m[15], 16, 0x1fa27cf8);
94        b = round3(b, c, d, a, m[2], 23, 0xc4ac5665);
95
96        a = round4(a, b, c, d, m[0], 6, 0xf4292244);
97        d = round4(d, a, b, c, m[7], 10, 0x432aff97);
98        c = round4(c, d, a, b, m[14], 15, 0xab9423a7);
99        b = round4(b, c, d, a, m[5], 21, 0xfc93a039);
100        a = round4(a, b, c, d, m[12], 6, 0x655b59c3);
101        d = round4(d, a, b, c, m[3], 10, 0x8f0ccc92);
102        c = round4(c, d, a, b, m[10], 15, 0xffeff47d);
103        b = round4(b, c, d, a, m[1], 21, 0x85845dd1);
104        a = round4(a, b, c, d, m[8], 6, 0x6fa87e4f);
105        d = round4(d, a, b, c, m[15], 10, 0xfe2ce6e0);
106        c = round4(c, d, a, b, m[6], 15, 0xa3014314);
107        b = round4(b, c, d, a, m[13], 21, 0x4e0811a1);
108        a = round4(a, b, c, d, m[4], 6, 0xf7537e82);
109        d = round4(d, a, b, c, m[11], 10, 0xbd3af235);
110        c = round4(c, d, a, b, m[2], 15, 0x2ad7d2bb);
111        b = round4(b, c, d, a, m[9], 21, 0xeb86d391);
112
113        a0 = a0.wrapping_add(a);
114        b0 = b0.wrapping_add(b);
115        c0 = c0.wrapping_add(c);
116        d0 = d0.wrapping_add(d);
117    }
118
119    (a0.to_be(), b0.to_be(), c0.to_be(), d0.to_be())
120}
121
122#[inline]
123fn round1(a: u32, b: u32, c: u32, d: u32, m: u32, s: u32, k: u32) -> u32 {
124    let f = (b & c) | (!b & d);
125    common(f, a, b, m, s, k)
126}
127
128#[inline]
129fn round2(a: u32, b: u32, c: u32, d: u32, m: u32, s: u32, k: u32) -> u32 {
130    let f = (b & d) | (c & !d);
131    common(f, a, b, m, s, k)
132}
133
134#[inline]
135fn round3(a: u32, b: u32, c: u32, d: u32, m: u32, s: u32, k: u32) -> u32 {
136    let f = b ^ c ^ d;
137    common(f, a, b, m, s, k)
138}
139
140#[inline]
141fn round4(a: u32, b: u32, c: u32, d: u32, m: u32, s: u32, k: u32) -> u32 {
142    let f = c ^ (b | !d);
143    common(f, a, b, m, s, k)
144}
145
/// Shared tail of every round step.
///
/// Adds `f`, `a`, `k` and `m` with wrapping arithmetic, rotates the sum left by `s` bits, then
/// adds `b`, again wrapping on overflow.
#[inline]
fn common(f: u32, a: u32, b: u32, m: u32, s: u32, k: u32) -> u32 {
    f.wrapping_add(a).wrapping_add(k).wrapping_add(m).rotate_left(s).wrapping_add(b)
}

/// SIMD variant that hashes several single-block messages at once, one message per lane.
#[cfg(feature = "simd")]
pub mod simd {
    use std::array;
    use std::simd::num::SimdUint as _;
    use std::simd::{LaneCount, Simd, SupportedLaneCount};

    /// Computes the MD5 digests of `N` messages in parallel.
    ///
    /// Each buffer holds one message of `size` bytes that fits in a single 64 byte block.
    /// Every buffer in the slice is padded in place (the `0x80` marker at index `size`, the
    /// length in bits in the last 8 bytes), but only the first `N` buffers are hashed. The
    /// bytes between the marker and the length field are not written and must already be zero.
    ///
    /// Returns four arrays holding the four digest words; index `i` of each array belongs to
    /// `buffers[i]`.
    ///
    /// # Panics
    ///
    /// Panics if fewer than `N` buffers are supplied or if `size` is greater than 55, which
    /// would leave no room for the padding within the block.
    #[inline]
    // The unrolled rounds are intentionally kept in a single function.
    #[expect(clippy::too_many_lines)]
    pub fn hash<const N: usize>(
        buffers: &mut [[u8; 64]],
        size: usize,
    ) -> ([u32; N], [u32; N], [u32; N], [u32; N])
    where
        LaneCount<N>: SupportedLaneCount,
    {
        // Assume all buffers are the same size.
        let end = 64 - 8;
        assert!(buffers.len() >= N, "one buffer is required per lane");
        assert!(size < end, "message must leave room for the padding in a single block");

        // The length field is always 64 bits wide whatever the width of `usize`.
        let bits = (size as u64).wrapping_mul(8);

        for buffer in buffers.iter_mut() {
            buffer[size] = 0x80;
            buffer[end..].copy_from_slice(&bits.to_le_bytes());
        }

        // Only reads are needed from here on.
        let buffers: &[[u8; 64]] = buffers;

        let mut a0: Simd<u32, N> = Simd::splat(0x67452301);
        let mut b0: Simd<u32, N> = Simd::splat(0xefcdab89);
        let mut c0: Simd<u32, N> = Simd::splat(0x98badcfe);
        let mut d0: Simd<u32, N> = Simd::splat(0x10325476);

        let mut a = a0;
        let mut b = b0;
        let mut c = c0;
        let mut d = d0;

        // Each message word is gathered right before its first use in round one.
        let m0 = message(buffers, 0);
        a = round1(a, b, c, d, m0, 7, 0xd76aa478);
        let m1 = message(buffers, 1);
        d = round1(d, a, b, c, m1, 12, 0xe8c7b756);
        let m2 = message(buffers, 2);
        c = round1(c, d, a, b, m2, 17, 0x242070db);
        let m3 = message(buffers, 3);
        b = round1(b, c, d, a, m3, 22, 0xc1bdceee);
        let m4 = message(buffers, 4);
        a = round1(a, b, c, d, m4, 7, 0xf57c0faf);
        let m5 = message(buffers, 5);
        d = round1(d, a, b, c, m5, 12, 0x4787c62a);
        let m6 = message(buffers, 6);
        c = round1(c, d, a, b, m6, 17, 0xa8304613);
        let m7 = message(buffers, 7);
        b = round1(b, c, d, a, m7, 22, 0xfd469501);
        let m8 = message(buffers, 8);
        a = round1(a, b, c, d, m8, 7, 0x698098d8);
        let m9 = message(buffers, 9);
        d = round1(d, a, b, c, m9, 12, 0x8b44f7af);
        let m10 = message(buffers, 10);
        c = round1(c, d, a, b, m10, 17, 0xffff5bb1);
        let m11 = message(buffers, 11);
        b = round1(b, c, d, a, m11, 22, 0x895cd7be);
        let m12 = message(buffers, 12);
        a = round1(a, b, c, d, m12, 7, 0x6b901122);
        let m13 = message(buffers, 13);
        d = round1(d, a, b, c, m13, 12, 0xfd987193);
        let m14 = message(buffers, 14);
        c = round1(c, d, a, b, m14, 17, 0xa679438e);
        let m15 = message(buffers, 15);
        b = round1(b, c, d, a, m15, 22, 0x49b40821);

        a = round2(a, b, c, d, m1, 5, 0xf61e2562);
        d = round2(d, a, b, c, m6, 9, 0xc040b340);
        c = round2(c, d, a, b, m11, 14, 0x265e5a51);
        b = round2(b, c, d, a, m0, 20, 0xe9b6c7aa);
        a = round2(a, b, c, d, m5, 5, 0xd62f105d);
        d = round2(d, a, b, c, m10, 9, 0x02441453);
        c = round2(c, d, a, b, m15, 14, 0xd8a1e681);
        b = round2(b, c, d, a, m4, 20, 0xe7d3fbc8);
        a = round2(a, b, c, d, m9, 5, 0x21e1cde6);
        d = round2(d, a, b, c, m14, 9, 0xc33707d6);
        c = round2(c, d, a, b, m3, 14, 0xf4d50d87);
        b = round2(b, c, d, a, m8, 20, 0x455a14ed);
        a = round2(a, b, c, d, m13, 5, 0xa9e3e905);
        d = round2(d, a, b, c, m2, 9, 0xfcefa3f8);
        c = round2(c, d, a, b, m7, 14, 0x676f02d9);
        b = round2(b, c, d, a, m12, 20, 0x8d2a4c8a);

        a = round3(a, b, c, d, m5, 4, 0xfffa3942);
        d = round3(d, a, b, c, m8, 11, 0x8771f681);
        c = round3(c, d, a, b, m11, 16, 0x6d9d6122);
        b = round3(b, c, d, a, m14, 23, 0xfde5380c);
        a = round3(a, b, c, d, m1, 4, 0xa4beea44);
        d = round3(d, a, b, c, m4, 11, 0x4bdecfa9);
        c = round3(c, d, a, b, m7, 16, 0xf6bb4b60);
        b = round3(b, c, d, a, m10, 23, 0xbebfbc70);
        a = round3(a, b, c, d, m13, 4, 0x289b7ec6);
        d = round3(d, a, b, c, m0, 11, 0xeaa127fa);
        c = round3(c, d, a, b, m3, 16, 0xd4ef3085);
        b = round3(b, c, d, a, m6, 23, 0x04881d05);
        a = round3(a, b, c, d, m9, 4, 0xd9d4d039);
        d = round3(d, a, b, c, m12, 11, 0xe6db99e5);
        c = round3(c, d, a, b, m15, 16, 0x1fa27cf8);
        b = round3(b, c, d, a, m2, 23, 0xc4ac5665);

        a = round4(a, b, c, d, m0, 6, 0xf4292244);
        d = round4(d, a, b, c, m7, 10, 0x432aff97);
        c = round4(c, d, a, b, m14, 15, 0xab9423a7);
        b = round4(b, c, d, a, m5, 21, 0xfc93a039);
        a = round4(a, b, c, d, m12, 6, 0x655b59c3);
        d = round4(d, a, b, c, m3, 10, 0x8f0ccc92);
        c = round4(c, d, a, b, m10, 15, 0xffeff47d);
        b = round4(b, c, d, a, m1, 21, 0x85845dd1);
        a = round4(a, b, c, d, m8, 6, 0x6fa87e4f);
        d = round4(d, a, b, c, m15, 10, 0xfe2ce6e0);
        c = round4(c, d, a, b, m6, 15, 0xa3014314);
        b = round4(b, c, d, a, m13, 21, 0x4e0811a1);
        a = round4(a, b, c, d, m4, 6, 0xf7537e82);
        d = round4(d, a, b, c, m11, 10, 0xbd3af235);
        c = round4(c, d, a, b, m2, 15, 0x2ad7d2bb);
        b = round4(b, c, d, a, m9, 21, 0xeb86d391);

        a0 += a;
        b0 += b;
        c0 += c;
        d0 += d;

        // Reverse the bytes of every word so each lane reads in digest output order.
        (
            a0.swap_bytes().to_array(),
            b0.swap_bytes().to_array(),
            c0.swap_bytes().to_array(),
            d0.swap_bytes().to_array(),
        )
    }

    /// Gathers little-endian message word `i` from each of the first `N` buffers into one
    /// vector, so that lane `n` holds the word from `buffers[n]`.
    #[inline]
    fn message<const N: usize>(buffers: &[[u8; 64]], i: usize) -> Simd<u32, N>
    where
        LaneCount<N>: SupportedLaneCount,
    {
        let start = 4 * i;
        let end = start + 4;
        Simd::from_array(array::from_fn(|lane| {
            let slice = &buffers[lane][start..end];
            u32::from_le_bytes(slice.try_into().unwrap())
        }))
    }

    /// First-round step: each bit comes from `c` where `b` is set and from `d` otherwise.
    #[inline]
    fn round1<const N: usize>(
        a: Simd<u32, N>,
        b: Simd<u32, N>,
        c: Simd<u32, N>,
        d: Simd<u32, N>,
        m: Simd<u32, N>,
        s: u32,
        k: u32,
    ) -> Simd<u32, N>
    where
        LaneCount<N>: SupportedLaneCount,
    {
        let f = (b & c) | (!b & d);
        common(f, a, b, m, s, k)
    }

    /// Second-round step: each bit comes from `b` where `d` is set and from `c` otherwise.
    #[inline]
    fn round2<const N: usize>(
        a: Simd<u32, N>,
        b: Simd<u32, N>,
        c: Simd<u32, N>,
        d: Simd<u32, N>,
        m: Simd<u32, N>,
        s: u32,
        k: u32,
    ) -> Simd<u32, N>
    where
        LaneCount<N>: SupportedLaneCount,
    {
        let f = (b & d) | (c & !d);
        common(f, a, b, m, s, k)
    }

    /// Third-round step using the parity function `b ^ c ^ d`.
    #[inline]
    fn round3<const N: usize>(
        a: Simd<u32, N>,
        b: Simd<u32, N>,
        c: Simd<u32, N>,
        d: Simd<u32, N>,
        m: Simd<u32, N>,
        s: u32,
        k: u32,
    ) -> Simd<u32, N>
    where
        LaneCount<N>: SupportedLaneCount,
    {
        let f = b ^ c ^ d;
        common(f, a, b, m, s, k)
    }

    /// Fourth-round step using the function `c ^ (b | !d)`.
    #[inline]
    fn round4<const N: usize>(
        a: Simd<u32, N>,
        b: Simd<u32, N>,
        c: Simd<u32, N>,
        d: Simd<u32, N>,
        m: Simd<u32, N>,
        s: u32,
        k: u32,
    ) -> Simd<u32, N>
    where
        LaneCount<N>: SupportedLaneCount,
    {
        let f = c ^ (b | !d);
        common(f, a, b, m, s, k)
    }

    /// Shared tail of every round step: adds `f`, `a`, `k` and `m`, rotates the sum left by
    /// `s` bits, then adds `b`.
    #[inline]
    fn common<const N: usize>(
        f: Simd<u32, N>,
        a: Simd<u32, N>,
        b: Simd<u32, N>,
        m: Simd<u32, N>,
        s: u32,
        k: u32,
    ) -> Simd<u32, N>
    where
        LaneCount<N>: SupportedLaneCount,
    {
        let k = Simd::splat(k);
        let first = f + a + k + m;
        // Rotate left written as two shifts; every caller passes `s` strictly between 0 and 32.
        let second = (first << s) | (first >> (32 - s));
        second + b
    }
}