1
use std::{
2
    collections::{BTreeMap, HashMap},
3
    io::{Read, Write},
4
    iter::repeat_with,
5
    path::Path,
6
    sync::Arc,
7
    time::Duration,
8
};
9

            
10
use file_manager::{fs::StdFileManager, memory::MemoryFileManager, FileManager};
11
use parking_lot::Mutex;
12
use tempfile::tempdir;
13

            
14
use crate::{
15
    entry::NEW_ENTRY, Configuration, Entry, EntryId, LogManager, RecoveredSegment, Recovery,
16
    SegmentReader, WriteAheadLog,
17
};
18

            
19
212
#[derive(Default, Debug, Clone)]
20
struct LoggingCheckpointer {
21
    invocations: Arc<Mutex<Vec<CheckpointCall>>>,
22
}
23

            
24
412
#[derive(Debug)]
25
enum CheckpointCall {
26
    ShouldRecoverSegment {
27
        version_info: Vec<u8>,
28
    },
29
    Recover {
30
        entry_id: EntryId,
31
        data: Vec<Vec<u8>>,
32
    },
33
    Checkpoint {
34
        data: HashMap<EntryId, Vec<Vec<u8>>>,
35
    },
36
}
37

            
38
impl<M> LogManager<M> for LoggingCheckpointer
39
where
40
    M: file_manager::FileManager,
41
{
42
205
    fn should_recover_segment(&mut self, segment: &RecoveredSegment) -> std::io::Result<Recovery> {
43
205
        let mut invocations = self.invocations.lock();
44
205
        invocations.push(CheckpointCall::ShouldRecoverSegment {
45
205
            version_info: segment.version_info.clone(),
46
205
        });
47
205

            
48
205
        Ok(Recovery::Recover)
49
205
    }
50

            
51
105
    fn recover(&mut self, entry: &mut Entry<'_, M::File>) -> std::io::Result<()> {
52
105
        let entry_id = entry.id;
53

            
54
105
        if let Some(chunks) = entry.read_all_chunks()? {
55
105
            let mut invocations = self.invocations.lock();
56
105
            invocations.push(CheckpointCall::Recover {
57
105
                entry_id,
58
105
                data: chunks,
59
105
            });
60
105
        }
61

            
62
105
        Ok(())
63
105
    }
64

            
65
102
    fn checkpoint_to(
66
102
        &mut self,
67
102
        _last_checkpointed_id: EntryId,
68
102
        entries: &mut SegmentReader<M::File>,
69
102
        _wal: &WriteAheadLog<M>,
70
102
    ) -> std::io::Result<()> {
71
102
        let mut invocations = self.invocations.lock();
72
102
        let mut data = HashMap::new();
73
306
        while let Some(mut entry) = entries.read_entry()? {
74
204
            if let Some(entry_chunks) = entry.read_all_chunks()? {
75
204
                data.insert(entry.id(), entry_chunks);
76
204
            }
77
        }
78
102
        invocations.push(CheckpointCall::Checkpoint { data });
79
102

            
80
102
        Ok(())
81
102
    }
82
}
83

            
84
2
fn basic<M: FileManager, P: AsRef<Path>>(manager: M, path: P) {
85
2
    let checkpointer = LoggingCheckpointer::default();
86
2

            
87
2
    let message = b"hello world";
88
2

            
89
2
    let config = Configuration::default_with_manager(path, manager);
90
2

            
91
2
    let wal = config.clone().open(checkpointer.clone()).unwrap();
92
2
    let mut writer = wal.begin_entry().unwrap();
93
2
    let record = writer.write_chunk(message).unwrap();
94
2
    let written_entry_id = writer.commit().unwrap();
95
2
    println!("hello world written to {record:?} in {written_entry_id:?}");
96
2
    drop(wal);
97
2

            
98
2
    let invocations = checkpointer.invocations.lock();
99
2
    assert!(invocations.is_empty());
100
2
    drop(invocations);
101
2

            
102
2
    let wal = config.open(checkpointer.clone()).unwrap();
103
2

            
104
2
    let invocations = checkpointer.invocations.lock();
105
2
    println!("Invocations: {invocations:?}");
106
2
    assert_eq!(invocations.len(), 2);
107
2
    match &invocations[0] {
108
2
        CheckpointCall::ShouldRecoverSegment { version_info } => {
109
2
            assert!(version_info.is_empty());
110
        }
111
        other => unreachable!("unexpected invocation: {other:?}"),
112
    }
113
2
    match &invocations[1] {
114
2
        CheckpointCall::Recover { entry_id, data } => {
115
2
            assert_eq!(written_entry_id, *entry_id);
116
2
            assert_eq!(data.len(), 1);
117
2
            assert_eq!(data[0], message);
118
        }
119
        other => unreachable!("unexpected invocation: {other:?}"),
120
    }
121
2
    drop(invocations);
122
2

            
123
2
    let mut reader = wal.read_at(record.position).unwrap();
124
2
    let mut buffer = Vec::new();
125
2
    reader.read_to_end(&mut buffer).unwrap();
126
2
    assert_eq!(buffer, message);
127
2
    assert!(reader.crc_is_valid().expect("error validating crc"));
128
2
    drop(reader);
129
2

            
130
2
    let mut writer = wal.begin_entry().unwrap();
131
2
    let _ = writer.write_chunk(message).unwrap();
132
2
    let written_entry_id = writer.commit().unwrap();
133
2

            
134
2
    wal.checkpoint_active()
135
2
        .expect("Could not checkpoint active log");
136
2
    wal.shutdown().unwrap();
137
2

            
138
2
    let invocations = checkpointer.invocations.lock();
139
2
    println!("Invocations: {invocations:?}");
140
2
    assert_eq!(invocations.len(), 3);
141
2
    match &invocations[2] {
142
2
        CheckpointCall::Checkpoint { data } => {
143
2
            let item = data.get(&written_entry_id).expect(&format!(
144
2
                "Could not find checkpointed entry: {:?}",
145
2
                written_entry_id
146
2
            ));
147
2
            assert_eq!(item.len(), 1);
148
2
            assert_eq!(item[0], message);
149
        }
150
        other => unreachable!("unexpected invocation: {other:?}"),
151
    }
152
2
    drop(invocations);
153
2
}
154

            
155
1
/// Runs [`basic`] against the real filesystem inside a fresh temporary directory.
#[test]
fn basic_std() {
    let temp = tempdir().unwrap();
    basic(StdFileManager::default(), &temp);
}

