1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
//! Packet Decoder
//!
//! [`BitStream`] is the key to making this problem tractable. It works like an iterator, allowing
//! us to consume an arbitrary number of bits from the input and convert this to a number.
//!
//! It works by maintaining an internal `u64` buffer. If more bits are requested than the buffer
//! currently holds, the buffer is topped up 4 bits at a time, one hexadecimal digit of the input
//! per step.
//!
//! Additionally it keeps track of the total number of bits consumed so far. This is needed when
//! parsing packets that use the total length in bits to determine sub-packets.
//!
//! The decoded packet data is stored as a tree-like struct allowing recursive solutions to part 1
//! and part 2 to reuse the same decoded input.
use std::str::Bytes;

/// Reads arbitrary-width unsigned integers, most significant bit first, from a string of
/// hexadecimal digits.
struct BitStream<'a> {
    /// Number of unread bits held in the low end of `bits`.
    available: u64,
    /// Bit buffer; only the lowest `available` bits have not been consumed yet.
    bits: u64,
    /// Total number of bits consumed by [`BitStream::next`] so far.
    read: u64,
    /// Remaining input, one byte per hexadecimal digit.
    iter: Bytes<'a>,
}

impl BitStream<'_> {
    /// Creates a stream over the hexadecimal digits of `s`.
    fn from(s: &str) -> BitStream<'_> {
        BitStream { available: 0, bits: 0, read: 0, iter: s.bytes() }
    }

    /// Consumes the next `amount` bits and returns them as a number, most significant bit first.
    ///
    /// # Panics
    ///
    /// Panics if the input runs out of digits or contains a non-hexadecimal character.
    /// `amount` must be at most 61. Larger reads would push still-unread bits out of the
    /// 64-bit buffer, which is checked in debug builds.
    fn next(&mut self, amount: u64) -> u64 {
        debug_assert!(amount <= 61, "cannot read more than 61 bits at once");

        // Top up the buffer one hex digit (4 bits) at a time. Bits shifted past bit 63 have
        // already been consumed, so discarding them is harmless.
        while self.available < amount {
            self.available += 4;
            self.bits = (self.bits << 4) | self.hex_to_binary();
        }

        self.available -= amount;
        self.read += amount;

        let mask = (1 << amount) - 1;
        (self.bits >> self.available) & mask
    }

    /// Decodes the next input character as a hexadecimal digit (upper or lower case) and
    /// returns its 4-bit value.
    fn hex_to_binary(&mut self) -> u64 {
        let byte = self.iter.next().expect("bit stream ran out of hexadecimal digits");
        // `to_digit` accepts both cases. Any other byte is rejected instead of being silently
        // mapped to a bogus (or underflowing) value.
        let digit = char::from(byte)
            .to_digit(16)
            .unwrap_or_else(|| panic!("invalid hexadecimal digit {:?}", char::from(byte)));
        u64::from(digit)
    }
}

/// A decoded packet. Operator packets own their sub-packets, so a decoded transmission is a tree.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Packet {
    /// Type ID 4: a single number, encoded in the input as 4-bit groups.
    Literal { version: u64, type_id: u64, value: u64 },
    /// Any other type ID: an operation (selected by `type_id`) applied to the sub-packets.
    Operator { version: u64, type_id: u64, packets: Vec<Packet> },
}

impl Packet {
    /// Decodes one packet, including all nested sub-packets, starting at the current position
    /// of `bit_stream`.
    ///
    /// Every packet starts with a 3-bit version and a 3-bit type ID. Type ID 4 is a literal
    /// whose value is spread over 5-bit groups: a continuation flag followed by 4 value bits.
    /// Any other type ID is an operator. Its next bit selects how the children are delimited:
    /// `0` means a 15-bit field follows giving the children's total length in bits, and `1`
    /// means an 11-bit field follows giving the number of children.
    fn from(bit_stream: &mut BitStream<'_>) -> Packet {
        let version = bit_stream.next(3);
        let type_id = bit_stream.next(3);

        if type_id == 4 {
            let mut value = 0;
            loop {
                // A set flag means another 4-bit group follows this one.
                let has_more = bit_stream.next(1) == 1;
                value = (value << 4) | bit_stream.next(4);
                if !has_more {
                    break;
                }
            }
            return Packet::Literal { version, type_id, value };
        }

        let packets = if bit_stream.next(1) == 0 {
            // Read the length field before sampling `read`, so that `end` is measured from
            // just after the 15-bit field itself.
            let length = bit_stream.next(15);
            let end = bit_stream.read + length;
            let mut children = Vec::new();
            while bit_stream.read < end {
                children.push(Packet::from(bit_stream));
            }
            children
        } else {
            let count = bit_stream.next(11);
            (0..count).map(|_| Packet::from(bit_stream)).collect()
        };

        Packet::Operator { version, type_id, packets }
    }
}

/// Decodes the outermost packet, and recursively all of its sub-packets, from hexadecimal text.
pub fn parse(input: &str) -> Packet {
    Packet::from(&mut BitStream::from(input))
}

/// Adds up the version numbers of `packet` and of every packet nested inside it.
pub fn part1(packet: &Packet) -> u64 {
    match packet {
        Packet::Literal { version, .. } => *version,
        Packet::Operator { version, packets, .. } => {
            packets.iter().fold(*version, |total, child| total + part1(child))
        }
    }
}

/// Evaluates the expression tree rooted at `packet`.
///
/// Literals evaluate to their value. Operators apply sum (0), product (1), minimum (2) or
/// maximum (3) to their children's values. Type IDs 5, 6 and 7 compare the first two
/// children's values (greater than, less than, equal) and yield `1` or `0`.
///
/// # Panics
///
/// Panics if a min/max operator has no children, if a comparison has fewer than two children,
/// or if an operator carries an unknown type ID.
pub fn part2(packet: &Packet) -> u64 {
    let (type_id, packets) = match packet {
        Packet::Literal { value, .. } => return *value,
        Packet::Operator { type_id, packets, .. } => (*type_id, packets),
    };

    let mut values = packets.iter().map(part2);
    match type_id {
        0 => values.sum(),
        1 => values.product(),
        2 => values.min().unwrap(),
        3 => values.max().unwrap(),
        5..=7 => {
            let first = values.next().unwrap();
            let second = values.next().unwrap();
            let holds = match type_id {
                5 => first > second,
                6 => first < second,
                _ => first == second,
            };
            u64::from(holds)
        }
        _ => unreachable!(),
    }
}