1
use std::{
2
    future::Future,
3
    marker::PhantomData,
4
    pin::Pin,
5
    task::{Context, Poll, Waker},
6
};
7

            
8
use crate::{BudgetContext, BudgetContextData, BudgetResult, Budgetable, Container};
9

            
10
4
fn run_with_budget<Budget, Backing, F>(
11
4
    future: impl FnOnce(BudgetContext<Backing, Budget>) -> F,
12
4
    initial_budget: Budget,
13
4
) -> BudgetedFuture<Budget, Backing, F>
14
4
where
15
4
    Budget: Budgetable,
16
4
    Backing: Container<BudgetContextData<Budget>>,
17
4
    F: Future,
18
4
{
19
4
    let context = BudgetContext {
20
4
        data: Backing::new(BudgetContextData {
21
4
            budget: initial_budget,
22
4
            paused_future: None,
23
4
        }),
24
4
        _budget: PhantomData,
25
4
    };
26
4
    BudgetedFuture {
27
4
        state: Some(BudgetedFutureState {
28
4
            future: Box::pin(future(context.clone())),
29
4
            context,
30
4
            _budget: PhantomData,
31
4
        }),
32
4
    }
33
4
}
34

            
35
/// Executes a future with a given budget when awaited.
36
///
37
/// This future is returned from [`run_with_budget()`].
38
#[must_use = "the future must be awaited to be executed"]
39
struct BudgetedFuture<Budget, Backing, F>
40
where
41
    Backing: Container<BudgetContextData<Budget>>,
42
{
43
    state: Option<BudgetedFutureState<Budget, Backing, F>>,
44
}
45

            
46
struct BudgetedFutureState<Budget, Backing, F>
47
where
48
    Backing: Container<BudgetContextData<Budget>>,
49
{
50
    future: Pin<Box<F>>,
51
    context: BudgetContext<Backing, Budget>,
52
    _budget: PhantomData<Budget>,
53
}
54

            
55
impl<Budget, Backing, F> Future for BudgetedFuture<Budget, Backing, F>
56
where
57
    F: Future,
58
    Budget: Budgetable,
59
    Backing: Container<BudgetContextData<Budget>>,
60
{
61
    type Output = Progress<Budget, Backing, F>;
62

            
63
63
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
64
63
        let state = self
65
63
            .state
66
63
            .take()
67
63
            .expect("poll called after future was complete");
68
63
        match poll_async_future_with_budget(state.future, cx, state.context) {
69
27
            BudgetPoll::Ready(result) => Poll::Ready(result),
70
36
            BudgetPoll::Pending { future, context } => {
71
36
                self.state = Some(BudgetedFutureState {
72
36
                    future,
73
36
                    context,
74
36
                    _budget: PhantomData,
75
36
                });
76
36
                Poll::Pending
77
            }
78
        }
79
63
    }
80
}
81

            
82
enum BudgetPoll<Budget, Backing, F>
83
where
84
    F: Future,
85
    Budget: Budgetable,
86
    Backing: Container<BudgetContextData<Budget>>,
87
{
88
    Ready(Progress<Budget, Backing, F>),
89
    Pending {
90
        future: Pin<Box<F>>,
91
        context: BudgetContext<Backing, Budget>,
92
    },
93
}
94

            
95
4569
fn poll_async_future_with_budget<Budget, Backing, F>(
96
4569
    mut future: Pin<Box<F>>,
97
4569
    cx: &mut Context<'_>,
98
4569
    budget_context: BudgetContext<Backing, Budget>,
99
4569
) -> BudgetPoll<Budget, Backing, F>
100
4569
where
101
4569
    Budget: Budgetable,
102
4569
    F: Future,
103
4569
    Backing: Container<BudgetContextData<Budget>>,
104
4569
{
105
4569
    budget_context.data.map_locked(|data| {
106
4569
        data.paused_future = None;
107
4569
    });
108
4569

            
109
4569
    let pinned_future = Pin::new(&mut future);
110
4569
    let future_result = pinned_future.poll(cx);
111
4569

            
112
4569
    match future_result {
113
4
        Poll::Ready(output) => BudgetPoll::Ready(Progress::Complete(BudgetResult {
114
4
            output,
115
4
            balance: budget_context.data.map_locked(|data| data.budget.clone()),
116
4
        })),
117
        Poll::Pending => {
118
4565
            let waker = budget_context
119
4565
                .data
120
4565
                .map_locked(|data| data.paused_future.take());
121
4565
            if let Some(waker) = waker {
122
4528
                BudgetPoll::Ready(Progress::NoBudget(IncompleteFuture {
123
4528
                    future,
124
4528
                    paused_future: waker,
125
4528
                    context: budget_context,
126
4528
                }))
127
            } else {
128
37
                BudgetPoll::Pending {
129
37
                    future,
130
37
                    context: budget_context,
131
37
                }
132
            }
133
        }
134
    }
135
4569
}
136

            
137
/// The progress of a future's execution.
138
enum Progress<Budget, Backing, F>
139
where
140
    Budget: Budgetable,
