use std::io::{self, BufWriter, Read, Seek, SeekFrom, Write};
use std::sync::{Arc, Weak};

use okaywal::file_manager::{self, File as _};
use okaywal::{EntryId, LogManager, ReadChunkResult, WriteAheadLog};

use crate::format::{ByteUtil, GrainAllocationInfo, GrainAllocationStatus, GrainId, TransactionId};
use crate::store::BasinState;
use crate::util::{u32_to_usize, usize_to_u32};
use crate::{Data as DatabaseData, Database, Error, Result};

#[derive(Debug)]
13
pub struct WalManager<FileManager>
14
where
15
    FileManager: file_manager::FileManager,
16
{
17
    db: Weak<DatabaseData<FileManager>>,
18
    scratch: Vec<u8>,
19
}
impl<FileManager> WalManager<FileManager>
22
where
23
    FileManager: file_manager::FileManager,
24
{
25
331
    pub(super) fn new(database: &Arc<DatabaseData<FileManager>>) -> Self {
26
331
        Self {
27
331
            db: Arc::downgrade(database),
28
331
            scratch: Vec::new(),
29
331
        }
30
331
    }
31
}
impl<FileManager> LogManager<FileManager> for WalManager<FileManager>
34
where
35
    FileManager: file_manager::FileManager,
