1
#![doc = include_str!(".crate-docs.md")]
2
#![forbid(unsafe_code)]
3
#![warn(
4
    clippy::cargo,
5
    missing_docs,
6
    // clippy::missing_docs_in_private_items,
7
    clippy::pedantic,
8
    future_incompatible,
9
    rust_2018_idioms,
10
)]
11
#![allow(clippy::option_if_let_else, clippy::module_name_repetitions)]
12

            
13
use std::{
14
    ops::{Deref, DerefMut},
15
    pin::Pin,
16
    sync::{
17
        atomic::{AtomicUsize, Ordering},
18
        Arc,
19
    },
20
    task::Poll,
21
    time::{Duration, Instant},
22
};
23

            
24
use event_listener::{Event, EventListener};
25
use futures_util::{FutureExt, Stream};
26
use parking_lot::{RwLock, RwLockReadGuard, RwLockUpgradableReadGuard, RwLockWriteGuard};
27

            
28
/// A watchable wrapper for a value.
29
5
#[derive(Default, Debug)]
30
pub struct Watchable<T> {
31
    data: Arc<Data<T>>,
32
}
33

            
34
impl<T> Clone for Watchable<T> {
35
1
    fn clone(&self) -> Self {
36
1
        self.data.watchables.fetch_add(1, Ordering::AcqRel);
37
1
        Self {
38
1
            data: self.data.clone(),
39
1
        }
40
1
    }
41
}
42

            
43
impl<T> Drop for Watchable<T> {
44
16
    fn drop(&mut self) {
45
16
        if self.data.watchables.fetch_sub(1, Ordering::AcqRel) == 1 {
46
15
            // Last watchable
47
15
            self.shutdown();
48
15
        }
49
16
    }
50
}
51

            
52
impl<T> Watchable<T> {
53
    /// Returns a new instance with the initial value provided.
54
10
    pub fn new(initial_value: T) -> Self {
55
10
        Self {
56
10
            data: Arc::new(Data {
57
10
                value: RwLock::new(initial_value),
58
10
                changed: RwLock::new(Some(Event::new())),
59
10
                version: AtomicUsize::new(0),
60
10
                watchers: AtomicUsize::new(0),
61
10
                watchables: AtomicUsize::new(1),
62
10
            }),
63
10
        }
64
10
    }
65

            
66
    /// Returns a new watcher that can monitor for changes to the contained
67
    /// value.
68
89
    pub fn watch(&self) -> Watcher<T> {
69
89
        self.data.watchers.fetch_add(1, Ordering::AcqRel);
70
89
        Watcher {
71
89
            version: AtomicUsize::new(self.data.current_version()),
72
89
            watched: self.data.clone(),
73
89
        }
74
89
    }
75

            
76
    /// Replaces the current value contained and notifies all watching
77
    /// [`Watcher`]s. Returns the previously stored value.
78
4008
    pub fn replace(&self, new_value: T) -> T {
79
4008
        let mut stored = self.data.value.write();
80
4008
        let mut old_value = new_value;
81
4008
        std::mem::swap(&mut *stored, &mut old_value);
82
4008
        self.data.increment_version();
83
4008
        old_value
84
4008
    }
85

            
86
    /// Updates the current value, if it is different from the contained value.
87
    /// Returns `Ok(previous_value)` if `new_value != previous_value`, otherwise
88
    /// returns `Err(new_value)`.
89
    ///
90
    /// # Errors
91
    ///
92
    /// Returns `Err(new_value)` if the currently stored value is equal to `new_value`.
93
20002
    pub fn update(&self, new_value: T) -> Result<T, T>
94
20002
    where
95
20002
        T: PartialEq,
96
20002
    {
97
20002
        let stored = self.data.value.upgradable_read();
98
20002
        if *stored == new_value {
99
3
            Err(new_value)
100
        } else {
101
19999
            let mut stored = RwLockUpgradableReadGuard::upgrade(stored);
102
19999
            let mut old_value = new_value;
103
19999
            std::mem::swap(&mut *stored, &mut old_value);
104
19999
            self.data.increment_version();
105
19999
            Ok(old_value)
106
        }
107
20002
    }
108

            
109
    /// Returns a write guard that allows updating the value. If the inner value
110
    /// is accessed through [`DerefMut::deref_mut()`], all [`Watcher`]s will be
111
    /// notified when the returned guard is dropped.
112
    ///
113
    /// [`WatchableWriteGuard`] holds an exclusive lock. No other threads will
114
    /// be able to read or write the contained value until the guard is dropped.
115
3
    pub fn write(&self) -> WatchableWriteGuard<'_, T> {
116
3
        WatchableWriteGuard {
117
3
            watchable: self,
118
3
            guard: self.data.value.write(),
119
3
            accessed_mut: false,
120
3
        }
121
3
    }
