1
use std::sync::{Arc, Weak};
2
use std::thread::JoinHandle;
3

            
4
use okaywal::{file_manager, WriteAheadLog};
5
use watchable::{Watchable, Watcher};
6

            
7
use crate::format::TransactionId;
8
use crate::{Data, Database, Error, Result};
9

            
10
#[derive(Debug)]
11
pub struct Checkpointer {
12
    watchable: Watchable<TransactionId>,
13
    handle_receiver: flume::Receiver<JoinHandle<Result<(), Error>>>,
14
}
15

            
16
impl Checkpointer {
17
331
    pub fn new(current_checkpointed_transaction: TransactionId) -> (Self, Spawner) {
18
331
        let watchable = Watchable::new(current_checkpointed_transaction);
19
331
        let watcher = watchable.watch();
20
331
        let (handle_sender, handle_receiver) = flume::bounded(1);
21
331

            
22
331
        (
23
331
            Self {
24
331
                watchable,
25
331
                handle_receiver,
26
331
            },
27
331
            Spawner {
28
331
                watcher,
29
331
                handle_sender,
30
331
            },
31
331
        )
32
331
    }
33

            
34
368
    pub fn checkpoint_to(&self, tx_id: TransactionId) {
35
368
        let _ = self.watchable.update(tx_id);
36
368
    }
37

            
38
331
    pub fn shutdown(&self) -> Result<()> {
39
331
        self.watchable.shutdown();
40
331
        let join_handle = self
41
331
            .handle_receiver
42
331
            .recv()
43
331
            .expect("handle should always be sent after spawning");
44
331
        join_handle.join().map_err(|_| Error::ThreadJoin)?
45
331
    }
46
}
47

            
48
#[derive(Debug)]
49
pub struct Spawner {
50
    watcher: Watcher<TransactionId>,
51
    handle_sender: flume::Sender<JoinHandle<Result<(), Error>>>,
52
}
53

            
54
impl Spawner {
55
331
    pub(super) fn spawn<FileManager>(
56
331
        self,
57
331
        current_checkpointed_tx: TransactionId,
58
331
        data: &Arc<Data<FileManager>>,
59
331
        wal: &WriteAheadLog<FileManager>,
60
331
    ) -> Result<()>
61
331
    where
62
331
        FileManager: file_manager::FileManager,
63
331
    {
64
331
        let data = Arc::downgrade(data);
65
331
        let wal = wal.clone();
66
331
        let thread_handle = std::thread::Builder::new()
67
331
            .name(String::from("sediment-cp"))
68
331
            .spawn(move || {
69
331
                sediment_checkpoint_thread(current_checkpointed_tx, self.watcher, data, wal)
70
331
            })
71
331
            .expect("failed to spawn thread");
72
331
        self.handle_sender
73
331
            .send(thread_handle)
74
331
            .expect("this send should never fail");
75
331
        Ok(())
76
331
    }
77
}
78

            
79
331
fn sediment_checkpoint_thread<FileManager>(
80
331
    baseline_transaction: TransactionId,
81
331
    mut tx_receiver: Watcher<TransactionId>,
82
331
    data: Weak<Data<FileManager>>,
83
331
    wal: WriteAheadLog<FileManager>,
84
331
) -> Result<()>
85
331
where
86
331
    FileManager: file_manager::FileManager,
87
331
{
88
331
    let mut current_tx_id = baseline_transaction;
89
365
    while let Ok(transaction_to_checkpoint) = tx_receiver.next_value() {
90
34
        if transaction_to_checkpoint <= current_tx_id {
91
            continue;
92
34
        }
93

            
94
34
        if let Some(data) = data.upgrade() {
95
34
            let db = Database {
96
34
                data,
97
34
                wal: wal.clone(),
98
34
            };
99

            
100
            // Find all commit log entries that are <=
101
            // transaction_to_checkpoint.
102
34
            let mut current_commit_log = db.commit_log_head()?;
103
34
            let mut archived_grains = Vec::new();
104
34
            let mut commit_logs_to_archive = Vec::new();
105
7661
            while let Some(entry) = current_commit_log {
106
7657
                if entry.transaction_id > current_tx_id
107
7627
                    && entry.transaction_id <= transaction_to_checkpoint
108
1987
                {
109
1987
                    archived_grains.extend(entry.archived_grains.iter().copied());
110
1987
                    commit_logs_to_archive.push(entry.grain_id);
111
5670
                } else if entry.transaction_id <= current_tx_id {
112
                    // We can't go any further back.
113
30
                    break;
114
5640
                }
115

            
116
7627
                current_commit_log = entry.next_entry(&db)?;
117
            }
118

            
119
34
            let mut tx = db.begin_transaction()?;
120
2021
            for commit_log_id in commit_logs_to_archive {
121
1987
                tx.archive(commit_log_id)?;
122
            }
123
34
            tx.free_grains(&archived_grains)?;
124
34
            tx.checkpointed_to(transaction_to_checkpoint)?;
125
34
            tx.commit()?;
126

            
127
34
            current_tx_id = transaction_to_checkpoint;
128
        }
129
    }
130

            
131
331
    Ok(())
132
331
}