1
use std::borrow::Cow;
2
use std::cmp::Ordering;
3
use std::fmt::{Debug, Display};
4
use std::io::Write;
5
use std::ops::Range;
6
use std::usize;
7

            
8
use byteorder::WriteBytesExt;
9
use serde::de::{SeqAccess, Visitor};
10
use serde::{ser, Deserialize, Serialize};
11
#[cfg(feature = "tracing")]
12
use tracing::instrument;
13

            
14
use crate::format::{self, Kind, Special, CURRENT_VERSION};
15
use crate::{Error, Result};
16

            
17
/// A Pot serializer.
18
pub struct Serializer<'a, W: WriteBytesExt> {
19
    symbol_map: SymbolMapRef<'a>,
20
    output: W,
21
    bytes_written: usize,
22
}
23

            
24
impl<'a, W: WriteBytesExt> Debug for Serializer<'a, W> {
    /// Shows the symbol table and byte counter. `output` is left out because
    /// `W` is not required to implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            symbol_map,
            bytes_written,
            ..
        } = self;
        f.debug_struct("Serializer")
            .field("symbol_map", symbol_map)
            .field("bytes_written", bytes_written)
            .finish()
    }
}
32

            
33
impl<'a, W: WriteBytesExt> Serializer<'a, W> {
34
    /// Returns a new serializer outputting written bytes into `output`.
35
    #[inline]
36
1155
    pub fn new(output: W) -> Result<Self> {
37
1155
        Self::new_with_symbol_map(
38
1155
            output,
39
1155
            SymbolMapRef::Ephemeral(EphemeralSymbolMap::default()),
40
1155
        )
41
1155
    }
42

            
43
1165
    fn new_with_symbol_map(mut output: W, symbol_map: SymbolMapRef<'a>) -> Result<Self> {
44
1165
        let bytes_written = format::write_header(&mut output, CURRENT_VERSION)?;
45
1165
        Ok(Self {
46
1165
            symbol_map,
47
1165
            output,
48
1165
            bytes_written,
49
1165
        })
50
1165
    }
51

            
52
2242275
    #[cfg_attr(feature = "tracing", instrument)]
53
1121235
    fn write_symbol(&mut self, symbol: &'static str) -> Result<()> {
54
        let registered_symbol = self.symbol_map.find_or_add(symbol);
55
        if registered_symbol.new {
56
            // The arg is the length followed by a 0 bit.
57
            let arg = (symbol.len() as u64) << 1;
58
            self.bytes_written += format::write_atom_header(&mut self.output, Kind::Symbol, arg)?;
59
            self.output.write_all(symbol.as_bytes())?;
60
            self.bytes_written += symbol.len();
61
        } else {
62
            // When a symbol was already emitted, just emit the id followed by a 1 bit.
63
            self.bytes_written += format::write_atom_header(
64
                &mut self.output,
65
                Kind::Symbol,
66
                u64::from((registered_symbol.id << 1) | 1),
67
            )?;
68
        }
69
        Ok(())
70
    }
71
}
72

            
73
impl<'de, 'a: 'de, W: WriteBytesExt + 'a> ser::Serializer for &'de mut Serializer<'a, W> {
74
    type Error = Error;
75
    type Ok = ();
76
    type SerializeMap = MapSerializer<'de, 'a, W>;
77
    type SerializeSeq = Self;
78
    type SerializeStruct = MapSerializer<'de, 'a, W>;
79
    type SerializeStructVariant = MapSerializer<'de, 'a, W>;
80
    type SerializeTuple = Self;
81
    type SerializeTupleStruct = Self;
82
    type SerializeTupleVariant = Self;
83

            
84
    #[inline]
85
1
    fn is_human_readable(&self) -> bool {
86
1
        false
87
1
    }
88

            
89
4
    #[cfg_attr(feature = "tracing", instrument)]
90
4
    #[inline]
91
    fn serialize_bool(self, v: bool) -> Result<()> {
92
        self.bytes_written += format::write_bool(&mut self.output, v)?;
93
        Ok(())
94
    }
95

            
96
14
    #[cfg_attr(feature = "tracing", instrument)]
97
14
    #[inline]
98
    fn serialize_i8(self, v: i8) -> Result<()> {
99
        self.bytes_written += format::write_i8(&mut self.output, v)?;
100
        Ok(())
101
    }
102

            
103
14
    #[cfg_attr(feature = "tracing", instrument)]
104
14
    #[inline]
105
    fn serialize_i16(self, v: i16) -> Result<()> {
106
        self.bytes_written += format::write_i16(&mut self.output, v)?;
107
        Ok(())
108
    }
109

            
110
18
    #[cfg_attr(feature = "tracing", instrument)]
111
18
    #[inline]
112
    fn serialize_i32(self, v: i32) -> Result<()> {
113
        self.bytes_written += format::write_i32(&mut self.output, v)?;
114
        Ok(())
115
    }
116

            
117
14
    #[cfg_attr(feature = "tracing", instrument)]
118
14
    #[inline]
119
    fn serialize_i64(self, v: i64) -> Result<()> {
120
        self.bytes_written += format::write_i64(&mut self.output, v)?;
121
        Ok(())
122
    }
123

            
124
60
    #[cfg_attr(feature = "tracing", instrument)]
125
60
    #[inline]
126
    fn serialize_i128(self, v: i128) -> Result<()> {
127
        self.bytes_written += format::write_i128(&mut self.output, v)?;
128
        Ok(())
129
    }
130

            
131
31
    #[cfg_attr(feature = "tracing", instrument)]
132
26
    #[inline]
133
    fn serialize_u8(self, v: u8) -> Result<()> {
134
        self.bytes_written += format::write_u8(&mut self.output, v)?;
135
        Ok(())
136
    }
137

            
138
280020
    #[cfg_attr(feature = "tracing", instrument)]
139
140017
    #[inline]
140
    fn serialize_u16(self, v: u16) -> Result<()> {
141
        self.bytes_written += format::write_u16(&mut self.output, v)?;
142
        Ok(())
143
    }
144

            
145
17
    #[cfg_attr(feature = "tracing", instrument)]
146
17
    #[inline]
147
    fn serialize_u32(self, v: u32) -> Result<()> {
148
        self.bytes_written += format::write_u32(&mut self.output, v)?;
149
        Ok(())
150
    }
151

            
152
280050
    #[cfg_attr(feature = "tracing", instrument)]
153
140044
    #[inline]
154
    fn serialize_u64(self, v: u64) -> Result<()> {
155
        self.bytes_written += format::write_u64(&mut self.output, v)?;
156
        Ok(())
157
    }
158

            
159
40
    #[cfg_attr(feature = "tracing", instrument)]
160
40
    #[inline]
161
    fn serialize_u128(self, v: u128) -> Result<()> {
162
        self.bytes_written += format::write_u128(&mut self.output, v)?;
163
        Ok(())
164
    }
165

            
166
16
    #[cfg_attr(feature = "tracing", instrument)]
167
16
    #[inline]
168
    fn serialize_f32(self, v: f32) -> Result<()> {
169
        self.bytes_written += format::write_f32(&mut self.output, v)?;
170
        Ok(())
171
    }
172

            
173
24
    #[cfg_attr(feature = "tracing", instrument)]
174
24
    #[inline]
175
    fn serialize_f64(self, v: f64) -> Result<()> {
176
        self.bytes_written += format::write_f64(&mut self.output, v)?;
177
        Ok(())
178
    }
179

            
180
13
    #[cfg_attr(feature = "tracing", instrument)]
181
13
    #[inline]
182
    fn serialize_char(self, v: char) -> Result<()> {
183
        self.bytes_written += format::write_u32(&mut self.output, v as u32)?;
184
        Ok(())
185
    }
186

            
187
980022
    #[cfg_attr(feature = "tracing", instrument)]
188
490017
    #[inline]
189
    fn serialize_str(self, v: &str) -> Result<()> {
190
        self.bytes_written += format::write_str(&mut self.output, v)?;
191
        Ok(())
192
    }
193

            
194
12
    #[cfg_attr(feature = "tracing", instrument)]
195
9
    #[inline]
196
    fn serialize_bytes(self, v: &[u8]) -> Result<()> {
197
        self.bytes_written += format::write_bytes(&mut self.output, v)?;
198
        Ok(())
199
    }
200

            
201
140034
    #[cfg_attr(feature = "tracing", instrument)]
202
70020
    #[inline]
203
    fn serialize_none(self) -> Result<()> {
204
        self.bytes_written += format::write_none(&mut self.output)?;
205
        Ok(())
206
    }
207

            
208
139984
    #[cfg_attr(feature = "tracing", instrument(level = "trace", skip(value)))]
209
    #[inline]
210
    fn serialize_some<T>(self, value: &T) -> Result<()>
211
    where
212
        T: ?Sized + Serialize,
213
    {
214
        value.serialize(self)
215
    }
216

            
217
7
    #[cfg_attr(feature = "tracing", instrument)]
218
7
    #[inline]
219
    fn serialize_unit(self) -> Result<()> {
220
        self.bytes_written += format::write_unit(&mut self.output)?;
221
        Ok(())
222
    }
223

            
224
2
    #[cfg_attr(feature = "tracing", instrument)]
225
2
    #[inline]
226
    fn serialize_unit_struct(self, _name: &'static str) -> Result<()> {
227
        self.serialize_unit()
228
    }
229

            
230
280010
    #[cfg_attr(feature = "tracing", instrument)]
231
140007
    #[inline]
232
    fn serialize_unit_variant(
233
        self,
234
        _name: &'static str,
235
        _variant_index: u32,
236
        variant: &'static str,
237
    ) -> Result<()> {
238
        format::write_named(&mut self.output)?;
239
        self.write_symbol(variant)?;
240
        Ok(())
241
    }
242

            
243
    #[cfg_attr(feature = "tracing", instrument(level = "trace", skip(value)))]
244
    #[inline]
245
    fn serialize_newtype_struct<T>(self, _name: &'static str, value: &T) -> Result<()>
246
    where
247
        T: ?Sized + Serialize,
248
    {
249
        value.serialize(self)
250
    }
251

            
252
2
    #[cfg_attr(feature = "tracing", instrument(level = "trace", skip(value)))]
253
    #[inline]
254
    fn serialize_newtype_variant<T>(
255
        self,
256
        _name: &'static str,
257
        _variant_index: u32,
258
        variant: &'static str,
259
        value: &T,
260
    ) -> Result<()>
261
    where
262
        T: ?Sized + Serialize,
263
    {
264
        format::write_named(&mut self.output)?;
265
        self.write_symbol(variant)?;
266
        value.serialize(&mut *self)?;
267
        Ok(())
268
    }
269

            
270
2026
    #[cfg_attr(feature = "tracing", instrument)]
271
1020
    #[inline]
272
    fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeq> {
273
        let len = len.ok_or(Error::SequenceSizeMustBeKnown)?;
274
        self.bytes_written +=
275
            format::write_atom_header(&mut self.output, Kind::Sequence, len as u64)?;
276
        Ok(self)
277
    }
278

            
279
2
    #[cfg_attr(feature = "tracing", instrument)]
280
2
    #[inline]
281
    fn serialize_tuple(self, len: usize) -> Result<Self::SerializeTuple> {
282
        self.serialize_seq(Some(len))
283
    }
284

            
285
2
    #[cfg_attr(feature = "tracing", instrument)]
286
2
    #[inline]
287
    fn serialize_tuple_struct(
288
        self,
289
        _name: &'static str,
290
        len: usize,
291
    ) -> Result<Self::SerializeTupleStruct> {
292
        self.serialize_seq(Some(len))
293
    }
294

            
295
2
    #[cfg_attr(feature = "tracing", instrument)]
296
2
    #[inline]
297
    fn serialize_tuple_variant(
298
        self,
299
        _name: &'static str,
300
        _variant_index: u32,
301
        variant: &'static str,
302
        len: usize,
303
    ) -> Result<Self::SerializeTupleVariant> {
304
        format::write_named(&mut self.output)?;
305
        self.write_symbol(variant)?;
306
        self.serialize_seq(Some(len))
307
    }
308

            
309
282046
    #[cfg_attr(feature = "tracing", instrument)]
310
141035
    #[inline]
311
    fn serialize_map(self, len: Option<usize>) -> Result<Self::SerializeMap> {
312
        if let Some(len) = len {
313
            self.bytes_written +=
314
                format::write_atom_header(&mut self.output, Kind::Map, len as u64)?;
315
            Ok(MapSerializer {
316
                serializer: self,
317
                known_length: true,
318
            })
319
        } else {
320
            self.bytes_written += format::write_special(&mut self.output, Special::DynamicMap)?;
321
            Ok(MapSerializer {
322
                serializer: self,
323
                known_length: false,
324
            })
325
        }
326
    }
327

            
328
282043
    #[cfg_attr(feature = "tracing", instrument)]
329
141032
    #[inline]
330
    fn serialize_struct(self, _name: &'static str, len: usize) -> Result<Self::SerializeStruct> {
331
        self.serialize_map(Some(len))
332
    }
333

            
334
2
    #[cfg_attr(feature = "tracing", instrument)]
335
2
    #[inline]
336
    fn serialize_struct_variant(
337
        self,
338
        name: &'static str,
339
        _variant_index: u32,
340
        variant: &'static str,
341
        len: usize,
342
    ) -> Result<Self::SerializeStructVariant> {
343
        format::write_named(&mut self.output)?;
344
        self.write_symbol(variant)?;
345
        self.serialize_struct(name, len)
346
    }
347
}
348

            
349
impl<'de, 'a: 'de, W: WriteBytesExt + 'a> ser::SerializeSeq for &'de mut Serializer<'a, W> {
350
    type Error = Error;
351
    type Ok = ();
352

            
353
    #[inline]
354
140025
    fn serialize_element<T>(&mut self, value: &T) -> Result<()>
355
140025
    where
356
140025
        T: ?Sized + Serialize,
357
140025
    {
358
140025
        value.serialize(&mut **self)
359
140025
    }
360

            
361
    #[inline]
362
1014
    fn end(self) -> Result<()> {
363
1014
        Ok(())
364
1014
    }
365
}
366

            
367
impl<'de, 'a: 'de, W: WriteBytesExt + 'a> ser::SerializeTuple for &'de mut Serializer<'a, W> {
368
    type Error = Error;
369
    type Ok = ();
370

            
371
    #[inline]
372
6
    fn serialize_element<T>(&mut self, value: &T) -> Result<()>
373
6
    where
374
6
        T: ?Sized + Serialize,
375
6
    {
376
6
        value.serialize(&mut **self)
377
6
    }
378

            
379
    #[inline]
380
2
    fn end(self) -> Result<()> {
381
2
        Ok(())
382
2
    }
383
}
384

            
385
impl<'de, 'a: 'de, W: WriteBytesExt + 'a> ser::SerializeTupleStruct for &'de mut Serializer<'a, W> {
386
    type Error = Error;
387
    type Ok = ();
388

            
389
    #[inline]
390
4
    fn serialize_field<T>(&mut self, value: &T) -> Result<()>
391
4
    where
392
4
        T: ?Sized + Serialize,
393
4
    {
394
4
        value.serialize(&mut **self)
395
4
    }
396

            
397
    #[inline]
398
2
    fn end(self) -> Result<()> {
399
2
        Ok(())
400
2
    }
401
}
402

            
403
impl<'de, 'a: 'de, W: WriteBytesExt + 'a> ser::SerializeTupleVariant
404
    for &'de mut Serializer<'a, W>
405
{
406
    type Error = Error;
407
    type Ok = ();
408

            
409
    #[inline]
410
4
    fn serialize_field<T>(&mut self, value: &T) -> Result<()>
411
4
    where
412
4
        T: ?Sized + Serialize,
413
4
    {
414
4
        value.serialize(&mut **self)
415
4
    }
416

            
417
    #[inline]
418
2
    fn end(self) -> Result<()> {
419
2
        Ok(())
420
2
    }
421
}
422

            
423
/// Serializes map-like values.
424
pub struct MapSerializer<'de, 'a, W: WriteBytesExt> {
425
    serializer: &'de mut Serializer<'a, W>,
426
    known_length: bool,
427
}
428

            
429
impl<'de, 'a: 'de, W: WriteBytesExt + 'a> ser::SerializeMap for MapSerializer<'de, 'a, W> {
430
    type Error = Error;
431
    type Ok = ();
432

            
433
    #[inline]
434
5
    fn serialize_key<T>(&mut self, key: &T) -> Result<()>
435
5
    where
436
5
        T: ?Sized + Serialize,
437
5
    {
438
5
        key.serialize(&mut *self.serializer)
439
5
    }
440

            
441
    #[inline]
442
5
    fn serialize_value<T>(&mut self, value: &T) -> Result<()>
443
5
    where
444
5
        T: ?Sized + Serialize,
445
5
    {
446
5
        value.serialize(&mut *self.serializer)
447
5
    }
448

            
449
    #[inline]
450
3
    fn end(self) -> Result<()> {
451
3
        if !self.known_length {
452
2
            format::write_special(&mut self.serializer.output, Special::DynamicEnd)?;
453
1
        }
454
3
        Ok(())
455
3
    }
456
}
457

            
458
impl<'de, 'a: 'de, W: WriteBytesExt + 'a> ser::SerializeStruct for MapSerializer<'de, 'a, W> {
459
    type Error = Error;
460
    type Ok = ();
461

            
462
    #[inline]
463
    fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<()>
464
    where
465
        T: ?Sized + Serialize,
466
    {
467
981220
        self.serializer.write_symbol(key)?;
468
981220
        value.serialize(&mut *self.serializer)
469
981220
    }
470

            
471
    #[inline]
472
141030
    fn end(self) -> Result<()> {
473
141030
        if !self.known_length {
474
            format::write_special(&mut self.serializer.output, Special::DynamicEnd)?;
475
141030
        }
476
141030
        Ok(())
477
141030
    }
478
}
479

            
480
impl<'de, 'a: 'de, W: WriteBytesExt + 'a> ser::SerializeStructVariant
481
    for MapSerializer<'de, 'a, W>
482
{
483
    type Error = Error;
484
    type Ok = ();
485

            
486
    #[inline]
487
    fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<()>
488
    where
489
        T: ?Sized + Serialize,
490
    {
491
2
        self.serializer.write_symbol(key)?;
492
2
        value.serialize(&mut *self.serializer)
493
2
    }
494

            
495
    #[inline]
496
2
    fn end(self) -> Result<()> {
497
2
        if !self.known_length {
498
            format::write_special(&mut self.serializer.output, Special::DynamicEnd)?;
499
2
        }
500
2
        Ok(())
501
2
    }
502
}
503

            
504
1156
/// Symbol table owned by a single serializer and dropped along with it.
#[derive(Default)]
struct EphemeralSymbolMap {
    /// `(symbol, id)` pairs, kept ordered by the symbol's address so lookups
    /// can binary-search on pointers instead of comparing strings.
    symbols: Vec<(&'static str, u32)>,
}
508

            
509
/// Result of looking up (and possibly registering) a symbol.
struct RegisteredSymbol {
    /// The symbol's id within its table.
    id: u32,
    /// `true` if the lookup registered the symbol, so its text must be written.
    new: bool,
}
513

            
514
impl EphemeralSymbolMap {
515
    #[allow(clippy::cast_possible_truncation)]
516
4484211
    fn find_or_add(&mut self, symbol: &'static str) -> RegisteredSymbol {
517
4484211
        // Symbols have to be static strings, and so we can rely on the addres
518
4484211
        // not changing. To avoid string comparisons, we're going to use the
519
4484211
        // address of the str in the map.
520
4484211
        let symbol_address = symbol.as_ptr() as usize;
521
4484211
        // Perform a binary search to find this existing element.
522
4484211
        match self
523
4484211
            .symbols
524
14060726
            .binary_search_by(|check| (check.0.as_ptr() as usize).cmp(&symbol_address))
525
        {
526
4431834
            Ok(position) => RegisteredSymbol {
527
4431834
                id: self.symbols[position].1,
528
4431834
                new: false,
529
4431834
            },
530
52377
            Err(position) => {
531
52377
                let id = self.symbols.len() as u32;
532
52377
                self.symbols.insert(position, (symbol, id));
533
52377
                RegisteredSymbol { id, new: true }
534
            }
535
        }
536
4484211
    }
537
}
538

            
539
impl Debug for EphemeralSymbolMap {
540
485
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
541
485
        let mut set = f.debug_set();
542
7858
        for index in SymbolIdSorter::new(&self.symbols, |sym| sym.1) {
543
1976
            set.entry(&self.symbols[index].0);
544
1976
        }
545
485
        set.finish()
546
485
    }
547
}
548

            
549
/// A list of previously serialized symbols.
///
/// Reuse one map across several serializations (for example via
/// [`SymbolMap::serialize_to`]). Symbols already in the map are then written
/// as ids instead of being spelled out again.
pub struct SymbolMap {
    /// The text of every registered symbol, concatenated back to back.
    symbols: String,
    /// `(byte range into symbols, id)` pairs, kept sorted by symbol text.
    entries: Vec<(Range<usize>, u32)>,
    /// `(address of a &'static str, id)` cache that lets repeat lookups skip
    /// string comparisons.
    static_lookup: Vec<(usize, u32)>,
}
555

            
556
impl Debug for SymbolMap {
557
135
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
558
135
        let mut s = f.debug_set();
559
1255
        for entry in &self.entries {
560
1120
            s.entry(&&self.symbols[entry.0.clone()]);
561
1120
        }
562
135
        s.finish()
563
135
    }
564
}
565

            
566
impl Default for SymbolMap {
567
    #[inline]
568
5
    fn default() -> Self {
569
5
        Self::new()
570
5
    }
571
}
572

            
573
impl SymbolMap {
574
    /// Returns a new, empty symbol map.
575
    #[must_use]
576
13
    pub const fn new() -> Self {
577
13
        Self {
578
13
            symbols: String::new(),
579
13
            entries: Vec::new(),
580
13
            static_lookup: Vec::new(),
581
13
        }
582
13
    }
583

            
584
    /// Returns a serializer that writes into `output` and persists symbols
585
    /// into `self`.
586
    #[inline]
587
10
    pub fn serializer_for<W: WriteBytesExt>(&mut self, output: W) -> Result<Serializer<'_, W>> {
588
10
        Serializer::new_with_symbol_map(output, SymbolMapRef::Persistent(self))
589
10
    }