122

            
123
    /// Returns a guard which can be used to access the value held within the
124
    /// variable. This guard does not block other threads from reading the
125
    /// value.
126
1
    pub fn read(&self) -> WatchableReadGuard<'_, T> {
127
1
        WatchableReadGuard(self.data.value.read())
128
1
    }
129

            
130
    /// Returns the currently contained value.
131
    #[must_use]
132
1
    pub fn get(&self) -> T
133
1
    where
134
1
        T: Clone,
135
1
    {
136
1
        self.data.value.read().clone()
137
1
    }
138

            
139
    /// Returns the number of [`Watcher`]s for this value.
140
    #[must_use]
141
6
    pub fn watchers(&self) -> usize {
142
6
        self.data.watchers.load(Ordering::Acquire)
143
6
    }
144

            
145
    /// Returns true if there are any [`Watcher`]s for this value.
146
    #[must_use]
147
3
    pub fn has_watchers(&self) -> bool {
148
3
        self.watchers() > 0
149
3
    }
150

            
151
    /// Disconnects all [`Watcher`]s.
152
    ///
153
    /// All future value updates will not be observed by the watchers, but the
154
    /// last value will still be readable before the watcher signals that it is
155
    /// disconnected.
156
16
    pub fn shutdown(&self) {
157
16
        let mut changed = self.data.changed.write();
158
16
        if let Some(changed) = changed.take() {
159
15
            changed.notify(usize::MAX);
160
15
        }
161
16
    }
162
}
163

            
164
impl<T> Data<T> {
165
65693
    fn current_version(&self) -> usize {
166
65693
        self.version.load(Ordering::Acquire)
167
65693
    }
168

            
169
24008
    fn increment_version(&self) {
170
24008
        self.version.fetch_add(1, Ordering::AcqRel);
171
24008
        let changed = self.changed.read();
172
24008
        if let Some(changed) = changed.as_ref() {
173
24008
            changed.notify(usize::MAX);
174
24008
        }
175
24008
    }
176
}
177

            
178
/// A read guard that allows reading the currently stored value in a
179
/// [`Watchable`]. No values can be stored within the source [`Watchable`] while
180
/// this guard exists.
181
///
182
/// The inner value is accessible through [`Deref`].
183
#[must_use]
184
pub struct WatchableReadGuard<'a, T>(RwLockReadGuard<'a, T>);
185

            
186
impl<'a, T> Deref for WatchableReadGuard<'a, T> {
187
    type Target = T;
188

            
189
16826
    fn deref(&self) -> &Self::Target {
190
16826
        &self.0
191
16826
    }
192
}
193

            
194
/// A write guard that allows updating the currently stored value in a
195
/// [`Watchable`].
196
///
197
/// The inner value is readable through [`Deref`], and modifiable through
198
/// [`DerefMut`]. Any usage of [`DerefMut`] will cause all [`Watcher`]s to be
199
/// notified of an updated value when the guard is dropped.
200
///
201
/// [`WatchableWriteGuard`] is an exclusive guard. No other threads will be
202
/// able to read or write the contained value until the guard is dropped.
203
#[must_use]
204
pub struct WatchableWriteGuard<'a, T> {
205
    watchable: &'a Watchable<T>,
206
    accessed_mut: bool,
207
    guard: RwLockWriteGuard<'a, T>,
