1
use std::collections::VecDeque;
2
use std::sync::{Arc, Condvar, Mutex};
3

            
4
use crc32c::crc32c;
5
use okaywal::{file_manager, EntryWriter, LogPosition};
6

            
7
use crate::atlas::IndexMetadata;
8
use crate::commit_log::{CommitLogEntry, NewGrain};
9
use crate::format::{GrainId, TransactionId};
10
use crate::util::usize_to_u32;
11
use crate::wal::WalChunk;
12
use crate::{Database, Error, Result};
13

            
14
#[derive(Debug)]
15
pub struct Transaction<'db, FileManager>
16
where
17
    FileManager: file_manager::FileManager,
18
{
19
    database: &'db Database<FileManager>,
20
    entry: Option<EntryWriter<'db, FileManager>>,
21
    guard: Option<TransactionGuard>,
22
    state: Option<CommittingTransaction>,
23
}
24

            
25
impl<'db, FileManager> Transaction<'db, FileManager>
26
where
27
    FileManager: file_manager::FileManager,
28
{
29
3356
    pub(super) fn new(
30
3356
        database: &'db Database<FileManager>,
31
3356
        entry: EntryWriter<'db, FileManager>,
32
3356
        guard: TransactionGuard,
33
3356
    ) -> Result<Self> {
34
3356
        let metadata = guard.current_index_metadata();
35
3356
        Ok(Self {
36
3356
            database,
37
3356
            state: Some(CommittingTransaction {
38
3356
                metadata,
39
3356
                written_grains: Vec::new(),
40
3356
                log_entry: CommitLogEntry::new(
41
3356
                    TransactionId::from(entry.id()),
42
3356
                    metadata.commit_log_head,
43
3356
                    metadata.embedded_header_data,
44
3356
                    metadata.checkpoint_target,
45
3356
                    metadata.checkpointed_to,
46
3356
                ),
47
3356
            }),
48
3356

            
49
3356
            entry: Some(entry),
50
3356
            guard: Some(guard),
51
3356
        })
52
3356
    }
53

            
54
30754
    pub fn write(&mut self, data: &[u8]) -> Result<GrainId> {
55
30754
        let data_length = usize_to_u32(data.len())?;
56
30754
        let grain_id = self.database.data.atlas.reserve(data_length)?;
57

            
58
30754
        let entry = self.entry.as_mut().expect("entry missing");
59
30754
        let mut chunk = entry.begin_chunk(WalChunk::new_grain_length(data_length))?;
60
30754
        WalChunk::write_new_grain(grain_id, data, &mut chunk)?;
61
30754
        let record = chunk.finish()?;
62

            
63
30754
        let state = self.state.as_mut().expect("state missing");
64
30754
        state.written_grains.push((grain_id, record.position));
65
30754
        state.log_entry.new_grains.push(NewGrain {
66
30754
            id: grain_id,
67
30754
            crc32: crc32c(data),
68
30754
        });
69
30754
        Ok(grain_id)
70
30754
    }
71

            
72
    /// Marks `grain` as archived in this transaction.
    ///
    /// The grain is validated against the atlas first. An archive command is
    /// then appended to the write-ahead log entry and the grain is added to
    /// the commit log entry's archived grains.
    ///
    /// # Errors
    ///
    /// Returns an error if the validity check or any of the log writes fail.
    ///
    /// # Panics
    ///
    /// Panics if the transaction's log entry or state is missing.
    pub fn archive(&mut self, grain: GrainId) -> Result<()> {
        self.database.data.atlas.check_grain_validity(grain)?;

        let wal_entry = self.entry.as_mut().expect("entry missing");
        let mut command = wal_entry.begin_chunk(WalChunk::COMMAND_LENGTH)?;
        WalChunk::write_archive_grain(grain, &mut command)?;
        command.finish()?;

        self.state
            .as_mut()
            .expect("state missing")
            .log_entry
            .archived_grains
            .push(grain);

        Ok(())
    }
85

            
86
34
    /// Records that every grain in `grains` is being freed.
    ///
    /// One free command per grain is appended to the write-ahead log entry.
    /// Once all of them are written, the grains are added to the commit log
    /// entry's freed grains.
    ///
    /// # Errors
    ///
    /// Returns the first log write error. In that case none of `grains` are
    /// added to the commit log entry.
    ///
    /// # Panics
    ///
    /// Panics if the transaction's log entry or state is missing.
    pub(crate) fn free_grains(&mut self, grains: &[GrainId]) -> Result<()> {
        let wal_entry = self.entry.as_mut().expect("entry missing");
        for &grain in grains {
            let mut command = wal_entry.begin_chunk(WalChunk::COMMAND_LENGTH)?;
            WalChunk::write_free_grain(grain, &mut command)?;
            command.finish()?;
        }

        let pending = self.state.as_mut().expect("state missing");
        pending
            .log_entry
            .freed_grains
            .extend(grains.iter().copied());

        Ok(())
    }
99

            
100
    #[allow(clippy::drop_ref)]
101
13
    pub fn set_embedded_header(&mut self, new_header: Option<GrainId>) -> Result<()> {
102
13
        let entry = self.entry.as_mut().expect("entry missing");
103
13
        let mut chunk = entry.begin_chunk(WalChunk::COMMAND_LENGTH)?;
104
13
        WalChunk::write_embedded_header_update(new_header, &mut chunk)?;
105
13
        chunk.finish()?;
106

            
107
13
        let mut state = self.state.as_mut().expect("state missing");
108
13
        if let Some(old_header) = state.log_entry.embedded_header_data {
109
9
            drop(state);
110
9
            self.archive(old_header)?;
111
9
            state = self.state.as_mut().expect("state missing");
112
4
        }
113

            
114
13
        state.metadata.embedded_header_data = new_header;
115
13
        state.log_entry.embedded_header_data = new_header;
116
13

            
117
13
        Ok(())
118
13
    }
119

            
120
3011
    pub fn checkpoint_to(&mut self, tx_id: TransactionId) -> Result<()> {
121
3011
        let entry = self.entry.as_mut().expect("entry missing");
122
3011
        let mut state = self.state.as_mut().expect("state missing");
123
3011
        if tx_id <= state.log_entry.checkpoint_target {
124
            // already the checkpoint target
125
31
            return Ok(());
126
2980
        } else if tx_id >= entry.id() {
127
1
            return Err(Error::InvalidTransactionId);
128
2979
        }
129

            
130
2979
        let mut chunk = entry.begin_chunk(WalChunk::COMMAND_LENGTH)?;
131
2979
        WalChunk::write_checkpoint_to(tx_id, &mut chunk)?;
132
2979
        chunk.finish()?;
133

            
134
2979
        state.log_entry.checkpoint_target = tx_id;
135
2979

            
136
2979
        Ok(())
137
3011
    }
138

            
139
35
    pub(crate) fn checkpointed_to(&mut self, tx_id: TransactionId) -> Result<()> {
140
35
        let entry = self.entry.as_mut().expect("entry missing");
141
35
        let mut state = self.state.as_mut().expect("state missing");
142
35
        if tx_id <= state.log_entry.checkpointed_to {
143
            // already the checkpoint target
144
            return Ok(());
145
35
        } else if tx_id >= entry.id() {
146
1
            return Err(Error::InvalidTransactionId);
147
34
        }
148

            
149
34
        let mut chunk = entry.begin_chunk(WalChunk::COMMAND_LENGTH)?;
150
34
        WalChunk::write_checkpointed_to(tx_id, &mut chunk)?;
151
34
        chunk.finish()?;
152

            
153
34
        state.log_entry.checkpointed_to = tx_id;
154
34

            
155
34
        Ok(())
156
35
    }
157

            
158
    #[allow(clippy::drop_ref)]
159
3353
    pub fn commit(mut self) -> Result<TransactionId> {
160
3353
        let state = self.state.as_mut().expect("state missing");
161
3353
        // Write the commit log entry
162
3353
        state.log_entry.freed_grains.sort_unstable();
163
3353
        let mut log_entry_bytes = Vec::new();
164
3353
        state.log_entry.serialize_to(&mut log_entry_bytes)?;
165
3353
        drop(state);
166
3353
        let new_commit_log_head = self.write(&log_entry_bytes)?;
167

            
168
3353
        let mut state = self.state.take().expect("state missing");
169
3353
        state.metadata.commit_log_head = Some(new_commit_log_head);
170
3353
        // Because we end up caching the log_entry, we need it to match what's
171
3353
        // on disk. What we just stored did not contain the newly written commit
172
3353
        // log head grain. We need to remove it from the entry.
173
3353
        state.log_entry.new_grains.pop();
174
3353

            
175
3353
        // Write the transaction tail
176
3353
        let mut entry = self.entry.take().expect("entry missing");
177
3353
        let mut chunk = entry.begin_chunk(WalChunk::COMMAND_LENGTH)?;
178
3353
        WalChunk::write_transaction_tail(new_commit_log_head, &mut chunk)?;
179
3353
        chunk.finish()?;
180

            
181
3353
        let guard = self.guard.take().expect("tx guard missing");
182
3353

            
183
3353
        let transaction_id = state.log_entry.transaction_id;
184
3353
        let finalizer = guard.stage(state, self.database);
185
3353

            
186
3353
        entry.commit()?;
187

            
188
3353
        finalizer.finalize()?;
189

            
190
3353
        Ok(transaction_id)
191
3353
    }
192

            
193
1
    /// Abandons the transaction, consuming it.
    ///
    /// # Errors
    ///
    /// Returns any error reported while rolling back the write-ahead log entry
    /// or the grains written by this transaction.
    pub fn rollback(mut self) -> Result<()> {
        // After this call the entry is gone, so `Drop` will not roll back a
        // second time.
        self.rollback_transaction()
    }
196

            
197
3
    fn rollback_transaction(&mut self) -> Result<()> {
198
3
        let mut state = self.state.take().expect("state missing");
199
3
        let entry = self.entry.take().expect("entry missing");
200
3

            
201
3
        let result = entry.rollback();
202
3

            
203
3
        self.database
204
3
            .data
205
3
            .atlas
206
3
            .rollback_grains(state.written_grains.drain(..).map(|(g, _)| g))?;
207

            
208
3
        result?;
209

            
210
3
        Ok(())
211
3
    }
212
}
213

            
214
impl<'db, FileManager> Drop for Transaction<'db, FileManager>
215
where
216
    FileManager: file_manager::FileManager,
