ouisync/db/mod.rs

1#[macro_use]
2mod macros;
3
4mod connection;
5mod id;
6mod migrations;
7
8pub use id::DatabaseId;
9pub use migrations::SCHEMA_VERSION;
10
11use tracing::Span;
12
13use deadlock::ExpectShortLifetime;
14use ref_cast::RefCast;
15use sqlx::{
16    sqlite::{
17        Sqlite, SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions, SqliteSynchronous,
18        SqliteTransactionManager,
19    },
20    Row, SqlitePool, TransactionManager,
21};
22use std::{
23    fmt, io,
24    ops::{Deref, DerefMut},
25    panic::Location,
26    path::Path,
27    time::Duration,
28};
29#[cfg(test)]
30use tempfile::TempDir;
31use thiserror::Error;
32use tokio::{fs, task};
33
/// How long acquiring a connection from either pool may take before giving up (five minutes).
const ACQUIRE_TIMEOUT: Duration = Duration::from_secs(300);
/// Pooled connections that stay idle for this long are closed to free their resources.
const IDLE_TIMEOUT: Duration = Duration::from_secs(60);
/// Lifetime threshold handed to `ExpectShortLifetime` for every pooled connection; holding a
/// connection longer than this is presumably reported as a warning.
const WARN_AFTER_CONNECTION_LIFETIME: Duration = Duration::from_secs(30);
37
38pub use self::connection::Connection;
39
40/// Database connection pool.
41#[derive(Clone)]
42pub struct Pool {
43    // Pool with multiple read-only connections
44    reads: SqlitePool,
45    // Pool with a single writable connection.
46    write: SqlitePool,
47}
48
49impl Pool {
50    async fn create(conn_options: SqliteConnectOptions) -> Result<Self, sqlx::Error> {
51        let conn_options = conn_options
52            .journal_mode(SqliteJournalMode::Wal)
53            .synchronous(SqliteSynchronous::Normal)
54            .pragma("recursive_triggers", "ON");
55
56        let pool_options = SqlitePoolOptions::new()
57            // Disable the test as it breaks cancel-safety (also it's unnecessary in our case)
58            .test_before_acquire(false)
59            // Expire idle connections to conserve resources (threads, file descriptors)
60            .idle_timeout(IDLE_TIMEOUT)
61            .acquire_timeout(ACQUIRE_TIMEOUT);
62
63        let write = pool_options
64            .clone()
65            .max_connections(1)
66            .connect_with(conn_options.clone().optimize_on_close(true, Some(1000)))
67            .await?;
68
69        let reads = pool_options
70            .max_connections(8)
71            .connect_with(conn_options.read_only(true))
72            .await?;
73
74        Ok(Self { reads, write })
75    }
76
77    /// Acquire a read-only database connection.
78    #[track_caller]
79    pub fn acquire(&self) -> impl Future<Output = Result<PoolConnection, sqlx::Error>> + '_ {
80        PoolConnection::acquire(&self.reads, Location::caller())
81    }
82
83    /// Begin a read-only transaction. See [`ReadTransaction`] for more details.
84    #[track_caller]
85    pub fn begin_read(&self) -> impl Future<Output = Result<ReadTransaction, sqlx::Error>> + '_ {
86        ReadTransaction::begin(&self.reads, Location::caller())
87    }
88
89    /// Begin a write transaction. See [`WriteTransaction`] for more details.
90    #[track_caller]
91    pub fn begin_write(&self) -> impl Future<Output = Result<WriteTransaction, sqlx::Error>> + '_ {
92        let location = Location::caller();
93
94        async move {
95            Ok(WriteTransaction {
96                inner: ReadTransaction::begin(&self.write, location).await?,
97            })
98        }
99    }
100
101    pub(crate) async fn close(&self) -> Result<(), sqlx::Error> {
102        // Make sure to first close `reads` and only then `write`. That way when closing the write
103        // connection it is the last remaining connection and so it performs a WAL checkpoint and
104        // removes the auxiliary db files (*-wal and *-shm).
105        self.reads.close().await;
106        self.write.close().await;
107
108        Ok(())
109    }
110}
111
112/// Database connection from pool
113pub struct PoolConnection {
114    inner: sqlx::pool::PoolConnection<Sqlite>,
115    _track_lifetime: ExpectShortLifetime,
116}
117
118impl PoolConnection {
119    // Internal
120    async fn acquire(
121        pool: &SqlitePool,
122        location: &'static Location<'static>,
123    ) -> Result<Self, sqlx::Error> {
124        let inner = pool.acquire().await?;
125        let track_lifetime = ExpectShortLifetime::new_in(WARN_AFTER_CONNECTION_LIFETIME, location);
126
127        Ok(Self {
128            inner,
129            _track_lifetime: track_lifetime,
130        })
131    }
132}
133
134impl Deref for PoolConnection {
135    type Target = Connection;
136
137    fn deref(&self) -> &Self::Target {
138        Connection::ref_cast(self.inner.deref())
139    }
140}
141
142impl DerefMut for PoolConnection {
143    fn deref_mut(&mut self) -> &mut Self::Target {
144        Connection::ref_cast_mut(self.inner.deref_mut())
145    }
146}
147
148/// Transaction that allows only reading.
149///
150/// This is useful if one wants to make sure the observed database content doesn't change for the
151/// duration of the transaction even in the presence of concurrent writes. In other words - a read
152/// transaction represents an immutable snapshot of the database at the point the transaction was
153/// created. A read transaction doesn't need to be committed or rolled back - it's implicitly ended
154/// when the `ReadTransaction` instance drops.
155pub struct ReadTransaction {
156    inner: PoolConnection,
157    closed: bool,
158}
159
160impl ReadTransaction {
161    // Internal
162    async fn begin(
163        pool: &SqlitePool,
164        location: &'static Location<'static>,
165    ) -> Result<Self, sqlx::Error> {
166        let mut inner = PoolConnection::acquire(pool, location).await?;
167        SqliteTransactionManager::begin(&mut inner.inner, None).await?;
168
169        Ok(Self {
170            inner,
171            closed: false,
172        })
173    }
174
175    // Internal
176    async fn commit(mut self) -> Result<Committed, sqlx::Error> {
177        SqliteTransactionManager::commit(&mut self.inner.inner).await?;
178        self.closed = true;
179        Ok(Committed(self))
180    }
181}
182
183impl Deref for ReadTransaction {
184    type Target = Connection;
185
186    fn deref(&self) -> &Self::Target {
187        self.inner.deref()
188    }
189}
190
191impl DerefMut for ReadTransaction {
192    fn deref_mut(&mut self) -> &mut Self::Target {
193        self.inner.deref_mut()
194    }
195}
196
197impl fmt::Debug for ReadTransaction {
198    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
199        f.debug_struct("ReadTransaction").finish_non_exhaustive()
200    }
201}
202
203impl_executor_by_deref!(ReadTransaction);
204
205impl Drop for ReadTransaction {
206    fn drop(&mut self) {
207        if !self.closed {
208            SqliteTransactionManager::start_rollback(&mut self.inner.inner);
209        }
210    }
211}
212
213// Wrapper for a transaction that's been committed. This allows to delay releasing the underlying
214// connection to the pool while disallowing using the connection to execute any db operations.
215struct Committed(#[allow(dead_code)] ReadTransaction);
216
217/// Transaction that allows both reading and writing.
218///
219/// At most one task can hold a write transaction at any time. Any other tasks are blocked on
220/// calling `begin_write` until the task that currently holds it is done with it (commits it or
221/// rolls it back). Performing read-only operations concurrently while a write transaction is in
222/// use is still allowed. Those operations will not see the writes performed via the write
223/// transaction until that transaction is committed however.
224pub struct WriteTransaction {
225    inner: ReadTransaction,
226}
227
228impl WriteTransaction {
229    /// Commits the transaction.
230    ///
231    /// # Cancel safety
232    ///
233    /// If the future returned by this function is cancelled before completion, the transaction
234    /// is guaranteed to be either committed or rolled back but there is no way to tell in advance
235    /// which of the two operations happens.
236    pub async fn commit(self) -> Result<(), sqlx::Error> {
237        self.inner.commit().await?;
238        Ok(())
239    }
240
241    /// Commits the transaction and if (and only if) the commit completes successfully, runs the
242    /// given closure.
243    ///
244    /// # Atomicity
245    ///
246    /// If the commit succeeds, the closure is guaranteed to complete before another write
247    /// transaction begins.
248    ///
249    /// # Cancel safety
250    ///
251    /// The commits completes and if it succeeds the closure gets called. This is guaranteed to
252    /// happen even if the future returned from this function is cancelled before completion.
253    ///
254    /// # Insufficient alternatives
255    ///
256    /// ## Calling `commit().await?` and then calling `f()`
257    ///
258    /// This is not enough because it has these possible outcomes depending on whether and when
259    /// cancellation happened:
260    ///
261    /// 1. `commit` completes successfully and `f` is called
262    /// 2. `commit` completes with error and `f` is not called
263    /// 3. `commit` is cancelled but the transaction is still committed and `f` is not called
264    /// 4. `commit` is cancelled and the transaction rolls back and `f` is not called
265    ///
266    /// Number 3 is typically not desirable.
267    ///
268    /// ## Calling `f` using a RAII guard
269    ///
270    /// This is still not enough because it has the following possible outcomes:
271    ///
272    /// 1. `commit` completes successfully and `f` is called
273    /// 2. `commit` completes with error and `f` is called
274    /// 3. `commit` is cancelled but the transaction is still committed and `f` is called
275    /// 4. `commit` is cancelled and the transaction rolls back and `f` is called
276    ///
277    /// Numbers 2 and 4 are not desirable. Number 2 can be handled by explicitly handling the error
278    /// case and disabling the guard but there is nothing to do about number 4.
279    pub async fn commit_and_then<F, R>(self, f: F) -> Result<R, sqlx::Error>
280    where
281        F: FnOnce() -> R + Send + 'static,
282        R: Send + 'static,
283    {
284        let span = Span::current();
285
286        task::spawn(async move {
287            // IMPORTANT: `_committed` must live until `f` completes.
288            let _committed = self.inner.commit().await?;
289            let result = span.in_scope(f);
290            Ok(result)
291        })
292        .await
293        .unwrap()
294    }
295}
296
297impl Deref for WriteTransaction {
298    type Target = ReadTransaction;
299
300    fn deref(&self) -> &Self::Target {
301        &self.inner
302    }
303}
304
305impl DerefMut for WriteTransaction {
306    fn deref_mut(&mut self) -> &mut Self::Target {
307        &mut self.inner
308    }
309}
310
311impl std::fmt::Debug for WriteTransaction {
312    fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
313        write!(f, "WriteTransaction{{ inner:{:?} }}", self.inner)
314    }
315}
316
317impl_executor_by_deref!(WriteTransaction);
318
319/// Creates a new database and opens a connection to it.
320pub(crate) async fn create(path: impl AsRef<Path>) -> Result<Pool, Error> {
321    let path = path.as_ref();
322
323    if fs::metadata(path).await.is_ok() {
324        return Err(Error::Exists);
325    }
326
327    create_directory(path).await?;
328
329    let connect_options = SqliteConnectOptions::new()
330        .filename(path)
331        .create_if_missing(true);
332
333    let pool = Pool::create(connect_options).await.map_err(Error::Open)?;
334
335    migrations::run(&pool).await?;
336
337    Ok(pool)
338}
339
/// Creates a new database named `temp.db` inside a fresh temporary directory. Useful for tests.
///
/// The returned `TempDir` removes the directory when dropped, so keep it alive for as long as
/// the pool is used.
#[cfg(test)]
pub(crate) async fn create_temp() -> Result<(TempDir, Pool), Error> {
    let dir = TempDir::new().map_err(Error::CreateDirectory)?;
    let db_path = dir.path().join("temp.db");
    let pool = create(db_path).await?;
    Ok((dir, pool))
}
348
349/// Opens a connection to the specified database. Fails if the db doesn't exist.
350pub(crate) async fn open(path: impl AsRef<Path>) -> Result<Pool, Error> {
351    let connect_options = SqliteConnectOptions::new().filename(path);
352    let pool = Pool::create(connect_options).await.map_err(Error::Open)?;
353
354    migrations::run(&pool).await?;
355
356    Ok(pool)
357}
358
359/// Opens a connection to the specified database. Fails if the db doesn't exist.
360pub async fn open_without_migrations(path: impl AsRef<Path>) -> Result<Pool, Error> {
361    let connect_options = SqliteConnectOptions::new().filename(path);
362    let pool = Pool::create(connect_options).await.map_err(Error::Open)?;
363
364    Ok(pool)
365}
366
367async fn create_directory(path: &Path) -> Result<(), Error> {
368    if let Some(dir) = path.parent() {
369        fs::create_dir_all(dir)
370            .await
371            .map_err(Error::CreateDirectory)?
372    }
373
374    Ok(())
375}
376
/// Reinterprets an `i64` read from the database as the `u64` it was stored from.
///
/// sqlx lacks native `u64` support, so unsigned values are kept bit-for-bit in a signed column.
/// This is the exact inverse of [`encode_u64`] and is lossless for every input.
pub(crate) const fn decode_u64(i: i64) -> u64 {
    u64::from_ne_bytes(i.to_ne_bytes())
}

