1
use std::collections::HashSet;
2
use std::fmt::Debug;
3
use std::marker::PhantomData;
4

            
5
pub struct LruCache<Key, Value> {
6
    nodes: Vec<Node<Key, Value>>,
7
    head: Option<NodeId>,
8
    tail: Option<NodeId>,
9
    vacant: Option<NodeId>,
10
    sequence: usize,
11
    length: usize,
12
}
13

            
14
impl<Key, Value> LruCache<Key, Value> {
15
15
    pub fn new(capacity: usize) -> Self {
16
15
        Self {
17
15
            nodes: Vec::with_capacity(capacity),
18
15
            head: None,
19
15
            tail: None,
20
15
            vacant: None,
21
15
            sequence: 0,
22
15
            length: 0,
23
15
        }
24
15
    }
25

            
26
14
    pub const fn len(&self) -> usize {
27
14
        self.length
28
14
    }
29

            
30
22
    pub const fn sequence(&self) -> usize {
31
22
        self.sequence
32
22
    }
33

            
34
30
    pub const fn head(&self) -> Option<NodeId> {
35
30
        self.head
36
30
    }
37

            
38
10
    pub const fn tail(&self) -> Option<NodeId> {
39
10
        self.tail
40
10
    }
41

            
42
8
    pub const fn iter(&self) -> Iter<'_, Key, Value> {
43
8
        Iter {
44
8
            cache: self,
45
8
            node: IterState::BeforeHead,
46
8
        }
47
8
    }
48

            
49
2013
    pub fn get(&mut self, node: NodeId) -> &Node<Key, Value> {
50
2013
        self.touch(node);
51
2013
        &self.nodes[node.as_usize()]
52
2013
    }
53

            
54
82
    pub fn get_without_touch(&self, node: NodeId) -> &Node<Key, Value> {
55
82
        &self.nodes[node.as_usize()]
56
82
    }
57

            
58
2002
    pub fn get_mut(&mut self, node: NodeId) -> &mut Node<Key, Value> {
59
2002
        self.touch(node);
60
2002
        &mut self.nodes[node.as_usize()]
61
2002
    }
62

            
63
8059
    pub fn push(&mut self, key: Key, value: Value) -> (NodeId, Option<Removed<Key, Value>>) {
64
8059
        let (node, result) = if self.head.is_some() {
65
8042
            self.push_front(key, value)
66
        } else {
67
            // First node of the list.
68
17
            self.allocate_node(key, value)
69
        };
70
8059
        (
71
8059
            node,
72
8059
            result.map(|(key, value)| Removed::Evicted(key, value)),
73
8059
        )
74
8059
    }
75

            
76
4017
    pub fn touch(&mut self, node_index: NodeId) {
77
4017
        if self.head == Some(node_index) {
78
            // No-op.
79
6
            return;
80
4011
        }
81
4011

            
82
4011
        self.sequence += 1;
83
4011

            
84
4011
        // An entry already exists. Reuse the node.
85
4011
        self.nodes[node_index.as_usize()].last_accessed = self.sequence;
86
4011

            
87
4011
        // Update the next pointer to the current head.
88
4011
        let mut next = self.head;
89
4011
        std::mem::swap(&mut next, &mut self.nodes[node_index.as_usize()].next);
90
4011
        // Get and clear the previous node, as this node is going to be the new
91
4011
        // head.
92
4011
        let previous = self.nodes[node_index.as_usize()].previous.take().unwrap();
93
4011
        // Update the previous pointer's next to the previous next value.
94
4011
        self.nodes[previous.as_usize()].next = next;
95
4011
        if self.tail == Some(node_index) {
96
22
            // If this is the tail, update the tail to the previous node.
97
22
            self.tail = Some(previous);
98
3989
        } else {
99
3989
            // Otherwise, we need to update the next node's previous to point to
100
3989
            // this node's former previous.
101
3989
            self.nodes[next.unwrap().as_usize()].previous = Some(previous);
102
3989
        }
103

            
104
        // Move this node to the front
105
4011
        self.nodes[self.head.unwrap().as_usize()].previous = Some(node_index);
106
4011

            
107
4011
        self.head = Some(node_index);
108
4017
    }