217
{
218
3356
    fn drop(&mut self) {
219
3356
        if self.entry.is_some() {
220
2
            self.rollback_transaction()
221
2
                .expect("error rolling back transaction");
222
3354
        }
223
3356
    }
224
}
225

            
226
6709
#[derive(Debug, Clone)]
227
pub struct TransactionLock {
228
    data: Arc<TransactionLockData>,
229
}
230

            
231
impl TransactionLock {
232
331
    pub fn new(initial_metadata: IndexMetadata) -> Self {
233
331
        Self {
234
331
            data: Arc::new(TransactionLockData {
235
331
                tx_lock: Mutex::new(TransactionState::new(initial_metadata)),
236
331
                tx_sync: Condvar::new(),
237
331
            }),
238
331
        }
239
331
    }
240

            
241
3356
    pub(super) fn lock(&self) -> TransactionGuard {
242
3356
        let mut state = self.data.tx_lock.lock().expect("can't panick");
243

            
244
        // Wait for the locked status to be relinquished
245
5887
        while state.in_transaction {
246
2531
            state = self.data.tx_sync.wait(state).expect("can't panick");
247
2531
        }
248

            
249
        // Acquire the locked status
250
3356
        state.in_transaction = true;
251
3356

            
252
3356
        // Return the guard
253
3356
        TransactionGuard { lock: self.clone() }
254
3356
    }
255
}
256

            
257
#[derive(Debug)]
258
struct TransactionLockData {
259
    tx_lock: Mutex<TransactionState>,
260
    tx_sync: Condvar,
261
}
262

            
263
/// Ensures only one thread has access to begin a transaction at any given time.
264
///
265
/// This guard ensures that no two threads try to update some of the in-memory
266
/// state at the same time. The Write-Ahead Log always ensures only one thread
267
/// can write to it already, but we need extra guarantees because we don't want
268
/// to publish some state until after the WAL has confirmed its commit.
269
#[derive(Debug)]
270
pub(super) struct TransactionGuard {
271
    lock: TransactionLock,
272
}
273

            
274
impl TransactionGuard {
275
3356
    pub fn current_index_metadata(&self) -> IndexMetadata {
276
3356
        let state = self.lock.data.tx_lock.lock().expect("cannot panic");
277
3356
        state.metadata
278
3356
    }
279

            
280
3353
    pub(super) fn stage<FileManager>(
281
3353
        self,
282
3353
        tx: CommittingTransaction,
283
3353
        db: &'_ Database<FileManager>,
284
3353
    ) -> TransactionFinalizer<'_, FileManager>
285
3353
    where
286
3353
        FileManager: file_manager::FileManager,
287
3353
    {
288
3353
        let id = tx.log_entry.transaction_id;
289
3353
        let mut state = self.lock.data.tx_lock.lock().expect("cannot panic");
290
3353
        state.metadata = tx.metadata;
291
3353
        state.committing_transactions.push_back(tx);
292
3353

            
293
3353
        TransactionFinalizer {
294
3353
            db,
295
3353
            lock: self.lock.clone(),
296
3353
            id,
297
3353
        }
298
3353
    }
299
}
300

            
301
impl Drop for TransactionGuard {
302
3356
    fn drop(&mut self) {
303
3356
        // Reset the locked status
304
3356
        let mut state = self.lock.data.tx_lock.lock().expect("can't panick");
305
3356
        state.in_transaction = false;
306
3356
        drop(state);
307
3356

            
308
3356
        // Notify the next waiter.
309
3356
        self.lock.data.tx_sync.notify_one();
310
3356
    }
311
}
312

            
313
#[derive(Debug)]
314
struct TransactionState {
315
    in_transaction: bool,
316
    metadata: IndexMetadata,
317
    committing_transactions: VecDeque<CommittingTransaction>,
318
}
319

            
320
impl TransactionState {
321
331
    pub fn new(initial_metadata: IndexMetadata) -> Self {
322
331
        Self {
323
331
            in_transaction: false,
324
331
            metadata: initial_metadata,
325
331
            committing_transactions: VecDeque::new(),
326
331
        }
327
331
    }
328
}
329

            
330
#[derive(Debug)]
331
pub(super) struct CommittingTransaction {
332
    metadata: IndexMetadata,
333
    written_grains: Vec<(GrainId, LogPosition)>,
334
    log_entry: CommitLogEntry,
335
}
336

            
337
#[derive(Debug)]
338
pub(super) struct TransactionFinalizer<'a, FileManager>
339
where
340
    FileManager: file_manager::FileManager,
