1
#![doc = include_str!("./.crate-docs.md")]
2
#![forbid(unsafe_code)]
3
#![warn(
4
    clippy::cargo,
5
    missing_docs,
6
    // clippy::missing_docs_in_private_items,
7
    clippy::pedantic,
8
    future_incompatible,
9
    rust_2018_idioms,
10
)]
11
#![allow(
12
    clippy::missing_errors_doc, // TODO clippy::missing_errors_doc
13
    clippy::option_if_let_else,
14
)]
15

            
16
use std::{
17
    fmt::Display,
18
    io::{BufRead, BufReader, Read, Write},
19
};
20

            
21
use ordered_varint::Variable;
22
pub use transmog;
23

            
24
/// Marker bytes that begin every version header built by `header`.
///
/// A complete header is these four bytes immediately followed by the version
/// number encoded with `ordered_varint`.
const MAGIC_CODE: &[u8] = b"DVer";
25

            
26
/// A type whose version number is a compile-time constant.
///
/// Every `ConstVersioned` type automatically implements [`Versioned`] through a
/// blanket impl that reports [`ConstVersioned::VERSION`].
pub trait ConstVersioned {
    /// The version of this type.
    ///
    /// A version of `0` is treated as "unversioned": no header is emitted for it.
    const VERSION: u64;
}
31

            
32
/// A value that can report a version number.
///
/// [`write_header`] and [`wrap`] use this version to decide which header, if
/// any, to emit; a version of `0` produces no header at all.
pub trait Versioned {
    /// The version of this value.
    fn version(&self) -> u64;
}
37

            
38
/// Any type with a constant version reports that constant for every value.
impl<T> Versioned for T
where
    T: ConstVersioned,
{
    fn version(&self) -> u64 {
        <Self as ConstVersioned>::VERSION
    }
}
43

            
44
impl Versioned for u64 {
45
8
    fn version(&self) -> u64 {
46
8
        *self
47
8
    }
48
}
49

            
50
14
fn header(version: u64) -> Option<Vec<u8>> {
51
14
    if version > 0 {
52
10
        let mut header = Vec::with_capacity(13);
53
10
        header.extend(MAGIC_CODE);
54
10
        version
55
10
            .encode_variable(&mut header)
56
10
            .expect("version too large");
57
10
        Some(header)
58
    } else {
59
4
        None
60
    }
61
14
}
62

            
63
/// Write a version header for `versioned`, if needed, to `write`.
64
pub fn write_header<V: Versioned, W: Write>(
65
    versioned: &V,
66
    mut write: W,
67
) -> Result<(), std::io::Error> {
68
2
    if let Some(header) = header(versioned.version()) {
69
1
        write.write_all(&header)?;
70
1
    }
71
2
    Ok(())
72
2
}
73

            
74
/// Wrap `data` with a version header for `versioned`, if needed.
75
pub fn wrap<V: Versioned>(versioned: &V, mut data: Vec<u8>) -> Vec<u8> {
76
4
    if let Some(header) = header(versioned.version()) {
77
3
        data.reserve(header.len());
78
3
        data.splice(0..0, header);
79
4
    }
80

            
81
4
    data
82
4
}
83

            
84
/// Decode a payload that may or may not contain a version header. If no header
85
/// is found, `callback` is invoked with `0`. If a header is found, the parsed
86
/// version number is passed to `callback`.
87
3
pub fn decode<E: Display, T, R: Read, F: FnOnce(u64, BufReader<R>) -> Result<T, Error<E>>>(
88
3
    data: R,
89
3
    callback: F,
90
3
) -> Result<T, Error<E>> {
91
3
    let mut buffered = BufReader::with_capacity(13, data);
92
3
    let mut peeked_header = buffered.fill_buf()?;
93

            
94
3
    if peeked_header.starts_with(&MAGIC_CODE[0..4]) {
95
1
        let header_start = peeked_header.as_ptr() as usize;
96
1
        peeked_header = &peeked_header[4..];
97

            
98
1
        let version = u64::decode_variable(&mut peeked_header)?;
99
1
        let header_end = peeked_header.as_ptr() as usize;
100
1
        buffered.consume(header_end - header_start);
101
1

            
102
1
        callback(version, buffered)
103
    } else {
104
2
        callback(0, buffered)
105
    }
106
3
}
107

            
108
/// Decode a payload that may or may not contain a version header. If no header
109
/// is found, the result is `(0, data)`. If a header is found, the parsed
110
/// version number is returned along with a slice reference containing the
111
/// previously-wrapped data.
112
#[must_use]
113
14
pub fn unwrap_version(mut data: &[u8]) -> (u64, &[u8]) {
114
14
    if data.starts_with(&MAGIC_CODE[0..4]) {
115
7
        data = &data[4..];
116
7
        if let Ok(version) = u64::decode_variable(&mut data) {
117
7
            return (version, data);
118
        }
119
7
    }
120
7
    (0, data)
121
14
}
122

            
123
/// An error from `transmog-versions`.
///
/// `E` is the error type of the wrapped format; it surfaces through
/// [`Error::Format`].
#[derive(thiserror::Error, Debug)]
pub enum Error<E: Display> {
    /// An unknown version was encountered.
    #[error("{0}")]
    UnknownVersion(#[from] UnknownVersion),
    /// An io error occurred, for example while [`decode`] was reading or
    /// parsing a version header.
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
    /// An error occurred from a format.
    #[error("{0}")]
    Format(E),
}
136

            
137
/// An unknown version was encountered.
///
/// The wrapped value is the version number that was not recognized.
#[derive(thiserror::Error, Debug)]
#[error("unknown version: {0}")]
pub struct UnknownVersion(pub u64);
141

            
142
1
#[test]
fn basic_tests() {
    use std::convert::Infallible;

    let payload = b"hello world";

    // Version 0 must not emit a header, so `decode` reports version 0 and
    // yields the payload untouched.
    let mut unversioned = Vec::new();
    write_header(&0_u64, &mut unversioned).unwrap();
    unversioned.extend(payload);
    decode::<Infallible, _, _, _>(&unversioned[..], |version, mut reader| {
        assert_eq!(version, 0);
        let mut contents = Vec::new();
        reader.read_to_end(&mut contents).unwrap();
        assert_eq!(contents, payload);
        Ok(())
    })
    .unwrap();

    // Round trip a non-zero version through `wrap` / `unwrap_version`.
    let wrapped = wrap(&1_u64, payload.to_vec());
    let (found_version, inner) = unwrap_version(&wrapped);
    assert_eq!(found_version, 1);
    assert_eq!(inner, payload);

    // Data without a header is reported as version 0 and returned as-is.
    assert_eq!(unwrap_version(&payload[..]), (0, &payload[..]));
}