1#[macro_use]
2mod macros;
3
4mod connection;
5mod id;
6mod migrations;
7
8pub use id::DatabaseId;
9pub use migrations::SCHEMA_VERSION;
10
11use tracing::Span;
12
13use deadlock::ExpectShortLifetime;
14use ref_cast::RefCast;
15use sqlx::{
16 ConnectOptions, Row, SqlitePool, TransactionManager,
17 sqlite::{
18 Sqlite, SqliteAutoVacuum, SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions,
19 SqliteSynchronous, SqliteTransactionManager,
20 },
21};
22use std::{
23 fmt, io,
24 ops::{Deref, DerefMut},
25 panic::Location,
26 path::Path,
27 time::Duration,
28};
29#[cfg(test)]
30use tempfile::TempDir;
31use thiserror::Error;
32use tokio::{fs, sync::Semaphore, task};
33
/// Maximum time `acquire`/`begin_*` wait for a pooled connection before failing (five minutes).
const ACQUIRE_TIMEOUT: Duration = Duration::from_secs(300);
/// Connections left idle in a pool for longer than this may be closed by the pool.
const IDLE_TIMEOUT: Duration = Duration::from_secs(60);
/// Expected maximum hold time passed to each checked-out connection's lifetime tracker.
const WARN_AFTER_CONNECTION_LIFETIME: Duration = Duration::from_secs(30);
37
38pub use self::connection::Connection;
39
40#[derive(Clone)]
42pub struct Pool {
43 reads: SqlitePool,
45 write: SqlitePool,
47}
48
49impl Pool {
50 async fn create(conn_options: SqliteConnectOptions) -> Result<Self, sqlx::Error> {
51 if fs::try_exists(conn_options.get_filename())
52 .await
53 .unwrap_or(false)
54 {
55 if let Err(error) = enable_auto_vacuum(conn_options.get_filename()).await {
61 tracing::warn!(?error, "failed to enable auto vacuum");
62 }
63 }
64
65 let conn_options = conn_options
66 .journal_mode(SqliteJournalMode::Wal)
67 .synchronous(SqliteSynchronous::Normal)
68 .pragma("recursive_triggers", "ON");
69
70 let pool_options = SqlitePoolOptions::new()
71 .test_before_acquire(false)
73 .idle_timeout(IDLE_TIMEOUT)
75 .acquire_timeout(ACQUIRE_TIMEOUT);
76
77 let write = pool_options
78 .clone()
79 .max_connections(1)
80 .connect_with(
81 conn_options
82 .clone()
83 .auto_vacuum(SqliteAutoVacuum::Full)
84 .optimize_on_close(true, Some(1000)),
85 )
86 .await?;
87
88 let reads = pool_options
89 .max_connections(8)
90 .connect_with(conn_options.read_only(true))
91 .await?;
92
93 Ok(Self { reads, write })
94 }
95
96 #[track_caller]
98 pub fn acquire(&self) -> impl Future<Output = Result<PoolConnection, sqlx::Error>> + '_ {
99 PoolConnection::acquire(&self.reads, Location::caller())
100 }
101
102 #[track_caller]
104 pub fn begin_read(&self) -> impl Future<Output = Result<ReadTransaction, sqlx::Error>> + '_ {
105 ReadTransaction::begin(&self.reads, Location::caller())
106 }
107
108 #[track_caller]
110 pub fn begin_write(&self) -> impl Future<Output = Result<WriteTransaction, sqlx::Error>> + '_ {
111 let location = Location::caller();
112
113 async move {
114 Ok(WriteTransaction {
115 inner: ReadTransaction::begin(&self.write, location).await?,
116 })
117 }
118 }
119
120 pub(crate) async fn close(&self) -> Result<(), sqlx::Error> {
121 self.reads.close().await;
125 self.write.close().await;
126
127 Ok(())
128 }
129}
130
131pub struct PoolConnection {
133 inner: sqlx::pool::PoolConnection<Sqlite>,
134 _track_lifetime: ExpectShortLifetime,
135}
136
137impl PoolConnection {
138 async fn acquire(
140 pool: &SqlitePool,
141 location: &'static Location<'static>,
142 ) -> Result<Self, sqlx::Error> {
143 let inner = pool.acquire().await?;
144 let track_lifetime = ExpectShortLifetime::new_in(WARN_AFTER_CONNECTION_LIFETIME, location);
145
146 Ok(Self {
147 inner,
148 _track_lifetime: track_lifetime,
149 })
150 }
151}
152
153impl Deref for PoolConnection {
154 type Target = Connection;
155
156 fn deref(&self) -> &Self::Target {
157 Connection::ref_cast(self.inner.deref())
158 }
159}
160
161impl DerefMut for PoolConnection {
162 fn deref_mut(&mut self) -> &mut Self::Target {
163 Connection::ref_cast_mut(self.inner.deref_mut())
164 }
165}
166
167pub struct ReadTransaction {
175 inner: PoolConnection,
176 closed: bool,
177}
178
179impl ReadTransaction {
180 async fn begin(
182 pool: &SqlitePool,
183 location: &'static Location<'static>,
184 ) -> Result<Self, sqlx::Error> {
185 let mut inner = PoolConnection::acquire(pool, location).await?;
186 SqliteTransactionManager::begin(&mut inner.inner, None).await?;
187
188 Ok(Self {
189 inner,
190 closed: false,
191 })
192 }
193
194 async fn commit(mut self) -> Result<Committed, sqlx::Error> {
196 SqliteTransactionManager::commit(&mut self.inner.inner).await?;
197 self.closed = true;
198 Ok(Committed(self))
199 }
200}
201
202impl Deref for ReadTransaction {
203 type Target = Connection;
204
205 fn deref(&self) -> &Self::Target {
206 self.inner.deref()
207 }
208}
209
210impl DerefMut for ReadTransaction {
211 fn deref_mut(&mut self) -> &mut Self::Target {
212 self.inner.deref_mut()
213 }
214}
215
216impl fmt::Debug for ReadTransaction {
217 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
218 f.debug_struct("ReadTransaction").finish_non_exhaustive()
219 }
220}
221
222impl_executor_by_deref!(ReadTransaction);
223
224impl Drop for ReadTransaction {
225 fn drop(&mut self) {
226 if !self.closed {
227 SqliteTransactionManager::start_rollback(&mut self.inner.inner);
228 }
229 }
230}
231
232struct Committed(#[allow(dead_code)] ReadTransaction);
235
236pub struct WriteTransaction {
244 inner: ReadTransaction,
245}
246
247impl WriteTransaction {
248 pub async fn commit(self) -> Result<(), sqlx::Error> {
256 self.inner.commit().await?;
257 Ok(())
258 }
259
260 pub async fn commit_and_then<F, R>(self, f: F) -> Result<R, sqlx::Error>
299 where
300 F: FnOnce() -> R + Send + 'static,
301 R: Send + 'static,
302 {
303 let span = Span::current();
304
305 task::spawn(async move {
306 let _committed = self.inner.commit().await?;
308 let result = span.in_scope(f);
309 Ok(result)
310 })
311 .await
312 .unwrap()
313 }
314}
315
316impl Deref for WriteTransaction {
317 type Target = ReadTransaction;
318
319 fn deref(&self) -> &Self::Target {
320 &self.inner
321 }
322}
323
324impl DerefMut for WriteTransaction {
325 fn deref_mut(&mut self) -> &mut Self::Target {
326 &mut self.inner
327 }
328}
329
330impl std::fmt::Debug for WriteTransaction {
331 fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
332 write!(f, "WriteTransaction{{ inner:{:?} }}", self.inner)
333 }
334}
335
336impl_executor_by_deref!(WriteTransaction);
337
338#[derive(Debug, Error)]
339pub enum Error {
340 #[error("failed to create database directory")]
341 CreateDirectory(#[source] io::Error),
342 #[error("database already exists")]
343 Exists,
344 #[error("failed to open database")]
345 Open(#[source] sqlx::Error),
346 #[error("failed to execute database query")]
347 Query(#[from] sqlx::Error),
348}
349
350pub(crate) async fn create(path: impl AsRef<Path>) -> Result<Pool, Error> {
352 let path = path.as_ref();
353
354 if fs::metadata(path).await.is_ok() {
355 return Err(Error::Exists);
356 }
357
358 create_directory(path).await?;
359
360 let connect_options = SqliteConnectOptions::new()
361 .filename(path)
362 .create_if_missing(true);
363
364 let pool = Pool::create(connect_options).await.map_err(Error::Open)?;
365
366 migrations::run(&pool).await?;
367
368 Ok(pool)
369}
370
/// Test helper: creates a fresh database named `temp.db` inside a new temporary directory.
///
/// The returned `TempDir` must be kept alive for as long as the pool is used. Dropping it
/// deletes the directory.
#[cfg(test)]
pub(crate) async fn create_temp() -> Result<(TempDir, Pool), Error> {
    let dir = TempDir::new().map_err(Error::CreateDirectory)?;
    let db_path = dir.path().join("temp.db");

    create(db_path).await.map(|pool| (dir, pool))
}
379
380pub(crate) async fn open(path: impl AsRef<Path>) -> Result<Pool, Error> {
382 let connect_options = SqliteConnectOptions::new().filename(path);
383 let pool = Pool::create(connect_options).await.map_err(Error::Open)?;
384
385 migrations::run(&pool).await?;
386
387 Ok(pool)
388}
389
390pub async fn open_without_migrations(path: impl AsRef<Path>) -> Result<Pool, Error> {
392 let connect_options = SqliteConnectOptions::new().filename(path);
393 let pool = Pool::create(connect_options).await.map_err(Error::Open)?;
394
395 Ok(pool)
396}
397
398async fn create_directory(path: &Path) -> Result<(), Error> {
399 if let Some(dir) = path.parent() {
400 fs::create_dir_all(dir)
401 .await
402 .map_err(Error::CreateDirectory)?
403 }
404
405 Ok(())
406}
407
/// Decodes a `u64` stored in an SQLite `INTEGER` (`i64`) column.
///
/// This is a lossless bit-for-bit reinterpretation: negative values map to the upper half of
/// the `u64` range. It is the inverse of `encode_u64`.
pub(crate) const fn decode_u64(i: i64) -> u64 {
    u64::from_ne_bytes(i.to_ne_bytes())
}
413
/// Encodes a `u64` for storage in an SQLite `INTEGER` (`i64`) column.
///
/// This is a lossless bit-for-bit reinterpretation: values above `i64::MAX` become negative.
/// It is the inverse of `decode_u64`.
pub(crate) const fn encode_u64(u: u64) -> i64 {
    i64::from_ne_bytes(u.to_ne_bytes())
}
419
420async fn enable_auto_vacuum(db_path: &Path) -> Result<(), Error> {
422 let mut conn = SqliteConnectOptions::new()
423 .filename(db_path)
424 .connect()
425 .await?;
426
427 let auto_vacuum: u32 = sqlx::query("PRAGMA auto_vacuum")
428 .fetch_one(&mut conn)
429 .await?
430 .get(0);
431
432 if auto_vacuum != 0 {
433 return Ok(());
434 }
435
436 static SEMAPHORE: Semaphore = Semaphore::const_new(1);
441 let _permit = SEMAPHORE.acquire().await.unwrap();
442
443 sqlx::query("PRAGMA auto_vacuum=FULL")
444 .execute(&mut conn)
445 .await?;
446
447 sqlx::query("VACUUM").execute(&mut conn).await?;
449
450 Ok(())
451}
452
453async fn get_pragma(conn: &mut Connection, name: &str) -> Result<u32, Error> {
454 Ok(sqlx::query(&format!("PRAGMA {name}"))
455 .fetch_one(&mut *conn)
456 .await?
457 .get(0))
458}
459
460async fn set_pragma(conn: &mut Connection, name: &str, value: u32) -> Result<(), Error> {
461 sqlx::query(&format!("PRAGMA {name} = {value}"))
463 .execute(&mut *conn)
464 .await?;
465
466 Ok(())
467}
468
#[cfg(test)]
mod tests {
    use super::*;

    /// `decode_u64` must reinterpret the bits, mapping negatives to the upper `u64` half.
    #[test]
    fn decode_u64_sanity_check() {
        let cases = [
            (0, 0),
            (1, 1),
            (-1, u64::MAX),
            (i64::MIN, u64::MAX / 2 + 1),
            (i64::MAX, u64::MAX / 2),
        ];

        for (input, expected) in cases {
            assert_eq!(decode_u64(input), expected, "decode_u64({input})");
        }
    }

    /// `encode_u64` must be the exact inverse of `decode_u64`.
    #[test]
    fn encode_u64_sanity_check() {
        let cases = [
            (0, 0),
            (1, 1),
            (u64::MAX / 2, i64::MAX),
            (u64::MAX / 2 + 1, i64::MIN),
            (u64::MAX, -1),
        ];

        for (input, expected) in cases {
            assert_eq!(encode_u64(input), expected, "encode_u64({input})");
        }
    }
}