341
{
342
    db: &'a Database<FileManager>,
343
    lock: TransactionLock,
344
    id: TransactionId,
345
}
346

            
347
impl<'a, FileManager> TransactionFinalizer<'a, FileManager>
348
where
349
    FileManager: file_manager::FileManager,
350
{
351
3353
    pub fn finalize(self) -> Result<()> {
352
3353
        let mut state = self.lock.data.tx_lock.lock().expect("can't panic");
353

            
354
6706
        while state
355
6706
            .committing_transactions
356
6706
            .front()
357
6706
            .map_or(false, |tx| tx.log_entry.transaction_id <= self.id)
358
        {
359
3353
            let mut tx_to_commit = state
360
3353
                .committing_transactions
361
3353
                .pop_front()
362
3353
                .expect("just checked");
363
3353
            self.db.data.atlas.note_transaction_committed(
364
3353
                tx_to_commit.metadata,
365
3353
                tx_to_commit.written_grains.drain(..),
366
3353
                &tx_to_commit.log_entry.freed_grains,
367
3353
                false,
368
3353
            )?;
369
3353
            self.db.data.commit_logs.cache(
370
3353
                tx_to_commit
371
3353
                    .metadata
372
3353
                    .commit_log_head
373
3353
                    .expect("commit log must be present"),
374
3353
                Arc::new(tx_to_commit.log_entry),
375
3353
            )?;
376
        }
377

            
378
3353
        Ok(())
379
3353
    }
380
}