ouisync/db/
mod.rs

1#[macro_use]
2mod macros;
3
4mod connection;
5mod id;
6mod migrations;
7
8pub use id::DatabaseId;
9pub use migrations::SCHEMA_VERSION;
10
11use tracing::Span;
12
13use deadlock::ExpectShortLifetime;
14use ref_cast::RefCast;
15use sqlx::{
16    ConnectOptions, Row, SqlitePool, TransactionManager,
17    sqlite::{
18        Sqlite, SqliteAutoVacuum, SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions,
19        SqliteSynchronous, SqliteTransactionManager,
20    },
21};
22use std::{
23    fmt, io,
24    ops::{Deref, DerefMut},
25    panic::Location,
26    path::Path,
27    time::Duration,
28};
29#[cfg(test)]
30use tempfile::TempDir;
31use thiserror::Error;
32use tokio::{fs, sync::Semaphore, task};
33
// Maximum time to wait for a connection from either pool; passed to `acquire_timeout` in
// `Pool::create`.
const ACQUIRE_TIMEOUT: Duration = Duration::from_secs(5 * 60);
// Idle connections are expired after this long to conserve resources; passed to `idle_timeout`
// in `Pool::create`.
const IDLE_TIMEOUT: Duration = Duration::from_secs(60);
// Threshold handed to `ExpectShortLifetime` for every acquired connection. Judging by the name,
// a connection held longer than this triggers a warning (the tracker lives in the `deadlock`
// crate - confirm there).
const WARN_AFTER_CONNECTION_LIFETIME: Duration = Duration::from_secs(30);
37
38pub use self::connection::Connection;
39
/// Database connection pool.
///
/// Internally this is a pair of sqlx pools over the same database file: `reads` hands out
/// read-only connections and `write` hands out the single writable connection. Cloning a `Pool`
/// clones the handles to both underlying pools.
#[derive(Clone)]
pub struct Pool {
    // Pool with multiple read-only connections (used by `acquire` and `begin_read`).
    reads: SqlitePool,
    // Pool with a single writable connection (used by `begin_write`).
    write: SqlitePool,
}
48
49impl Pool {
50    async fn create(conn_options: SqliteConnectOptions) -> Result<Self, sqlx::Error> {
51        if fs::try_exists(conn_options.get_filename())
52            .await
53            .unwrap_or(false)
54        {
55            // Try to enable auto-vacuum, but if it fails [^1] it's not a critical failure as we can
56            // keep using the db without auto-vacuum [^2]. Just log the error and keep going.
57            //
58            // [^1]: for example, because of low disk space
59            // [^2]: with the downside that disk space won't be reclaimed after data deletion
60            if let Err(error) = enable_auto_vacuum(conn_options.get_filename()).await {
61                tracing::warn!(?error, "failed to enable auto vacuum");
62            }
63        }
64
65        let conn_options = conn_options
66            .journal_mode(SqliteJournalMode::Wal)
67            .synchronous(SqliteSynchronous::Normal)
68            .pragma("recursive_triggers", "ON");
69
70        let pool_options = SqlitePoolOptions::new()
71            // Disable the test as it breaks cancel-safety (also it's unnecessary in our case)
72            .test_before_acquire(false)
73            // Expire idle connections to conserve resources (threads, file descriptors)
74            .idle_timeout(IDLE_TIMEOUT)
75            .acquire_timeout(ACQUIRE_TIMEOUT);
76
77        let write = pool_options
78            .clone()
79            .max_connections(1)
80            .connect_with(
81                conn_options
82                    .clone()
83                    .auto_vacuum(SqliteAutoVacuum::Full)
84                    .optimize_on_close(true, Some(1000)),
85            )
86            .await?;
87
88        let reads = pool_options
89            .max_connections(8)
90            .connect_with(conn_options.read_only(true))
91            .await?;
92
93        Ok(Self { reads, write })
94    }
95
96    /// Acquire a read-only database connection.
97    #[track_caller]
98    pub fn acquire(&self) -> impl Future<Output = Result<PoolConnection, sqlx::Error>> + '_ {
99        PoolConnection::acquire(&self.reads, Location::caller())
100    }
101
102    /// Begin a read-only transaction. See [`ReadTransaction`] for more details.
103    #[track_caller]
104    pub fn begin_read(&self) -> impl Future<Output = Result<ReadTransaction, sqlx::Error>> + '_ {
105        ReadTransaction::begin(&self.reads, Location::caller())
106    }
107
108    /// Begin a write transaction. See [`WriteTransaction`] for more details.
109    #[track_caller]
110    pub fn begin_write(&self) -> impl Future<Output = Result<WriteTransaction, sqlx::Error>> + '_ {
111        let location = Location::caller();
112
113        async move {
114            Ok(WriteTransaction {
115                inner: ReadTransaction::begin(&self.write, location).await?,
116            })
117        }
118    }
119
120    pub(crate) async fn close(&self) -> Result<(), sqlx::Error> {
121        // Make sure to first close `reads` and only then `write`. That way when closing the write
122        // connection it is the last remaining connection and so it performs a WAL checkpoint and
123        // removes the auxiliary db files (*-wal and *-shm).
124        self.reads.close().await;
125        self.write.close().await;
126
127        Ok(())
128    }
129}
130
/// Database connection from pool
///
/// Dereferences to [`Connection`].
pub struct PoolConnection {
    // The underlying sqlx connection checked out of the pool.
    inner: sqlx::pool::PoolConnection<Sqlite>,
    // Lifetime tracker created with `WARN_AFTER_CONNECTION_LIFETIME` and the location of the
    // code that acquired the connection. Never read; it only needs to live as long as `self`.
    _track_lifetime: ExpectShortLifetime,
}
136
137impl PoolConnection {
138    // Internal
139    async fn acquire(
140        pool: &SqlitePool,
141        location: &'static Location<'static>,
142    ) -> Result<Self, sqlx::Error> {
143        let inner = pool.acquire().await?;
144        let track_lifetime = ExpectShortLifetime::new_in(WARN_AFTER_CONNECTION_LIFETIME, location);
145
146        Ok(Self {
147            inner,
148            _track_lifetime: track_lifetime,
149        })
150    }
151}
152
153impl Deref for PoolConnection {
154    type Target = Connection;
155
156    fn deref(&self) -> &Self::Target {
157        Connection::ref_cast(self.inner.deref())
158    }
159}
160
161impl DerefMut for PoolConnection {
162    fn deref_mut(&mut self) -> &mut Self::Target {
163        Connection::ref_cast_mut(self.inner.deref_mut())
164    }
165}
166
/// Transaction that allows only reading.
///
/// This is useful if one wants to make sure the observed database content doesn't change for the
/// duration of the transaction even in the presence of concurrent writes. In other words - a read
/// transaction represents an immutable snapshot of the database at the point the transaction was
/// created. A read transaction doesn't need to be committed or rolled back - it's implicitly ended
/// when the `ReadTransaction` instance drops.
pub struct ReadTransaction {
    // Connection the transaction runs on.
    inner: PoolConnection,
    // Set to `true` once the transaction has been committed. When still `false` on drop, a
    // rollback is started.
    closed: bool,
}
178
179impl ReadTransaction {
180    // Internal
181    async fn begin(
182        pool: &SqlitePool,
183        location: &'static Location<'static>,
184    ) -> Result<Self, sqlx::Error> {
185        let mut inner = PoolConnection::acquire(pool, location).await?;
186        SqliteTransactionManager::begin(&mut inner.inner, None).await?;
187
188        Ok(Self {
189            inner,
190            closed: false,
191        })
192    }
193
194    // Internal
195    async fn commit(mut self) -> Result<Committed, sqlx::Error> {
196        SqliteTransactionManager::commit(&mut self.inner.inner).await?;
197        self.closed = true;
198        Ok(Committed(self))
199    }
200}
201
202impl Deref for ReadTransaction {
203    type Target = Connection;
204
205    fn deref(&self) -> &Self::Target {
206        self.inner.deref()
207    }
208}
209
210impl DerefMut for ReadTransaction {
211    fn deref_mut(&mut self) -> &mut Self::Target {
212        self.inner.deref_mut()
213    }
214}
215
216impl fmt::Debug for ReadTransaction {
217    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
218        f.debug_struct("ReadTransaction").finish_non_exhaustive()
219    }
220}
221
// NOTE(review): the macro comes from the `macros` module (not visible here). Judging by its name
// it makes `ReadTransaction` usable as a query executor by delegating to its `Deref` target -
// confirm against the macro definition.
impl_executor_by_deref!(ReadTransaction);
223
224impl Drop for ReadTransaction {
225    fn drop(&mut self) {
226        if !self.closed {
227            SqliteTransactionManager::start_rollback(&mut self.inner.inner);
228        }
229    }
230}
231
// Wrapper for a transaction that's been committed. Holding on to it delays releasing the
// underlying connection back to the pool, while the wrapper itself exposes no way to use the
// connection to execute any db operations. The field is never read (hence the
// `allow(dead_code)`); it exists only to keep the transaction alive.
struct Committed(#[allow(dead_code)] ReadTransaction);
235
/// Transaction that allows both reading and writing.
///
/// At most one task can hold a write transaction at any time. Any other tasks are blocked on
/// calling `begin_write` until the task that currently holds it is done with it (commits it or
/// rolls it back). Performing read-only operations concurrently while a write transaction is in
/// use is still allowed. Those operations will not see the writes performed via the write
/// transaction until that transaction is committed however.
pub struct WriteTransaction {
    // A transaction started on the single connection of the writable pool. It is stored as a
    // `ReadTransaction` so that begin / commit / rollback-on-drop are shared with read
    // transactions.
    inner: ReadTransaction,
}
246
247impl WriteTransaction {
248    /// Commits the transaction.
249    ///
250    /// # Cancel safety
251    ///
252    /// If the future returned by this function is cancelled before completion, the transaction
253    /// is guaranteed to be either committed or rolled back but there is no way to tell in advance
254    /// which of the two operations happens.
255    pub async fn commit(self) -> Result<(), sqlx::Error> {
256        self.inner.commit().await?;
257        Ok(())
258    }
259
260    /// Commits the transaction and if (and only if) the commit completes successfully, runs the
261    /// given closure.
262    ///
263    /// # Atomicity
264    ///
265    /// If the commit succeeds, the closure is guaranteed to complete before another write
266    /// transaction begins.
267    ///
268    /// # Cancel safety
269    ///
270    /// The commits completes and if it succeeds the closure gets called. This is guaranteed to
271    /// happen even if the future returned from this function is cancelled before completion.
272    ///
273    /// # Insufficient alternatives
274    ///
275    /// ## Calling `commit().await?` and then calling `f()`
276    ///
277    /// This is not enough because it has these possible outcomes depending on whether and when
278    /// cancellation happened:
279    ///
280    /// 1. `commit` completes successfully and `f` is called
281    /// 2. `commit` completes with error and `f` is not called
282    /// 3. `commit` is cancelled but the transaction is still committed and `f` is not called
283    /// 4. `commit` is cancelled and the transaction rolls back and `f` is not called
284    ///
285    /// Number 3 is typically not desirable.
286    ///
287    /// ## Calling `f` using a RAII guard
288    ///
289    /// This is still not enough because it has the following possible outcomes:
290    ///
291    /// 1. `commit` completes successfully and `f` is called
292    /// 2. `commit` completes with error and `f` is called
293    /// 3. `commit` is cancelled but the transaction is still committed and `f` is called
294    /// 4. `commit` is cancelled and the transaction rolls back and `f` is called
295    ///
296    /// Numbers 2 and 4 are not desirable. Number 2 can be handled by explicitly handling the error
297    /// case and disabling the guard but there is nothing to do about number 4.
298    pub async fn commit_and_then<F, R>(self, f: F) -> Result<R, sqlx::Error>
299    where
300        F: FnOnce() -> R + Send + 'static,
301        R: Send + 'static,
302    {
303        let span = Span::current();
304
305        task::spawn(async move {
306            // IMPORTANT: `_committed` must live until `f` completes.
307            let _committed = self.inner.commit().await?;
308            let result = span.in_scope(f);
309            Ok(result)
310        })
311        .await
312        .unwrap()
313    }
314}
315
// Everything a `ReadTransaction` can do, a `WriteTransaction` can do as well: expose the inner
// transaction by shared reference.
impl Deref for WriteTransaction {
    type Target = ReadTransaction;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
323
// Mutable counterpart of the `Deref` impl: exposes the inner transaction by mutable reference.
impl DerefMut for WriteTransaction {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}
329
330impl std::fmt::Debug for WriteTransaction {
331    fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
332        write!(f, "WriteTransaction{{ inner:{:?} }}", self.inner)
333    }
334}
335
// NOTE(review): the macro comes from the `macros` module (not visible here). Judging by its name
// it makes `WriteTransaction` usable as a query executor by delegating to its `Deref` target -
// confirm against the macro definition.
impl_executor_by_deref!(WriteTransaction);
337
/// Errors returned by the database module.
#[derive(Debug, Error)]
pub enum Error {
    /// The directory for the database file (or the temporary directory used in tests) could not
    /// be created.
    #[error("failed to create database directory")]
    CreateDirectory(#[source] io::Error),
    /// `create` was asked to create a database at a path that already exists.
    #[error("database already exists")]
    Exists,
    /// Setting up the connection pools for the database failed.
    #[error("failed to open database")]
    Open(#[source] sqlx::Error),
    /// A query failed. Any `sqlx::Error` propagated with `?` converts into this variant.
    #[error("failed to execute database query")]
    Query(#[from] sqlx::Error),
}
349
350/// Creates a new database and opens a connection to it.
351pub(crate) async fn create(path: impl AsRef<Path>) -> Result<Pool, Error> {
352    let path = path.as_ref();
353
354    if fs::metadata(path).await.is_ok() {
355        return Err(Error::Exists);
356    }
357
358    create_directory(path).await?;
359
360    let connect_options = SqliteConnectOptions::new()
361        .filename(path)
362        .create_if_missing(true);
363
364    let pool = Pool::create(connect_options).await.map_err(Error::Open)?;
365
366    migrations::run(&pool).await?;
367
368    Ok(pool)
369}
370
/// Creates a new database (`temp.db`) inside a fresh temporary directory. Useful for tests.
///
/// The returned `TempDir` must be kept alive for as long as the pool is in use.
#[cfg(test)]
pub(crate) async fn create_temp() -> Result<(TempDir, Pool), Error> {
    let dir = TempDir::new().map_err(Error::CreateDirectory)?;
    let db_path = dir.path().join("temp.db");

    create(db_path).await.map(|pool| (dir, pool))
}
379
380/// Opens a connection to the specified database. Fails if the db doesn't exist.
381pub(crate) async fn open(path: impl AsRef<Path>) -> Result<Pool, Error> {
382    let connect_options = SqliteConnectOptions::new().filename(path);
383    let pool = Pool::create(connect_options).await.map_err(Error::Open)?;
384
385    migrations::run(&pool).await?;
386
387    Ok(pool)
388}
389
390/// Opens a connection to the specified database. Fails if the db doesn't exist.
391pub async fn open_without_migrations(path: impl AsRef<Path>) -> Result<Pool, Error> {
392    let connect_options = SqliteConnectOptions::new().filename(path);
393    let pool = Pool::create(connect_options).await.map_err(Error::Open)?;
394
395    Ok(pool)
396}
397
398async fn create_directory(path: &Path) -> Result<(), Error> {
399    if let Some(dir) = path.parent() {
400        fs::create_dir_all(dir)
401            .await
402            .map_err(Error::CreateDirectory)?
403    }
404
405    Ok(())
406}
407
// Reinterprets the bits of an `i64` as a `u64`. The sqlx crate has no native `u64` support, so
// `u64` values travel through the database as `i64` and get converted back with this function.
pub(crate) const fn decode_u64(i: i64) -> u64 {
    u64::from_ne_bytes(i.to_ne_bytes())
}
413
// Reinterprets the bits of a `u64` as an `i64`. The sqlx crate has no native `u64` support, so
// `u64` values are converted with this function before being handed to the database.
pub(crate) const fn encode_u64(u: u64) -> i64 {
    i64::from_ne_bytes(u.to_ne_bytes())
}
419
420// Enable auto-vacuum on the given database unless already enabled
421async fn enable_auto_vacuum(db_path: &Path) -> Result<(), Error> {
422    let mut conn = SqliteConnectOptions::new()
423        .filename(db_path)
424        .connect()
425        .await?;
426
427    let auto_vacuum: u32 = sqlx::query("PRAGMA auto_vacuum")
428        .fetch_one(&mut conn)
429        .await?
430        .get(0);
431
432    if auto_vacuum != 0 {
433        return Ok(());
434    }
435
436    // VACUUM requires up to twice the size of the original database of free disk space. To prevent
437    // exhausting available disk space when opening multiple databases in parallel, we limit the
438    // number of dbs being vacuumed at the same time to one. Dbs that don't need to be vacuumed are
439    // not affected by this limit.
440    static SEMAPHORE: Semaphore = Semaphore::const_new(1);
441    let _permit = SEMAPHORE.acquire().await.unwrap();
442
443    sqlx::query("PRAGMA auto_vacuum=FULL")
444        .execute(&mut conn)
445        .await?;
446
447    // Execute `VACUUM` command for the `auto_vacuum` pragma to take effect.
448    sqlx::query("VACUUM").execute(&mut conn).await?;
449
450    Ok(())
451}
452
453async fn get_pragma(conn: &mut Connection, name: &str) -> Result<u32, Error> {
454    Ok(sqlx::query(&format!("PRAGMA {name}"))
455        .fetch_one(&mut *conn)
456        .await?
457        .get(0))
458}
459
460async fn set_pragma(conn: &mut Connection, name: &str, value: u32) -> Result<(), Error> {
461    // `bind` doesn't seem to be supported for setting PRAGMAs...
462    sqlx::query(&format!("PRAGMA {name} = {value}"))
463        .execute(&mut *conn)
464        .await?;
465
466    Ok(())
467}
468
#[cfg(test)]
mod tests {
    use super::*;

    // Check the casts are lossless

    // Spot-checks `decode_u64` at the boundaries of both halves of the `i64` range.
    #[test]
    fn decode_u64_sanity_check() {
        // [0i64,     i64::MAX] -> [0u64,             u64::MAX / 2]
        // [i64::MIN,    -1i64] -> [u64::MAX / 2 + 1,     u64::MAX]

        assert_eq!(decode_u64(0), 0);
        assert_eq!(decode_u64(1), 1);
        assert_eq!(decode_u64(-1), u64::MAX);
        assert_eq!(decode_u64(i64::MIN), u64::MAX / 2 + 1);
        assert_eq!(decode_u64(i64::MAX), u64::MAX / 2);
    }

    // Spot-checks `encode_u64` at the boundaries of both halves of the `u64` range.
    #[test]
    fn encode_u64_sanity_check() {
        // [0u64,             u64::MAX / 2] -> [0i64,     i64::MAX]
        // [u64::MAX / 2 + 1,     u64::MAX] -> [i64::MIN,    -1i64]

        assert_eq!(encode_u64(0), 0);
        assert_eq!(encode_u64(1), 1);
        assert_eq!(encode_u64(u64::MAX / 2), i64::MAX);
        assert_eq!(encode_u64(u64::MAX / 2 + 1), i64::MIN);
        assert_eq!(encode_u64(u64::MAX), -1);
    }
}