1
use std::io::{Read, Write};
2
use std::num::TryFromIntError;
3

            
4
use crate::Variable;
5

            
6
/// A variable-length encoded unsigned integer.
///
/// Values in the range `0..2^124` can be represented. An encoded value starts
/// with a byte whose high nibble is the number of bytes that follow it
/// (`0..=15`). The low nibble of that first byte, together with the following
/// bytes, holds the integer in big-endian order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Unsigned(pub(crate) u128);
14

            
15
impl Unsigned {
16
131273
    pub(crate) fn encode_be_bytes<W: Write, const N: usize>(
17
131273
        mut value: [u8; N],
18
131273
        mut output: W,
19
131273
    ) -> std::io::Result<usize> {
20
131273
        // Because we encode "extra bytes" in 4 bits, we must keep the extra
21
131273
        // bytes to 15 or less. This only affects 128-bit encoding
22
131273
        if N == 16 && value[0] >> 4 != 0 {
23
1
            return Err(std::io::Error::from(std::io::ErrorKind::InvalidData));
24
131272
        }
25
131272

            
26
131272
        let (total_length, extra_bytes) = value
27
131272
            .iter()
28
131272
            .enumerate()
29
133163
            .find_map(|(index, &byte)| {
30
133163
                if byte > 0 {
31
131266
                    let extra_bytes = N - 1 - index;
32
131266
                    if byte < 16 {
33
7805
                        Some((extra_bytes + 1, extra_bytes))
34
                    } else {
35
123461
                        Some((extra_bytes + 2, extra_bytes + 1))
36
                    }
37
                } else {
38
1897
                    None
39
                }
40
133163
            })
41
131272
            .unwrap_or((0, 0));
42
131272
        let total_length = total_length.max(1);
43
131272

            
44
131272
        let extra_bytes_encoded = (extra_bytes as u8) << 4;
45
131272
        if total_length > N {
46
            // We need an extra byte to store the length information
47
122886
            output.write_all(&[extra_bytes_encoded])?;
48
122886
            output.write_all(&value)?;
49
        } else {
50
8386
            value[N - total_length] |= extra_bytes_encoded;
51
8386
            output.write_all(&value[N - total_length..])?;
52
        }
53

            
54
131272
        Ok(total_length)
55
131273
    }
56

            
57
200
    pub(crate) fn decode_variable_bytes<R: Read, const N: usize>(
58
200
        mut input: R,
59
200
    ) -> std::io::Result<[u8; N]> {
60
200
        let mut buffer = [0_u8; N];
61
200
        input.read_exact(&mut buffer[0..1])?;
62
200
        let first_byte = buffer[0];
63
200
        let length = (first_byte >> 4) as usize;
64
200
        if length > N {
65
1
            return Err(std::io::Error::from(std::io::ErrorKind::InvalidData));
66
199
        }
67
199
        input.read_exact(&mut buffer[N - length..])?;
68
199
        match N - length {
69
5
            0 => {
70
5
                // We overwrite the first byte with the read operation, so we need
71
5
                // to fill back in the bits from the first byte.
72
5
                buffer[0] |= first_byte & 0b1111;
73
5
            }
74
19
            1 => {
75
19
                // Clear the top 4 bits of the first byte. The lower 4 bits may
76
19
                // still have data in them.
77
19
                buffer[0] &= 0b1111;
78
19
            }
79
175
            _ => {
80
175
                // Move the first byte's data into the last byte read, then
81
175
                // clear our initial byte.
82
175
                buffer[N - 1 - length] |= first_byte & 0b1111;
83
175
                buffer[0] = 0;
84
175
            }
85
        }
86
199
        Ok(buffer)
87
200
    }
88
}
89

            
90
impl Variable for Unsigned {
91
38
    fn encode_variable<W: Write>(&self, output: W) -> std::io::Result<usize> {
92
38
        Self::encode_be_bytes(self.0.to_be_bytes(), output)
93
38
    }
94

            
95
38
    fn decode_variable<R: Read>(input: R) -> std::io::Result<Self> {
96
38
        let buffer = Self::decode_variable_bytes(input)?;
97

            
98
38
        Ok(Self(u128::from_be_bytes(buffer)))
99
38
    }
100
}
101

            
102
/// Implements `TryFrom<Unsigned>` for a primitive integer type. The conversion
/// fails with [`TryFromIntError`] when the stored `u128` does not fit.
macro_rules! impl_primitive_from_varint {
    ($ty:ty) => {
        impl TryFrom<Unsigned> for $ty {
            type Error = TryFromIntError;

            fn try_from(value: Unsigned) -> Result<Self, Self::Error> {
                let Unsigned(inner) = value;
                Self::try_from(inner)
            }
        }
    };
}
113

            
114
/// Implements `From<$ty> for Unsigned` by converting `$ty` into the storage
/// type `$dest` with `From`.
macro_rules! impl_varint_from_primitive {
    ($ty:ty, $dest:ty) => {
        impl From<$ty> for Unsigned {
            fn from(value: $ty) -> Self {
                let widened: $dest = From::from(value);
                Unsigned(widened)
            }
        }
    };
}
123

            
124
impl_varint_from_primitive!(u8, u128);
125
impl_varint_from_primitive!(u16, u128);
126
impl_varint_from_primitive!(u32, u128);
127
impl_varint_from_primitive!(u64, u128);
128
impl_varint_from_primitive!(u128, u128);
129

            
130
impl_primitive_from_varint!(u8);
131
impl_primitive_from_varint!(u16);
132
impl_primitive_from_varint!(u32);
133
impl_primitive_from_varint!(u64);
134
impl_primitive_from_varint!(usize);
135

            
136
impl From<Unsigned> for u128 {
137
1
    fn from(value: Unsigned) -> Self {
138
1
        value.0
139
1
    }
140
}
141

            
142
impl From<usize> for Unsigned {
143
3
    fn from(value: usize) -> Self {
144
3
        Self(value as u128)
145
3
    }
146
}