208
}
209

            
210
impl<'a, T> Deref for WatchableWriteGuard<'a, T> {
211
    type Target = T;
212

            
213
2
    fn deref(&self) -> &Self::Target {
214
2
        &self.guard
215
2
    }
216
}
217

            
218
impl<'a, T> DerefMut for WatchableWriteGuard<'a, T> {
219
1
    fn deref_mut(&mut self) -> &mut Self::Target {
220
1
        self.accessed_mut = true;
221
1
        &mut self.guard
222
1
    }
223
}
224

            
225
impl<'a, T> Drop for WatchableWriteGuard<'a, T> {
226
3
    fn drop(&mut self) {
227
3
        if self.accessed_mut {
228
1
            self.watchable.data.increment_version();
229
2
        }
230
3
    }
231
}
232

            
233
#[derive(Debug)]
234
struct Data<T> {
235
    changed: RwLock<Option<Event>>,
236
    version: AtomicUsize,
237
    watchers: AtomicUsize,
238
    watchables: AtomicUsize,
239
    value: RwLock<T>,
240
}
241

            
242
impl<T> Default for Data<T>
243
where
244
    T: Default,
245
{
246
5
    fn default() -> Self {
247
5
        Self {
248
5
            changed: RwLock::new(Some(Event::new())),
249
5
            version: AtomicUsize::new(0),
250
5
            watchers: AtomicUsize::new(0),
251
5
            watchables: AtomicUsize::new(1),
252
5
            value: RwLock::default(),
253
5
        }
254
5
    }
255
}
256

            
257
/// An observer of a [`Watchable`] value.
258
///
259
/// ## Cloning behavior
260
///
261
/// Cloning a watcher also clones the current watching state. If the watcher
262
/// hasn't read the value currently stored, the cloned instance will also
263
/// consider the current value unread.
264
#[derive(Debug)]
265
#[must_use]
266
pub struct Watcher<T> {
267
    version: AtomicUsize,
268
    watched: Arc<Data<T>>,
269
}
270

            
271
impl<T> Drop for Watcher<T> {
272
90
    fn drop(&mut self) {
273
90
        self.watched.watchers.fetch_sub(1, Ordering::AcqRel);
274
90
    }
275
}
276

            
277
impl<T> Clone for Watcher<T> {
278
1
    fn clone(&self) -> Self {
279
1
        Self {
280
1
            version: AtomicUsize::new(self.version.load(Ordering::Relaxed)),
281
1
            watched: self.watched.clone(),
282
1
        }
283
1
    }
284
}
285

            
286
/// Reason a watcher did not receive a listener to wait on.
#[derive(Debug)]
enum CreateListenerError {
    /// An unread value is stored, so there is nothing to wait for.
    NewValueAvailable,
    /// The change event was taken by a shutdown; no more values will arrive.
    Disconnected,
}
291

            
292
/// A watch operation failed because all [`Watchable`] instances have been
293
/// dropped.
294
#[derive(Debug, thiserror::Error, Eq, PartialEq)]
295
#[error("all watchable instances have been dropped")]
296
pub struct Disconnected;
297

            
298
/// A watch operation with a timeout failed.
299
#[derive(Debug, thiserror::Error, Eq, PartialEq)]
300
pub enum TimeoutError {
301
    /// A watch operation failed because all [`Watchable`] instances have been
302
    /// dropped.
303
    #[error("all watchable instances have been dropped")]
304
    Disconnected,
305
    /// No new values were written before the timeout elapsed
306
    #[error("no new values were written before the timeout elapsed")]
307
    Timeout,
308
}
309

            
310
impl<T> Watcher<T> {
311
16774
    fn create_listener_if_needed(&self) -> Result<Pin<Box<EventListener>>, CreateListenerError> {
312
16774
        let changed = self.watched.changed.read();
313
16774
        match (changed.as_ref(), self.is_current()) {
314
596
            (_, false) => Err(CreateListenerError::NewValueAvailable),
315
19
            (None, _) => Err(CreateListenerError::Disconnected),
316
16159
            (Some(changed), true) => {
317
16159
                let listener = changed.listen();
318
16159

            
319
16159
                // Between now and creating the listener, an update may have
320
16159
                // come in, so we need to check again before returning the
321
16159
                // listener.
322
16159
                if self.is_current() {
323
15843
                    Ok(listener)
324
                } else {
325
316
                    Err(CreateListenerError::NewValueAvailable)
326
                }
327
            }
328
        }
329
16774
    }
330

            
331
    /// Returns true if the latest value has been read from this instance.
332
    #[must_use]
333
48774
    pub fn is_current(&self) -> bool {
334
48774
        self.version.load(Ordering::Relaxed) == self.watched.current_version()
335
48774
    }
336

            
337
    /// Updates this instance's state to reflect that it has read the currently
338
    /// stored value. The next call to a watch call will block until the next
339
    /// value is stored.
340
    ///
341
    /// Returns true if the internal state was updated, and false if no changes
342
    /// were necessary.
343
6
    pub fn mark_read(&self) -> bool {
344
6
        let current_version = self.watched.current_version();
345
6
        let mut stored_version = self.version.load(Ordering::Acquire);
346
6
        while stored_version < current_version {
347
2
            match self.version.compare_exchange(
348
2
                stored_version,
349
2
                current_version,
350
2
                Ordering::Release,
351
2
                Ordering::Acquire,
352
2
            ) {
353
2
                Ok(_) => return true,
354
                Err(new_stored) => stored_version = new_stored,
355
            }
356
        }
357
4
        false
358
6
    }
359

            
360
    /// Watches for a new value to be stored in the source [`Watchable`]. If the
361
    /// current value hasn't been accessed through [`Self::read()`] or marked
362
    /// read with [`Self::mark_read()`], this call will block the calling
363
    /// thread until a new value has been published.
364
    ///
365
    /// # Errors
366
    ///
367
    /// Returns [`Disconnected`] if all instances of [`Watchable`] have been
368
    /// dropped and the current value has been read.
369
409
    pub fn watch(&self) -> Result<(), Disconnected> {
370
        loop {
371
409
            match self.create_listener_if_needed() {
372
204
                Ok(mut listener) => {
373
204
                    listener.as_mut().wait();
374
204
                    if !self.is_current() {
375
204
                        break;
376
                    }
377
                }
378
14
                Err(CreateListenerError::Disconnected) => return Err(Disconnected),
379
191
                Err(CreateListenerError::NewValueAvailable) => break,
380
            }
381
        }
382

            
383
395
        Ok(())
384
409
    }
385

            
386
    /// Watches for a new value to be stored in the source [`Watchable`]. If the
387
    /// current value hasn't been accessed through [`Self::read()`] or marked
388
    /// read with [`Self::mark_read()`], this call will block the calling
389
    /// thread until a new value has been published or until `duration` has
390
    /// elapsed.
391
    ///
392
    /// # Errors
393
    ///
394
    /// - [`TimeoutError::Disconnected`]: All instances of [`Watchable`] have
395
    /// been dropped and the current value has been read.
396
    /// - [`TimeoutError::Timeout`]: A timeout occurred before a new value was
397
    /// written.
398
3
    pub fn watch_timeout(&self, duration: Duration) -> Result<(), TimeoutError> {
399
3
        self.watch_until(Instant::now() + duration)
400
3
    }
401

            
402
    /// Watches for a new value to be stored in the source [`Watchable`]. If the
403
    /// current value hasn't been accessed through [`Self::read()`] or marked
404
    /// read with [`Self::mark_read()`], this call will block the calling
405
    /// thread until a new value has been published or until `deadline`.
406
    ///
407
    /// # Errors
408
    ///
409
    /// - [`TimeoutError::Disconnected`]: All instances of [`Watchable`] have
410
    /// been dropped and the current value has been read.
411
    /// - [`TimeoutError::Timeout`]: A timeout occurred before a new value was
412
    /// written.
413
6
    pub fn watch_until(&self, deadline: Instant) -> Result<(), TimeoutError> {
414
        loop {
415
8
            match self.create_listener_if_needed() {
416
4
                Ok(mut listener) => {
417
4
                    if listener.as_mut().wait_deadline(deadline).is_some() {
418
2
                        if !self.is_current() {
419
                            break;
420
2
                        } else if Instant::now() < deadline {
421
2
                            // Spurious wake-up
422
2
                        }
423
                    } else {
424
2
                        return Err(TimeoutError::Timeout);
425
                    }
426
                }
427
2
                Err(CreateListenerError::Disconnected) => return Err(TimeoutError::Disconnected),
428
2
                Err(CreateListenerError::NewValueAvailable) => break,
429
            }
430
        }
431

            
432
2
        Ok(())
433
6
    }
434

            
435
    /// Watches for a new value to be stored in the source [`Watchable`]. If the
436
    /// current value hasn't been accessed through [`Self::read()`] or marked
437
    /// read with [`Self::mark_read()`], the async task will block until
438
    /// a new value has been published.
439
    ///
440
    /// # Errors
441
    ///
442
    /// Returns [`Disconnected`] if all instances of [`Watchable`] have been
443
    /// dropped and the current value has been read.
444
16340
    pub async fn watch_async(&self) -> Result<(), Disconnected> {
445
        loop {
446
16340
            match self.create_listener_if_needed() {
447
15625
                Ok(listener) => {
448
15625
                    listener.await;
449
15625
                    if !self.is_current() {
450
15625
                        break;
451
                    }
452
                }
453
                Err(CreateListenerError::Disconnected) => return Err(Disconnected),
454
715
                Err(CreateListenerError::NewValueAvailable) => break,
455
            }
456
        }
457
16340
        Ok(())
458
16340
    }
459

            
460
    /// Returns a read guard that allows reading the currently stored value.
461
    /// This function does not consider the value read, and the next call to a
462
    /// watch function will be unaffected.
463
1
    pub fn peek(&self) -> WatchableReadGuard<'_, T> {
464
1
        let guard = self.watched.value.read();
465
1
        WatchableReadGuard(guard)
466
1
    }
