1
use std::{
2
    sync::{Arc, Condvar, Mutex},
3
    task::Waker,
4
};
5

            
6
use crate::{sealed::BudgetableSealed, Budgetable};
7

            
8
/// An atomic budget storage that can be replenished by other threads or tasks
9
/// than the one driving the budgeted task.
10
30
#[derive(Clone, Debug, Default)]
11
pub struct ReplenishableBudget {
12
    data: Arc<Data>,
13
}
14

            
15
2
/// Bookkeeping for a replenishable budget.
#[derive(Debug, Default)]
struct State {
    /// Counter identifying the current replenish "generation".
    generation: usize,
    /// The generation during which a spend request was last refused, or
    /// `None` when no refusal is currently recorded.
    denied_budget_at_generation: Option<usize>,
    /// Amount of budget currently available.
    budget: usize,
    /// Wakers of tasks that asked to be told about budget changes.
    wakers: Vec<Waker>,
}
22

            
23
2
#[derive(Debug, Default)]
24
struct Data {
25
    sync: Condvar,
26
    state: Mutex<State>,
27
}
28

            
29
impl ReplenishableBudget {
30
    /// Adds `amount` to the budget. This will wake up the task if it is
31
    /// currently waiting for additional budget.
32
23217
    pub fn replenish(&self, amount: usize) {
33
23217
        let mut state = self.data.state.lock().expect("poisoned");
34
23217
        state.generation = state.generation.wrapping_add(1);
35
23217
        state.denied_budget_at_generation = None;
36
23217
        state.budget = state.budget.saturating_add(amount);
37

            
38
1848779
        for waker in &state.wakers {
39
1848779
            waker.wake_by_ref();
40
1848779
        }
41
23217
        drop(state);
42
23217
        self.data.sync.notify_all();
43
23217
    }
44

            
45
    /// Returns the remaining budget.
46
    #[must_use]
47
3
    pub fn remaining(&self) -> usize {
48
3
        self.data.state.lock().expect("poisoned").budget
49
3
    }
50
}
51

            
52
impl ReplenishableBudget {
53
    /// Returns a new instance with the intiial budget provided.
54
    #[must_use]
55
6
    pub fn new(initial_budget: usize) -> Self {
56
6
        Self {
57
6
            data: Arc::new(Data {
58
6
                sync: Condvar::new(),
59
6
                state: Mutex::new(State {
60
6
                    generation: 0,
61
6
                    denied_budget_at_generation: None,
62
6
                    budget: initial_budget,
63
6
                    wakers: Vec::default(),
64
6
                }),
65
6
            }),
66
6
        }
67
6
    }
68
}
69

            
70
// Empty impl: no items are defined here.
// NOTE(review): presumably `Budgetable` is the public-facing trait whose
// behavior comes from `BudgetableSealed` — confirm against the trait definition.
impl Budgetable for ReplenishableBudget {}
71

            
72
impl BudgetableSealed for ReplenishableBudget {
73
    fn get(&self) -> usize {
74
        self.remaining()
75
    }
76

            
77
519205
    fn spend(&mut self, amount: usize) -> bool {
78
519205
        let mut state = self.data.state.lock().expect("poisoned");
79
519205
        if let Some(remaining) = state.budget.checked_sub(amount) {
80
23225
            state.denied_budget_at_generation = None;
81
23225
            state.budget = remaining;
82
23225
            true
83
        } else {
84
495980
            state.denied_budget_at_generation = Some(state.generation);
85
495980
            false
86
        }
87
519205
    }
88

            
89
    fn replenish(&mut self, amount: usize) {
90
        ReplenishableBudget::replenish(self, amount);
91
    }
92

            
93
530700
    fn add_waker(&self, new_waker: &std::task::Waker) {
94
530700
        let mut state = self.data.state.lock().expect("poisoned");
95

            
96
530700
        if let Some((_, waker)) = state
97
530700
            .wakers
98
530700
            .iter_mut()
99
530700
            .enumerate()
100
22820386
            .find(|(_, waker)| waker.will_wake(new_waker))
101
481286
        {
102
481286
            *waker = new_waker.clone();
103
481286
        } else {
104
49414
            state.wakers.push(new_waker.clone());
105
49414
        }
106
530700
    }
107

            
108
66630
    fn remove_waker(&self, reference: &std::task::Waker) {
109
66630
        let mut state = self.data.state.lock().expect("poisoned");
110
66630
        if let Some((index, _)) = state
111
66630
            .wakers
112
66630
            .iter()
113
66630
            .enumerate()
114
4118047
            .find(|(_, waker)| waker.will_wake(reference))
115
49389
        {
116
49389
            state.wakers.remove(index);
117
49389
        }
118
66630
    }
119

            
120
13021
    fn park_for_budget(&self) {
121
13021
        let mut state = self.data.state.lock().expect("poisoned");
122
        loop {
123
            // If the budget hasn't been denied or the generation has changed
124
            // since the budget was denied, do not park.
125
41084
            if state.denied_budget_at_generation.is_none()
126
28063
                || state.denied_budget_at_generation != Some(state.generation)
127
            {
128
13021
                break;
129
28063
            }
130
28063

            
131
28063
            // Park the thread using a condvar.
132
28063
            state = self.data.sync.wait(state).expect("poisoned");
133
        }
134
13021
    }
135
}