1
use std::{
2
    marker::PhantomData,
3
    pin::Pin,
4
    task::{Context, Poll},
5
};
6

            
7
use futures_core::ready;
8
use futures_sink::Sink;
9
use ordered_varint::Variable;
10
use tokio::io::AsyncWrite;
11
use transmog::Format;
12

            
13
/// A wrapper around an asynchronous sink that accepts, serializes, and sends Transmog-encoded
/// values.
///
/// To use, provide a writer that implements [`AsyncWrite`], and then use [`Sink`] to send values.
///
/// Note that a `TransmogWriter` must be of the type [`AsyncDestination`] in order to be
/// compatible with a [`TransmogReader`](super::TransmogReader) on the remote end (recall that it
/// requires the serialized size prefixed to the serialized data). The default is
/// [`SyncDestination`], but you can switch between the two with [`TransmogWriter::for_async`]
/// and [`TransmogWriter::for_sync`].
#[derive(Debug)]
pub struct TransmogWriter<W, T, D, F> {
    /// The format used to serialize each value.
    format: F,
    /// The destination that buffered, serialized bytes are written to.
    writer: W,
    /// How many bytes at the front of `buffer` have already been accepted by `writer`.
    pub(crate) written: usize,
    /// Serialized bytes waiting to be written to `writer`.
    pub(crate) buffer: Vec<u8>,
    /// Reusable space used to measure a value's serialized size when the format cannot report
    /// it up front.
    scratch_buffer: Vec<u8>,
    from: PhantomData<T>,
    dest: PhantomData<D>,
}
32

            
33
// The `Sink` methods call `Pin::get_mut`, which requires `Self: Unpin`. This impl makes the
// writer `Unpin` whenever `W` is, regardless of `T`, `D`, and `F` (which would otherwise
// influence the auto-derived `Unpin` through the `PhantomData` fields and the `format` field).
impl<W: Unpin, T, D, F> Unpin for TransmogWriter<W, T, D, F> {}
34

            
35
impl<W, T, D, F> TransmogWriter<W, T, D, F> {
36
    /// Gets a reference to the underlying format.
37
    ///
38
    /// It is inadvisable to directly write to the underlying writer.
39
6
    pub fn format(&self) -> &F {
40
6
        &self.format
41
6
    }
42

            
43
    /// Gets a reference to the underlying writer.
44
    ///
45
    /// It is inadvisable to directly write to the underlying writer.
46
    pub fn get_ref(&self) -> &W {
47
        &self.writer
48
    }
49

            
50
    /// Gets a mutable reference to the underlying writer.
51
    ///
52
    /// It is inadvisable to directly write to the underlying writer.
53
13838
    pub fn get_mut(&mut self) -> &mut W {
54
13838
        &mut self.writer
55
13838
    }
56

            
57
    /// Unwraps this `TransmogWriter`, returning the underlying writer.
58
    ///
59
    /// Note that any leftover serialized data that has not yet been sent is lost.
60
12
    pub fn into_inner(self) -> (W, F) {
61
12
        (self.writer, self.format)
62
12
    }
63
}
64

            
65
impl<W, T, F> TransmogWriter<W, T, SyncDestination, F> {
66
    /// Returns a new instance that sends `format`-encoded data over `writer`.
67
30
    pub fn new(writer: W, format: F) -> Self {
68
30
        TransmogWriter {
69
30
            format,
70
30
            buffer: Vec::new(),
71
30
            scratch_buffer: Vec::new(),
72
30
            writer,
73
30
            written: 0,
74
30
            from: PhantomData,
75
30
            dest: PhantomData,
76
30
        }
77
30
    }
78

            
79
    /// Returns a new instance that sends `format`-encoded data over
80
    /// `W::defcfault()`.
81
    pub fn default_for(format: F) -> Self
82
    where
83
        W: Default,
84
    {
85
        Self::new(W::default(), format)
86
    }
87
}
88

            
89
impl<W, T, F> TransmogWriter<W, T, SyncDestination, F> {
90
    /// Make this writer include the serialized data's size before each serialized value.
91
    ///
92
    /// This is necessary for compatability with [`TransmogReader`](super::TransmogReader).
93
12
    pub fn for_async(self) -> TransmogWriter<W, T, AsyncDestination, F> {
94
12
        self.make_for()
95
12
    }
96
}
97

            
98
impl<W, T, D, F> TransmogWriter<W, T, D, F> {
99
18
    pub(crate) fn make_for<D2>(self) -> TransmogWriter<W, T, D2, F> {
100
18
        TransmogWriter {
101
18
            format: self.format,
102
18
            buffer: self.buffer,
103
18
            writer: self.writer,
104
18
            written: self.written,
105
18
            from: self.from,
106
18
            scratch_buffer: self.scratch_buffer,
107
18
            dest: PhantomData,
108
18
        }
109
18
    }
110
}
111

            
112
impl<W, T, F> TransmogWriter<W, T, AsyncDestination, F> {
113
    /// Make this writer only send Transmog-encoded values.
114
    ///
115
    /// This is necessary for compatability with stock Transmog receivers.
116
    pub fn for_sync(self) -> TransmogWriter<W, T, SyncDestination, F> {
117
        self.make_for()
118
    }
119
}
120

            
121
/// Destination marker: each value is written as bare Transmog-encoded bytes with no size
/// prefix, which is what stock Transmog receivers expect.
#[derive(Debug)]
pub struct SyncDestination;

