1
use std::io::{self, Seek, SeekFrom, Write};
2

            
3
use crate::to_io_result::ToIoResult;
4

            
5
#[derive(Debug)]
6
pub struct Buffered<F>
7
where
8
    F: Bufferable + Seek + Write,
9
{
10
    buffer: Vec<u8>,
11
    position: u64,
12
    buffer_write_position: usize,
13
    length: u64,
14
    file: F,
15
}
16

            
17
impl<F> Buffered<F>
18
where
19
    F: Bufferable + Seek + Write,
20
{
21
720
    pub fn with_capacity(mut file: F, capacity: usize) -> io::Result<Self> {
22
720
        let length = file.len()?;
23
720
        let position = file.stream_position()?;
24
720
        Ok(Self {
25
720
            buffer: Vec::with_capacity(capacity),
26
720
            position,
27
720
            buffer_write_position: 0,
28
720
            length,
29
720
            file,
30
720
        })
31
720
    }
32

            
33
16804
    fn flush_buffer(&mut self) -> io::Result<()> {
34
16804
        if !self.buffer.is_empty() {
35
14556
            self.file.write_all(&self.buffer)?;
36
14556
            let bytes_written = u64::try_from(self.buffer.len()).to_io()?;
37
14556
            self.position += bytes_written;
38
14556
            self.length = self.length.max(self.position);
39
14556
            self.buffer_write_position = 0;
40
14556
            self.buffer.clear();
41
2248
        }
42
16804
        Ok(())
43
16804
    }
44

            
45
92574
    pub fn position(&self) -> u64 {
46
92574
        self.position + u64::try_from(self.buffer_write_position).expect("impossibly large buffer")
47
92574
    }
48

            
49
22348
    pub const fn buffer_position(&self) -> u64 {
50
22348
        self.position
51
22348
    }
52

            
53
11174
    pub fn inner(&self) -> &F {
54
11174
        &self.file
55
11174
    }
56
}
57

            
58
impl<F> Write for Buffered<F>
59
where
60
    F: Bufferable + Seek + Write,
61
{
62
215634
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
63
215634
        if self.buffer.capacity() == self.buffer_write_position {
64
566
            self.flush_buffer()?;
65
215068
        }
66

            
67
        // If what we're writing is larger than our buffer, skip the buffer
68
        // entirely.
69
215634
        if buf.len() > self.buffer.capacity() {
70
            // Ensure what we've buffered is already written.
71
2641
            self.flush_buffer()?;
72
2641
            let bytes_written = self.file.write(buf)?;
73
2641
            self.position += u64::try_from(bytes_written).to_io()?;
74
2641
            return Ok(bytes_written);
75
212993
        }
76
212993

            
77
212993
        let bytes_remaining = self.buffer.capacity() - self.buffer_write_position;
78
212993
        let bytes_to_write = buf.len().min(bytes_remaining);
79
212993
        if bytes_to_write > 0 {
80
212993
            let bytes_to_copy =
81
212993
                (self.buffer.len() - self.buffer_write_position).min(bytes_to_write);
82
212993
            if bytes_to_copy > 0 {
83
2
                self.buffer[self.buffer_write_position..self.buffer_write_position + bytes_to_copy]
84
2
                    .copy_from_slice(&buf[..bytes_to_copy]);
85
212991
            }
86
212993
            let bytes_to_extend = bytes_to_write - bytes_to_copy;
87
212993
            if bytes_to_extend > 0 {
88
212991
                self.buffer
89
212991
                    .extend_from_slice(&buf[bytes_to_copy..bytes_to_write]);
90
212991
            }
91
212993
            self.buffer_write_position += bytes_to_write;
92
        }
93

            
94
212993
        Ok(bytes_to_write)
95
215634
    }
96

            
97
    fn flush(&mut self) -> io::Result<()> {
98
11349
        self.flush_buffer()?;
99
11349
        self.file.flush()?;
100
11349
        Ok(())
101
11349
    }
102
}
103

            
104
impl<F> Seek for Buffered<F>
105
where
106
    F: Bufferable + Seek + Write,
107
{
108
2249
    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
109
2249
        let buffer_write_position = u64::try_from(self.buffer_write_position).to_io()?;
110
2249
        let new_position = match pos {
111
2249
            SeekFrom::Start(position) => position,
112
            SeekFrom::End(offset) => {
113
                if let Ok(offset) = u64::try_from(offset) {
114
                    offset + self.length
115
                } else {
116
                    let offset = u64::try_from(-offset).unwrap();
117
                    self.length - offset
118
                }
119
            }
120
            SeekFrom::Current(offset) => {
121
                if let Ok(offset) = u64::try_from(offset) {
122
                    self.position + buffer_write_position + offset
123
                } else {
124
                    let absolute_offset = -offset;
125
                    let offset = u64::try_from(absolute_offset).unwrap();
126
                    self.position + buffer_write_position - offset
127
                }
128
            }
129
        };
130

            
131
2249
        let buffer_len = u64::try_from(self.buffer.len()).unwrap();
132
2249
        let new_position_in_buffer = match new_position.checked_sub(self.position) {
133
1
            Some(position) if position < buffer_len => Some(position),
134
2248
            _ => None,
135
        };
136

            
137
2249
        if let Some(new_position_in_buffer) = new_position_in_buffer {
138
1
            self.buffer_write_position = usize::try_from(new_position_in_buffer).to_io()?;
139
        } else {
140
2248
            self.flush_buffer()?;
141
2248
            self.file.seek(SeekFrom::Start(new_position))?;
142
2248
            self.position = new_position;
143
        }
144

            
145
2249
        Ok(new_position)
146
2249
    }
147
}
148

            
/// A file-like handle that can report and change its length.
///
/// [`Buffered`] uses [`Bufferable::len`] to learn the starting length of the
/// file it wraps.
pub trait Bufferable {
    /// Returns the current length of the file.
    fn len(&self) -> io::Result<u64>;

    /// Changes the length of the file to `new_length`.
    fn set_len(&self, new_length: u64) -> io::Result<()>;
}
154
impl<F> Bufferable for F
155
where
156
    F: file_manager::File,
157
{
158
720
    fn len(&self) -> io::Result<u64> {
159
720
        file_manager::File::len(self)
160
720
    }
161

            
162
    fn set_len(&self, new_length: u64) -> io::Result<()> {
163
        self.set_len(new_length)
164
    }
165
}
166

            
/// An owned run of bytes tagged with a `u64` position.
///
/// NOTE(review): not used within this file; presumably `position` is the
/// stream offset of the first byte in `bytes`. Confirm against callers.
#[derive(Debug)]
pub struct WriteBuffer {
    /// The bytes held by this buffer.
    pub bytes: Vec<u8>,
    /// Position associated with `bytes` (see the note on the type).
    pub position: u64,
}