467

            
468
    /// Returns a read guard that allows reading the currently stored value.
469
    /// This function marks the stored value as read, ensuring that the next
470
    /// call to a watch function will block until the a new value has been
471
    /// published.
472
16824
    pub fn read(&self) -> WatchableReadGuard<'_, T> {
473
16824
        let guard = self.watched.value.read();
474
16824
        self.version
475
16824
            .store(self.watched.current_version(), Ordering::Relaxed);
476
16824
        WatchableReadGuard(guard)
477
16824
    }
478

            
479
    /// Returns the currently contained value. This function marks the stored
480
    /// value as read, ensuring that the next call to a watch function will
481
    /// block until the a new value has been published.
482
    #[must_use]
483
1
    pub fn get(&self) -> T
484
1
    where
485
1
        T: Clone,
486
1
    {
487
1
        self.read().clone()
488
1
    }
489

            
490
    /// Watches for a new value to be stored in the source [`Watchable`] and
491
    /// returns a clone of it. If the current value hasn't been accessed through
492
    /// [`Self::read()`] or marked read with [`Self::mark_read()`], this call
493
    /// will block the calling thread until a new value has been published.
494
    ///
495
    /// # Errors
496
    ///
497
    /// Returns [`Disconnected`] if all instances of [`Watchable`] have been
