1
use core::marker::PhantomData;
2

            
3
use serde::de::{MapAccess, Visitor};
4
use serde::ser::{SerializeMap, SerializeSeq};
5
use serde::{Deserialize, Serialize};
6

            
7
use crate::{Map, Set, Sort};
8

            
9
impl<Key, Value> Serialize for Map<Key, Value>
10
where
11
    Key: Serialize + Sort<Key>,
12
    Value: Serialize,
13
{
14
    #[inline]
15
1
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
16
1
    where
17
1
        S: serde::Serializer,
18
1
    {
19
1
        let mut map = serializer.serialize_map(Some(self.len()))?;
20
3
        for field in self {
21
2
            map.serialize_entry(field.key(), &field.value)?;
22
        }
23
1
        map.end()
24
1
    }
25
}
26

            
27
impl<'de, Key, Value> Deserialize<'de> for Map<Key, Value>
28
where
29
    Key: Deserialize<'de> + Sort<Key>,
30
    Value: Deserialize<'de>,
31
{
32
    #[inline]
33
3
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
34
3
    where
35
3
        D: serde::Deserializer<'de>,
36
3
    {
37
3
        deserializer.deserialize_map(MapVisitor(PhantomData))
38
3
    }
39
}
40

            
41
/// Serde visitor used to deserialize a `Map<Key, Value>`.
///
/// It stores no data; the `PhantomData` only records which key and value
/// types the visitor produces.
struct MapVisitor<Key, Value>(PhantomData<(Key, Value)>);
42

            
43
impl<'de, Key, Value> Visitor<'de> for MapVisitor<Key, Value>
44
where
45
    Key: Deserialize<'de> + Sort<Key>,
46
    Value: Deserialize<'de>,
47
{
48
    type Value = Map<Key, Value>;
49

            
50
    #[inline]
51
1
    fn expecting(&self, formatter: &mut alloc::fmt::Formatter) -> alloc::fmt::Result {
52
1
        formatter.write_str("a Map")
53
1
    }
54

            
55
    #[inline]
56
2
    fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
57
2
    where
58
2
        A: MapAccess<'de>,
59
2
    {
60
2
        let mut obj = Map::with_capacity(map.size_hint().unwrap_or(0));
61
6
        while let Some((key, value)) = map.next_entry()? {
62
4
            obj.insert(key, value);
63
4
        }
64
2
        Ok(obj)
65
2
    }
66
}
67

            
68
impl<Key> Serialize for Set<Key>
69
where
70
    Key: Ord + Serialize,
71
{
72
1
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
73
1
    where
74
1
        S: serde::Serializer,
75
1
    {
76
1
        let mut seq = serializer.serialize_seq(Some(self.len()))?;
77
3
        for field in self {
78
2
            seq.serialize_element(field)?;
79
        }
80
1
        seq.end()
81
1
    }
82
}
83

            
84
impl<'de, Key> Deserialize<'de> for Set<Key>
85
where
86
    Key: Ord + Deserialize<'de>,
87
{
88
3
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
89
3
    where
90
3
        D: serde::Deserializer<'de>,
91
3
    {
92
3
        deserializer.deserialize_seq(SetVisitor(PhantomData))
93
3
    }
94
}
95

            
96
/// Serde visitor used to deserialize a `Set<Key>`.
///
/// It stores no data; the `PhantomData` only records which element type the
/// visitor produces.
struct SetVisitor<Key>(PhantomData<(Key,)>);
97

            
98
impl<'de, Key> Visitor<'de> for SetVisitor<Key>
99
where
100
    Key: Deserialize<'de> + Sort<Key>,
101
{
102
    type Value = Set<Key>;
103

            
104
    #[inline]
105
1
    fn expecting(&self, formatter: &mut alloc::fmt::Formatter) -> alloc::fmt::Result {
106
1
        formatter.write_str("a Set")
107
1
    }
108

            
109
    #[inline]
110
2
    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
111
2
    where
112
2
        A: serde::de::SeqAccess<'de>,
113
2
    {
114
2
        let mut obj = Set::with_capacity(seq.size_hint().unwrap_or(0));
115
6
        while let Some(key) = seq.next_element()? {
116
4
            obj.insert(key);
117
4
        }
118
2
        Ok(obj)
119
2
    }
120
}
121

            
122
1
/// Round-trips a two-entry `Map` through serde tokens and checks the error
/// reported when the input is not a map.
#[test]
fn map_tests() {
    use serde_test::{assert_de_tokens_error, assert_tokens, Token};

    let map: Map<u8, u16> = [(1, 1), (2, 2)].into_iter().collect();

    // Token stream the map must both serialize to and deserialize from.
    let expected = [
        Token::Map { len: Some(2) },
        Token::U8(1),
        Token::U16(1),
        Token::U8(2),
        Token::U16(2),
        Token::MapEnd,
    ];
    assert_tokens(&map, &expected);

    // A bare integer is not a map; the message ends with `expecting`'s text.
    let not_a_map = [Token::U8(1)];
    assert_de_tokens_error::<Map<u8, u16>>(
        &not_a_map,
        "invalid type: integer `1`, expected a Map",
    );
}
144

            
145
1
/// Round-trips a two-element `Set` through serde tokens and checks the error
/// reported when the input is not a sequence.
#[test]
fn set_tests() {
    use serde_test::{assert_de_tokens_error, assert_tokens, Token};

    let set: Set<u8> = [1, 2].into_iter().collect();

    // Token stream the set must both serialize to and deserialize from.
    let expected = [
        Token::Seq { len: Some(2) },
        Token::U8(1),
        Token::U8(2),
        Token::SeqEnd,
    ];
    assert_tokens(&set, &expected);

    // A bare integer is not a sequence; the message ends with `expecting`'s text.
    let not_a_seq = [Token::U8(1)];
    assert_de_tokens_error::<Set<u8>>(&not_a_seq, "invalid type: integer `1`, expected a Set");
}