590

            
591
    /// Serializes `value` into `writer` while persisting symbols into `self`.
592
    pub fn serialize_to<T, W>(&mut self, writer: W, value: &T) -> Result<()>
593
    where
594
        W: Write,
595
        T: Serialize,
596
    {
597
10
        value.serialize(&mut self.serializer_for(writer)?)
598
10
    }
599

            
600
    /// Serializes `value` into a new `Vec<u8>` while persisting symbols into
601
    /// `self`.
602
6
    pub fn serialize_to_vec<T>(&mut self, value: &T) -> Result<Vec<u8>>
603
6
    where
604
6
        T: Serialize,
605
6
    {
606
6
        let mut output = Vec::new();
607
6
        self.serialize_to(&mut output, value)?;
608
6
        Ok(output)
609
6
    }
610

            
611
157
    fn find_or_add(&mut self, symbol: &'static str) -> RegisteredSymbol {
612
157
        // Symbols have to be static strings, and so we can rely on the addres
613
157
        // not changing. To avoid string comparisons, we're going to use the
614
157
        // address of the str in the map.
615
157
        let symbol_address = symbol.as_ptr() as usize;
616
157
        // Perform a binary search to find this existing element.
617
157
        match self
618
157
            .static_lookup
619
378
            .binary_search_by(|check| symbol_address.cmp(&check.0))
620
        {
621
69
            Ok(position) => RegisteredSymbol {
622
69
                id: self.static_lookup[position].1,
623
69
                new: false,
624
69
            },
625
88
            Err(position) => {
626
88
                // This static symbol hasn't been encountered before.
627
88
                let symbol = self.find_entry_by_str(symbol);
628
88
                self.static_lookup
629
88
                    .insert(position, (symbol_address, symbol.id));
630
88
                symbol
631
            }
632
        }
633
157
    }
