aoc/year2021/
day16.rs

1//! Packet Decoder
2//!
3//! [`BitStream`] is the key to making this problem tractable. It works like an iterator, allowing
4//! us to consume an arbitrary number of bits from the input and convert this to a number.
5//!
6//! It works by maintaining an internal `u64` buffer. If the requested number of bits is larger than
7//! the buffer's current size then additional bits are added to the buffer 4 at a time from each
8//! hexadecimal digit of the input data.
9//!
10//! Additionally it keeps track of the total number of bits consumed so far. This is needed when
11//! parsing packets that use the total length in bits to determine sub-packets.
12//!
//! The decoded packets are stored as a tree of [`Packet`] enum values, allowing recursive
//! solutions to part 1 and part 2 to reuse the same decoded input.
15use std::str::Bytes;
16
/// Reads an arbitrary number of bits at a time from a string of hexadecimal digits.
///
/// Bits are pulled into an internal `u64` buffer four at a time (one hex digit per refill) and
/// handed out most-significant bit first.
struct BitStream<'a> {
    /// Number of buffered bits (the low `available` bits of `bits`) not yet consumed.
    available: u64,
    /// Bit buffer; only its low `available` bits are still unread.
    bits: u64,
    /// Total number of bits consumed so far through [`BitStream::next`].
    read: u64,
    /// Remaining hexadecimal digits of the input.
    iter: Bytes<'a>,
}

impl BitStream<'_> {
    /// Creates a stream over the hexadecimal digits of `s`.
    fn from(s: &str) -> BitStream<'_> {
        BitStream { available: 0, bits: 0, read: 0, iter: s.bytes() }
    }

    /// Consumes the next `amount` bits and returns them as an unsigned number.
    ///
    /// # Panics
    ///
    /// Panics if the input runs out of hex digits or contains a non-hex character.
    /// `amount` must be at most 61: at most 3 unread bits are left over between calls and a
    /// refill adds 4 bits, so larger requests would not fit in the 64-bit buffer.
    fn next(&mut self, amount: u64) -> u64 {
        debug_assert!(amount <= 61, "cannot read {} bits at once", amount);

        while self.available < amount {
            self.available += 4;
            self.bits = (self.bits << 4) | self.hex_to_binary();
        }

        self.available -= amount;
        self.read += amount;

        let mask = (1 << amount) - 1;
        (self.bits >> self.available) & mask
    }

    /// Returns the value (0-15) of the next hex digit in the input, accepting either case.
    ///
    /// # Panics
    ///
    /// Panics if the input is exhausted or the next byte is not a hexadecimal digit.
    fn hex_to_binary(&mut self) -> u64 {
        let byte = self
            .iter
            .next()
            .expect("bit stream exhausted: input has too few hex digits");
        // `to_digit(16)` handles 0-9, a-f and A-F; the previous `byte - 55` arithmetic only
        // worked for uppercase letters and underflowed on bytes such as '\n'.
        let digit = char::from(byte)
            .to_digit(16)
            .unwrap_or_else(|| panic!("invalid hexadecimal digit {:?}", char::from(byte)));
        u64::from(digit)
    }
}
48
/// A decoded packet: either a literal value or an operator over nested sub-packets.
///
/// Both variants keep the packet's `type_id`. Literals are decoded from type id 4; for
/// operators the type id selects how the sub-packet values are combined.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Packet {
    /// A packet carrying a single literal number.
    Literal { version: u64, type_id: u64, value: u64 },
    /// A packet whose `packets` are its sub-packets, in the order they were decoded.
    Operator { version: u64, type_id: u64, packets: Vec<Packet> },
}
53
54impl Packet {
55    fn from(bit_stream: &mut BitStream<'_>) -> Packet {
56        let version = bit_stream.next(3);
57        let type_id = bit_stream.next(3);
58
59        if type_id == 4 {
60            let mut todo = true;
61            let mut value = 0;
62
63            while todo {
64                todo = bit_stream.next(1) == 1;
65                value = (value << 4) | bit_stream.next(4);
66            }
67
68            Packet::Literal { version, type_id, value }
69        } else {
70            let mut packets = Vec::new();
71
72            if bit_stream.next(1) == 0 {
73                let target = bit_stream.next(15) + bit_stream.read;
74                while bit_stream.read < target {
75                    packets.push(Self::from(bit_stream));
76                }
77            } else {
78                let sub_packets = bit_stream.next(11);
79                for _ in 0..sub_packets {
80                    packets.push(Self::from(bit_stream));
81                }
82            }
83
84            Packet::Operator { version, type_id, packets }
85        }
86    }
87}
88
89pub fn parse(input: &str) -> Packet {
90    let mut bit_stream = BitStream::from(input);
91    Packet::from(&mut bit_stream)
92}
93
94pub fn part1(packet: &Packet) -> u64 {
95    fn helper(packet: &Packet) -> u64 {
96        match packet {
97            Packet::Literal { version, .. } => *version,
98            Packet::Operator { version, packets, .. } => {
99                *version + packets.iter().map(helper).sum::<u64>()
100            }
101        }
102    }
103
104    helper(packet)
105}
106
107pub fn part2(packet: &Packet) -> u64 {
108    fn helper(packet: &Packet) -> u64 {
109        match packet {
110            Packet::Literal { value, .. } => *value,
111            Packet::Operator { type_id, packets, .. } => {
112                let mut iter = packets.iter().map(helper);
113                match type_id {
114                    0 => iter.sum(),
115                    1 => iter.product(),
116                    2 => iter.min().unwrap(),
117                    3 => iter.max().unwrap(),
118                    5 => (iter.next().unwrap() > iter.next().unwrap()) as u64,
119                    6 => (iter.next().unwrap() < iter.next().unwrap()) as u64,
120                    7 => (iter.next().unwrap() == iter.next().unwrap()) as u64,
121                    _ => unreachable!(),
122                }
123            }
124        }
125    }
126
127    helper(packet)
128}