1#[macro_use]
2mod macros;
3
4mod connection;
5mod id;
6mod migrations;
7
8pub use id::DatabaseId;
9pub use migrations::SCHEMA_VERSION;
10
11use tracing::Span;
12
13use deadlock::ExpectShortLifetime;
14use ref_cast::RefCast;
15use sqlx::{
16 sqlite::{
17 Sqlite, SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions, SqliteSynchronous,
18 SqliteTransactionManager,
19 },
20 Row, SqlitePool, TransactionManager,
21};
22use std::{
23 fmt, io,
24 ops::{Deref, DerefMut},
25 panic::Location,
26 path::Path,
27 time::Duration,
28};
29#[cfg(test)]
30use tempfile::TempDir;
31use thiserror::Error;
32use tokio::{fs, task};
33
/// Maximum time to wait for a free pooled connection (passed to
/// `SqlitePoolOptions::acquire_timeout` for both pools).
const ACQUIRE_TIMEOUT: Duration = Duration::from_secs(300);
/// Passed to `SqlitePoolOptions::idle_timeout` for both the read and the write pool.
const IDLE_TIMEOUT: Duration = Duration::from_secs(60);
/// Threshold handed to `ExpectShortLifetime` for every acquired connection; presumably a
/// connection held longer than this gets reported (TODO confirm in the `deadlock` crate).
const WARN_AFTER_CONNECTION_LIFETIME: Duration = Duration::from_secs(30);
37
38pub use self::connection::Connection;
39
40#[derive(Clone)]
42pub struct Pool {
43 reads: SqlitePool,
45 write: SqlitePool,
47}
48
49impl Pool {
50 async fn create(conn_options: SqliteConnectOptions) -> Result<Self, sqlx::Error> {
51 let conn_options = conn_options
52 .journal_mode(SqliteJournalMode::Wal)
53 .synchronous(SqliteSynchronous::Normal)
54 .pragma("recursive_triggers", "ON");
55
56 let pool_options = SqlitePoolOptions::new()
57 .test_before_acquire(false)
59 .idle_timeout(IDLE_TIMEOUT)
61 .acquire_timeout(ACQUIRE_TIMEOUT);
62
63 let write = pool_options
64 .clone()
65 .max_connections(1)
66 .connect_with(conn_options.clone().optimize_on_close(true, Some(1000)))
67 .await?;
68
69 let reads = pool_options
70 .max_connections(8)
71 .connect_with(conn_options.read_only(true))
72 .await?;
73
74 Ok(Self { reads, write })
75 }
76
77 #[track_caller]
79 pub fn acquire(&self) -> impl Future<Output = Result<PoolConnection, sqlx::Error>> + '_ {
80 PoolConnection::acquire(&self.reads, Location::caller())
81 }
82
83 #[track_caller]
85 pub fn begin_read(&self) -> impl Future<Output = Result<ReadTransaction, sqlx::Error>> + '_ {
86 ReadTransaction::begin(&self.reads, Location::caller())
87 }
88
89 #[track_caller]
91 pub fn begin_write(&self) -> impl Future<Output = Result<WriteTransaction, sqlx::Error>> + '_ {
92 let location = Location::caller();
93
94 async move {
95 Ok(WriteTransaction {
96 inner: ReadTransaction::begin(&self.write, location).await?,
97 })
98 }
99 }
100
101 pub(crate) async fn close(&self) -> Result<(), sqlx::Error> {
102 self.reads.close().await;
106 self.write.close().await;
107
108 Ok(())
109 }
110}
111
112pub struct PoolConnection {
114 inner: sqlx::pool::PoolConnection<Sqlite>,
115 _track_lifetime: ExpectShortLifetime,
116}
117
118impl PoolConnection {
119 async fn acquire(
121 pool: &SqlitePool,
122 location: &'static Location<'static>,
123 ) -> Result<Self, sqlx::Error> {
124 let inner = pool.acquire().await?;
125 let track_lifetime = ExpectShortLifetime::new_in(WARN_AFTER_CONNECTION_LIFETIME, location);
126
127 Ok(Self {
128 inner,
129 _track_lifetime: track_lifetime,
130 })
131 }
132}
133
134impl Deref for PoolConnection {
135 type Target = Connection;
136
137 fn deref(&self) -> &Self::Target {
138 Connection::ref_cast(self.inner.deref())
139 }
140}
141
142impl DerefMut for PoolConnection {
143 fn deref_mut(&mut self) -> &mut Self::Target {
144 Connection::ref_cast_mut(self.inner.deref_mut())
145 }
146}
147
148pub struct ReadTransaction {
156 inner: PoolConnection,
157 closed: bool,
158}
159
160impl ReadTransaction {
161 async fn begin(
163 pool: &SqlitePool,
164 location: &'static Location<'static>,
165 ) -> Result<Self, sqlx::Error> {
166 let mut inner = PoolConnection::acquire(pool, location).await?;
167 SqliteTransactionManager::begin(&mut inner.inner, None).await?;
168
169 Ok(Self {
170 inner,
171 closed: false,
172 })
173 }
174
175 async fn commit(mut self) -> Result<Committed, sqlx::Error> {
177 SqliteTransactionManager::commit(&mut self.inner.inner).await?;
178 self.closed = true;
179 Ok(Committed(self))
180 }
181}
182
183impl Deref for ReadTransaction {
184 type Target = Connection;
185
186 fn deref(&self) -> &Self::Target {
187 self.inner.deref()
188 }
189}
190
191impl DerefMut for ReadTransaction {
192 fn deref_mut(&mut self) -> &mut Self::Target {
193 self.inner.deref_mut()
194 }
195}
196
197impl fmt::Debug for ReadTransaction {
198 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
199 f.debug_struct("ReadTransaction").finish_non_exhaustive()
200 }
201}
202
203impl_executor_by_deref!(ReadTransaction);
204
205impl Drop for ReadTransaction {
206 fn drop(&mut self) {
207 if !self.closed {
208 SqliteTransactionManager::start_rollback(&mut self.inner.inner);
209 }
210 }
211}
212
213struct Committed(#[allow(dead_code)] ReadTransaction);
216
217pub struct WriteTransaction {
225 inner: ReadTransaction,
226}
227
228impl WriteTransaction {
229 pub async fn commit(self) -> Result<(), sqlx::Error> {
237 self.inner.commit().await?;
238 Ok(())
239 }
240
241 pub async fn commit_and_then<F, R>(self, f: F) -> Result<R, sqlx::Error>
280 where
281 F: FnOnce() -> R + Send + 'static,
282 R: Send + 'static,
283 {
284 let span = Span::current();
285
286 task::spawn(async move {
287 let _committed = self.inner.commit().await?;
289 let result = span.in_scope(f);
290 Ok(result)
291 })
292 .await
293 .unwrap()
294 }
295}
296
297impl Deref for WriteTransaction {
298 type Target = ReadTransaction;
299
300 fn deref(&self) -> &Self::Target {
301 &self.inner
302 }
303}
304
305impl DerefMut for WriteTransaction {
306 fn deref_mut(&mut self) -> &mut Self::Target {
307 &mut self.inner
308 }
309}
310
311impl std::fmt::Debug for WriteTransaction {
312 fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
313 write!(f, "WriteTransaction{{ inner:{:?} }}", self.inner)
314 }
315}
316
317impl_executor_by_deref!(WriteTransaction);
318
319pub(crate) async fn create(path: impl AsRef<Path>) -> Result<Pool, Error> {
321 let path = path.as_ref();
322
323 if fs::metadata(path).await.is_ok() {
324 return Err(Error::Exists);
325 }
326
327 create_directory(path).await?;
328
329 let connect_options = SqliteConnectOptions::new()
330 .filename(path)
331 .create_if_missing(true);
332
333 let pool = Pool::create(connect_options).await.map_err(Error::Open)?;
334
335 migrations::run(&pool).await?;
336
337 Ok(pool)
338}
339
/// Test helper: creates a fresh database named `temp.db` inside a new temporary directory.
///
/// The returned `TempDir` must be kept alive for as long as the pool is used.
#[cfg(test)]
pub(crate) async fn create_temp() -> Result<(TempDir, Pool), Error> {
    let dir = TempDir::new().map_err(Error::CreateDirectory)?;
    let db_path = dir.path().join("temp.db");

    create(db_path).await.map(|pool| (dir, pool))
}
348
349pub(crate) async fn open(path: impl AsRef<Path>) -> Result<Pool, Error> {
351 let connect_options = SqliteConnectOptions::new().filename(path);
352 let pool = Pool::create(connect_options).await.map_err(Error::Open)?;
353
354 migrations::run(&pool).await?;
355
356 Ok(pool)
357}
358
359pub async fn open_without_migrations(path: impl AsRef<Path>) -> Result<Pool, Error> {
361 let connect_options = SqliteConnectOptions::new().filename(path);
362 let pool = Pool::create(connect_options).await.map_err(Error::Open)?;
363
364 Ok(pool)
365}
366
367async fn create_directory(path: &Path) -> Result<(), Error> {
368 if let Some(dir) = path.parent() {
369 fs::create_dir_all(dir)
370 .await
371 .map_err(Error::CreateDirectory)?
372 }
373
374 Ok(())
375}
376
/// Reinterprets the bits of an `i64` (SQLite's signed 64-bit integer) as a `u64`.
///
/// Lossless and the exact inverse of [`encode_u64`]: negative inputs map to the upper half of
/// the `u64` range.
pub(crate) const fn decode_u64(i: i64) -> u64 {
    u64::from_ne_bytes(i.to_ne_bytes())
}

