1
use core::fmt::{self, Debug};
2

            
3
use crate::map::{self, Field, OwnedOrRef};
4
use crate::{Map, Sort};
5

            
6
/// An iterator over the vakyes in a [`Set`].
7
pub type Iter<'a, T> = map::Keys<'a, T, ()>;
8
/// An iterator that converts a [`Set`] into its owned values.
9
pub type IntoIter<T> = map::IntoKeys<T, ()>;
10

            
11
/// An ordered collection of unique `T`s.
12
///
13
/// This data type only allows each unique value to be stored once.
14
///
15
/// ```rust
16
/// use kempt::Set;
17
///
18
/// let mut set = Set::new();
19
/// set.insert(1);
20
/// assert!(!set.insert(1));
21
/// assert_eq!(set.len(), 1);
22
/// ```
23
///
24
/// The values in the collection are automatically sorted using `T`'s [`Ord`]
25
/// implementation.
26
///
27
/// ```rust
28
/// use kempt::Set;
29
///
30
/// let mut set = Set::new();
31
/// set.insert(1);
32
/// set.insert(3);
33
/// set.insert(2);
34
/// assert_eq!(set.member(0), Some(&1));
35
/// assert_eq!(set.member(1), Some(&2));
36
/// assert_eq!(set.member(2), Some(&3));
37
/// ```
38
2
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
39
pub struct Set<T>(Map<T, ()>)
40
where
41
    T: Sort<T>;
42

            
43
impl<T> Default for Set<T>
44
where
45
    T: Sort<T>,
46
{
47
    #[inline]
48
1
    fn default() -> Self {
49
1
        Self::new()
50
1
    }
51
}
52

            
53
impl<T> Set<T>
54
where
55
    T: Sort<T>,
56
{
57
    /// Returns an empty set.
58
    #[must_use]
59
    #[inline]
60
1
    pub const fn new() -> Self {
61
1
        Self(Map::new())
62
1
    }
63

            
64
    /// Returns an empty set with enough allocated memory to store `capacity`
65
    /// values without reallocating.
66
    #[must_use]
67
    #[inline]
68
6
    pub fn with_capacity(capacity: usize) -> Self {
69
6
        Self(Map::with_capacity(capacity))
70
6
    }
71

            
72
    /// Returns the current capacity this map can hold before it must
73
    /// reallocate.
74
    #[must_use]
75
    #[inline]
76
5
    pub fn capacity(&self) -> usize {
77
5
        self.0.capacity()
78
5
    }
79

            
80
    /// Inserts or replaces `value` in the set, returning `true` if the
81
    /// collection is modified. If a previously contained value returns
82
    /// [`Ordering::Equal`](core::cmp::Ordering::Equal) from [`Ord::cmp`], the
83
    /// collection will not be modified and `false` will be returned.
84
    #[inline]
85
12
    pub fn insert(&mut self, value: T) -> bool {
86
12
        self.0.insert_with(value, || ()).is_none()
87
12
    }
88

            
89
    /// Inserts or replaces `value` in the set. If a previously contained value
90
    /// returns [`Ordering::Equal`](core::cmp::Ordering::Equal) from
91
    /// [`Ord::cmp`], the new value will overwrite the stored value and it will
92
    /// be returned.
93
    #[inline]
94
1
    pub fn replace(&mut self, value: T) -> Option<T> {
95
1
        self.0.insert(value, ()).map(|field| field.into_parts().0)
96
1
    }
97

            
98
    /// Returns true if the set contains a matching `value`.
99
    #[inline]
100
1
    pub fn contains<SearchFor>(&self, value: &SearchFor) -> bool
101
1
    where
102
1
        T: Sort<SearchFor>,
103
1
        SearchFor: ?Sized,
104
1
    {
105
1
        self.0.contains(value)
106
1
    }
107

            
108
    /// Returns the contained value that matches `value`.
109
    #[inline]
110
1
    pub fn get<SearchFor>(&self, value: &SearchFor) -> Option<&T>
111
1
    where
112
1
        T: Sort<SearchFor>,
113
1
        SearchFor: ?Sized,
114
1
    {
115
1
        self.0.get_field(value).map(Field::key)
116
1
    }
117

            
118
    /// Removes a value from the set, returning the value if it was removed.
119
    #[inline]
120
2
    pub fn remove<SearchFor>(&mut self, value: &SearchFor) -> Option<T>
121
2
    where
122
2
        T: Sort<SearchFor>,
123
2
        SearchFor: ?Sized,
124
2
    {
125
2
        self.0.remove(value).map(|field| field.into_parts().0)
126
2
    }
127

            
128
    /// Returns the member at `index` inside of this ordered set. Returns `None`
129
    /// if `index` is greater than or equal to the set's length.
130
    #[inline]
131
2
    pub fn member(&self, index: usize) -> Option<&T> {
132
2
        self.0.field(index).map(Field::key)
133
2
    }
134

            
135
    /// Removes the member at `index`.
136
    ///
137
    /// # Panics
138
    ///
139
    /// A panic will occur if `index` is greater than or equal to the set's
140
    /// length.
141
    #[inline]
142
    pub fn remove_member(&mut self, index: usize) -> T {
143
        self.0.remove_by_index(index).into_key()
144
    }
145

            
146
    /// Returns the number of members in this set.
147
    #[must_use]
148
    #[inline]
149
4
    pub fn len(&self) -> usize {
150
4
        self.0.len()
151
4
    }
152

            
153
    /// Returns true if there are no members in this set.
154
    #[must_use]
155
    #[inline]
156
2
    pub fn is_empty(&self) -> bool {
157
2
        self.0.is_empty()
158
2
    }
159

            
160
    /// Returns an iterator over the members in this set.
161
    #[must_use]
162
    #[inline]
163
1
    pub fn iter(&self) -> Iter<'_, T> {
164
1
        self.into_iter()
165
1
    }