634

            
635
    #[allow(clippy::cast_possible_truncation)]
636
88
    fn find_entry_by_str(&mut self, symbol: &str) -> RegisteredSymbol {
637
88
        match self
638
88
            .entries
639
186
            .binary_search_by(|check| self.symbols[check.0.clone()].cmp(symbol))
640
        {
641
2
            Ok(index) => RegisteredSymbol {
642
2
                id: self.entries[index].1,
643
2
                new: false,
644
2
            },
645
86
            Err(insert_at) => {
646
86
                let id = self.entries.len() as u32;
647
86
                let start = self.symbols.len();
648
86
                self.symbols.push_str(symbol);
649
86
                self.entries
650
86
                    .insert(insert_at, (start..self.symbols.len(), id));
651
86
                RegisteredSymbol { id, new: true }
652
            }
653
        }
654
88
    }
655

            
656
    /// Inserts `symbol` into this map.
657
    ///
658
    /// Returns true if this symbol had not previously been registered. Returns
659
    /// false if the symbol was already included in the map.
660
    pub fn insert(&mut self, symbol: &str) -> bool {
661
        self.find_entry_by_str(symbol).new
662
    }
663

            
664
    /// Returns the number of entries in this map.
665
    #[must_use]
666
7
    pub fn len(&self) -> usize {
667
7
        self.entries.len()
668
7
    }