109

            
110
8042
    fn push_front(&mut self, key: Key, value: Value) -> (NodeId, Option<(Key, Value)>) {
111
8042
        let (node, removed) = self.allocate_node(key, value);
112
8042
        self.sequence += 1;
113
8042
        let mut entry = &mut self.nodes[node.as_usize()];
114
8042
        entry.last_accessed = self.sequence;
115
8042
        entry.next = Some(self.head.unwrap());
116
8042

            
117
8042
        let mut previous_head = &mut self.nodes[self.head.unwrap().as_usize()];
118
8042
        debug_assert!(previous_head.previous.is_none());
119
8042
        previous_head.previous = Some(node);
120
8042
        self.head = Some(node);
121
8042
        (node, removed)
122
8042
    }
123

            
124
    fn allocate_node(&mut self, key: Key, value: Value) -> (NodeId, Option<(Key, Value)>) {
125
8059
        if let Some(vacant) = self.vacant {
126
            // Pull a node off the vacant list.
127
6
            self.vacant = self.nodes[vacant.as_usize()].next;
128
6
            self.nodes[vacant.as_usize()].next = None;
129
6
            self.nodes[vacant.as_usize()].entry = Entry::Occupied { key, value };
130
6
            self.length += 1;
131
6
            if self.head.is_none() {
132
2
                self.head = Some(vacant);
133
2
                self.tail = Some(vacant);
134
4
            }
135
6
            (vacant, None)
136
8053
        } else if self.nodes.len() == self.nodes.capacity() {
137
            // Expire the least recently used key (tail).
138
4012
            let index = self.tail.unwrap();
139
4012
            self.tail = self.nodes[index.as_usize()].previous;
140
4012
            if let Some(previous) = self.tail {
141
4012
                self.nodes[previous.as_usize()].next = None;
142
4012
            }
143
4012
            self.nodes[index.as_usize()].previous = None;
144
4012

            
145
4012
            let mut entry = Entry::Occupied { key, value };
146
4012
            std::mem::swap(&mut entry, &mut self.nodes[index.as_usize()].entry);
147
4012

            
148
4012
            (index, entry.into())
149
        } else {
150
            // We have capacity to fill.
151
4041
            let index = NodeId(self.nodes.len() as u32);
152
4041
            self.length += 1;
153
4041
            self.nodes.push(Node {
154
4041
                last_accessed: self.sequence,
155
4041
                previous: None,
156
4041
                next: None,
157
4041
                entry: Entry::Occupied { key, value },
158
4041
            });
159
4041
            if self.head.is_none() {
160
15
                self.head = Some(index);
161
15
                self.tail = Some(index);
162
4026
            }
163
4041
            (index, None)
164
        }
165
8059
    }
166

            
167
22
    pub fn remove(&mut self, node: NodeId) -> ((Key, Value), Option<NodeId>, Option<NodeId>) {
168
22
        self.length -= 1;
169
22
        let removed = self.nodes[node.as_usize()].entry.evict();
170
22
        let mut next = self.vacant;
171
22
        std::mem::swap(&mut next, &mut self.nodes[node.as_usize()].next);
172
22
        let previous = self.nodes[node.as_usize()].previous.take();
173

            
174
22
        if let Some(previous) = previous {
175
4
            self.nodes[previous.as_usize()].next = next;
176
18
        }
177
22
        if let Some(next) = next {
178
12
            self.nodes[next.as_usize()].previous = previous;
179
12
        }
180

            
181
22
        if self.tail == Some(node) {
182
10
            self.tail = previous;
183
12
        }
184

            
185
22
        if self.head == Some(node) {
186
18
            self.head = next;
187
18
        }
188

            
189
22
        self.vacant = Some(node);
190
22

            
191
22
        (removed, next, previous)
192
22
    }
193
}
194

            
195
impl<Key, Value> Debug for LruCache<Key, Value>
196
where
197
    Key: Debug,
198
    Value: Debug,
199
{
200
2
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201
2
        let mut list = f.debug_list();
202
2
        if let Some(head) = self.head {
203
2
            let mut seen_nodes = HashSet::new();
204
2
            let mut current_node = head;
205
2
            let mut end_found = false;
206
4
            while seen_nodes.insert(current_node) {
207
4
                let node = &self.nodes[current_node.as_usize()];
208
4
                list.entry(node);
209
4
                current_node = if let Some(next) = node.next {
210
2
                    next
211
                } else {
212
2
                    end_found = true;
213
2
                    break;
214
                };
215
            }
216

            
217
2
            assert!(end_found, "cycle detected");
218
        }
219

            
220
2
        list.finish()
221
2
    }