166

            
167
    /// Returns an iterator that yields a single reference to all members found
168
    /// in either `self` or `other`.
169
    ///
170
    /// This iterator is guaranteed to return results in the sort order of the
171
    /// `Key` type.
172
    #[must_use]
173
    #[inline]
174
2
    pub fn union<'a>(&'a self, other: &'a Set<T>) -> Union<'a, T> {
175
2
        Union(self.0.union(&other.0))
176
2
    }
177

            
178
    /// Returns an iterator that yields a single reference to all members found
179
    /// in both `self` and `other`.
180
    ///
181
    /// This iterator is guaranteed to return results in the sort order of the
182
    /// `Key` type.
183
    #[must_use]
184
    #[inline]
185
2
    pub fn intersection<'a>(&'a self, other: &'a Set<T>) -> Intersection<'a, T> {
186
2
        Intersection(self.0.intersection(&other.0))
187
2
    }
188

            
189
    /// Returns an iterator that yields a single reference to all members found
190
    /// in `self` but not `other`.
191
    ///
192
    /// This iterator is guaranteed to return results in the sort order of the
193
    /// `Key` type.
194
    #[must_use]
195
    #[inline]
196
2
    pub fn difference<'a>(&'a self, other: &'a Set<T>) -> Difference<'a, T> {
197
2
        Difference(self.0.difference(&other.0))
198
2
    }
199

            
200
    /// Returns an iterator over the contents of this set. After the iterator is
201
    /// dropped, this set will be empty.
202
    #[inline]
203
    pub fn drain(&mut self) -> Drain<'_, T> {
204
        Drain(self.0.drain())
205
    }
206

            
207
    /// Clears the contents of this collection.
208
    ///
209
    /// This does not return any allocated memory to the OS.
210
    #[inline]
211
1
    pub fn clear(&mut self) {
212
1
        self.0.clear();
213
1
    }
214

            
215
    /// Resizes this collection to fit its contents exactly.
216
    ///
217
    /// This function will reallocate its internal storage to fit the contents
218
    /// of this collection's current size. If the allocation is already the
219
    /// correct size, this is a no-op.
220
    #[inline]
221
1
    pub fn shrink_to_fit(&mut self) {
222
1
        self.0.shrink_to_fit();
223
1
    }
224

            
225
    /// Resizes this collection to be able to hold `min_capacity`.
226
    ///
227
    /// This function will reallocate its internal storage to fit the contents
228
    /// of this collection's current size. If the allocation is already the
229
    /// correct size, this is a no-op.
230
    ///
231
    /// If the length of this collection is larger than `min_capacity`, this
232
    /// function will behave identically to
233
    /// [`shrink_to_fit()`](Self::shrink_to_fit).
234
    #[inline]
235
1
    pub fn shrink_to(&mut self, min_capacity: usize) {
236
1
        self.0.shrink_to(min_capacity);
237
1
    }