/// Runs [`basic`] against the in-memory file manager rooted at `/`.
#[test]
fn basic_memory() {
    let root = "/";
    basic(MemoryFileManager::default(), root);
}
165

            
166
4
#[derive(Debug, Default, Clone)]
167
struct VerifyingCheckpointer {
168
    entries: Arc<Mutex<BTreeMap<EntryId, Vec<Vec<u8>>>>>,
169
}
170

            
171
impl<M> LogManager<M> for VerifyingCheckpointer
172
where
173
    M: file_manager::FileManager,
174
{
175
1
    fn recover(&mut self, entry: &mut Entry<'_, M::File>) -> std::io::Result<()> {
176
1
        dbg!(entry.id);
177
1
        if let Some(chunks) = entry.read_all_chunks()? {
178
1
            let mut entries = self.entries.lock();
179
1
            entries.insert(dbg!(entry.id), chunks);
180
1
        }
181

            
182
1
        Ok(())
183
1
    }
184

            
185
15
    fn checkpoint_to(
186
15
        &mut self,
187
15
        last_checkpointed_id: EntryId,
188
15
        reader: &mut SegmentReader<M::File>,
189
15
        _wal: &WriteAheadLog<M>,
190
15
    ) -> std::io::Result<()> {
191
15
        println!("Checkpointed to {last_checkpointed_id:?}");
192
15
        let mut entries = self.entries.lock();
193
104
        while let Some(mut entry) = reader.read_entry().unwrap() {
194
89
            let expected_data = entries.remove(&entry.id).expect("unknown entry id");
195
89
            let stored_data = entry
196
89
                .read_all_chunks()
197
89
                .unwrap()
198
89
                .expect("encountered aborted entry");
199
89
            assert_eq!(expected_data, stored_data);
200
        }
201
15
        entries.retain(|entry_id, _| *entry_id > last_checkpointed_id);
202
15
        Ok(())
203
15
    }
204
}
205

            
206
2
fn multithreaded<M: FileManager, P: AsRef<Path>>(manager: M, path: P) {
207
2
    let mut threads = Vec::new();
208
2

            
209
2
    let checkpointer = VerifyingCheckpointer::default();
210
2
    let original_entries = checkpointer.entries.clone();
211
2
    let wal = Configuration::default_with_manager(path.as_ref(), manager.clone())
212
2
        .open(checkpointer)
213
2
        .unwrap();
214

            
215
12
    for _ in 0..5 {
216
10
        let wal = wal.clone();
217
10
        let written_entries = original_entries.clone();
218
10
        threads.push(std::thread::spawn(move || {
219
10
            let mut rng = fastrand::Rng::new();
220
100
            for _ in 1..10 {
221
90
                let mut messages = Vec::with_capacity(rng.usize(1..=8));
222
90
                let mut writer = wal.begin_entry().unwrap();
223
417
                for _ in 0..messages.capacity() {
224
13822449
                    let message = repeat_with(|| 42)
225
417
                        .take(rng.usize(..65_536))
226
417
                        .collect::<Vec<_>>();
227
417
                    // let message = vec![42; 256];
228
417
                    writer.write_chunk(&message).unwrap();
229
417
                    messages.push(message);
230
417
                }
231
                // Lock entries before pushing the commit to ensure that a
232
                // checkpoint operation can't happen before we insert this
233
                // entry.
234
90
                let mut entries = written_entries.lock();
235
90
                entries.insert(writer.id(), messages);
236
90
                drop(entries);
237
90
                writer.commit().unwrap();
238
            }
239
10
        }));
240
10
    }
241

            
242
12
    for thread in threads {
243
10
        thread.join().unwrap();
244
10
    }
245

            
246
2
    wal.shutdown().unwrap();
247
2

            
248
2
    println!("Reopening log");
249
2
    let checkpointer = VerifyingCheckpointer::default();
250
2
    let recovered_entries = checkpointer.entries.clone();
251
2
    let _wal = Configuration::default_with_manager(path.as_ref(), manager)
252
2
        .open(checkpointer)
253
2
        .unwrap();
254
2
    let recovered_entries = recovered_entries.lock();
255
2
    let original_entries = original_entries.lock();
256
2
    // Check keys first because it's easier to verify a list of ids than it is
257
2
    // to look at debug output of a bunch of bytes.
258
2
    assert_eq!(
259
2
        original_entries.keys().collect::<Vec<_>>(),
260
2
        recovered_entries.keys().collect::<Vec<_>>()
261
2
    );
262
2
    assert_eq!(&*original_entries, &*recovered_entries);
263
2
}
264

            
265
1
/// Runs [`multithreaded`] against the real filesystem inside a fresh temporary
/// directory.
#[test]
fn multithreaded_std() {
    let temp = tempdir().unwrap();
    multithreaded(StdFileManager::default(), &temp);
}

