ouisync/db/migrations.rs

1use super::{get_pragma, set_pragma, Connection, Error, Pool};
2use include_dir::{include_dir, Dir, File};
3use std::sync::LazyLock;
4
5/// Latest schema version
6pub static SCHEMA_VERSION: LazyLock<u32> = LazyLock::new(|| {
7    MIGRATIONS
8        .files()
9        .filter_map(get_migration)
10        .map(|(version, _)| version)
11        .max()
12        .unwrap_or(0)
13});
14
15/// Apply all pending migrations.
16pub(super) async fn run(pool: &Pool) -> Result<(), Error> {
17    let mut migrations: Vec<_> = MIGRATIONS.files().filter_map(get_migration).collect();
18    migrations.sort_by_key(|(version, _)| *version);
19
20    for (version, sql) in migrations {
21        apply(pool, version, sql).await?;
22    }
23
24    Ok(())
25}
26
27static MIGRATIONS: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/src/db/migrations");
28
29fn get_migration<'a>(file: &'a File<'_>) -> Option<(u32, &'a str)> {
30    if !file
31        .path()
32        .extension()
33        .map(|ext| ext == "sql")
34        .unwrap_or(false)
35    {
36        return None;
37    }
38
39    let stem = file.path().file_stem()?.to_str()?;
40
41    if !stem.starts_with('v') {
42        return None;
43    }
44    let version: u32 = stem[1..].parse().ok()?;
45    let sql = file.contents_utf8()?;
46
47    Some((version, sql))
48}
49
50async fn apply(pool: &Pool, dst_version: u32, sql: &str) -> Result<(), Error> {
51    let mut tx = pool.begin_write().await?;
52
53    let src_version = get_version(&mut tx).await?;
54    if src_version >= dst_version {
55        return Ok(());
56    }
57
58    assert_eq!(
59        dst_version,
60        src_version + 1,
61        "migrations must be applied in order"
62    );
63
64    sqlx::query(sql).execute(&mut tx).await?;
65    set_version(&mut tx, dst_version).await?;
66
67    tx.commit().await?;
68
69    Ok(())
70}
71
72/// Gets the current schema version of the database.
73async fn get_version(conn: &mut Connection) -> Result<u32, Error> {
74    get_pragma(conn, "user_version").await
75}
76
77/// Sets the current schema version of the database.
78async fn set_version(conn: &mut Connection, value: u32) -> Result<(), Error> {
79    set_pragma(conn, "user_version", value).await
80}