669

            
670
    /// Returns true if the map has no entries.
671
    #[must_use]
672
1
    pub fn is_empty(&self) -> bool {
673
1
        self.len() == 0
674
1
    }
675

            
676
    /// Adds all symbols encountered in `value`.
677
    ///
678
    /// Returns the number of symbols added.
679
    ///
680
    /// Due to how serde works, this function can only encounter symbols that
681
    /// are being used. For example, if `T` is an enum, only variant being
682
    /// passed in will have its name, and additional calls for each variant will
683
    /// be needed to ensure every symbol is added.
684
7
    pub fn populate_from<T>(&mut self, value: &T) -> Result<usize, SymbolMapPopulationError>
685
7
    where
686
7
        T: Serialize,
687
7
    {
688
7
        let start_count = self.entries.len();
689
7
        value.serialize(&mut SymbolMapPopulator(self))?;
690
7
        Ok(self.entries.len() - start_count)
691
7
    }
692
}
693

            
694
impl Serialize for SymbolMap {
695
    #[inline]
696
2
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
697
2
    where
698
2
        S: serde::Serializer,
699
2
    {
700
        use serde::ser::SerializeSeq;
701
2
        let mut seq = serializer.serialize_seq(Some(self.len()))?;
702
4
        for index in SymbolIdSorter::new(&self.entries, |entry| entry.1) {
703
4
            seq.serialize_element(&self.symbols[self.entries[index].0.clone()])?;
704
        }
705
2
        seq.end()
706
2
    }
707
}
708

            
709
impl<'de> Deserialize<'de> for SymbolMap {
710
    #[inline]
711
1
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
712
1
    where
713
1
        D: serde::Deserializer<'de>,
714
1
    {
715
1
        deserializer.deserialize_seq(SymbolMapVisitor)
716
1
    }
717
}
718

            
719
/// serde visitor that rebuilds a `SymbolMap` from a sequence of strings.
struct SymbolMapVisitor;
720

            
721
impl<'de> Visitor<'de> for SymbolMapVisitor {
722
    type Value = SymbolMap;
723

            
724
    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
725
        formatter.write_str("symbol map")
726
    }
