1
use std::{
2
    io,
3
    marker::PhantomData,
4
    pin::Pin,
5
    task::{Context, Poll},
6
};
7

            
8
use bytes::{Buf, BytesMut};
9
use futures_core::{ready, Stream};
10
use ordered_varint::Variable;
11
use tokio::io::{AsyncRead, ReadBuf};
12
use transmog::OwnedDeserializer;
13

            
14
/// A wrapper around an asynchronous reader that produces an asynchronous stream
15
/// of Transmog-decoded values.
16
///
17
/// To use, provide a reader that implements [`AsyncRead`], and then use
18
/// [`Stream`] to access the deserialized values.
19
///
20
/// Note that the sender *must* prefix each serialized item with its size
21
/// encoded using [`ordered-varint`](ordered_varint).
22
#[derive(Debug)]
23
pub struct TransmogReader<R, T, F> {
24
    format: F,
25
    reader: R,
26
    pub(crate) buffer: BytesMut,
27
    into: PhantomData<T>,
28
}
29

            
30
// Only `reader` is ever pinned (via `Pin::new(&mut self.reader)`); `format`
// and the `PhantomData<T>` marker are never pinned, so the reader alone
// decides whether the wrapper may be moved.
impl<R, T, F> Unpin for TransmogReader<R, T, F>
where
    R: Unpin,
{
}
31

            
32
impl<R, T, F> TransmogReader<R, T, F> {
33
    /// Gets a reference to the underlying reader.
34
    ///
35
    /// It is inadvisable to directly read from the underlying reader.
36
    pub fn get_ref(&self) -> &R {
37
        &self.reader
38
    }
39

            
40
    /// Gets a mutable reference to the underlying reader.
41
    ///
42
    /// It is inadvisable to directly read from the underlying reader.
43
163868
    pub fn get_mut(&mut self) -> &mut R {
44
163868
        &mut self.reader
45
163868
    }
46

            
47
    /// Returns a reference to the internally buffered data.
48
    ///
49
    /// This will not attempt to fill the buffer if it is empty.
50
    pub fn buffer(&self) -> &[u8] {
51
        &self.buffer[..]
52
    }
53

            
54
    /// Unwraps this `TransmogReader`, returning the underlying reader.
55
    ///
56
    /// Note that any leftover data in the internal buffer is lost.
57
12
    pub fn into_inner(self) -> R {
58
12
        self.reader
59
12
    }
60
}
61

            
62
impl<R, T, F> TransmogReader<R, T, F> {
63
    /// Returns a new instance that reads `format`-encoded data for `reader`.
64
30
    pub fn new(reader: R, format: F) -> Self {
65
30
        TransmogReader {
66
30
            format,
67
30
            buffer: BytesMut::with_capacity(8192),
68
30
            reader,
69
30
            into: PhantomData,
70
30
        }
71
30
    }
72

            
73
    /// Returns a new instance that reads `format`-encoded data for `R::default()`.
74
    pub fn default_for(format: F) -> Self
75
    where
76
        R: Default,
77
    {
78
        Self::new(R::default(), format)
79
    }
80
}
81

            
82
impl<R, T, F> Stream for TransmogReader<R, T, F>
83
where
84
    R: AsyncRead + Unpin,
85
    F: OwnedDeserializer<T>,
86
{
87
    type Item = Result<T, F::Error>;
88
164175
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
89
        loop {
90
164193
            let fill_result = ready!(self
91
164193
                .as_mut()
92
164193
                .fill(cx, 9)
93
164193
                .map_err(<F::Error as From<std::io::Error>>::from))?;
94

            
95
163894
            let mut buf_reader = &self.buffer[..];
96
163894
            let buffer_start = buf_reader.as_ptr() as usize;
97
163894
            if let Ok(message_size) = u64::decode_variable(&mut buf_reader) {
98
163890
                let header_len = buf_reader.as_ptr() as usize - buffer_start;
99
163890
                let target_buffer_size = usize::try_from(message_size).unwrap() + header_len;
100

            
101
163890
                ready!(self
102
163890
                    .as_mut()
103
163890
                    .fill(cx, target_buffer_size)
104
163890
                    .map_err(<F::Error as From<std::io::Error>>::from))?;
105

            
106
163872
                if self.buffer.len() >= target_buffer_size {
107
163854
                    let message = self
108
163854
                        .format
109
163854
                        .deserialize_owned(&self.buffer[header_len..target_buffer_size])
110
163854
                        .unwrap();
111
163854
                    self.buffer.advance(target_buffer_size);
112
163854
                    break Poll::Ready(Some(Ok(message)));
113
18
                }
114
4
            } else if let ReadResult::Eof = fill_result {
115
4
                break Poll::Ready(None);
116
            }
117
        }
118
164175
    }
119
}
120

            
121
/// Outcome of a single `fill` attempt.
#[derive(Debug)]
enum ReadResult {
    /// The buffer already held enough bytes, or the read appended new bytes.
    ReceivedData,
    /// The read returned zero bytes, i.e. the reader reached end of stream.
    Eof,
}
126

            
127
impl<R, T, F> TransmogReader<R, T, F>
128
where
129
    R: AsyncRead + Unpin,
130
{
131
328083
    fn fill(
132
328083
        mut self: Pin<&mut Self>,
133
328083
        cx: &mut Context<'_>,
134
328083
        target_size: usize,
135
328083
    ) -> Poll<Result<ReadResult, io::Error>> {
136
328083
        if self.buffer.len() >= target_size {
137
            // we already have the bytes we need!
138
292049
            return Poll::Ready(Ok(ReadResult::ReceivedData));
139
36034
        }
140
36034

            
141
36034
        // make sure we can fit all the data we're about to read
142
36034
        // and then some, so we don't do a gazillion syscalls
143
36034
        if self.buffer.capacity() < target_size {
144
36028
            let missing = target_size - self.buffer.capacity();
145
36028
            self.buffer.reserve(missing);
146
36028
        }
147

            
148
36034
        let had = self.buffer.len();
149
36034
        // this is the bit we'll be reading into
150
36034
        let mut rest = self.buffer.split_off(had);
151
36034
        // this is safe because we're not extending beyond the reserved capacity
152
36034
        // and we're never reading unwritten bytes
153
36034
        let max = rest.capacity();
154
36034
        // In the original implementation, this was an unsafe operation.
155
36034
        // unsafe { rest.set_len(max) };
156
36034
        rest.resize(max, 0);
157
36034

            
158
36034
        let mut buf = ReadBuf::new(&mut rest[..]);
159
36034
        ready!(Pin::new(&mut self.reader).poll_read(cx, &mut buf))?;
160
35717
        let n = buf.filled().len();
161
35717
        // adopt the new bytes
162
35717
        let read = rest.split_to(n);
163
35717
        self.buffer.unsplit(read);
164
35717
        if n == 0 {
165
6
            return Poll::Ready(Ok(ReadResult::Eof));
166
35711
        }
167
35711

            
168
35711
        Poll::Ready(Ok(ReadResult::ReceivedData))
169
328083
    }
170
}