1
use std::{
2
    borrow::Cow,
3
    fmt,
4
    ops::{Deref, DerefMut},
5
};
6

            
7
use serde::{
8
    de::{Error, SeqAccess, Visitor},
9
    Deserialize, Serialize,
10
};
11

            
12
use crate::{print_bytes, ArcBytes, OwnedBytes};
13

            
14
impl<'a> Serialize for ArcBytes<'a> {
15
3
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
16
3
    where
17
3
        S: serde::Serializer,
18
3
    {
19
3
        serializer.serialize_bytes(self.as_slice())
20
3
    }
21
}
22

            
23
impl<'a, 'de: 'a> Deserialize<'de> for ArcBytes<'a> {
24
5
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
25
5
    where
26
5
        D: serde::Deserializer<'de>,
27
5
    {
28
5
        deserializer
29
5
            .deserialize_bytes(BufferVisitor)
30
5
            .map(Self::from)
31
5
    }
32
}
33

            
34
impl Serialize for OwnedBytes {
35
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
36
    where
37
        S: serde::Serializer,
38
    {
39
        serializer.serialize_bytes(self.0.as_slice())
40
    }
41
}
42

            
43
impl<'de> Deserialize<'de> for OwnedBytes {
44
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
45
    where
46
        D: serde::Deserializer<'de>,
47
    {
48
        deserializer
49
            .deserialize_byte_buf(BufferVisitor)
50
            .map(|bytes| match bytes {
51
                Cow::Borrowed(borrowed) => Self(ArcBytes::owned(borrowed.to_vec())),
52
                Cow::Owned(vec) => Self(ArcBytes::owned(vec)),
53
            })
54
    }
55
}
56

            
57
/// A `Vec<u8>` wrapper that supports serializing efficiently in Serde.
58
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Default)]
59
pub struct Bytes(pub Vec<u8>);
60

            
61
impl Bytes {
62
    /// Returns the underlying Vec.
63
    #[allow(clippy::missing_const_for_fn)] // false positive
64
    #[must_use]
65
    pub fn into_vec(self) -> Vec<u8> {
66
        self.0
67
    }
68
}
69

            
70
impl_std_cmp!(Bytes);
71

            
72
impl<'a> std::fmt::Debug for Bytes {
73
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74
        let slice = self.as_slice();
75
        write!(f, "Bytes {{ length: {}, bytes: [", slice.len())?;
76
        print_bytes(slice, f)?;
77
        f.write_str("] }")
78
    }
79
}
80

            
81
impl From<Vec<u8>> for Bytes {
82
    fn from(buffer: Vec<u8>) -> Self {
83
        Self(buffer)
84
    }
85
}
86

            
87
impl<'a> From<&'a [u8]> for Bytes {
88
    fn from(buffer: &'a [u8]) -> Self {
89
        Self(buffer.to_vec())
90
    }
91
}
92

            
93
impl<const N: usize> From<[u8; N]> for Bytes {
94
    fn from(buffer: [u8; N]) -> Self {
95
        Self(buffer.to_vec())
96
    }
97
}
98

            
99
impl<'a> From<ArcBytes<'a>> for Bytes {
100
    fn from(buffer: ArcBytes<'a>) -> Self {
101
        Self(buffer.into_vec())
102
    }
103
}
104

            
105
impl<'a> From<Bytes> for ArcBytes<'a> {
106
    fn from(bytes: Bytes) -> Self {
107
        ArcBytes::owned(bytes.0)
108
    }
109
}
110

            
111
impl Deref for Bytes {
112
    type Target = Vec<u8>;
113

            
114
1
    fn deref(&self) -> &Self::Target {
115
1
        &self.0
116
1
    }
117
}
118

            
119
impl DerefMut for Bytes {
120
    fn deref_mut(&mut self) -> &mut Self::Target {
121
        &mut self.0
122
    }
123
}
124

            
125
impl Serialize for Bytes {
126
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
127
    where
128
        S: serde::Serializer,
129
    {
130
        serializer.serialize_bytes(self.0.as_slice())
131
    }
132
}
133

            
134
impl<'de> Deserialize<'de> for Bytes {
135
1
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
136
1
    where
137
1
        D: serde::Deserializer<'de>,
138
1
    {
139
1
        deserializer
140
1
            .deserialize_byte_buf(BufferVisitor)
141
1
            .map(|bytes| match bytes {
142
1
                Cow::Borrowed(borrowed) => Self(borrowed.to_vec()),
143
                Cow::Owned(vec) => Self(vec),
144
1
            })
145
1
    }
146
}
147

            
148
/// A `Cow<'a, [u8]>` wrapper that supports serializing efficiently in Serde.
149
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Default)]
150
pub struct CowBytes<'a>(pub Cow<'a, [u8]>);
151

            
152
impl_std_cmp!(CowBytes<'a>);
153

            
154
impl<'a> CowBytes<'a> {
155
    /// Returns the underlying Cow.
156
    #[allow(clippy::missing_const_for_fn)] // false positive
157
    #[must_use]
158
    pub fn into_cow(self) -> Cow<'a, [u8]> {
159
        self.0
160
    }
161

            
162
    /// Returns the underlying Vec inside of the Cow, or clones the borrowed bytes into a new Vec..
163
    #[allow(clippy::missing_const_for_fn)] // false positive
164
    #[must_use]
165
    pub fn into_vec(self) -> Vec<u8> {
166
        match self.0 {
167
            Cow::Borrowed(bytes) => bytes.to_vec(),
168
            Cow::Owned(vec) => vec,
169
        }
170
    }
171

            
172
    /// Returns a slice of the contained data.
173
    #[must_use]
174
    pub fn as_slice(&self) -> &[u8] {
175
        &self.0
176
    }
177
}
178

            
179
impl<'a> std::fmt::Debug for CowBytes<'a> {
180
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181
        let slice = &self[..];
182
        write!(
183
            f,
184
            "CowBytes {{ length: {}, bytes: {}[",
185
            slice.len(),
186
            if matches!(self.0, Cow::Borrowed(_)) {
187
                "&"
188
            } else {
189
                ""
190
            }
191
        )?;
192
        print_bytes(slice, f)?;
193
        f.write_str("] }")
194
    }