727

            
728
    #[inline]
729
1
    fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
730
1
    where
731
1
        A: SeqAccess<'de>,
732
1
    {
733
1
        let mut map = SymbolMap::new();
734
1
        if let Some(hint) = seq.size_hint() {
735
1
            map.entries.reserve(hint);
736
1
        }
737
1
        let mut id = 0;
738
3
        while let Some(element) = seq.next_element::<Cow<'_, str>>()? {
739
2
            let start = map.symbols.len();
740
2
            map.symbols.push_str(&element);
741
2
            map.entries.push((start..map.symbols.len(), id));
742
2
            id += 1;
743
2
        }
744

            
745
1
        map.entries
746
1
            .sort_by(|a, b| map.symbols[a.0.clone()].cmp(&map.symbols[b.0.clone()]));
747
1

            
748
1
        Ok(map)
749
1
    }
750
}
751

            
752
614
#[derive(Debug)]
753
enum SymbolMapRef<'a> {
754
    Ephemeral(EphemeralSymbolMap),
755
    Persistent(&'a mut SymbolMap),
756
}
757

            
758
impl SymbolMapRef<'_> {
759
4484337
    fn find_or_add(&mut self, symbol: &'static str) -> RegisteredSymbol {
760
4484337
        match self {
761
4484207
            SymbolMapRef::Ephemeral(map) => map.find_or_add(symbol),
762
130
            SymbolMapRef::Persistent(map) => map.find_or_add(symbol),
763
        }
764
4484337
    }