/// Runs [`multithreaded`] against the in-memory file manager rooted at `/`.
#[test]
fn multithreaded_memory() {
    let root = "/";
    multithreaded(MemoryFileManager::default(), root);
}
275

            
276
2
fn aborted_entry<M: FileManager, P: AsRef<Path>>(manager: M, path: P) {
277
2
    let checkpointer = LoggingCheckpointer::default();
278
2

            
279
2
    let message = b"hello world";
280
2

            
281
2
    let wal = Configuration::default_with_manager(path.as_ref(), manager.clone())
282
2
        .open(checkpointer.clone())
283
2
        .unwrap();
284
2
    let mut writer = wal.begin_entry().unwrap();
285
2
    let record = writer.write_chunk(message).unwrap();
286
2
    let written_entry_id = writer
287
2
        .commit_and(|file| file.write_all(&[NEW_ENTRY]))
288
2
        .unwrap();
289
2
    println!("hello world written to {record:?} in {written_entry_id:?}");
290
2
    drop(wal);
291
2

            
292
2
    let invocations = checkpointer.invocations.lock();
293
2
    assert!(invocations.is_empty());
294
2
    drop(invocations);
295
2

            
296
2
    Configuration::default_with_manager(path.as_ref(), manager)
297
2
        .open(checkpointer.clone())
298
2
        .unwrap();
299
2

            
300
2
    let invocations = checkpointer.invocations.lock();
301
2
    assert_eq!(invocations.len(), 2);
302
2
    match &invocations[0] {
303
2
        CheckpointCall::ShouldRecoverSegment { version_info } => {
304
2
            assert!(version_info.is_empty());
305
        }
306
        other => unreachable!("unexpected invocation: {other:?}"),
307
    }
308
2
    match &invocations[1] {
309
2
        CheckpointCall::Recover { entry_id, data } => {
310
2
            assert_eq!(written_entry_id, *entry_id);
311
2
            assert_eq!(data.len(), 1);
312
2
            assert_eq!(data[0], message);
313
        }
314
        other => unreachable!("unexpected invocation: {other:?}"),
315
    }
316
2
    drop(invocations);
317
2
}
318

            
319
1
/// Runs [`aborted_entry`] against the real filesystem inside a fresh temporary
/// directory.
#[test]
fn aborted_entry_std() {
    let temp = tempdir().unwrap();
    aborted_entry(StdFileManager::default(), &temp);
}