195
}
196

            
197
impl<'a> From<Bytes> for CowBytes<'a> {
198
    fn from(bytes: Bytes) -> Self {
199
        CowBytes(Cow::Owned(bytes.0))
200
    }
201
}
202

            
203
impl<'a> From<CowBytes<'a>> for Bytes {
204
    fn from(bytes: CowBytes<'a>) -> Self {
205
        match bytes.0 {
206
            Cow::Borrowed(bytes) => Self(bytes.to_vec()),
207
            Cow::Owned(vec) => Self(vec),
208
        }
209
    }
210
}
211

            
212
impl<'a> From<CowBytes<'a>> for ArcBytes<'a> {
213
    fn from(bytes: CowBytes<'a>) -> Self {
214
        ArcBytes::from(bytes.0)
215
    }
216
}
217

            
218
impl<'a> From<Vec<u8>> for CowBytes<'a> {
219
    fn from(buffer: Vec<u8>) -> Self {
220
        Self(Cow::Owned(buffer))
221
    }
222
}
223

            
224
impl<'a> From<&'a [u8]> for CowBytes<'a> {
225
    fn from(buffer: &'a [u8]) -> Self {
226
        Self(Cow::Borrowed(buffer))
227
    }
228
}
229

            
230
impl<'a, const N: usize> From<[u8; N]> for CowBytes<'a> {
231
    fn from(buffer: [u8; N]) -> Self {
232
        Self::from(buffer.to_vec())
233
    }
234
}
235

            
236
impl<'a, const N: usize> From<&'a [u8; N]> for CowBytes<'a> {
237
    fn from(buffer: &'a [u8; N]) -> Self {
238
        Self(Cow::Borrowed(buffer))
239
    }
240
}
241

            
242
impl<'a> Deref for CowBytes<'a> {
243
    type Target = Cow<'a, [u8]>;
244

            
245
2
    fn deref(&self) -> &Self::Target {
246
2
        &self.0
247
2
    }
248
}
249

            
250
impl<'a> DerefMut for CowBytes<'a> {
251
    fn deref_mut(&mut self) -> &mut Self::Target {
252
        &mut self.0
253
    }
254
}
255

            
256
impl<'a> Serialize for CowBytes<'a> {
257
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
258
    where
259
        S: serde::Serializer,
260
    {
261
        serializer.serialize_bytes(&self.0)
262
    }
263
}
264

            
265
impl<'de: 'a, 'a> Deserialize<'de> for CowBytes<'a> {
266
2
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
267
2
    where
268
2
        D: serde::Deserializer<'de>,
269
2
    {
270
2
        deserializer
271
2
            .deserialize_byte_buf(BufferVisitor)
272
2
            .map(CowBytes)
273
2
    }
274
}
275

            
276
struct BufferVisitor;
277

            
278
impl<'de> Visitor<'de> for BufferVisitor {
279
    type Value = Cow<'de, [u8]>;
280

            
281
    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
282
        formatter.write_str("bytes")
283
    }
284

            
285
1
    fn visit_seq<V>(self, mut visitor: V) -> Result<Self::Value, V::Error>
286
1
    where
287
1
        V: SeqAccess<'de>,
288
1
    {
289
1
        let mut bytes = if let Some(len) = visitor.size_hint() {
290
            Vec::with_capacity(len)
291
        } else {
292
1
            Vec::default()
293
        };
294

            
295
4
        while let Some(b) = visitor.next_element()? {
296
3
            bytes.push(b);
297
3
        }
298

            
299
1
        Ok(Cow::Owned(bytes))
300
1
    }
301

            
302
1
    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
303
1
    where
304
1
        E: Error,
305
1
    {
306
1
        Ok(Cow::Owned(v.to_vec()))
307
1
    }
