1
#![doc = include_str!("./.crate-docs.md")]
2
#![forbid(unsafe_code)]
3
#![warn(
4
    clippy::cargo,
5
    missing_docs,
6
    // clippy::missing_docs_in_private_items,
7
    clippy::pedantic,
8
    future_incompatible,
9
    rust_2018_idioms,
10
)]
11
#![allow(
12
    clippy::missing_errors_doc, // TODO clippy::missing_errors_doc
13
    clippy::option_if_let_else,
14
    clippy::module_name_repetitions,
15
)]
16

            
17
mod reader;
18
mod writer;
19

            
20
use std::{
21
    fmt, io,
22
    marker::PhantomData,
23
    ops::{Deref, DerefMut},
24
    pin::Pin,
25
    task::{Context, Poll},
26
};
27

            
28
use futures_core::Stream;
29
use futures_sink::Sink;
30
use tokio::io::{AsyncRead, ReadBuf};
31
pub use transmog;
32
use transmog::Format;
33

            
34
pub use self::{
35
    reader::TransmogReader,
36
    writer::{AsyncDestination, SyncDestination, TransmogWriter, TransmogWriterFor},
37
};
38

            
39
/// Builder helper to specify types without the need of turbofishing.
///
/// `TReads` and `TWrites` are carried only at the type level (through
/// [`PhantomData`]); they start out as `()` and are replaced by the
/// `sends`/`receives` style methods before the stream is built.
pub struct Builder<TReads, TWrites, TStream, TFormat> {
    /// The underlying stream that the built value will wrap.
    stream: TStream,
    /// The format instance handed to the built stream.
    format: TFormat,
    /// Zero-sized marker recording the payload types chosen so far.
    datatypes: PhantomData<(TReads, TWrites)>,
}

impl<TStream, TFormat> Builder<(), (), TStream, TFormat> {
    /// Returns a new stream builder for `stream` and `format`.
    ///
    /// Both payload types start out as `()`.
    #[must_use]
    pub fn new(stream: TStream, format: TFormat) -> Self {
        Self {
            stream,
            format,
            datatypes: PhantomData,
        }
    }
}
56

            
57
impl<TStream, TFormat> Builder<(), (), TStream, TFormat> {
58
    /// Sets `T` as the type for both sending and receiving.
59
2
    pub fn sends_and_receives<T>(self) -> Builder<T, T, TStream, TFormat>
60
2
    where
61
2
        TFormat: Format<'static, T>,
62
2
    {
63
2
        Builder {
64
2
            stream: self.stream,
65
2
            format: self.format,
66
2
            datatypes: PhantomData,
67
2
        }
68
2
    }
69
}
70

            
71
impl<TReads, TStream, TFormat> Builder<TReads, (), TStream, TFormat> {
72
    /// Sets `T` as the type of data that is written to this stream.
73
    pub fn sends<T>(self) -> Builder<TReads, T, TStream, TFormat>
74
    where
75
        TFormat: Format<'static, T>,
76
    {
77
        Builder {
78
            stream: self.stream,
79
            format: self.format,
80
            datatypes: PhantomData,
81
        }
82
    }
83
}
84

            
85
impl<TWrites, TStream, TFormat> Builder<(), TWrites, TStream, TFormat> {
86
    /// Sets `T` as the type of data that is read from this stream.
87
    pub fn receives<T>(self) -> Builder<T, TWrites, TStream, TFormat>
88
    where
89
        TFormat: Format<'static, T>,
90
    {
91
        Builder {
92
            stream: self.stream,
93
            format: self.format,
94
            datatypes: PhantomData,
95
        }
96
    }
97
}
98

            
99
impl<TReads, TWrites, TStream, TFormat> Builder<TReads, TWrites, TStream, TFormat>
100
where
101
    TFormat: Clone,