36
{
37
    fn recover(&mut self, entry: &mut okaywal::Entry<'_, FileManager::File>) -> io::Result<()> {
38
2
        if let Some(database) = self.db.upgrade() {
39
2
            let mut written_grains = Vec::new();
40
2
            let mut freed_grains = Vec::new();
41
2
            let mut index_metadata = database.atlas.current_index_metadata()?;
42
            loop {
43
10
                match entry.read_chunk()? {
44
8
                    ReadChunkResult::Chunk(mut chunk) => {
45
8
                        let position = chunk.log_position();
46
8
                        self.scratch
47
8
                            .resize(u32_to_usize(chunk.bytes_remaining())?, 0);
48
8
                        chunk.read_exact(&mut self.scratch)?;
49

            
50
8
                        let chunk = WalChunk::read(&self.scratch)?;
51
8
                        match chunk {
52
4
                            WalChunk::NewGrainInWal { id, .. } => {
53
4
                                written_grains.push((id, position))
54
                            }
55
2
                            WalChunk::FinishTransaction { commit_log_entry } => {
56
2
                                index_metadata.commit_log_head = Some(commit_log_entry);
57
2
                            }
58
2
                            WalChunk::UpdatedEmbeddedHeader(header) => {
59
2
                                index_metadata.embedded_header_data = header;
60
2
                            }
61
                            WalChunk::CheckpointTo(tx_id) => {
62
                                index_metadata.checkpoint_target = tx_id;
63
                            }
64
                            WalChunk::CheckpointedTo(tx_id) => {
65
                                index_metadata.checkpointed_to = tx_id;
66
                            }
67
                            WalChunk::FreeGrain(id) => {
68
                                freed_grains.push(id);
69
                            }
70
                            // Archiving a grain doesn't have any effect on the in-memory state
71
                            WalChunk::ArchiveGrain(_) => {}
72
                        }
73
                    }
74
2
                    ReadChunkResult::EndOfEntry => break,
75
                    ReadChunkResult::AbortedEntry => return Ok(()),
76
                }
77
            }
78

            
79
2
            freed_grains.sort_unstable();
80
2

            
81
2
            database.atlas.note_transaction_committed(
82
2
                index_metadata,
83
2
                written_grains,
84
2
                &freed_grains,
85
2
                true,
86
2
            )?;
87
        }
88

            
89
2
        Ok(())
90
2
    }
91

            
92
    fn checkpoint_to(
93
        &mut self,
94
        last_checkpointed_id: okaywal::EntryId,
95
        checkpointed_entries: &mut okaywal::SegmentReader<FileManager::File>,
96
        wal: &WriteAheadLog<FileManager>,
97
    ) -> std::io::Result<()> {
98
359
        if let Some(database) = self.db.upgrade() {
99
359
            let database = Database {
100
359
                data: database,
101
359
                wal: wal.clone(),
102
359
            };
103
359
            let fsyncs = database.data.store.file_manager.new_fsync_batch()?;
104
359
            let latest_tx_id = TransactionId::from(last_checkpointed_id);
105
359
            let mut store = database.data.store.lock()?;
106
359
            let mut needs_directory_sync = store.needs_directory_sync;
107
359
            let mut all_changed_grains = Vec::new();
108
359
            let mut latest_commit_log_entry = store.index.active.commit_log_head;
109
359
            let mut latest_embedded_header_data = store.index.active.embedded_header_data;
110
359
            let mut latest_checkpoint_target = store.index.active.checkpoint_target;
111
359
            let mut latest_checkpointed_to = store.index.active.checkpointed_to;
112
359
            // We allocate the transaction grains vec once and reuse the vec to
113
359
            // avoid reallocating.
114
359
            let mut transaction_grains = Vec::new();
115
3349
            'entry_loop: while let Some(mut entry) = checkpointed_entries.read_entry()? {
116
2990
                let checkpointed_tx = TransactionId::from(entry.id());
117
2990
                // Because an entry could be aborted, we need to make sure we don't
118
2990
                // modify our DiskState until after we've read every chunk. We will
119
2990
                // write new grain data directly to the segments, but the headers
120
2990
                // won't be updated until after the loop as well.
121
2990
                transaction_grains.clear();
122
39145
                while let Some(mut chunk) = match entry.read_chunk()? {
123
36155
                    ReadChunkResult::Chunk(chunk) => Some(chunk),
124
2990
                    ReadChunkResult::EndOfEntry => None,
125
                    ReadChunkResult::AbortedEntry => continue 'entry_loop,
126
                } {
127
36155
                    self.scratch.clear();
128
36155
                    chunk.read_to_end(&mut self.scratch)?;
129
36155
                    if !chunk.check_crc()? {
130
                        return Err(Error::ChecksumFailed.into());
131
36155
                    }
132
36155

            
133
36155
                    match WalChunk::read(&self.scratch)? {
134
27928
                        WalChunk::NewGrainInWal { id, data } => {
135
27928
                            let basin = store.basins.get_or_insert_with(id.basin_id(), || {
136
37
                                BasinState::default_for(id.basin_id())
137
27928
                            });
138
27928
                            let stratum = basin.get_or_allocate_stratum(
139
27928
                                id.stratum_id(),
140
27928
                                &database.data.store.directory,
141
27928
                            );
142
27928
                            let mut file = BufWriter::new(stratum.get_or_open_file(
143
27928
                                &database.data.store.file_manager,
144
27928
                                &mut needs_directory_sync,
145
27928
                            )?);
146

            
147
                            // Write the grain data to disk.
148
27928
                            let file_position = id.file_position();
149
27928
                            file.seek(SeekFrom::Start(file_position))?;
150
27928
                            file.write_all(&checkpointed_tx.to_be_bytes())?;
151
27928
                            file.write_all(&usize_to_u32(data.len())?.to_be_bytes())?;
152
27928
                            file.write_all(data)?;
153
27928
                            let crc32 = crc32c::crc32c(data);
154
27928
                            file.write_all(&crc32.to_be_bytes())?;
155
27928
                            file.flush()?;
156

            
157
27928
                            transaction_grains.push((id, GrainAllocationStatus::Allocated));
158
                        }
159
1667
                        WalChunk::ArchiveGrain(id) => {
160
1667
                            transaction_grains.push((id, GrainAllocationStatus::Archived));
161
1667
                        }
162
907
                        WalChunk::FreeGrain(id) => {
163
907
                            transaction_grains.push((id, GrainAllocationStatus::Free));
164
907
                        }
165
2990
                        WalChunk::FinishTransaction { commit_log_entry } => {
166
2990
                            latest_commit_log_entry = Some(commit_log_entry);
167
2990
                        }
168
10
                        WalChunk::UpdatedEmbeddedHeader(header) => {
169
10
                            latest_embedded_header_data = header;
170
10
                        }
171
2623
                        WalChunk::CheckpointTo(tx_id) => {
172
2623
                            latest_checkpoint_target = tx_id;
173
2623
                        }
174
30
                        WalChunk::CheckpointedTo(tx_id) => {
175
30
                            latest_checkpointed_to = tx_id;
176
30
                        }
177
                    }
178
                }
179

            
180
2990
                all_changed_grains.append(&mut transaction_grains);
181
            }
182

            
183
359
            all_changed_grains.sort_unstable();
184
359

            
185
359
            let mut index = 0;
186
761
            while let Some((first_id, _)) = all_changed_grains.get(index).cloned() {
187
402
                let basin = store.basins.get_or_insert_with(first_id.basin_id(), || {
188
                    BasinState::default_for(first_id.basin_id())
189
402
                });
190
402
                let stratum = basin
191
402
                    .get_or_allocate_stratum(first_id.stratum_id(), &database.data.store.directory);
192

            
193
                // Update the stratum header for the disk state.
194
                loop {
195
                    // This is a messy match statement, but the goal is to only
196
                    // re-lookup basin and stratum when we jump to a new
197
                    // stratum.
198
30904
                    match all_changed_grains.get(index).copied() {
199
30502
                        Some((id, status))
200
30545
                            if id.basin_id() == first_id.basin_id()
201
30525
                                && id.stratum_id() == first_id.stratum_id() =>
202
30502
                        {
203
30502
                            let local_index = usize::from(id.local_grain_index().as_u16());
204
30502
                            if status == GrainAllocationStatus::Free {
205
907
                                // Free grains are just 0s.
206
907
                                stratum.header.active.grains
207
907
                                    [local_index..local_index + usize::from(id.grain_count())]
208
907
                                    .fill(0);
209
907
                            } else {
210
221417
                                for index in 0..id.grain_count() {
211
221417
                                    let status = if status == GrainAllocationStatus::Allocated {
212
217800
                                        GrainAllocationInfo::allocated(id.grain_count() - index)
213
                                    } else {
214
3617
                                        GrainAllocationInfo::archived(id.grain_count() - index)
215
                                    };
216
221417
                                    stratum.header.active.grains
217
221417
                                        [local_index + usize::from(index)] = status.0;
218
                                }
219
                            }
220
                        }
221
402
                        _ => break,
222
                    }
223
30502
                    index += 1;
224
                }
225
402
                stratum.write_header(latest_tx_id, &fsyncs)?;
226
            }
227

            
228
359
            store.index.active.commit_log_head = latest_commit_log_entry;
229
359
            store.index.active.embedded_header_data = latest_embedded_header_data;
230
359
            store.index.active.checkpoint_target = latest_checkpoint_target;
231
359
            store.index.active.checkpointed_to = latest_checkpointed_to;
232
359

            
233
359
            store.write_header(latest_tx_id, &fsyncs)?;
234

            
235
359
            if needs_directory_sync {
236
38
                store.needs_directory_sync = false;
237
38
                fsyncs.queue_fsync_all(store.directory.try_clone()?)?;
238
321
            }
239

            
240
359
            fsyncs.wait_all()?;
241

            
242
359
            database
243
359
                .data
244
359
                .atlas
245
359
                .note_grains_checkpointed(&all_changed_grains)?;
246

            
247
359
            database
248
359
                .data
249
359
                .checkpointer
250
359
                .checkpoint_to(latest_checkpoint_target);
251
359

            
252
359
            Ok(())
253
        } else {
254
            // TODO OkayWAL should have a way to be told "shut down" from this
255
            // callback.
256
            Err(io::Error::from(Error::Shutdown))
257
        }
258
359
    }