222
}
223

            
224
/// The contents of a node: a stored key/value pair or nothing at all.
#[derive(Debug)]
enum Entry<Key, Value> {
    /// A stored key and its value.
    Occupied {
        /// The stored key.
        key: Key,
        /// The stored value.
        value: Value,
    },
    /// No entry is stored in this node.
    Vacant,
}
229

            
230
impl<Key, Value> Entry<Key, Value> {
231
22
    fn evict(&mut self) -> (Key, Value) {
232
22
        let mut entry = Self::Vacant;
233
22
        std::mem::swap(&mut entry, self);
234
22
        match entry {
235
22
            Self::Occupied { key, value } => (key, value),
236
            Self::Vacant => unreachable!("evict called on a vacant entry"),
237
        }
238
22
    }
239
}
240

            
241
impl<Key, Value> From<Entry<Key, Value>> for Option<(Key, Value)> {
242
4012
    fn from(entry: Entry<Key, Value>) -> Self {
243
4012
        match entry {
244
4012
            Entry::Occupied { key, value } => Some((key, value)),
245
            Entry::Vacant => None,
246
        }
247
4012
    }
248
}
249

            
250
pub struct Node<Key, Value> {
251
    entry: Entry<Key, Value>,
252
    previous: Option<NodeId>,
253
    next: Option<NodeId>,
254
    last_accessed: usize,
255
}
256

            
257
impl<Key, Value> Debug for Node<Key, Value>
258
where
259
    Key: Debug,
260
    Value: Debug,
261
{
262
4
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
263
4
        let mut debug = f.debug_struct("Node");
264

            
265
4
        if let Entry::Occupied { key, value } = &self.entry {
266
4
            debug.field("key", key);
267
4
            debug.field("value", value);
268
4
        }
269
4
        debug.field("last_accessed", &self.last_accessed);
270
4

            
271
4
        debug.finish()
272
4
    }
273
}
274

            
275
impl<Key, Value> Node<Key, Value> {
276
8
    pub const fn last_accessed(&self) -> usize {
277
8
        self.last_accessed
278
8
    }
279

            
280
94
    pub fn key(&self) -> &Key {
281
94
        match &self.entry {
282
94
            Entry::Occupied { key, .. } => key,
283
            Entry::Vacant => unreachable!("EntryRef can't be made against Vacant"),
284
        }
285
94
    }
286

            
287
2092
    pub fn value(&self) -> &Value {
288
2092
        match &self.entry {
289
2092
            Entry::Occupied { value, .. } => value,
290
            Entry::Vacant => unreachable!("EntryRef can't be made against Vacant"),
291
        }
292
2092
    }
293

            
294
    pub fn value_mut(&mut self) -> &mut Value {
295
        match &mut self.entry {
296
            Entry::Occupied { value, .. } => value,
297
            Entry::Vacant => unreachable!("EntryRef can't be made against Vacant"),
298
        }
299
    }
300

            
301
2002
    pub fn replace_value(&mut self, mut new_value: Value) -> Value {
302
2002
        match &mut self.entry {
303
2002
            Entry::Occupied { value, .. } => {
304
2002
                std::mem::swap(value, &mut new_value);
305
2002
                new_value
306
            }
307
            Entry::Vacant => unreachable!("EntryRef can't be made against Vacant"),
308
        }
309
2002
    }
310
}
311

            
312
/// A reference to an entry in a Least Recently Used map.
313
#[derive(Debug)]
314
pub struct EntryRef<'a, Cache, Key, Value>
315
where
316
    Cache: EntryCache<Key, Value>,
317
{
318
    cache: &'a mut Cache,
319
    node: NodeId,
320
    accessed: bool,
321
    _phantom: PhantomData<(Key, Value)>,
322
}
323

            
324
pub trait EntryCache<Key, Value> {
325
    fn cache(&self) -> &LruCache<Key, Value>;
326
    fn cache_mut(&mut self) -> &mut LruCache<Key, Value>;
327
    fn remove(&mut self, node: NodeId) -> ((Key, Value), Option<NodeId>, Option<NodeId>);
328
}
329

            
330
8072
/// Index of a node within an `LruCache`'s node storage.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
#[repr(transparent)]
pub struct NodeId(u32);
333

            
334
impl NodeId {
335
60492
    const fn as_usize(self) -> usize {
336
60492
        self.0 as usize
337
60492
    }
338
}
339

            
340
impl<'a, Cache, Key, Value> EntryRef<'a, Cache, Key, Value>
341
where
342
    Cache: EntryCache<Key, Value>,