141
    F: Future,
142
    Backing: Container<BudgetContextData<Budget>>,
143
{
144
    /// The future was interrupted because it requested to spend more budget
145
    /// than was available.
146
    NoBudget(IncompleteFuture<Budget, Backing, F>),
147
    /// The future has completed.
148
    Complete(BudgetResult<F::Output, Budget>),
149
}
150

            
151
/// A future that was budgeted using [`run_with_budget()`] that has
152
/// not yet completed.
153
struct IncompleteFuture<Budget, Backing, F>
154
where
155
    F: Future,
156
    Backing: Container<BudgetContextData<Budget>>,
157
{
158
    future: Pin<Box<F>>,
159
    paused_future: Waker,
160
    context: BudgetContext<Backing, Budget>,
161
}
162

            
163
impl<Budget, Backing, F> IncompleteFuture<Budget, Backing, F>
164
where
165
    F: Future,
166
    Budget: Budgetable,
167
    Backing: Container<BudgetContextData<Budget>>,
168
{
169
    /// Adds `additional_budget` to the remaining balance and continues
170
    /// executing the future.
171
    ///
172
    /// This function returns a future that must be awaited for anything to happen.
173
23
    pub fn continue_with_additional_budget(
174
23
        self,
175
23
        additional_budget: usize,
176
23
    ) -> BudgetedFuture<Budget, Backing, F> {
177
23
        let Self {
178
23
            future,
179
23
            paused_future,
180
23
            context,
181
23
            ..
182
23
        } = self;
183
23
        paused_future.wake();
184
23
        context
185
23
            .data
186
23
            .map_locked(|data| data.budget.replenish(additional_budget));
187
23

            
188
23
        BudgetedFuture {
189
23
            state: Some(BudgetedFutureState {
190
23
                future,
191
23
                context,
192
23
                _budget: PhantomData,
193
23
            }),
194
23
        }
195
23
    }
196
    /// Waits for additional budget to be allocated through
197
    /// [`ReplenishableBudget::replenish()`].
198
4505
    pub fn wait_for_budget(self) -> WaitForBudgetFuture<Budget, Backing, F> {
199
4505
        let Self {
200
4505
            future,
201
4505
            paused_future: waker,
202
4505
            context,
203
4505
            ..
204
4505
        } = self;
205
4505
        WaitForBudgetFuture {
206
4505
            has_returned_pending: false,
207
4505
            paused_future: Some(waker),
208
4505
            future: BudgetedFuture {
209
4505
                state: Some(BudgetedFutureState {
210
4505
                    future,
211
4505
                    context,
212
4505
                    _budget: PhantomData,
213
4505
                }),
214
4505
            },
215
4505
        }
216
4505
    }
217
}
218

            
219
/// A future that waits for additional budget to be allocated through
220
/// [`ReplenishableBudget::replenish()`].
221
///
222
/// This must be awaited to be executed.
223
#[must_use = "the future must be awaited to be executed"]
224
struct WaitForBudgetFuture<Budget, Backing, F>
225
where
226
    F: Future,
227
    Backing: Container<BudgetContextData<Budget>>,
228
{
229
    has_returned_pending: bool,
230
    paused_future: Option<Waker>,
231
    future: BudgetedFuture<Budget, Backing, F>,
232
}
233

            
234
impl<Budget, Backing, F> Future for WaitForBudgetFuture<Budget, Backing, F>
235
where
236
    F: Future,
237
    Budget: Budgetable,
238
    Backing: Container<BudgetContextData<Budget>>,