259
}
#[derive(Debug)]
262
pub enum WalChunk<'a> {
263
    NewGrainInWal { id: GrainId, data: &'a [u8] },
264
    ArchiveGrain(GrainId),
265
    FreeGrain(GrainId),
266
    UpdatedEmbeddedHeader(Option<GrainId>),
267
    CheckpointTo(TransactionId),
268
    CheckpointedTo(TransactionId),
269
    FinishTransaction { commit_log_entry: GrainId },
270
}
impl<'a> WalChunk<'a> {
273
    pub const COMMAND_LENGTH: u32 = 9;
274
    pub const COMMAND_LENGTH_USIZE: usize = Self::COMMAND_LENGTH as usize;
275

            
276
36165
    pub fn read(buffer: &'a [u8]) -> Result<Self> {
277
36165
        if buffer.len() < Self::COMMAND_LENGTH_USIZE {
278
1
            return Err(Error::ValueOutOfBounds);
279
36164
        }
280
36164
        let kind = buffer[0];
281
36164
        match kind {
282
            0 => Ok(Self::NewGrainInWal {
283
27932
                id: GrainId::from_bytes(&buffer[1..9]).ok_or(Error::InvalidGrainId)?,
284
27932
                data: &buffer[9..],
285
            }),
286
            1 => Ok(Self::ArchiveGrain(
287
1667
                GrainId::from_bytes(&buffer[1..9]).ok_or(Error::InvalidGrainId)?,
288
            )),
289
            2 => Ok(Self::FreeGrain(
290
907
                GrainId::from_bytes(&buffer[1..9]).ok_or(Error::InvalidGrainId)?,
291
            )),
292
12
            3 => Ok(Self::UpdatedEmbeddedHeader(GrainId::from_bytes(
293
12
                &buffer[1..9],
294
12
            ))),
295
2623
            4 => Ok(Self::CheckpointTo(TransactionId::from(EntryId(
296
2623
                u64::from_be_bytes(buffer[1..9].try_into().expect("u64 is 8 bytes")),
297
2623
            )))),
298
30
            5 => Ok(Self::CheckpointedTo(TransactionId::from(EntryId(
299
30
                u64::from_be_bytes(buffer[1..9].try_into().expect("u64 is 8 bytes")),
300
30
            )))),
301
            255 => Ok(Self::FinishTransaction {
302
2992
                commit_log_entry: GrainId::from_bytes(&buffer[1..9])
303
2992
                    .ok_or(Error::InvalidGrainId)?,
304
            }),
305
1
            _ => Err(Error::verification_failed("invalid wal chunk")),
306
        }
307
36165
    }
308

            
309
    pub fn write_new_grain<W: Write>(grain_id: GrainId, data: &[u8], writer: &mut W) -> Result<()> {
310
30754
        writer.write_all(&[0])?;
311
30754
        writer.write_all(&grain_id.to_be_bytes())?;
312
30754
        writer.write_all(data)?;
313
30754
        Ok(())
314
30754
    }
315

            
316
    pub fn write_archive_grain<W: Write>(grain_id: GrainId, writer: &mut W) -> Result<()> {
317
2005
        writer.write_all(&[1])?;
318
2005
        writer.write_all(&grain_id.to_be_bytes())?;
319
2005
        Ok(())
320
2005
    }
321

            
322
    pub fn write_free_grain<W: Write>(grain_id: GrainId, writer: &mut W) -> Result<()> {
323
1119
        writer.write_all(&[2])?;
324
1119
        writer.write_all(&grain_id.to_be_bytes())?;
325
1119
        Ok(())
326
1119
    }
327

            
328
    pub fn write_embedded_header_update<W: Write>(
329
        new_embedded_header: Option<GrainId>,
330
        writer: &mut W,
331
    ) -> Result<()> {
332
13
        writer.write_all(&[3])?;
333
13
        writer.write_all(&new_embedded_header.unwrap_or(GrainId::NONE).to_be_bytes())?;
334
13
        Ok(())
335
13
    }
336

            
337
    pub fn write_checkpoint_to<W: Write>(
338
        checkpoint_to: TransactionId,
339
        writer: &mut W,
340
    ) -> Result<()> {
341
2979
        writer.write_all(&[4])?;
342
2979
        writer.write_all(&checkpoint_to.to_be_bytes())?;
343
2979
        Ok(())
344
2979
    }
345

            
346
    pub fn write_checkpointed_to<W: Write>(
347
        checkpointed_to: TransactionId,
348
        writer: &mut W,
349
    ) -> Result<()> {
350
34
        writer.write_all(&[5])?;
351
34
        writer.write_all(&checkpointed_to.to_be_bytes())?;
352
34
        Ok(())
353
34
    }
354

            
355
30754
    pub const fn new_grain_length(data_length: u32) -> u32 {
356
30754
        data_length + 9
357
30754
    }
358

            
359
    pub fn write_transaction_tail<W: Write>(
360
        commit_log_entry_id: GrainId,
361
        writer: &mut W,
362
    ) -> Result<()> {
363
3353
        writer.write_all(&[255])?;
364
3353
        writer.write_all(&commit_log_entry_id.to_be_bytes())?;
365
3353
        Ok(())
366
3353
    }
367
}
#[test]
fn wal_chunk_error_tests() {
    // The valid WalChunks are all tested by virtue of testing sediment. The
    // errors, however, are nearly impossible to simulate due to the wal being
    // completely abstracted away.

    // A full-length command with an unknown kind byte fails verification.
    assert!(matches!(
        WalChunk::read(&[254, 0, 0, 0, 0, 0, 0, 0, 0]),
        Err(Error::VerificationFailed(_))
    ));
    // A buffer shorter than the command prefix is rejected before the kind
    // byte is even inspected.
    assert!(matches!(
        WalChunk::read(&[254]),
        Err(Error::ValueOutOfBounds)
    ));
}