/// Reinterprets the bits of a `u64` as an `i64`, so it can be stored as SQLite's signed 64-bit
/// integer.
///
/// Values above `i64::MAX` become negative; [`decode_u64`] restores them.
pub(crate) const fn encode_u64(u: u64) -> i64 {
    i64::from_ne_bytes(u.to_ne_bytes())
}
388
389#[derive(Debug, Error)]
390pub enum Error {
391 #[error("failed to create database directory")]
392 CreateDirectory(#[source] io::Error),
393 #[error("database already exists")]
394 Exists,
395 #[error("failed to open database")]
396 Open(#[source] sqlx::Error),
397 #[error("failed to execute database query")]
398 Query(#[from] sqlx::Error),
399}
400
401async fn get_pragma(conn: &mut Connection, name: &str) -> Result<u32, Error> {
402 Ok(sqlx::query(&format!("PRAGMA {}", name))
403 .fetch_one(&mut *conn)
404 .await?
405 .get(0))
406}
407
408async fn set_pragma(conn: &mut Connection, name: &str, value: u32) -> Result<(), Error> {
409 sqlx::query(&format!("PRAGMA {} = {}", name, value))
411 .execute(&mut *conn)
412 .await?;
413
414 Ok(())
415}
416
#[cfg(test)]
mod tests {
    use super::*;

    /// Boundary values of the `i64` -> `u64` bit reinterpretation.
    #[test]
    fn decode_u64_sanity_check() {
        let cases = [
            (0_i64, 0_u64),
            (1, 1),
            (-1, u64::MAX),
            (i64::MIN, u64::MAX / 2 + 1),
            (i64::MAX, u64::MAX / 2),
        ];

        for (encoded, expected) in cases {
            assert_eq!(decode_u64(encoded), expected);
        }
    }

    /// Boundary values of the `u64` -> `i64` bit reinterpretation.
    #[test]
    fn encode_u64_sanity_check() {
        let cases = [
            (0_u64, 0_i64),
            (1, 1),
            (u64::MAX / 2, i64::MAX),
            (u64::MAX / 2 + 1, i64::MIN),
            (u64::MAX, -1),
        ];

        for (decoded, expected) in cases {
            assert_eq!(encode_u64(decoded), expected);
        }
    }
}