498
    /// dropped and the current value has been read.
499
17
    pub fn next_value(&self) -> Result<T, Disconnected>
500
17
    where
501
17
        T: Clone,
502
17
    {
503
17
        self.watch().map(|()| self.read().clone())
504
17
    }
505

            
506
    /// Watches for a new value to be stored in the source [`Watchable`] and
507
    /// returns a clone of it. If the current value hasn't been accessed through
508
    /// [`Self::read()`] or marked read with [`Self::mark_read()`], this call
509
    /// will asynchronously wait for a new value to be published.
510
    ///
511
    /// The async task is safe to be cancelled without losing track of the last
512
    /// read value.
513
    ///
514
    /// # Errors
515
    ///
516
    /// Returns [`Disconnected`] if all instances of [`Watchable`] have been
517
    /// dropped and the current value has been read.
518
1
    pub async fn next_value_async(&self) -> Result<T, Disconnected>
519
1
    where
520
1
        T: Clone,
521
1
    {
522
1
        self.watch_async().await.map(|()| self.read().clone())
523
1
    }
524

            
525
    /// Returns this watcher in a type that implements [`Stream`].
526
2
    pub fn into_stream(self) -> WatcherStream<T> {
527
2
        WatcherStream {
528
2
            watcher: self,
529
2
            listener: None,
530
2
        }
531
2
    }