765
}
766

            
767
/// A serde serializer that produces no output. It only records the struct
/// field names and enum variant names it encounters into the wrapped map.
struct SymbolMapPopulator<'a>(&'a mut SymbolMap);
768

            
769
impl<'ser, 'a> serde::ser::Serializer for &'ser mut SymbolMapPopulator<'a> {
770
    type Error = SymbolMapPopulationError;
771
    type Ok = ();
772
    type SerializeMap = Self;
773
    type SerializeSeq = Self;
774
    type SerializeStruct = Self;
775
    type SerializeStructVariant = Self;
776
    type SerializeTuple = Self;
777
    type SerializeTupleStruct = Self;
778
    type SerializeTupleVariant = Self;
779

            
780
    #[inline]
781
    fn serialize_bool(self, _v: bool) -> std::result::Result<Self::Ok, Self::Error> {
782
        Ok(())
783
    }
784

            
785
    #[inline]
786
1
    fn serialize_i8(self, _v: i8) -> std::result::Result<Self::Ok, Self::Error> {
787
1
        Ok(())
788
1
    }
789

            
790
    #[inline]
791
1
    fn serialize_i16(self, _v: i16) -> std::result::Result<Self::Ok, Self::Error> {
792
1
        Ok(())
793
1
    }
794

            
795
    #[inline]
796
1
    fn serialize_i32(self, _v: i32) -> std::result::Result<Self::Ok, Self::Error> {
797
1
        Ok(())
798
1
    }
799

            
800
    #[inline]
801
1
    fn serialize_i64(self, _v: i64) -> std::result::Result<Self::Ok, Self::Error> {
802
1
        Ok(())
803
1
    }
804

            
805
    #[inline]
806
1
    fn serialize_u8(self, _v: u8) -> std::result::Result<Self::Ok, Self::Error> {
807
1
        Ok(())
808
1
    }
809

            
810
    #[inline]
811
1
    fn serialize_u16(self, _v: u16) -> std::result::Result<Self::Ok, Self::Error> {
812
1
        Ok(())
813
1
    }
814

            
815
    #[inline]
816
1
    fn serialize_u32(self, _v: u32) -> std::result::Result<Self::Ok, Self::Error> {
817
1
        Ok(())
818
1
    }
819

            
820
    #[inline]
821
6
    fn serialize_u64(self, _v: u64) -> std::result::Result<Self::Ok, Self::Error> {
822
6
        Ok(())
823
6
    }
824

            
825
    #[inline]
826
1
    fn serialize_i128(self, _v: i128) -> Result<Self::Ok, Self::Error> {
827
1
        Ok(())
828
1
    }
829

            
830
    #[inline]
831
1
    fn serialize_u128(self, _v: u128) -> Result<Self::Ok, Self::Error> {
832
1
        Ok(())
833
1
    }
834

            
835
    #[inline]
836
1
    fn serialize_f32(self, _v: f32) -> std::result::Result<Self::Ok, Self::Error> {
837
1
        Ok(())
838
1
    }
839

            
840
    #[inline]
841
1
    fn serialize_f64(self, _v: f64) -> std::result::Result<Self::Ok, Self::Error> {
842
1
        Ok(())
843
1
    }
844

            
845
    #[inline]
846
1
    fn serialize_char(self, _v: char) -> std::result::Result<Self::Ok, Self::Error> {
847
1
        Ok(())
848
1
    }
849

            
850
    #[inline]
851
1
    fn serialize_str(self, _v: &str) -> std::result::Result<Self::Ok, Self::Error> {
852
1
        Ok(())
853
1
    }
854

            
855
    #[inline]
856
    fn serialize_bytes(self, _v: &[u8]) -> std::result::Result<Self::Ok, Self::Error> {
857
        Ok(())
858
    }
859

            
860
    #[inline]
861
    fn serialize_none(self) -> std::result::Result<Self::Ok, Self::Error> {
862
        Ok(())
863
    }
864

            
865
    #[inline]
866
    fn serialize_some<T: ?Sized>(self, value: &T) -> std::result::Result<Self::Ok, Self::Error>
867
    where
868
        T: Serialize,
869
    {
870
        value.serialize(self)
871
    }
872

            
873
    #[inline]
874
    fn serialize_unit(self) -> std::result::Result<Self::Ok, Self::Error> {
875
        Ok(())
876
    }
877

            
878
    #[inline]
879
    fn serialize_unit_struct(
880
        self,
881
        _name: &'static str,
882
    ) -> std::result::Result<Self::Ok, Self::Error> {
883
        Ok(())
884
    }
885

            
886
    #[inline]
887
2
    fn serialize_unit_variant(
888
2
        self,
889
2
        _name: &'static str,
890
2
        _variant_index: u32,
891
2
        variant: &'static str,
892
2
    ) -> std::result::Result<Self::Ok, Self::Error> {
893
2
        self.0.find_or_add(variant);
894
2
        Ok(())
895
2
    }
896

            
897
    #[inline]
898
    fn serialize_newtype_struct<T: ?Sized>(
899
        self,
900
        _name: &'static str,
901
        value: &T,
902
    ) -> std::result::Result<Self::Ok, Self::Error>
903
    where
904
        T: Serialize,
905
    {
906
        value.serialize(self)
907
    }
908

            
909
    #[inline]
910
1
    fn serialize_newtype_variant<T: ?Sized>(
911
1
        self,
912
1
        _name: &'static str,
913
1
        _variant_index: u32,
914
1
        variant: &'static str,
915
1
        value: &T,
916
1
    ) -> std::result::Result<Self::Ok, Self::Error>
917
1
    where
918
1
        T: Serialize,
919
1
    {
920
1
        self.0.find_or_add(variant);
921
1
        value.serialize(self)
922
1
    }
923

            
924
    #[inline]
925
    fn serialize_seq(
926
        self,
927
        _len: Option<usize>,
928
    ) -> std::result::Result<Self::SerializeSeq, Self::Error> {
929
        Ok(self)
930
    }
931

            
932
    #[inline]
933
    fn serialize_tuple(
934
        self,
935
        _len: usize,
936
    ) -> std::result::Result<Self::SerializeTuple, Self::Error> {
937
        Ok(self)
938
    }
939

            
940
    #[inline]
941
    fn serialize_tuple_struct(
942
        self,
943
        _name: &'static str,
944
        _len: usize,
945
    ) -> std::result::Result<Self::SerializeTupleStruct, Self::Error> {
946
        Ok(self)
947
    }
948

            
949
    #[inline]
950
1
    fn serialize_tuple_variant(
951
1
        self,
952
1
        _name: &'static str,
953
1
        _variant_index: u32,
954
1
        variant: &'static str,
955
1
        _len: usize,
956
1
    ) -> std::result::Result<Self::SerializeTupleVariant, Self::Error> {
957
1
        self.0.find_or_add(variant);
958
1
        Ok(self)
959
1
    }
960

            
961
    #[inline]
962
    fn serialize_map(
963
        self,
964
        _len: Option<usize>,
965
    ) -> std::result::Result<Self::SerializeMap, Self::Error> {
966
        Ok(self)
967
    }
968

            
969
    #[inline]
970
2
    fn serialize_struct(
971
2
        self,
972
2
        _name: &'static str,
973
2
        _len: usize,
974
2
    ) -> std::result::Result<Self::SerializeStruct, Self::Error> {
975
2
        Ok(self)
976
2
    }
977

            
978
    #[inline]
979
1
    fn serialize_struct_variant(
980
1
        self,
981
1
        _name: &'static str,
982
1
        _variant_index: u32,
983
1
        variant: &'static str,
984
1
        _len: usize,
985
1
    ) -> std::result::Result<Self::SerializeStructVariant, Self::Error> {
986
1
        self.0.find_or_add(variant);
987
1
        Ok(self)
988
1
    }
989
}
990

            
991
impl serde::ser::SerializeMap for &mut SymbolMapPopulator<'_> {
992
    type Error = SymbolMapPopulationError;