/// Reinterprets a `u64` as an `i64` so that it can be stored through sqlx, which lacks native
/// `u64` support.
///
/// Values above `i64::MAX` become negative numbers; [`decode_u64`] reverses the mapping exactly.
pub(crate) const fn encode_u64(u: u64) -> i64 {
    i64::from_ne_bytes(u.to_ne_bytes())
}
388
389#[derive(Debug, Error)]
390pub enum Error {
391    #[error("failed to create database directory")]
392    CreateDirectory(#[source] io::Error),
393    #[error("database already exists")]
394    Exists,
395    #[error("failed to open database")]
396    Open(#[source] sqlx::Error),
397    #[error("failed to execute database query")]
398    Query(#[from] sqlx::Error),
399}
400
401async fn get_pragma(conn: &mut Connection, name: &str) -> Result<u32, Error> {
402    Ok(sqlx::query(&format!("PRAGMA {}", name))
403        .fetch_one(&mut *conn)
404        .await?
405        .get(0))
406}
407
408async fn set_pragma(conn: &mut Connection, name: &str, value: u32) -> Result<(), Error> {
409    // `bind` doesn't seem to be supported for setting PRAGMAs...
410    sqlx::query(&format!("PRAGMA {} = {}", name, value))
411        .execute(&mut *conn)
412        .await?;
413
414    Ok(())
415}
416
#[cfg(test)]
mod tests {
    use super::*;

    // The casts must be lossless. The signed range maps onto the unsigned range as
    //   [0i64,     i64::MAX] <-> [0u64,             u64::MAX / 2]
    //   [i64::MIN,    -1i64] <-> [u64::MAX / 2 + 1,     u64::MAX]
    const PAIRS: [(i64, u64); 5] = [
        (0, 0),
        (1, 1),
        (-1, u64::MAX),
        (i64::MIN, u64::MAX / 2 + 1),
        (i64::MAX, u64::MAX / 2),
    ];

    #[test]
    fn decode_u64_sanity_check() {
        for (signed, unsigned) in PAIRS {
            assert_eq!(decode_u64(signed), unsigned, "decode_u64({signed})");
        }
    }

    #[test]
    fn encode_u64_sanity_check() {
        for (signed, unsigned) in PAIRS {
            assert_eq!(encode_u64(unsigned), signed, "encode_u64({unsigned})");
        }
    }
}