102
{
103
    /// Build this stream to include the serialized data's size before each
104
    /// serialized value.
105
    ///
106
    /// This is necessary for compatability with a remote [`TransmogReader`].
107
2
    pub fn for_async(self) -> TransmogStream<TReads, TWrites, TStream, AsyncDestination, TFormat> {
108
2
        TransmogStream::new(self.stream, self.format).for_async()
109
2
    }
110

            
111
    /// Build this stream only send Transmog-encoded values.
112
    ///
113
    /// This is necessary for compatability with stock Transmog receivers.
114
    pub fn for_sync(self) -> TransmogStream<TReads, TWrites, TStream, SyncDestination, TFormat> {
115
        TransmogStream::new(self.stream, self.format)
116
    }
117
}
118

            
119
/// A wrapper around an asynchronous stream that receives and sends bincode-encoded values.
120
///
121
/// To use, provide a stream that implements both [`AsyncWrite`](tokio::io::AsyncWrite) and [`AsyncRead`], and then use
122
/// [`Sink`] to send values and [`Stream`] to receive them.
123
///
124
/// Note that an `TransmogStream` must be of the type [`AsyncDestination`] in order to be
125
/// compatible with an [`TransmogReader`] on the remote end (recall that it requires the
126
/// serialized size prefixed to the serialized data). The default is [`SyncDestination`], but these
127
/// can be easily toggled between using [`TransmogStream::for_async`].
128
#[derive(Debug)]
129
pub struct TransmogStream<TReads, TWrites, TStream, TDestination, TFormat> {
130
    stream: TransmogReader<
131
        InternalTransmogWriter<TStream, TWrites, TDestination, TFormat>,
132
        TReads,
133
        TFormat,
134
    >,
135
}
136

            
137
#[doc(hidden)]
138
pub struct InternalTransmogWriter<TStream, T, TDestination, TFormat>(
139
    TransmogWriter<TStream, T, TDestination, TFormat>,
140
);
141

            
142
impl<TStream: fmt::Debug, T, TDestination, TFormat> fmt::Debug
143
    for InternalTransmogWriter<TStream, T, TDestination, TFormat>
144
{
145
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
146
        self.get_ref().fmt(f)
147
    }
148
}
149

            
150
impl<TReads, TWrites, TStream, TDestination, TFormat>
151
    TransmogStream<TReads, TWrites, TStream, TDestination, TFormat>
152
{
153
    /// Gets a reference to the underlying stream.
154
    ///
155
    /// It is inadvisable to directly read from or write to the underlying stream.
156
    pub fn get_ref(&self) -> &TStream {
157
        self.stream.get_ref().0.get_ref()
158
    }
159

            
160
    /// Gets a mutable reference to the underlying stream.
161
    ///
162
    /// It is inadvisable to directly read from or write to the underlying stream.
163
    pub fn get_mut(&mut self) -> &mut TStream {
164
        self.stream.get_mut().0.get_mut()
165
    }
166

            
167
    /// Unwraps this `TransmogStream`, returning the underlying stream.
168
    ///
169
    /// Note that any leftover serialized data that has not yet been sent, or received data that
170
    /// has not yet been deserialized, is lost.
171
12
    pub fn into_inner(self) -> (TStream, TFormat) {
172
12
        self.stream.into_inner().0.into_inner()
173
12
    }
174
}
175

            
176
impl<TStream, TFormat> TransmogStream<(), (), TStream, SyncDestination, TFormat> {
177
    /// Creates a new instance that sends `format`-encoded payloads over `stream`.
178
2
    pub fn build(stream: TStream, format: TFormat) -> Builder<(), (), TStream, TFormat> {
179
2
        Builder::new(stream, format)
180
2
    }
181
}
182

            
183
impl<TReads, TWrites, TStream, TFormat>
184
    TransmogStream<TReads, TWrites, TStream, SyncDestination, TFormat>
185
where
186
    TFormat: Clone,
187
{
188
    /// Creates a new instance that sends `format`-encoded payloads over `stream`.
189
12
    pub fn new(stream: TStream, format: TFormat) -> Self {
190
12
        TransmogStream {
191
12
            stream: TransmogReader::new(
192
12
                InternalTransmogWriter(TransmogWriter::new(stream, format.clone())),
193
12
                format,
194
12
            ),
195
12
        }
196
12
    }
197

            
198
    /// Creates a new instance that sends `format`-encoded payloads over the
199
    /// default stream for `TStream`.
200
    pub fn default_for(format: TFormat) -> Self
201
    where
202
        TStream: Default,
203
    {
204
        Self::new(TStream::default(), format)
205
    }
206
}
207

            
208
impl<TReads, TWrites, TStream, TDestination, TFormat>
209
    TransmogStream<TReads, TWrites, TStream, TDestination, TFormat>
210
where
211
    TFormat: Clone,
