1
mod blocking {
2
    mod singlethreaded {
3

            
4
        use std::{
5
            cell::Cell,
6
            rc::Rc,
7
            time::{Duration, Instant},
8
        };
9

            
10
        use crate::blocking::singlethreaded::{Progress, Runtime};
11

            
12
1
        #[test]
13
1
        fn basic() {
14
1
            let counter = Rc::new(Cell::new(0));
15
1
            let future_counter = counter.clone();
16
1
            let future = |runtime: Runtime<usize>| async move {
17
1
                future_counter.set(0);
18
1
                runtime.spend(1).await;
19
1
                future_counter.set(1);
20
1
                runtime.spend(1).await;
21
1
                future_counter.set(2);
22
1
            };
23

            
24
1
            let incomplete = match Runtime::run_with_budget(future, 0) {
25
1
                Progress::NoBudget(incomplete) => {
26
1
                    assert_eq!(counter.get(), 0);
27
1
                    incomplete
28
                }
29
                Progress::Complete(result) => unreachable!("future completed: {result:?}"),
30
            };
31
1
            let incomplete = match incomplete.continue_with_additional_budget(1) {
32
1
                Progress::NoBudget(incomplete) => {
33
1
                    assert_eq!(counter.get(), 1);
34
1
                    incomplete
35
                }
36
                Progress::Complete(result) => unreachable!("future completed: {result:?}"),
37
            };
38
1
            let result = match incomplete.continue_with_additional_budget(1) {
39
1
                Progress::Complete(result) => result,
40
                Progress::NoBudget(_incomplete) => {
41
                    unreachable!("future didn't complete");
42
                }
43
            };
44
1
            assert_eq!(result.balance, 0);
45
1
        }
46

            
47
1
        #[test]
48
1
        fn non_budget_parking() {
49
1
            // This test uses flume's bounded channel mixed with a thread sleep to cause
50
1
            // the async task to wait if it sends messages too quickly. This means this
51
1
            // test should take the (sleep_duration * number_of_iterations - 1) time to
52
1
            // complete. As written currently, this is 1.4 seconds with an assertion of
53
1
            // greater than 1s.
54
1
            let (sender, receiver) = flume::bounded(1);
55
1
            std::thread::spawn(move || {
56
17
                while let Ok(message) = receiver.recv() {
57
16
                    println!("R: received {message}");
58
16
                    std::thread::sleep(Duration::from_millis(100));
59
16
                    println!("R: done 'processing'");
60
16
                }
61
1
            });
62
1

            
63
1
            let task = |runtime: Runtime<usize>| async move {
64
17
                for message in 0..=15 {
65
16
                    println!("S: requesting budget");
66
16
                    runtime.spend(1).await;
67
16
                    println!("S: sending {message}");
68
16
                    sender.send_async(message).await.unwrap();
69
16
                    println!("S: message sent");
70
                }
71
1
            };
72

            
73
1
            let started_at = Instant::now();
74
1
            let mut progress = Runtime::run_with_budget(task, 0);
75
17
            while let Progress::NoBudget(incomplete) = progress {
76
16
                println!("E: providing budget");
77
16
                progress = incomplete.continue_with_additional_budget(1);
78
16
            }
79
1
            let elapsed_time = started_at.elapsed();
80
1
            assert!(elapsed_time > Duration::from_secs(1));
81
1
        }
82
    }
83

            
84
    mod threadsafe {
85

            
86
        use std::time::Duration;
87

            
88
        use crate::{
89
            blocking::threadsafe::{Progress, Runtime},
90
            ReplenishableBudget,
91
        };
92
1
        #[test]
93
1
        fn external_budget() {
94
1
            let budget = ReplenishableBudget::default();
95
1
            let future = |runtime: Runtime<ReplenishableBudget>| async move {
96
101
                for _ in 0..100 {
97
100
                    println!("F> Spend 1");
98
200
                    runtime.spend(1).await;
99
                }
100
1
                println!("Done");
101
1
            };
102

            
103
1
            let thread_budget = budget.clone();
104
1
            std::thread::spawn(move || {
105
101
                for _ in 0..100 {
106
100
                    println!("T> Replenish 1");
107
100
                    thread_budget.replenish(1);
108
100
                    std::thread::sleep(Duration::from_micros(1));
109
100
                }
110
1
                println!("T> Done");
111
1
            });
112

            
113
1
            if let Progress::NoBudget(mut incomplete) = Runtime::run_with_budget(future, budget) {
114
100
                while let Progress::NoBudget(new_incomplete_task) = incomplete.wait_for_budget() {
115
99
                    println!("M> Waiting to complete");
116
99
                    incomplete = new_incomplete_task;
117
99
                }
118
            };
119
1
        }
120

            
121
1
        #[test]
122
1
        fn spawn() {
123
1
            let budget = ReplenishableBudget::new(7);
124
1
            // This test causes contention using a blocking flume channel. One task
125
1
            // spends budget and sends messages while the other replenishes the
126
1
            // budget and receives messages. To ensure that these tasks aren't
127
1
            // completely in lock-step, the channel is bounded at 3 and budget is
128
1
            // allocated every 7 messages.
129
1
            let task_budget = budget.clone();
130
1
            let task = |runtime: Runtime<ReplenishableBudget>| async move {
131
1
                let (sender, receiver) = flume::bounded(3);
132
1

            
133
1
                let sending_task = runtime.spawn({
134
1
                    let runtime = runtime.clone();
135
1
                    async move {
136
101
                        for message in 0..100 {
137
100
                            println!("S: requesting budget");
138
100
                            runtime.spend(1).await;
139
100
                            println!("S: sending {message}");
140
100
                            sender.send_async(message).await.unwrap();
141
100
                            println!("S: message sent");
142
                        }
143
1
                    }
144
1
                });
145
1

            
146
1
                let receiving_task = runtime.spawn(async move {
147
1
                    let mut counter = 0;
148
101
                    while let Ok(message) = receiver.recv_async().await {
149
100
                        println!("R: received {message}");
150
100
                        counter += 1;
151
100
                        if counter % 7 == 0 {
152
14
                            task_budget.replenish(7);
153
86
                        }
154
                    }
155
1
                });
156
1

            
157
57
                sending_task.await;
158
1
                receiving_task.await;
159
1
            };
160

            
161
1
            let result = Runtime::run_with_budget(task, budget).wait_until_complete();
162
1
            assert_eq!(result.balance.remaining(), 15 * 7 - 100);
163
1
        }
164

            
165
1
        #[test]
166
1
        fn nightmare() {
167
1
            const TASKS: usize = 100;
168
1
            const ITERS_PER_TASK: usize = 100;
169
1
            // This test launches a ton of tasks, while an external thread is
170
1
            // filling the budget. This test is aimed to try to find deadlocks.
171
1
            let budget = ReplenishableBudget::new(0);
172
1
            std::thread::spawn({
173
1
                let budget = budget.clone();
174
1
                move || {
175
10000
                    for i in 0..TASKS * ITERS_PER_TASK {
176
10000
                        std::thread::sleep(Duration::from_micros(u64::try_from(i).unwrap() % 10));
177
10000
                        budget.replenish(1);
178
10000
                    }
179
1
                    println!("Budget Filled");
180
1
                }
181
1
            });
182
1

            
183
1
            let task = |runtime: Runtime<ReplenishableBudget>| async move {
184
1
                let (sender, receiver) = flume::unbounded();
185

            
186
101
                for task in 0..TASKS {
187
100
                    runtime.spawn({
188
100
                        let runtime = runtime.clone();
189
100
                        let sender = sender.clone();
190
100
                        async move {
191
10100
                            for _ in 0..ITERS_PER_TASK {
192
328567
                                runtime.spend(1).await;
193
10000
                                println!("{task} Spent 1");
194
                            }
195
100
                            sender.send(()).unwrap();
196
100
                        }
197
100
                    });
198
100
                }
199

            
200
                // Wait for all tasks to send the completion message.
201
101
                for _ in 0..TASKS {
202
4729
                    receiver.recv_async().await.unwrap();
203
                }
204
1
            };
205

            
206
1
            let result = Runtime::run_with_budget(task, budget).wait_until_complete();
207
1
            assert_eq!(result.balance.remaining(), 0);
208
1
        }
209
    }
210
}
211

            
212
mod asynchronous {
213
    use std::time::{Duration, Instant};
214

            
215
    use tokio::task::LocalSet;
216

            
217
    use crate::{
218
        asynchronous::singlethreaded::{Context, Progress},
219
        ReplenishableBudget,
220
    };
221

            
222
1
    #[tokio::test]
223
1
    async fn external_runtime_compatability() {
224
1
        // This test uses flume's bounded channel mixed with a thread sleep to cause
225
1
        // the async task to wait if it sends messages too quickly. This means this
226
1
        // test should take the (sleep_duration * number_of_iterations - 1) time to
227
1
        // complete. As written currently, this is 1.4 seconds with an assertion of
228
1
        // greater than 1s.
229
1
        let (sender, receiver) = flume::bounded(1);
230
1
        tokio::task::spawn(async move {
231
15
            while let Ok(message) = receiver.recv_async().await {
232
15
                println!("R: received {message}");
233
15
                tokio::time::sleep(Duration::from_millis(100)).await;
234
14
                println!("R: done 'processing'");
235
            }
236
1
        });
237
1

            
238
1
        let task = |context: Context<usize>| async move {
239
17
            for message in 0..=15 {
240
16
                println!("S: requesting budget");
241
16
                context.spend(1).await;
242
16
                println!("S: sending {message}");
243
29
                sender.send_async(message).await.unwrap();
244
16
                println!("S: message sent");
245
            }
246
1
        };
247

            
248
1
        let started_at = Instant::now();
249
1
        let mut progress = Context::run_with_budget(task, 0).await;
250
17
        while let Progress::NoBudget(incomplete) = progress {
251
16
            println!("E: providing budget");
252
29
            progress = incomplete.continue_with_additional_budget(1).await;
253
        }
254
1
        let elapsed_time = started_at.elapsed();
255
1
        assert!(elapsed_time > Duration::from_secs(1));
256
1
    }
257

            
258
1
    #[tokio::test]
259
1
    async fn external_budget() {
260
1
        let budget = ReplenishableBudget::default();
261
1
        let future = |context: Context<ReplenishableBudget>| async move {
262
101
            for _ in 0..100 {
263
100
                println!("F> Spend 1");
264
2500
                context.spend(1).await;
265
            }
266
1
            println!("Done");
267
1
        };
268

            
269
1
        let thread_budget = budget.clone();
270
1
        std::thread::spawn(move || {
271
101
            for _ in 0..100 {
272
100
                println!("T> Replenish 1");
273
100
                thread_budget.replenish(1);
274
100
                std::thread::sleep(Duration::from_micros(1));
275
100
            }
276
1
            println!("T> Done");
277
1
        });
278

            
279
1
        if let Progress::NoBudget(mut incomplete) = Context::run_with_budget(future, budget).await {
280
2500
            while let Progress::NoBudget(new_incomplete_task) = incomplete.wait_for_budget().await {
281
2499
                println!("M> Waiting to complete");
282
2499
                incomplete = new_incomplete_task;
283
2499
            }
284
        };
285
1
    }
286

            
287
1
    #[tokio::test]
288
    #[cfg_attr(miri, ignore)] // LocalSet causes undefined behavior errors in miri
289
1
    async fn nightmare() {
290
1
        const TASKS: usize = 100;
291
1
        const ITERS_PER_TASK: usize = 100;
292
1
        // This test launches a ton of tasks, while an external thread is
293
1
        // filling the budget. This test is aimed to try to find deadlocks.
294
1
        let budget = ReplenishableBudget::new(0);
295
1
        std::thread::spawn({
296
1
            let budget = budget.clone();
297
1
            move || {
298
10000
                for i in 0..TASKS * ITERS_PER_TASK {
299
10000
                    std::thread::sleep(Duration::from_micros(u64::try_from(i).unwrap() % 10));
300
10000
                    budget.replenish(1);
301
10000
                }
302
1
                println!("Budget Filled");
303
1
            }
304
1
        });
305
1

            
306
1
        let task = |context: Context<ReplenishableBudget>| async move {
307
1
            let (sender, receiver) = flume::unbounded();
308
1
            let task_set = LocalSet::new();
309

            
310
101
            for task in 0..TASKS {
311
100
                task_set.spawn_local({
312
100
                    let context = context.clone();
313
100
                    let sender = sender.clone();
314
100
                    async move {
315
10100
                        for _ in 0..ITERS_PER_TASK {
316
86165
                            context.spend(1).await;
317
10000
                            println!("{task} Spent 1");
318
                        }
319
100
                        sender.send(()).unwrap();
320
100
                    }
321
100
                });
322
100
            }
323

            
324
1
            task_set
325
1
                .run_until(async move {
326
                    // Wait for all tasks to send the completion message.
327
101
                    for _ in 0..TASKS {
328
2006
                        receiver.recv_async().await.unwrap();
329
                    }
330
2006
                })
331
2006
                .await;
332
1
        };
333

            
334
1
        let result = Context::run_with_budget(task, budget)
335
            .await
336
2006
            .wait_until_complete()
337
2006
            .await;
338
1
        assert_eq!(result.balance.remaining(), 0);
339
1
    }
340
}