238
}
239

            
240
impl<T> Debug for Set<T>
241
where
242
    T: Sort<T> + Debug,
243
{
244
    #[inline]
245
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
246
        let mut s = f.debug_set();
247
        for member in self {
248
            s.entry(member);
249
        }
250
        s.finish()
251
    }
252
}
253

            
254
impl<'a, T> IntoIterator for &'a Set<T>
255
where
256
    T: Sort<T>,
257
{
258
    type IntoIter = Iter<'a, T>;
259
    type Item = &'a T;
260

            
261
    #[inline]
262
2
    fn into_iter(self) -> Self::IntoIter {
263
2
        self.0.keys()
264
2
    }
265
}
266

            
267
impl<T> FromIterator<T> for Set<T>
268
where
269
    T: Sort<T>,
270
{
271
    #[inline]
272
10
    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
273
29
        Self(iter.into_iter().map(|t| (t, ())).collect())
274
10
    }
275
}
276

            
277
/// An iterator that yields a single reference to all members found in either
278
/// [`Set`] being unioned.
279
///
280
/// This iterator is guaranteed to return results in the sort order of the `Key`
281
/// type.
282
pub struct Union<'a, T>(map::Union<'a, T, ()>)
283
where
284
    T: Sort<T>;
285

            
286
impl<'a, T> Iterator for Union<'a, T>
287
where
288
    T: Sort<T>,
289
{
290
    type Item = &'a T;
291

            
292
    #[inline]
293
12
    fn next(&mut self) -> Option<Self::Item> {
294
12
        self.0
295
12
            .next()
296
12
            .map(|unioned| unioned.map_both(|_, (), ()| OwnedOrRef::Owned(())).key)
297
12
    }
298

            
299
    #[inline]
300
2
    fn size_hint(&self) -> (usize, Option<usize>) {
301
2
        self.0.size_hint()
302
2
    }
303
}
304

            
305
/// An iterator that yields a single reference to all members found in both
306
/// [`Set`]s being intersected.
307
///
308
/// This iterator is guaranteed to return results in the sort order of the `Key`
309
/// type.
310
pub struct Intersection<'a, T>(map::Intersection<'a, T, ()>)
311
where
312
    T: Sort<T>;
313

            
314
impl<'a, T> Iterator for Intersection<'a, T>
315
where
316
    T: Sort<T>,
317
{
318
    type Item = &'a T;
319

            
320
    #[inline]
321
4
    fn next(&mut self) -> Option<Self::Item> {
322
4
        self.0.next().map(|(k, (), ())| k)
323
4
    }
324

            
325
    #[inline]
326
2
    fn size_hint(&self) -> (usize, Option<usize>) {
327
2
        self.0.size_hint()
328
2
    }
329
}
330

            
331
/// An iterator that yields a single reference to all members found in one
332
/// [`Set`], but not another.
333
///
334
/// This iterator is guaranteed to return results in the sort order of the `Key`
335
/// type.
336
pub struct Difference<'a, T>(map::Difference<'a, T, ()>)
337
where
338
    T: Sort<T>;
339

            
340
impl<'a, T> Iterator for Difference<'a, T>
341
where
342
    T: Sort<T>,
343
{
344
    type Item = &'a T;
345

            
346
    #[inline]
347
6
    fn next(&mut self) -> Option<Self::Item> {
348
6
        self.0.next().map(|(k, ())| k)
349
6
    }
350

            
351
    #[inline]
352
2
    fn size_hint(&self) -> (usize, Option<usize>) {
353
2
        self.0.size_hint()
354
2
    }
355
}
356

            
357
/// An iterator that drains the contents of a [`Set`].
358
///
359
/// When this is dropped, the remaining contents are drained.
360
pub struct Drain<'a, T>(map::Drain<'a, T, ()>);
361

            
362
impl<T> Iterator for Drain<'_, T> {
363
    type Item = T;
364

            
365
    #[inline]
366
    fn next(&mut self) -> Option<Self::Item> {
367
        self.0.next().map(map::Field::into_key)
368
    }