239
{
240
    type Output = Progress<Budget, Backing, F>;
241

            
242
9011
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
243
9011
        if self.has_returned_pending {
244
4506
            if let Some(future) = self.paused_future.take() {
245
4505
                future.wake();
246
4505
            }
247

            
248
4506
            let state = self
249
4506
                .future
250
4506
                .state
251
4506
                .take()
252
4506
                .expect("poll called after future was complete");
253
4506
            match poll_async_future_with_budget(state.future, cx, state.context) {
254
4505
                BudgetPoll::Ready(result) => Poll::Ready(result),
255
1
                BudgetPoll::Pending { future, context } => {
256
1
                    self.future.state = Some(BudgetedFutureState {
257
1
                        future,
258
1
                        context,
259
1
                        _budget: PhantomData,
260
1
                    });
261
1
                    Poll::Pending
262
                }
263
            }
264
        } else {
265
4505
            self.has_returned_pending = true;
266
4505
            let state = self.future.state.as_ref().expect("always present");
267
4505
            state
268
4505
                .context
269
4505
                .data
270
4505
                .map_locked(|data| data.budget.add_waker(cx.waker()));
271
4505
            Poll::Pending
272
        }
273
9011
    }
274
}
275

            
276
macro_rules! define_public_interface {
    ($modulename:ident, $backing:ident, $moduledocs:literal) => {
        #[doc = $moduledocs]
        pub mod $modulename {
            use std::{
                future::Future,
                pin::Pin,
                task::{self, Poll},
            };

            use crate::{BudgetResult, Budgetable, spend::$modulename::SpendBudget};

            type Backing<Budget> = crate::$backing<crate::BudgetContextData<Budget>>;
            type BudgetContext<Budget> = crate::BudgetContext<Backing<Budget>, Budget>;

            /// A budget-limited asynchronous context.
            #[derive(Clone, Debug)]
            pub struct Context<Budget>(BudgetContext<Budget>)
            where
                Budget: Budgetable;

            impl<Budget> Context<Budget>
            where
                Budget: Budgetable,
            {
                /// Executes `future` with the provided budget. The future will run until it
                /// completes or until it has invoked [`spend()`](Self::spend) enough to
                /// exhaust the budget provided. If the future never calls
                /// [`spend()`](Self::spend), this function will not return until the future
                /// has completed.
                ///
                /// This function returns a [`Future`] which must be awaited to execute the
                /// function.
                ///
                /// This implementation is runtime agnostic.
                ///
                /// # Panics
                ///
                /// Panics when called from within `future` or any code invoked by
                /// `future`.
                // NOTE(review): no panic is visible in this function's own body;
                // confirm where the documented panic originates.
                pub fn run_with_budget<F>(
                    future: impl FnOnce(Context<Budget>) -> F,
                    initial_budget: Budget,
                ) -> BudgetedFuture<Budget, F>
                where
                    F: Future,
                {
                    BudgetedFuture(super::run_with_budget(
                        |context| future(Context(context)),
                        initial_budget,
                    ))
                }

                /// Retrieves the current budget.
                ///
                /// This function should only be called by code that is guaranteed to be
                /// running inside this executor. Calling it from code that is not run by
                /// this executor is not supported.
                #[must_use]
                pub fn budget(&self) -> usize {
                    self.0.budget()
                }

                /// Spends `amount` from the current budget.
                ///
                /// This function returns a future which must be awaited.
                pub fn spend(&self, amount: usize) -> SpendBudget<'_, Budget> {
                    // TODO(review): `SpendBudget` is shared with other budgeting
                    // implementations; decide whether to re-export it from a
                    // crate-level module.
                    SpendBudget::from(self.0.spend(amount))
                }
            }

            /// Executes a future with a given budget when awaited.
            ///
            /// This future is returned from [`Context::run_with_budget()`] and
            /// [`IncompleteFuture::continue_with_additional_budget()`].
            #[must_use = "the future must be awaited to be executed"]
            pub struct BudgetedFuture<Budget, F>(super::BudgetedFuture<Budget, Backing<Budget>, F>)
            where
                Budget: Budgetable;

            impl<Budget, F> Future for BudgetedFuture<Budget, F>
            where
                Budget: Budgetable,
                F: Future,
            {
                type Output = Progress<Budget, F>;

                fn poll(
                    mut self: Pin<&mut Self>,
                    cx: &mut task::Context<'_>,
                ) -> Poll<Self::Output> {
                    // Delegate to the shared implementation, converting its
                    // output into this module's public `Progress` type.
                    Pin::new(&mut self.0).poll(cx).map(Progress::from)
                }
            }

            /// The progress of a future's execution.
            pub enum Progress<Budget, F>
            where
                Budget: Budgetable,
                F: Future,
            {
                /// The future was interrupted because it requested to spend more budget
                /// than was available.
                NoBudget(IncompleteFuture<Budget, F>),
                /// The future has completed.
                Complete(BudgetResult<F::Output, Budget>),
            }

            impl<Budget, F> Progress<Budget, F>
            where
                Budget: Budgetable,
                F: Future,
            {
                /// Continues executing the contained future until it is
                /// completed.
                ///
                /// Each time the future runs out of budget, this waits for more
                /// budget via [`IncompleteFuture::wait_for_budget()`].
                ///
                /// This function will never return if the future enters an
                /// infinite loop or deadlocks, regardless of whether the budget
                /// is exhausted or not.
                pub async fn wait_until_complete(self) -> BudgetResult<F::Output, Budget> {
                    let mut progress = self;
                    loop {
                        match progress {
                            Progress::NoBudget(incomplete) => {
                                progress = incomplete.wait_for_budget().await;
                            }
                            Progress::Complete(result) => break result,
                        }
                    }
                }
            }

            // Wraps the crate-internal progress type in this module's public types.
            impl<Budget, F> From<super::Progress<Budget, Backing<Budget>, F>> for Progress<Budget, F>
            where
                Budget: Budgetable,
                F: Future,
            {
                fn from(progress: super::Progress<Budget, Backing<Budget>, F>) -> Self {
                    match progress {
                        super::Progress::NoBudget(incomplete) => {
                            Progress::NoBudget(IncompleteFuture(incomplete))
                        }
                        super::Progress::Complete(result) => Progress::Complete(result),
                    }
                }
            }

            /// A future that was budgeted using [`Context::run_with_budget()`]
            /// that has not yet completed.
            pub struct IncompleteFuture<Budget, F>(
                pub(super) super::IncompleteFuture<Budget, Backing<Budget>, F>,
            )
            where
                F: Future,
                Budget: Budgetable;

            impl<Budget, F> IncompleteFuture<Budget, F>
            where
                F: Future,
                Budget: Budgetable,
            {
                /// Adds `additional_budget` to the remaining balance and continues
                /// executing the future.
                ///
                /// This function returns a future that must be awaited for anything to happen.
                pub fn continue_with_additional_budget(
                    self,
                    additional_budget: usize,
                ) -> BudgetedFuture<Budget, F> {
                    BudgetedFuture(self.0.continue_with_additional_budget(additional_budget))
                }

                /// Waits for additional budget to be allocated through
                /// [`ReplenishableBudget::replenish()`](crate::ReplenishableBudget::replenish).
                pub fn wait_for_budget(self) -> WaitForBudgetFuture<Budget, F> {
                    WaitForBudgetFuture(self.0.wait_for_budget())
                }
            }

            /// A future that waits for additional budget to be allocated
            /// through
            /// [`ReplenishableBudget::replenish()`](crate::ReplenishableBudget::replenish).
            ///
            /// This must be awaited to be executed.
            #[must_use = "the future must be awaited to be executed"]
            pub struct WaitForBudgetFuture<Budget, F>(
                pub(super) super::WaitForBudgetFuture<Budget, Backing<Budget>, F>,
            )
            where
                F: Future,
                Budget: Budgetable;

            impl<Budget, F> Future for WaitForBudgetFuture<Budget, F>
            where
                F: Future,
                Budget: Budgetable,
            {
                type Output = Progress<Budget, F>;

                fn poll(
                    mut self: Pin<&mut Self>,
                    cx: &mut task::Context<'_>,
                ) -> Poll<Self::Output> {
                    // Delegate to the shared implementation, converting its
                    // output into this module's public `Progress` type.
                    Pin::new(&mut self.0).poll(cx).map(Progress::from)
                }
            }
        }
    };
}
490

            
491
define_public_interface!(
492
    threadsafe,
493
    SyncContainer,
494
    "A threadsafe (`Send + Sync`), asynchronous budgeting implementation that is runtime agnostic.\n\nThe only difference between this module and the [`singlethreaded`] module is that this one uses [`std::sync::Arc`] and [`std::sync::Mutex`] instead of [`std::rc::Rc`] and [`std::cell::RefCell`]."
495
);
496

            
497
4
define_public_interface!(
498
4
    singlethreaded,
499
4
    NotSyncContainer,
500
4
    "A single-threaded (`!Send + !Sync`), asynchronous budgeting implementation that is runtime agnostic.\n\nThe only difference between this module and the [`threadsafe`] module is that this one uses [`std::rc::Rc`] and [`std::cell::RefCell`] instead of [`std::sync::Arc`] and [`std::sync::Mutex`]."
501
4
);