993
    type Ok = ();
994

            
995
    #[inline]
996
    fn serialize_key<T: ?Sized>(&mut self, key: &T) -> std::result::Result<(), Self::Error>
997
    where
998
        T: Serialize,
999
    {
        key.serialize(&mut SymbolMapPopulator(&mut *self.0))
    }

            
    #[inline]
    fn serialize_value<T: ?Sized>(&mut self, value: &T) -> std::result::Result<(), Self::Error>
    where
        T: Serialize,
    {
        value.serialize(&mut SymbolMapPopulator(&mut *self.0))
    }

            
    #[inline]
    fn end(self) -> std::result::Result<Self::Ok, Self::Error> {
        Ok(())
    }
}

            
/// Visits every sequence element with a nested populator over the same map.
impl serde::ser::SerializeSeq for &mut SymbolMapPopulator<'_> {
    type Ok = ();
    type Error = SymbolMapPopulationError;

    #[inline]
    fn serialize_element<T>(&mut self, value: &T) -> std::result::Result<(), Self::Error>
    where
        T: Serialize + ?Sized,
    {
        let mut nested = SymbolMapPopulator(&mut *self.0);
        value.serialize(&mut nested)
    }

    /// No state is kept per sequence, so ending one is a no-op.
    #[inline]
    fn end(self) -> std::result::Result<Self::Ok, Self::Error> {
        Ok(())
    }
}

            
impl serde::ser::SerializeStruct for &mut SymbolMapPopulator<'_> {
    type Error = SymbolMapPopulationError;
    type Ok = ();

            
    #[inline]
15
    fn serialize_field<T: ?Sized>(
15
        &mut self,
15
        key: &'static str,
15
        value: &T,
15
    ) -> std::result::Result<(), Self::Error>
15
    where
15
        T: Serialize,
15
    {
15
        self.0.find_or_add(key);
15
        value.serialize(&mut SymbolMapPopulator(&mut *self.0))
15
    }

            
    #[inline]
2
    fn end(self) -> std::result::Result<Self::Ok, Self::Error> {
2
        Ok(())
2
    }
}

            
impl serde::ser::SerializeStructVariant for &mut SymbolMapPopulator<'_> {
    type Error = SymbolMapPopulationError;
    type Ok = ();

            
    #[inline]
1
    fn serialize_field<T: ?Sized>(
1
        &mut self,
1
        key: &'static str,
1
        value: &T,
1
    ) -> std::result::Result<(), Self::Error>
1
    where
1
        T: Serialize,