343
{
344
39
    pub(crate) fn new(cache: &'a mut Cache, node: NodeId) -> Self {
345
39
        Self {
346
39
            node,
347
39
            cache,
348
39
            accessed: false,
349
39
            _phantom: PhantomData,
350
39
        }
351
39
    }
352

            
353
    /// Returns the key of this entry.
354
    #[must_use]
355
27
    pub fn key(&self) -> &Key {
356
27
        self.cache.cache().get_without_touch(self.node).key()
357
27
    }
358

            
359
    /// Returns the value of this entry.
360
    ///
361
    /// This function touches the key, making it the most recently used key.
362
    /// This function only touches the key once. Subsequent calls will return
363
    /// the value without touching the key. This remains true until
364
    /// `move_next()` or `move_previous()` are invoked.
365
    #[must_use]
366
2
    pub fn value(&mut self) -> &Value {
367
2
        if !self.accessed {
368
2
            self.accessed = true;
369
2
            self.touch();
370
2
        }
371
2
        self.cache.cache_mut().get(self.node).value()
372
2
    }
373

            
374
    /// Touches this key, making it the most recently used key.
375
2
    pub fn touch(&mut self) {
376
2
        self.cache.cache_mut().touch(self.node);
377
2
    }
378

            
379
    /// Returns the value of this entry.
380
    ///
381
    /// This function does not touch the key, preserving its current position in
382
    /// the lru cache.
383
    #[must_use]
384
10
    pub fn peek_value(&self) -> &Value {
385
10
        self.cache.cache().get_without_touch(self.node).value()
386
10
    }
387

            
388
    /// Returns the number of changes to the cache since this key was last
389
    /// touched.
390
    #[must_use]
391
14
    pub fn staleness(&self) -> usize {
392
14
        self.cache.cache().sequence().wrapping_sub(
393
14
            self.cache
394
14
                .cache()
395
14
                .get_without_touch(self.node)
396
14
                .last_accessed,
397
14
        )
398
14
    }
399

            
400
    /// Returns an iterator over the least-recently used keys beginning with the
401
    /// current entry.
402
4
    pub fn iter(&self) -> Iter<'_, Key, Value> {
403
4
        Iter {
404
4
            cache: self.cache.cache(),
405
4
            node: IterState::StartingAt(self.node),
406
4
        }
407
4
    }
408

            
409
    /// Updates this reference to point to the next least recently used key in
410
    /// the list. Returns true if a next entry was found, or returns false if
411
    /// the entry is the last entry in the list.
412
    #[must_use]
413
    pub fn move_next(&mut self) -> bool {
414
12
        if let Some(next) = self.cache.cache().get_without_touch(self.node).next {
415
6
            self.node = next;
416
6
            self.accessed = false;
417
6
            true
418
        } else {
419
6
            false
420
        }
421
12
    }
422

            
423
    /// Updates this reference to point to the next most recently used key in
424
    /// the list. Returns true if a previous entry was found, or returns false
425
    /// if the entry is the first entry in the list.
426
    #[must_use]
427
    pub fn move_previous(&mut self) -> bool {
428
8
        if let Some(previous) = self.cache.cache().get_without_touch(self.node).previous {
429
2
            self.node = previous;
430
2
            self.accessed = false;
431
2
            true
432
        } else {
433
6
            false
434
        }
435
8
    }
436

            
437
12
    fn remove_with_direction(mut self, move_next: bool) -> ((Key, Value), Option<Self>) {
438
12
        let (removed, next, previous) = self.cache.remove(self.node);
439
12
        let new_self = match (move_next, next, previous) {
440
2
            (true, Some(next), _) => {
441
2
                self.node = next;
442
2
                Some(self)
443
            }
444
2
            (false, _, Some(previous)) => {
445
2
                self.node = previous;
446
2
                Some(self)
447
            }
448
8
            _ => None,
449
        };
450
12
        (removed, new_self)
451
12
    }
452

            
453
    /// Removes and returns the current entry's key and value.
454
    #[must_use]
455
4
    pub fn take(self) -> (Key, Value) {
456
4
        let (removed, _) = self.remove_with_direction(true);
457
4
        removed
458
4
    }
459

            
460
    /// Removes and returns the current entry's key and value. If this was not
461
    /// the last entry, the next entry's [`EntryRef`] will be returned.