212
{
213
    /// Make this stream include the serialized data's size before each serialized value.
214
    ///
215
    /// This is necessary for compatability with a remote [`TransmogReader`].
216
12
    pub fn for_async(self) -> TransmogStream<TReads, TWrites, TStream, AsyncDestination, TFormat> {
217
12
        let (stream, format) = self.into_inner();
218
12
        TransmogStream {
219
12
            stream: TransmogReader::new(
220
12
                InternalTransmogWriter(TransmogWriter::new(stream, format.clone()).for_async()),
221
12
                format,
222
12
            ),
223
12
        }
224
12
    }
225

            
226
    /// Make this stream only send Transmog-encoded values.
227
    ///
228
    /// This is necessary for compatability with stock Transmog receivers.
229
    pub fn for_sync(self) -> TransmogStream<TReads, TWrites, TStream, SyncDestination, TFormat> {
230
        let (stream, format) = self.into_inner();
231
        TransmogStream::new(stream, format)
232
    }
233
}
234

            
235
/// A reader of Transmog-encoded data from a [`TcpStream`](tokio::net::TcpStream).
236
pub type TransmogTokioTcpReader<'a, TReads, TFormat> =
237
    TransmogReader<tokio::net::tcp::ReadHalf<'a>, TReads, TFormat>;
238
/// A writer of Transmog-encoded data to a [`TcpStream`](tokio::net::TcpStream).
239
pub type TransmogTokioTcpWriter<'a, TWrites, TDestination, TFormat> =
240
    TransmogWriter<tokio::net::tcp::WriteHalf<'a>, TWrites, TDestination, TFormat>;
241

            
242
impl<TReads, TWrites, TDestination, TFormat>
243
    TransmogStream<TReads, TWrites, tokio::net::TcpStream, TDestination, TFormat>
244
where
245
    TFormat: Clone,
246
{
247
    /// Split a TCP-based stream into a read half and a write half.
248
    ///
249
    /// This is more performant than using a lock-based split like the one provided by `tokio-io`
250
    /// or `futures-util` since we know that reads and writes to a `TcpStream` can continue
251
    /// concurrently.
252
    ///
253
    /// Any partially sent or received state is preserved.
254
6
    pub fn tcp_split(
255
6
        &mut self,
256
6
    ) -> (
257
6
        TransmogTokioTcpReader<'_, TReads, TFormat>,
258
6
        TransmogTokioTcpWriter<'_, TWrites, TDestination, TFormat>,
259
6
    ) {
260
6
        // First, steal the reader state so it isn't lost
261
6
        let rbuff = self.stream.buffer.split();
262
6
        // Then, fish out the writer
263
6
        let writer = &mut self.stream.get_mut().0;
264
6
        let format = writer.format().clone();
265
6
        // And steal the writer state so it isn't lost
266
6
        let write_buffer = writer.buffer.split_off(0);
267
6
        let write_buffer_written = writer.written;
268
6
        // Now split the stream
269
6
        let (r, w) = writer.get_mut().split();
270
6
        // Then put the reader back together
271
6
        let mut reader = TransmogReader::new(r, format.clone());
272
6
        reader.buffer = rbuff;
273
6
        // And then the writer
274
6
        let mut writer: TransmogWriter<_, _, TDestination, TFormat> =
275
6
            TransmogWriter::new(w, format).make_for();
276
6
        writer.buffer = write_buffer;
277
6
        writer.written = write_buffer_written;
278
6
        // All good!
279
6
        (reader, writer)
280
6
    }
281
}
282

            
283
impl<TStream, T, TDestination, TFormat> AsyncRead
284
    for InternalTransmogWriter<TStream, T, TDestination, TFormat>
285
where
286
    TStream: AsyncRead + Unpin,
287
{
288
13832
    fn poll_read(
289
13832
        self: Pin<&mut Self>,
290
13832
        cx: &mut Context<'_>,
291
13832
        buf: &mut ReadBuf<'_>,
292
13832
    ) -> Poll<Result<(), io::Error>> {
293
13832
        Pin::new(self.get_mut().get_mut()).poll_read(cx, buf)
294
13832
    }
295
}
296

            
297
impl<TStream, T, TDestination, TFormat> Deref
298
    for InternalTransmogWriter<TStream, T, TDestination, TFormat>
299
{
300
    type Target = TransmogWriter<TStream, T, TDestination, TFormat>;
301
    fn deref(&self) -> &Self::Target {
302
        &self.0
303
    }
304
}
305
impl<TStream, T, TDestination, TFormat> DerefMut
306
    for InternalTransmogWriter<TStream, T, TDestination, TFormat>
307
{
308
177694
    fn deref_mut(&mut self) -> &mut Self::Target {
309
177694
        &mut self.0
310
177694
    }
311
}
312

            
313
impl<TReads, TWrites, TStream, TDestination, TFormat> Stream
314
    for TransmogStream<TReads, TWrites, TStream, TDestination, TFormat>
