1
use std::io::{Read, Write};
2
use std::num::TryFromIntError;
3

            
4
use crate::Variable;
5

            
6
/// A signed integer value.
///
/// This type encodes values in the range `-2.pow(123)..2.pow(123)`. The top
/// 5 bits of the first encoded byte hold a length header:
///
/// - headers `16..=31` announce a non-negative value that is followed by
///   `header - 16` additional bytes;
/// - headers `0..=15` announce a negative value that is followed by
///   `15 - header` additional bytes.
///
/// Either way, between 0 and 15 bytes follow the first one. The remaining 3
/// bits of the first byte and any additional bytes store the integer in
/// big-endian two's-complement encoding, with redundant sign-extension bytes
/// dropped.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Ord, PartialOrd)]
pub struct Signed(i128);
16

            
17
impl Signed {
18
131437
    pub(crate) fn encode_be_bytes<W: Write, const N: usize>(
19
131437
        mut value: [u8; N],
20
131437
        mut output: W,
21
131437
    ) -> std::io::Result<usize> {
22
131437
        let check_bits = if N == 16 {
23
            // We reserve 5 bits for a signed 4 bit number, ranging from -16..=15.
24
325
            let reserved = value[0] >> 3;
25
325
            match reserved {
26
161
                0 => 0,
27
162
                0b11111 => 0xFF,
28
2
                _ => return Err(std::io::Error::from(std::io::ErrorKind::InvalidData)),
29
            }
30
131112
        } else if value[0] >> 7 == 0 {
31
            // positive
32
65555
            0
33
        } else {
34
            // negative
35
65557
            0xff
36
        };
37

            
38
131435
        let (total_length, extra_bytes) = value
39
131435
            .iter()
40
131435
            .enumerate()
41
135068
            .find_map(|(index, &byte)| {
42
135068
                if byte == check_bits {
43
3649
                    None
44
                } else {
45
131419
                    let extra_bytes = N - 1 - index;
46
131419
                    if byte >> 3 == check_bits >> 3 {
47
7315
                        Some((extra_bytes + 1, extra_bytes))
48
                    } else {
49
124104
                        Some((extra_bytes + 2, extra_bytes + 1))
50
                    }
51
                }
52
135068
            })
53
131435
            .unwrap_or((0, 0));
54
131435
        let total_length = total_length.max(1);
55

            
56
131435
        let length_header = if check_bits == 0 {
57
65716
            extra_bytes + 2_usize.pow(4)
58
        } else {
59
65719
            2_usize.pow(4) - extra_bytes - 1
60
        };
61

            
62
131435
        let encoded_length_header = (length_header as u8) << 3;
63
131435
        if total_length > N {
64
            // We can't fit the length in the buffer.
65
122892
            output.write_all(&[encoded_length_header | (check_bits >> 5)])?;
66
122892
            output.write_all(&value)?;
67
        } else {
68
            // Clear the top bits to prepare for the header
69
8543
            value[N - total_length] &= 0b111;
70
8543
            // Set the length bits
71
8543
            value[N - total_length] |= encoded_length_header;
72
8543

            
73
8543
            output.write_all(&value[N - total_length..])?;
74
        }
75

            
76
131435
        Ok(total_length)
77
131437
    }
78

            
79
305
    pub(crate) fn decode_variable_bytes<R: Read, const N: usize>(
80
305
        mut input: R,
81
305
    ) -> std::io::Result<[u8; N]> {
82
305
        let mut buffer = [0_u8; N];
83
305
        input.read_exact(&mut buffer[0..1])?;
84
305
        let first_byte = buffer[0];
85
305
        let encoded_length = first_byte as usize >> 3;
86
305
        let (negative, length) = if encoded_length >= 2_usize.pow(4) {
87
152
            (false, encoded_length - 2_usize.pow(4))
88
        } else {
89
153
            (true, 2_usize.pow(4) - (encoded_length + 1))
90
        };
91
305
        if length > N {
92
2
            return Err(std::io::Error::from(std::io::ErrorKind::InvalidData));
93
303
        }
94
303

            
95
303
        input.read_exact(&mut buffer[N - length..])?;
96

            
97
303
        match N - length {
98
            0 => {
99
                // We overwrote our first byte, but the first byte has some 3
100
                // bits of data we need to preserve.
101
10
                let mut first_bits = first_byte & 0b111;
102
10
                if negative {
103
5
                    first_bits ^= 0b111;
104
5
                }
105
10
                buffer[0] |= first_bits << 5;
106
            }
107
            1 => {
108
                // Clear the top 3 bits of the top byte, and negate if needed.
109
32
                buffer[0] &= 0b111;
110
32
                if negative {
111
16
                    buffer[0] ^= 0b1111_1000;
112
16
                }
113
            }
114
            _ => {
115
261
                buffer[N - 1 - length] |= first_byte & 0b111;
116
261
                if negative {
117
131
                    buffer[N - 1 - length] ^= 0b1111_1000;
118
132
                }
119
261
                buffer[0] = 0;
120
            }
121
        }
122

            
123
303
        if negative && N > 1 {
124
149
            let bytes_to_negate = N - length;
125
149
            // We know we can skip updating the last byte that contained data.
126
149
            if bytes_to_negate > 1 {
127
927
                for byte in &mut buffer[0..bytes_to_negate - 1] {
128
927
                    *byte ^= 0xFF;
129
927
                }
130
18
            }
131
154
        }
132

            
133
303
        Ok(buffer)
134
305
    }
135
}
136

            
137
impl Variable for Signed {
138
58
    fn encode_variable<W: Write>(&self, output: W) -> std::io::Result<usize> {
139
58
        Self::encode_be_bytes(self.0.to_be_bytes(), output)
140
58
    }
141

            
142
58
    fn decode_variable<R: Read>(mut input: R) -> std::io::Result<Self> {
143
58
        let mut buffer = [0_u8; 16];
144
58
        input.read_exact(&mut buffer[0..1])?;
145

            
146
58
        let encoded_length = buffer[0] as usize >> 3;
147
58
        let (negative, length) = if encoded_length >= 2_usize.pow(4) {
148
28
            (false, encoded_length - 2_usize.pow(4))
149
        } else {
150
30
            (true, 2_usize.pow(4) - (encoded_length + 1))
151
        };
152

            
153
58
        input.read_exact(&mut buffer[16 - length..])?;
154

            
155
58
        if length < 15 {
156
54
            buffer[15 - length] |= buffer[0] & 0b111;
157
54
            if negative {
158
28
                buffer[15 - length] ^= 0b1111_1000;
159
28
            }
160
54
            buffer[0] = 0;
161
        } else {
162
4
            buffer[0] &= 0b111;
163
4
            if negative {
164
2
                buffer[0] ^= 0b1111_1000;
165
2
            }
166
        }
167

            
168
58
        if negative {
169
277
            for byte in &mut buffer[0..15 - length] {
170
277
                *byte ^= 0xFF;
171
277
            }
172
28
        }
173

            
174
58
        Ok(Self(i128::from_be_bytes(buffer)))
175
58
    }
176
}
177

            
178
// Implements the fallible conversion `TryFrom<Signed>` for the integer type
// `$ty`. The conversion narrows the wrapped `i128` and reports a
// `TryFromIntError` when the value does not fit in `$ty`.
macro_rules! impl_primitive_from_varint {
    ($ty:ty) => {
        impl TryFrom<Signed> for $ty {
            type Error = TryFromIntError;

            fn try_from(value: Signed) -> Result<Self, Self::Error> {
                value.0.try_into()
            }
        }
    };
}
189

            
190
// Implements the infallible conversion `From<$ty>` for `Signed`. The value is
// converted with `<$dest>::from`, so `$dest` must be the type wrapped by
// `Signed` (`i128`) and must itself implement `From<$ty>`.
macro_rules! impl_varint_from_primitive {
    ($ty:ty, $dest:ty) => {
        impl From<$ty> for Signed {
            fn from(value: $ty) -> Self {
                Self(<$dest>::from(value))
            }
        }
    };
}
199

            
200
impl_varint_from_primitive!(i8, i128);
201
impl_varint_from_primitive!(i16, i128);
202
impl_varint_from_primitive!(i32, i128);
203
impl_varint_from_primitive!(i64, i128);
204
impl_varint_from_primitive!(i128, i128);
205

            
206
impl_primitive_from_varint!(i8);
207
impl_primitive_from_varint!(i16);
208
impl_primitive_from_varint!(i32);
209
impl_primitive_from_varint!(i64);
210
impl_primitive_from_varint!(isize);
211

            
212
impl From<Signed> for i128 {
    /// Unwraps the inner `i128`; this conversion cannot fail.
    fn from(value: Signed) -> Self {
        value.0
    }
}
217

            
218
impl From<isize> for Signed {
219
4
    fn from(value: isize) -> Self {
220
4
        Self(value as i128)
221
4
    }
222
}