532
}
533

            
534
impl<T> Iterator for Watcher<T>
535
where
536
    T: Clone,
537
{
538
    type Item = T;
539

            
540
8
    fn next(&mut self) -> Option<Self::Item> {
541
8
        self.next_value().ok()
542
8
    }
543
}
544

            
545
/// Asynchronous iterator for a [`Watcher`]. Implements [`Stream`].
546
#[derive(Debug)]
547
#[must_use]
548
pub struct WatcherStream<T> {
549
    watcher: Watcher<T>,
550
    listener: Option<Pin<Box<EventListener>>>,
551
}
552

            
553
impl<T> WatcherStream<T> {
554
    /// Returns the wrapped [`Watcher`].
555
1
    pub fn into_inner(self) -> Watcher<T> {
556
1
        self.watcher
557
1
    }
558
}
559

            
560
impl<T> Stream for WatcherStream<T>
561
where
562
    T: Clone,
563
{
564
    type Item = T;
565

            
566
26
    fn poll_next(
567
26
        mut self: std::pin::Pin<&mut Self>,
568
26
        cx: &mut std::task::Context<'_>,
569
26
    ) -> Poll<Option<Self::Item>> {
570
        // If we have a listener or we have already read the current value, we
571
        // need to poll the listener as a future first.
572
        loop {
573
27
            match self
574
27
                .listener
575
27
                .take()
576
27
                .ok_or(CreateListenerError::Disconnected)
577
27
                .or_else(|_| self.watcher.create_listener_if_needed())
578
            {
579
20
                Ok(mut listener) => {
580
20
                    match listener.poll_unpin(cx) {
581
                        Poll::Ready(()) => {
582
10
                            if !self.watcher.is_current() {
583
9
                                break;
584
1
                            }
585

            
586
                            // A new value is available. Fall through.
587
                        }
588
                        Poll::Pending => {
589
                            // The listener wasn't ready, store it again.
590
10
                            self.listener = Some(listener);
591
10
                            return Poll::Pending;
592
                        }
593
                    }
594
                }
595
4
                Err(CreateListenerError::NewValueAvailable) => break,
596
3
                Err(CreateListenerError::Disconnected) => return Poll::Ready(None),
597
            }
598
        }
599

            
600
13
        Poll::Ready(Some(self.watcher.read().clone()))
601
26
    }
602
}
603

            
604
1
#[test]
fn basics() {
    let watchable = Watchable::new(1_u32);
    assert!(!watchable.has_watchers());

    let first = watchable.watch();
    let second = watchable.watch();
    // A freshly created watcher has already "seen" the stored value.
    assert!(!first.mark_read());
    assert_eq!(watchable.watchers(), 2);

    assert_eq!(watchable.replace(2), 1);
    // The new value is unread, so watching returns immediately.
    first.watch().unwrap();
    // Peeking does not mark the value as read...
    assert_eq!(*first.peek(), 2);
    first.watch().unwrap();
    // ...but reading does. Blocking after a read is covered by `blocking_tests`.
    assert_eq!(*first.read(), 2);
    assert!(!first.mark_read());
    drop(first);
    assert_eq!(watchable.watchers(), 1);

    // The second watcher tracks its read state independently of the first.
    assert!(second.mark_read());
    assert_eq!(*second.read(), 2);
    drop(second);
    assert_eq!(watchable.watchers(), 0);
}
632

            
633
1
#[test]
fn accessing_values() {
    const EXPECTED: &str = "hello";

    let watchable = Watchable::new(EXPECTED.to_string());
    assert_eq!(watchable.get(), EXPECTED);
    assert_eq!(&*watchable.read(), EXPECTED);
    // Only dereferenced immutably, so no new version is published.
    assert_eq!(&*watchable.write(), EXPECTED);

    let watcher = watchable.watch();
    assert_eq!(watcher.get(), EXPECTED);
    assert_eq!(&*watcher.read(), EXPECTED);
}
644

            
645
1
#[test]
// The clones are the subject of this test, so they are intentionally "redundant".
#[allow(clippy::redundant_clone)]
fn clones() {
    let watchable = Watchable::default();
    let other_handle = watchable.clone();
    let watcher = watchable.watch();
    let cloned_watcher = watcher.clone();

    // Both watchers observe values published through either handle.
    for (handle, value) in [(&watchable, 1), (&other_handle, 2)] {
        handle.replace(value);
        assert_eq!(watcher.next_value().unwrap(), value);
        assert_eq!(cloned_watcher.next_value().unwrap(), value);
    }
}
660

            
661
1
#[test]
fn drop_watchable() {
    let watchable = Watchable::default();
    assert!(!watchable.has_watchers());
    let watcher = watchable.watch();

    watchable.replace(1_u32);
    assert_eq!(watcher.next_value().unwrap(), 1);

    // With the only handle gone and the last value read, the watcher is
    // disconnected.
    drop(watchable);
    assert_eq!(watcher.next_value(), Err(Disconnected));
}
673

            
674
1
#[test]
fn drop_watchable_timeouts() {
    let watchable = Watchable::new(0_u8);
    assert!(!watchable.has_watchers());

    let timeout_watcher = watchable.watch();
    let started = Instant::now();
    let timeout_thread = std::thread::spawn(move || {
        let err = timeout_watcher
            .watch_timeout(Duration::from_secs(15))
            .unwrap_err();
        assert_eq!(err, TimeoutError::Disconnected);
    });

    let until_watcher = watchable.watch();
    let until_thread = std::thread::spawn(move || {
        let deadline = Instant::now()
            .checked_add(Duration::from_secs(15))
            .unwrap();
        let err = until_watcher.watch_until(deadline).unwrap_err();
        assert_eq!(err, TimeoutError::Disconnected);
    });

    // Give both threads time to start waiting before disconnecting them.
    std::thread::sleep(Duration::from_millis(100));
    drop(watchable);

    timeout_thread.join().unwrap();
    until_thread.join().unwrap();

    // Both waits must end on disconnection, long before their 15 s limits.
    let elapsed = Instant::now().checked_duration_since(started).unwrap();
    assert!(elapsed.as_secs() < 1);
}
706

            
707
1
#[test]
fn timeouts() {
    let watchable = Watchable::new(1_u32);
    let watcher = watchable.watch();

    let started = Instant::now();
    let by_duration = watcher.watch_timeout(Duration::from_millis(100));
    assert!(matches!(by_duration, Err(TimeoutError::Timeout)));
    let by_deadline = watcher.watch_until(Instant::now() + Duration::from_millis(100));
    assert!(matches!(by_deadline, Err(TimeoutError::Timeout)));

    // Two 100 ms waits should take about 200 ms. The waiting logic is not
    // under our control, so allow a little slack to keep the test stable.
    let elapsed = Instant::now().checked_duration_since(started).unwrap();
    assert!(elapsed.as_millis() >= 180);

    // Once a new value is published, both calls return `Ok` without waiting
    // for the timeout.
    watchable.replace(2);
    watcher.watch_timeout(Duration::from_secs(1)).unwrap();
    watchable.replace(3);
    watcher
        .watch_until(Instant::now() + Duration::from_secs(1))
        .unwrap();
}
734

            
735
1
#[test]
fn deref_publish() {
    let watchable = Watchable::new(1_u32);
    let watcher = watchable.watch();

    // Reading through the write guard (`Deref`) publishes nothing.
    let guard = watchable.write();
    assert_eq!(*guard, 1);
    drop(guard);
    assert!(!watcher.mark_read());

    // Assigning through the write guard (`DerefMut`) publishes on drop.
    let mut guard = watchable.write();
    *guard = 2;
    drop(guard);
    assert!(watcher.mark_read());
}
752

            
753
1
#[test]
fn blocking_tests() {
    let watchable = Watchable::new(1_u32);
    let watcher = watchable.watch();
    let (read_done_tx, read_done_rx) = std::sync::mpsc::sync_channel(1);

    let worker = std::thread::spawn(move || {
        watcher.watch().unwrap();
        assert_eq!(*watcher.read(), 2);
        read_done_tx.send(()).unwrap();
        // The value was just read, so this waits for the next update.
        watcher.watch().map(|()| *watcher.read()).unwrap()
    });

    watchable.replace(2);
    // Wait until the worker has consumed `2` before publishing again.
    read_done_rx.recv().unwrap();
    assert!(watchable.update(42).is_ok());
    // Storing an equal value is rejected.
    assert!(watchable.update(42).is_err());

    assert_eq!(worker.join().unwrap(), 42);
}
774

            
775
1
#[test]
fn iterator_test() {
    const LAST: u32 = 1000;

    let watchable = Watchable::new(1_u32);
    let watcher = watchable.watch();
    let worker = std::thread::spawn(move || {
        let first = watcher.next_value().unwrap();
        // Each yielded value must differ from the previous one. Iteration
        // ends once the watchable has been dropped.
        let last = watcher.fold(first, |previous, value| {
            assert_ne!(previous, value);
            println!("Received {value}");
            value
        });
        assert_eq!(last, LAST);
    });

    for i in 1..=LAST {
        watchable.replace(i);
    }
    drop(watchable);
    worker.join().unwrap();
}
795

            
796
#[cfg(test)]
#[tokio::test(flavor = "multi_thread")]
async fn stream_test() {
    use futures_util::StreamExt;

    let watchable = Watchable::default();
    let watcher = watchable.watch();
    let worker = tokio::task::spawn(async move {
        let mut previous = watcher.next_value_async().await.unwrap();
        let mut stream = watcher.into_stream();
        while let Some(value) = stream.next().await {
            assert_ne!(previous, value);
            println!("Received {value}");
            previous = value;
        }
        assert_eq!(previous, 1000);

        // Polling a finished stream again must return `None` without
        // blocking or panicking.
        assert!(stream.next().await.is_none());

        // The unwrapped watcher keeps the stream's read state: nothing unread.
        let watcher = stream.into_inner();
        assert!(!watcher.mark_read());
    });

    for i in 1..=1000 {
        watchable.replace(i);
        // Pause every 100 updates so the worker gets a chance to run.
        if i % 100 == 0 {
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    }

    // Dropping the only handle ends the stream.
    drop(watchable);

    worker.await.unwrap();
}
835

            
836
1
#[test]
fn stress_test() {
    const WORKERS: usize = 10;
    const LAST: u32 = 10_000;

    let watchable = Watchable::new(1_u32);
    let workers: Vec<_> = (0..WORKERS)
        .map(|_| {
            let watcher = watchable.watch();
            std::thread::spawn(move || {
                let mut previous = *watcher.read();
                while watcher.watch().is_ok() {
                    let current = *watcher.read();
                    assert_ne!(previous, current);
                    previous = current;
                }
                assert_eq!(previous, LAST);
            })
        })
        .collect();

    for i in 1..=LAST {
        let _ = watchable.update(i);
    }
    drop(watchable);

    for worker in workers {
        worker.join().unwrap();
    }
}
862

            
863
#[cfg(test)]
#[tokio::test(flavor = "multi_thread")]
async fn stress_test_async() {
    const WORKERS: usize = 64;
    const LAST: u32 = 10_000;

    let watchable = Watchable::new(1_u32);
    let workers: Vec<_> = (0..WORKERS)
        .map(|_| {
            let watcher = watchable.watch();
            tokio::task::spawn(async move {
                let mut previous = *watcher.read();
                loop {
                    watcher.watch_async().await.unwrap();
                    let current = *watcher.read();
                    assert_ne!(previous, current);
                    if current == LAST {
                        break;
                    }
                    previous = current;
                }
            })
        })
        .collect();

    // Publish every value from a thread on tokio's blocking pool.
    tokio::task::spawn_blocking(move || {
        for i in 1..=LAST {
            let _ = watchable.update(i);
        }
    })
    .await
    .unwrap();

    for worker in workers {
        worker.await.unwrap();
    }
}
896

            
897
1
#[test]
fn shutdown() {
    let watchable = Watchable::new(0);
    let watcher = watchable.watch();

    // Publish a value, then disconnect every watcher.
    watchable.replace(1);
    watchable.shutdown();

    // The value stored before the shutdown is still delivered once...
    let delivered = watcher.next_value().expect("initial value missing");
    assert_eq!(delivered, 1);
    // ...after which the watcher reports that it is disconnected.
    watcher
        .next_value()
        .expect_err("watcher should be disconnected");
}