308

            
309
4
    fn visit_borrowed_bytes<E>(self, v: &'de [u8]) -> Result<Self::Value, E>
310
4
    where
311
4
        E: Error,
312
4
    {
313
4
        Ok(Cow::Borrowed(v))
314
4
    }
315

            
316
2
    fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
317
2
    where
318
2
        E: Error,
319
2
    {
320
2
        Ok(Cow::Owned(v))
321
2
    }
322

            
323
    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
324
    where
325
        E: Error,
326
    {
327
        Ok(Cow::Owned(v.as_bytes().to_vec()))
328
    }
329

            
330
    fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
331
    where
332
        E: Error,
333
    {
334
        Ok(Cow::Borrowed(v.as_bytes()))
335
    }
336

            
337
    fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
338
    where
339
        E: Error,
340
    {
341
        Ok(Cow::Owned(v.into_bytes()))
342
    }
343
}
344

            
345
1
/// Round-trips the wrappers in this module through `pot` and `serde_json`.
/// It covers each `BufferVisitor` entry point, and checks whether each result
/// borrows from the input or owns a copy.
#[test]
fn serialization_tests() {
    // Inside this function `Bytes` names the parent module's type of
    // `ArcBytes::buffer` (matched as `Bytes::Owned` / `Bytes::Borrowed` below).
    // This module's `Vec<u8>` wrapper is spelled `self::Bytes`.
    use super::Bytes;

    let simple_buffer = vec![1_u8, 2, 3];
    let simple_buffer = simple_buffer.as_slice();
    let simple_arcbytes = ArcBytes::from(simple_buffer);

    // deserialize_seq: a plain `&[u8]` serializes as a sequence of integers.
    let u8_sequence_bytes = pot::to_vec(&simple_buffer).unwrap();
    let buffer = pot::from_slice::<ArcBytes<'_>>(&u8_sequence_bytes).unwrap();
    assert_eq!(buffer, simple_buffer);
    assert!(matches!(buffer.buffer, Bytes::Owned(_)));

    // deserialize_borrowed_bytes
    let actual_bytes = pot::to_vec(&simple_arcbytes).unwrap();
    let buffer = pot::from_slice::<ArcBytes<'_>>(&actual_bytes).unwrap();
    assert_eq!(buffer, simple_buffer);
    assert!(matches!(buffer.buffer, Bytes::Borrowed(_)));

    // deserialize_byte_buf
    let json = serde_json::to_string(&simple_arcbytes).unwrap();
    let buffer = serde_json::from_str::<ArcBytes<'_>>(&json).unwrap();
    assert_eq!(buffer, simple_buffer);
    assert!(matches!(buffer.buffer, Bytes::Owned(_)));

    // deserialize_str
    let str_bytes = pot::to_vec(&"hello").unwrap();
    let buffer = pot::from_slice::<ArcBytes<'_>>(&str_bytes).unwrap();
    assert_eq!(buffer, b"hello");
    assert!(matches!(buffer.buffer, Bytes::Borrowed(_)));

    // deserialize_string: because of the escape sequence, "hello world" never
    // appears verbatim in the input, so the result cannot borrow from it.
    let buffer = serde_json::from_str::<ArcBytes<'_>>(r#""hello\u0020world""#).unwrap();
    assert_eq!(buffer, b"hello world");
    assert!(matches!(buffer.buffer, Bytes::Owned(_)));

    // Deserialize `Bytes`
    let actual_bytes = pot::to_vec(&simple_arcbytes).unwrap();
    let buffer = pot::from_slice::<self::Bytes>(&actual_bytes).unwrap();
    assert_eq!(buffer.as_slice(), simple_buffer);

    // Round-trip `Bytes` through JSON. Its `Serialize` impl emits an array of
    // integers, which comes back through `visit_seq`.
    let bytes_json = serde_json::to_string(&self::Bytes::from(simple_buffer)).unwrap();
    let buffer = serde_json::from_str::<self::Bytes>(&bytes_json).unwrap();
    assert_eq!(buffer.as_slice(), simple_buffer);

    // Deserialize `OwnedBytes`, then check it re-serializes exactly like the
    // `ArcBytes` it was decoded from.
    let owned = pot::from_slice::<OwnedBytes>(&actual_bytes).unwrap();
    assert_eq!(owned.0.as_slice(), simple_buffer);
    assert_eq!(pot::to_vec(&owned).unwrap(), actual_bytes);

    // Deserialize `CowBytes`
    let buffer = pot::from_slice::<self::CowBytes<'_>>(&actual_bytes).unwrap();
    assert_eq!(&buffer[..], simple_buffer);
    assert!(matches!(buffer.0, Cow::Borrowed(_)));
    let buffer = pot::from_slice::<self::CowBytes<'_>>(&u8_sequence_bytes).unwrap();
    assert_eq!(&buffer[..], simple_buffer);
    assert!(matches!(buffer.0, Cow::Owned(_)));

    // Round-trip `CowBytes` through JSON. The integer array cannot be borrowed,
    // so the result is owned.
    let cow_json = serde_json::to_string(&self::CowBytes::from(simple_buffer)).unwrap();
    let buffer = serde_json::from_str::<self::CowBytes<'_>>(&cow_json).unwrap();
    assert_eq!(&buffer[..], simple_buffer);
    assert!(matches!(buffer.0, Cow::Owned(_)));
}