462
    #[must_use]
463
4
    pub fn take_and_move_next(self) -> ((Key, Value), Option<Self>) {
464
4
        self.remove_with_direction(true)
465
4
    }
466

            
467
    /// Removes and returns the current entry's key and value. If this was not
468
    /// the first entry, the previous entry's [`EntryRef`] will be returned.
469
    #[must_use]
470
4
    pub fn take_and_move_previous(self) -> ((Key, Value), Option<Self>) {
471
4
        self.remove_with_direction(false)
472
4
    }
473

            
474
    /// Removes the current entry. If this was not the last entry, the next
475
    /// entry's [`EntryRef`] will be returned.
476
    #[must_use]
477
4
    pub fn remove_moving_next(self) -> Option<Self> {
478
4
        let (_, new_self) = self.take_and_move_next();
479
4
        new_self
480
4
    }
481

            
482
    /// Removes the current entry. If this was not the first entry, the previous
483
    /// entry's [`EntryRef`] will be returned.
484
    #[must_use]
485
4
    pub fn remove_moving_previous(self) -> Option<Self> {
486
4
        let (_, new_self) = self.take_and_move_previous();
487
4
        new_self
488
4
    }
489
}
490

            
491
/// A value or entry displaced by writing to the cache.
#[derive(Debug, Eq, PartialEq)]
pub enum Removed<Key, Value> {
    /// The previously stored value for the key that was written to.
    PreviousValue(Value),
    /// An entry was evicted to make room for the key that was written to.
    Evicted(Key, Value),
}
499

            
500
/// A double-ended iterator over a cache's keys and values in order from most
501
/// recently touched to least recently touched.
502
#[must_use]
503
pub struct Iter<'a, Key, Value> {
504
    cache: &'a LruCache<Key, Value>,
505
    node: IterState,
506
}
507

            
508
enum IterState {
509
    BeforeHead,
510
    AfterTail,
511
    StartingAt(NodeId),
512
    Node(NodeId),
513
}
514

            
515
impl<'a, Key, Value> Iterator for Iter<'a, Key, Value> {
516
    type Item = (&'a Key, &'a Value);
517

            
518
56
    fn next(&mut self) -> Option<Self::Item> {
519
56
        let next_node = match self.node {
520
8
            IterState::BeforeHead => self.cache.head,
521
2
            IterState::StartingAt(node) => Some(node),
522
46
            IterState::Node(node) => self.cache.nodes[node.as_usize()].next,
523
            IterState::AfterTail => None,
524
        };
525
56
        if let Some(node_id) = next_node {
526
46
            let node = &self.cache.nodes[node_id.as_usize()];
527
46
            self.node = IterState::Node(node_id);
528
46
            Some((node.key(), node.value()))
529
        } else {
530
10
            self.node = IterState::AfterTail;
531
10
            None
532
        }
533
56
    }
534
}
535
impl<'a, Key, Value> DoubleEndedIterator for Iter<'a, Key, Value> {
536
16
    fn next_back(&mut self) -> Option<Self::Item> {
537
16
        let previous_node = match self.node {
538
2
            IterState::BeforeHead => None,
539
10
            IterState::StartingAt(node) | IterState::Node(node) => {
540
12
                self.cache.nodes[node.as_usize()].previous
541
            }
542
2
            IterState::AfterTail => self.cache.tail,
543
        };
544
16
        if let Some(node_id) = previous_node {
545
12
            let node = &self.cache.nodes[node_id.as_usize()];
546
12
            self.node = IterState::Node(node_id);
547
12
            Some((node.key(), node.value()))
548
        } else {
549
4
            self.node = IterState::BeforeHead;
550
4
            None
551
        }
552
16
    }
553
}
554

            
555
pub struct IntoIter<Key, Value> {
556
    cache: LruCache<Key, Value>,
557
}
558

            
559
impl<Key, Value> From<LruCache<Key, Value>> for IntoIter<Key, Value> {
560
2
    fn from(cache: LruCache<Key, Value>) -> Self {
561
2
        Self { cache }
562
2
    }
563
}
564

            
565
impl<Key, Value> Iterator for IntoIter<Key, Value> {
566
    type Item = (Key, Value);
567

            
568
12
    fn next(&mut self) -> Option<Self::Item> {
569
12
        self.cache.head().map(|node| {
570
10
            let (removed, ..) = self.cache.remove(node);
571
10
            removed
572
12
        })
573
12
    }
574
}