369
}
370

            
371
1
#[test]
fn basics() {
    let mut set = Set::default();
    assert!(set.is_empty());
    assert!(set.insert(1));
    // Inserting an equal value must leave the set unchanged.
    assert!(!set.insert(1));
    assert_eq!(set.len(), 1);
    assert!(set.contains(&1));
    assert!(!set.contains(&2));
    assert_eq!(set.get(&1), Some(&1));
    assert_eq!(set.replace(1), Some(1));
    assert!(set.insert(0));

    assert_eq!(set.member(0), Some(&0));
    assert_eq!(set.member(1), Some(&1));
    // Out-of-range positions yield `None` rather than panicking.
    assert_eq!(set.member(2), None);

    assert_eq!(set.len(), 2);
    assert_eq!(set.remove(&0), Some(0));
    assert_eq!(set.len(), 1);
    // Removing a value that is no longer present is a no-op.
    assert_eq!(set.remove(&0), None);
    assert_eq!(set.len(), 1);
    assert_eq!(set.remove(&1), Some(1));
    assert_eq!(set.len(), 0);
    assert!(set.is_empty());
}
389

            
390
1
#[test]
fn union() {
    use alloc::vec::Vec;

    let unioned = |lhs: &Set<u8>, rhs: &Set<u8>| {
        lhs.union(rhs).copied().collect::<Vec<_>>()
    };
    let a: Set<u8> = [1, 3, 5].into_iter().collect();

    // The shared member `3` must be yielded only once.
    let b: Set<u8> = [2, 3, 4].into_iter().collect();
    assert_eq!(unioned(&a, &b), [1, 2, 3, 4, 5]);

    // `b` extends past the largest member of `a`.
    let b: Set<u8> = [2, 3, 6].into_iter().collect();
    assert_eq!(unioned(&a, &b), [1, 2, 3, 5, 6]);
}
400

            
401
1
#[test]
fn intersection() {
    use alloc::vec::Vec;

    let shared = |lhs: &Set<u8>, rhs: &Set<u8>| {
        lhs.intersection(rhs).copied().collect::<Vec<_>>()
    };
    let a: Set<u8> = [1, 3, 5].into_iter().collect();

    let b: Set<u8> = [2, 3, 4].into_iter().collect();
    assert_eq!(shared(&a, &b), [3]);

    let b: Set<u8> = [2, 3, 6].into_iter().collect();
    assert_eq!(shared(&a, &b), [3]);
}
411

            
412
1
#[test]
fn difference() {
    use alloc::vec::Vec;

    let only_in_lhs = |lhs: &Set<u8>, rhs: &Set<u8>| {
        lhs.difference(rhs).copied().collect::<Vec<_>>()
    };
    let a: Set<u8> = [1, 3, 5].into_iter().collect();

    let b: Set<u8> = [2, 3, 4].into_iter().collect();
    assert_eq!(only_in_lhs(&a, &b), [1, 5]);

    let b: Set<u8> = [2, 5, 6].into_iter().collect();
    assert_eq!(only_in_lhs(&a, &b), [1, 3]);
}
422

            
423
1
#[test]
fn lookup() {
    // `get` accepts a borrowed form (`str`) of the stored `String` and must
    // hand back the very value that was inserted, not a copy of it.
    let mut set = Set::with_capacity(1);
    let owned = alloc::string::String::from("hello");
    let original_buffer = owned.as_ptr();
    set.insert(owned);

    let found = set.get("hello").unwrap();
    assert_eq!(found.as_ptr(), original_buffer);
}
431

            
432
1
#[test]
fn iteration() {
    use alloc::vec::Vec;
    let mut set = Set::with_capacity(3);
    set.insert(1);
    set.insert(3);
    set.insert(2);
    assert_eq!(set.iter().copied().collect::<Vec<_>>(), &[1, 2, 3]);

    // `&Set` implements `IntoIterator`, so a `for` loop sees the same order.
    let mut seen = Vec::new();
    for member in &set {
        seen.push(*member);
    }
    assert_eq!(seen, [1, 2, 3]);
}

/// Covers `remove_member`, `drain`, and `Drain::next`.
#[test]
fn remove_member_and_drain() {
    use alloc::vec::Vec;
    let mut set = [1, 2, 3].into_iter().collect::<Set<u8>>();
    assert_eq!(set.remove_member(1), 2);
    assert_eq!(set.iter().copied().collect::<Vec<_>>(), [1, 3]);

    let drained = set.drain().collect::<Vec<_>>();
    assert_eq!(drained, [1, 3]);
    assert!(set.is_empty());
}

/// Covers the `Debug` implementation, which formats members in sorted order.
#[test]
fn debug_format() {
    let set = [3, 1, 2].into_iter().collect::<Set<u8>>();
    assert_eq!(alloc::format!("{:?}", set), "{1, 2, 3}");
}