/// Destination marker: each value is written with its serialized size prefixed, which is what
/// [`TransmogReader`](super::TransmogReader) expects to receive.
#[derive(Debug)]
pub struct AsyncDestination;
128

            
129
#[doc(hidden)]
130
pub trait TransmogWriterFor<T, F>
131
where
132
    F: Format<'static, T>,
133
{
134
    fn append(&mut self, item: &T) -> Result<(), F::Error>;
135
}
136

            
137
impl<W, T, F> TransmogWriterFor<T, F> for TransmogWriter<W, T, AsyncDestination, F>
138
where
139
    F: Format<'static, T>,
140
{
141
163854
    fn append(&mut self, item: &T) -> Result<(), F::Error> {
142
163854
        if let Some(serialized_length) = self.format.serialized_size(item)? {
143
163846
            let size = usize_to_u64(serialized_length)?;
144
163846
            size.encode_variable(&mut self.buffer)?;
145
163846
            self.format.serialize_into(item, &mut self.buffer)?;
146
        } else {
147
            // Use a scratch buffer to measure the size. This introduces an
148
            // extra data copy, but by reusing the scratch buffer, that should
149
            // be the only overhead.
150
8
            self.scratch_buffer.truncate(0);
151
8
            self.format.serialize_into(item, &mut self.scratch_buffer)?;
152

            
153
8
            let size = usize_to_u64(self.scratch_buffer.len())?;
154
8
            size.encode_variable(&mut self.buffer)?;
155
8
            self.buffer.append(&mut self.scratch_buffer);
156
        }
157
163854
        Ok(())
158
163854
    }
159
}
160

            
161
163854
/// Converts a byte length to `u64` for use as the varint size prefix.
///
/// # Errors
///
/// Returns an I/O error of kind `OutOfMemory` if `value` does not fit in a `u64`.
fn usize_to_u64(value: usize) -> Result<u64, std::io::Error> {
    match u64::try_from(value) {
        Ok(converted) => Ok(converted),
        Err(_) => Err(std::io::ErrorKind::OutOfMemory.into()),
    }
}
164

            
165
impl<W, T, F> TransmogWriterFor<T, F> for TransmogWriter<W, T, SyncDestination, F>
166
where
167
    F: Format<'static, T>,
168
{
169
    fn append(&mut self, item: &T) -> Result<(), F::Error> {
170
        self.format.serialize_into(item, &mut self.buffer)
171
    }
172
}
173

            
174
impl<W, T, D, F> Sink<T> for TransmogWriter<W, T, D, F>
175
where
176
    F: Format<'static, T>,
177
    W: AsyncWrite + Unpin,
178
    Self: TransmogWriterFor<T, F>,
179
{
180
    type Error = F::Error;
181

            
182
163854
    fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
183
163854
        Poll::Ready(Ok(()))
184
163854
    }
185

            
186
    fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
187
163854
        self.append(&item)?;
188
163854
        Ok(())
189
163854
    }
190

            
191
206
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
192
206
        // allow us to borrow fields separately
193
206
        let this = self.get_mut();
194

            
195
        // write stuff out if we need to
196
222
        while this.written != this.buffer.len() {
197
16
            let n =
198
187
                ready!(Pin::new(&mut this.writer).poll_write(cx, &this.buffer[this.written..]))?;
199
16
            this.written += n;
200
        }
201

            
202
        // we have to flush before we're really done
203
35
        this.buffer.clear();
204
35
        this.written = 0;
205
35
        Pin::new(&mut this.writer)
206
35
            .poll_flush(cx)
207
35
            .map_err(<F::Error as From<std::io::Error>>::from)
208
206
    }
209

            
210
4
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
211
4
        ready!(self.as_mut().poll_flush(cx))?;
212
4
        Pin::new(&mut self.writer)
213
4
            .poll_shutdown(cx)
214
4
            .map_err(<F::Error as From<std::io::Error>>::from)
215
4
    }
216
}