1
    {
1
        self.0.find_or_add(key);
1
        value.serialize(&mut SymbolMapPopulator(&mut *self.0))
1
    }

            
    #[inline]
1
    fn end(self) -> std::result::Result<Self::Ok, Self::Error> {
1
        Ok(())
1
    }
}

            
/// Visits every tuple element with a nested populator over the same map.
impl serde::ser::SerializeTuple for &mut SymbolMapPopulator<'_> {
    type Ok = ();
    type Error = SymbolMapPopulationError;

    #[inline]
    fn serialize_element<T>(&mut self, value: &T) -> std::result::Result<(), Self::Error>
    where
        T: Serialize + ?Sized,
    {
        let mut nested = SymbolMapPopulator(&mut *self.0);
        value.serialize(&mut nested)
    }

    /// Tuples keep no state here, so ending one is a no-op.
    #[inline]
    fn end(self) -> std::result::Result<Self::Ok, Self::Error> {
        Ok(())
    }
}
/// Visits every tuple-struct field with a nested populator over the same map.
impl serde::ser::SerializeTupleStruct for &mut SymbolMapPopulator<'_> {
    type Ok = ();
    type Error = SymbolMapPopulationError;

    #[inline]
    fn serialize_field<T>(&mut self, value: &T) -> std::result::Result<(), Self::Error>
    where
        T: Serialize + ?Sized,
    {
        let mut nested = SymbolMapPopulator(&mut *self.0);
        value.serialize(&mut nested)
    }

    /// Unnamed fields need no finalization, so this is a no-op.
    #[inline]
    fn end(self) -> std::result::Result<Self::Ok, Self::Error> {
        Ok(())
    }
}
impl serde::ser::SerializeTupleVariant for &mut SymbolMapPopulator<'_> {
    type Error = SymbolMapPopulationError;
    type Ok = ();

            
    #[inline]
2
    fn serialize_field<T: ?Sized>(&mut self, value: &T) -> std::result::Result<(), Self::Error>
2
    where
2
        T: Serialize,
2
    {
2
        value.serialize(&mut SymbolMapPopulator(&mut *self.0))
2
    }

            
    #[inline]
1
    fn end(self) -> std::result::Result<Self::Ok, Self::Error> {
1
        Ok(())
1
    }
}

            
/// Error raised when a [`Serialize`] implementation fails while the symbol
/// map is being populated.
///
/// Only the rendered error message is kept.
#[derive(Debug)]
pub struct SymbolMapPopulationError(String);

impl Display for SymbolMapPopulationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Emit the stored message verbatim (formatter width/fill are ignored).
        let Self(message) = self;
        f.write_str(message)
    }
}

impl std::error::Error for SymbolMapPopulationError {}

            
impl serde::ser::Error for SymbolMapPopulationError {
    /// Captures the `Display` rendering of a custom serialization error.
    fn custom<T: Display>(msg: T) -> Self {
        let message = msg.to_string();
        Self(message)
    }
}

            
/// Yields the indices of `source` in ascending order of the id that `map`
/// extracts from each entry, starting with id `0`.
///
/// The iterator does not allocate. Each call to `next` rescans `source`
/// (starting at `min`) for the entry whose id equals the next expected id,
/// so a single call is O(n) in the worst case. Iteration stops at the first
/// id that cannot be found. For every entry to be yielded, `source` should
/// therefore hold the ids `0..n` in any order.
struct SymbolIdSorter<'a, T, F> {
    /// The entries being ordered.
    source: &'a [T],
    /// Extracts the id of an entry.
    map: F,
    /// Every entry before this index has an id lower than `id`, so scans may
    /// start here.
    min: usize,
    /// The id the next call to `next` searches for.
    id: u32,
}

impl<'a, T, F> SymbolIdSorter<'a, T, F>
where
    F: FnMut(&T) -> u32,
{
    /// Creates a sorter over `source` that uses `map` to read each entry's id.
    pub fn new(source: &'a [T], map: F) -> Self {
        Self {
            source,
            map,
            min: 0,
            id: 0,
        }
    }
}

// No `T: Clone` bound: entries are only ever inspected by reference.
impl<T, F> Iterator for SymbolIdSorter<'_, T, F>
where
    F: FnMut(&T) -> u32,
{
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        let source = self.source;
        let start_min = self.min;
        // Once an entry with a not-yet-yielded id is seen, `min` must not
        // advance past it during this scan.
        let mut encountered_greater = false;
        for (relative_index, entry) in source[start_min..].iter().enumerate() {
            let index = start_min + relative_index;
            match (self.map)(entry).cmp(&self.id) {
                Ordering::Equal => {
                    self.id += 1;
                    if !encountered_greater {
                        self.min = index + 1;
                    }
                    return Some(index);
                }
                Ordering::Greater => encountered_greater = true,
                // A lower id was yielded by an earlier call. If every entry
                // scanned so far is like that, later scans can skip past it.
                Ordering::Less if !encountered_greater => self.min = index + 1,
                Ordering::Less => {}
            }
        }

        None
    }
}

            
1
#[test]
fn symbol_map_debug() {
    let mut map = EphemeralSymbolMap::default();
    // Every symbol is a slice of one literal, so their addresses increase in
    // the order "a", "b", "c", "d" regardless of the order they are inserted.
    let full_source = "abcd";
    for range in [1_usize..2, 0..1, 2..3, 3..4] {
        map.find_or_add(&full_source[range]);
    }

    // Internally the symbols end up sorted by memory address...
    for (index, expected) in ["a", "b", "c", "d"].iter().enumerate() {
        assert_eq!(map.symbols[index].0, *expected);
    }

    // ...while the Debug output lists them in the order they were inserted.
    assert_eq!(format!("{map:?}"), r#"{"b", "a", "c", "d"}"#);
}