/// Runs [`aborted_entry`] against the in-memory file manager rooted at `/`.
#[test]
fn aborted_entry_memory() {
    let root = "/";
    aborted_entry(MemoryFileManager::default(), root);
}
329

            
330
2
fn always_checkpointing<M: FileManager, P: AsRef<Path>>(manager: M, path: P) {
331
2
    let checkpointer = LoggingCheckpointer::default();
332
2
    let config =
333
2
        Configuration::default_with_manager(path.as_ref(), manager).checkpoint_after_bytes(33);
334
2

            
335
2
    let mut written_chunks = Vec::new();
336
202
    for i in 0_usize..100 {
337
200
        println!("About to insert {i}");
338
200
        let wal = config.clone().open(checkpointer.clone()).unwrap();
339
200
        let mut writer = wal.begin_entry().unwrap();
340
200
        println!("Writing {i}");
341
200
        let record = writer.write_chunk(&i.to_be_bytes()).unwrap();
342
200
        let written_entry_id = writer.commit().unwrap();
343
200
        println!("{i} written to {record:?} in {written_entry_id:?}");
344
200
        written_chunks.push((written_entry_id, record));
345
200
        wal.shutdown().unwrap();
346
200
    }
347

            
348
    // Because we write and close the database so quickly, it's possible for the
349
    // final write to not have been checkpointed by the background thread. So,
350
    // we'll reopen the WAL one more time and sleep for a moment to allow it to
351
    // finish.
352
2
    let _wal = config.open(checkpointer.clone()).unwrap();
353
2
    std::thread::sleep(Duration::from_millis(100));
354
2

            
355
2
    let invocations = checkpointer.invocations.lock();
356
2
    println!("Invocations: {invocations:?}");
357

            
358
200
    for (index, (entry_id, _)) in written_chunks.into_iter().enumerate() {
359
200
        println!("Checking {index}");
360
200
        let chunk_data = invocations
361
200
            .iter()
362
200
            .find_map(|call| {
363
20200
                if let CheckpointCall::Checkpoint { data, .. } = call {
364
5100
                    data.get(&entry_id)
365
                } else {
366
15100
                    None
367
                }
368
20200
            })
369
200
            .expect("entry not checkpointed");
370
200

            
371
200
        assert_eq!(chunk_data[0], index.to_be_bytes());
372
    }
373
2
}
374

            
375
1
/// Runs [`always_checkpointing`] against the real filesystem inside a fresh
/// temporary directory.
#[test]
fn always_checkpointing_std() {
    let temp = tempdir().unwrap();
    always_checkpointing(StdFileManager::default(), &temp);
}

/// Runs [`always_checkpointing`] against the in-memory file manager rooted at `/`.
#[test]
fn always_checkpointing_memory() {
    let root = "/";
    always_checkpointing(MemoryFileManager::default(), root);
}
385

            
386
1
/// Checks that an entry that is begun and then dropped without `commit` does
/// not stop the next, committed entry from being recovered.
#[test]
fn recover_after_first_rollback() {
    let dir = tempdir().unwrap();
    let checkpointer = LoggingCheckpointer::default();
    let log = WriteAheadLog::recover(&dir, checkpointer.clone()).unwrap();

    // The writer returned here is dropped at the end of the statement without
    // being committed, so this entry is rolled back.
    log.begin_entry().unwrap();

    let mut entry = log.begin_entry().unwrap();
    let expected_id = entry.id();
    entry.write_chunk(b"test").unwrap();
    entry.commit().unwrap();

    log.shutdown().unwrap();

    WriteAheadLog::recover(&dir, checkpointer.clone()).unwrap();

    let invocations = checkpointer.invocations.lock();
    println!("Invocations: {invocations:?}");
    assert_eq!(invocations.len(), 2);
    // Both checks report the offending invocation instead of failing with a
    // bare, message-less panic.
    assert!(
        matches!(invocations[0], CheckpointCall::ShouldRecoverSegment { .. }),
        "unexpected invocation: {:?}",
        invocations[0]
    );
    let CheckpointCall::Recover { entry_id, data } = &invocations[1] else {
        panic!("unexpected invocation: {:?}", invocations[1]);
    };
    assert_eq!(*entry_id, expected_id);
    assert_eq!(data.len(), 1);
    assert_eq!(data[0], b"test");
}