315
where
316
    TStream: Unpin,
317
    TransmogReader<
318
        InternalTransmogWriter<TStream, TWrites, TDestination, TFormat>,
319
        TReads,
320
        TFormat,
321
    >: Stream<Item = Result<TReads, TFormat::Error>>,
322
    TFormat: Format<'static, TWrites>,
323
{
324
    type Item = Result<TReads, TFormat::Error>;
325
82050
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
326
82050
        Pin::new(&mut self.stream).poll_next(cx)
327
82050
    }
328
}
329

            
330
impl<TReads, TWrites, TStream, TDestination, TFormat> Sink<TWrites>
331
    for TransmogStream<TReads, TWrites, TStream, TDestination, TFormat>
332
where
333
    TStream: Unpin,
334
    TransmogWriter<TStream, TWrites, TDestination, TFormat>: Sink<TWrites, Error = TFormat::Error>,
335
    TFormat: Format<'static, TWrites>,
336
{
337
    type Error = TFormat::Error;
338

            
339
81927
    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
340
81927
        Pin::new(&mut **self.stream.get_mut()).poll_ready(cx)
341
81927
    }
342

            
343
81927
    fn start_send(mut self: Pin<&mut Self>, item: TWrites) -> Result<(), Self::Error> {
344
81927
        Pin::new(&mut **self.stream.get_mut()).start_send(item)
345
81927
    }
346

            
347
7
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
348
7
        Pin::new(&mut **self.stream.get_mut()).poll_flush(cx)
349
7
    }
350

            
351
1
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
352
1
        Pin::new(&mut **self.stream.get_mut()).poll_close(cx)
353
1
    }
354
}
355

            
356
#[cfg(test)]
mod tests {
    use futures::prelude::*;
    use transmog::OwnedDeserializer;
    use transmog_bincode::Bincode;
    use transmog_pot::Pot;

    use super::*;

    /// Sends each entry of `values` to a local TCP echo server built on
    /// `tcp_split` and asserts that the same value comes back.
    async fn it_works<T, TFormat>(format: TFormat, values: &[T])
    where
        T: std::fmt::Debug + Clone + PartialEq + Send,
        TFormat: Format<'static, T> + OwnedDeserializer<T> + Clone + 'static,
    {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = listener.local_addr().unwrap();

        // Echo server: forwards everything it reads back to the sender.
        let server_format = format.clone();
        tokio::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            let mut server =
                TransmogStream::<T, T, _, _, _>::new(socket, server_format).for_async();
            let (incoming, outgoing) = server.tcp_split();
            incoming.forward(outgoing).await.unwrap();
        });

        let socket = tokio::net::TcpStream::connect(&server_addr).await.unwrap();
        let mut client = TransmogStream::<T, T, _, _, _>::new(socket, format).for_async();

        for value in values {
            client.send(value.clone()).await.unwrap();
            let echoed = client.next().await.unwrap().unwrap();
            assert_eq!(&echoed, value);
        }

        drop(client);
    }

    #[tokio::test]
    async fn it_works_bincode() {
        // Short payloads first...
        it_works(Bincode::default(), &[44, 42]).await;
        // ...then a single long payload.
        it_works(Bincode::default(), &[vec![0_u8; 1_000_000]]).await;
    }

    #[tokio::test]
    async fn it_works_pot() {
        // Short payloads first...
        it_works(Pot::default(), &[44, 42]).await;
        // ...then a single long payload.
        it_works(Pot::default(), &[vec![0_u8; 1_000_000]]).await;
    }

    /// Pushes many small values through the echo server and checks that they
    /// all come back, in order.
    #[tokio::test]
    async fn lots() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            let mut server =
                TransmogStream::<usize, usize, _, _, _>::new(socket, Bincode::default())
                    .for_async();
            let (incoming, outgoing) = server.tcp_split();
            incoming.forward(outgoing).await.unwrap();
        });

        let n = 81920;
        let socket = tokio::net::TcpStream::connect(&server_addr).await.unwrap();
        let mut c = TransmogStream::new(socket, Bincode::default()).for_async();

        let outgoing_values = futures::stream::iter(0_usize..n).map(Ok);
        outgoing_values.forward(&mut c).await.unwrap();

        let mut expected = 0;
        while let Some(received) = c.next().await.transpose().unwrap() {
            assert_eq!(expected, received);
            expected += 1;
        }
        assert_eq!(expected, n);
    }
}