ouisync/store/migrations.rs

use super::{error::Error, Store, WriteTransaction};
use crate::{
    crypto::sign::{Keypair, PublicKey},
    repository::data_version,
};

pub const DATA_VERSION: u64 = 1;

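/// Runs all data migrations in order, then asserts that the store ended up at the latest
/// version (`DATA_VERSION`).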
pub(super) async fn run_data(
    store: &Store,
    this_writer_id: PublicKey,
    write_keys: &Keypair,
) -> Result<(), Error> {
    v1::run(store, this_writer_id, write_keys).await?;

    // Ensure we are at the latest version.
    assert_eq!(
        data_version::get(store.acquire_read().await?.db()).await?,
        DATA_VERSION
    );

    Ok(())
}

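/// Begins a write transaction for the migration to `dst_version`. Returns `None` if the store is
/// already at that version (or newer) and the migration can be skipped.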
async fn begin(store: &Store, dst_version: u64) -> Result<Option<WriteTransaction>, Error> {
    let mut tx = store.begin_write().await?;

    let src_version = data_version::get(tx.db()).await?;
    if src_version >= dst_version {
        return Ok(None);
    }

    assert_eq!(
        dst_version,
        src_version + 1,
        "migrations must be applied in order"
    );

    // Bump the data version before running the migration. This is OK because if the migration
    // fails, the version bump is rolled back together with it.
    data_version::set(tx.db(), dst_version).await?;

    Ok(Some(tx))
}

/// Recompute block ids so the new ids are computed from both the ciphertext and the nonce (as
/// opposed to only the ciphertext), to protect against nonce tampering.
mod v1 {
    use super::{
        super::{root_node, Changeset, Reader},
        *,
    };
    use crate::{
        crypto::{sign::Keypair, Hash},
        protocol::{BlockContent, BlockId, BlockNonce, LeafNode, Proof, RootNode, RootNodeFilter},
    };
    use futures_util::TryStreamExt;
    use sqlx::Row;

    pub(super) async fn run(
        store: &Store,
        this_writer_id: PublicKey,
        write_keys: &Keypair,
    ) -> Result<(), Error> {
        let Some(mut tx) = begin(store, 1).await? else {
            return Ok(());
        };

        // Temporary table to map old block ids to new block ids.
        sqlx::query(
            "CREATE TEMPORARY TABLE block_id_translations (
                 old_block_id BLOB NOT NULL UNIQUE,
                 new_block_id BLOB NOT NULL UNIQUE
             )",
        )
        .execute(tx.db())
        .await?;

        recompute_block_ids(&mut tx).await?;
        recompute_index_hashes(&mut tx, this_writer_id, write_keys).await?;

        // Remove the temp table
        sqlx::query("DROP TABLE block_id_translations")
            .execute(tx.db())
            .await?;

        tx.commit().await?;

        Ok(())
    }

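    // Recompute the id of every block from its content and nonce, processing the blocks in
    // batches of 1024. Blocks whose recomputed id equals the stored one are left untouched.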
    async fn recompute_block_ids(tx: &mut WriteTransaction) -> Result<(), Error> {
        loop {
            let map: Vec<_> = sqlx::query(
                "SELECT id, nonce, content
                 FROM blocks
                 WHERE id NOT IN (SELECT new_block_id FROM block_id_translations)
                 LIMIT 1024",
            )
            .fetch(tx.db())
            .try_filter_map(|row| async move {
                let old_id: BlockId = row.get(0);

                let nonce: &[u8] = row.get(1);
                let nonce = BlockNonce::try_from(nonce)
                    .map_err(|error| sqlx::Error::Decode(error.into()))?;

                let mut content = BlockContent::new();
                content.copy_from_slice(row.get(2));

                let new_id = BlockId::new(&content, &nonce);

                if new_id != old_id {
                    Ok(Some((old_id, new_id)))
                } else {
                    Ok(None)
                }
            })
            .try_collect()
            .await?;

            if map.is_empty() {
                break;
            }

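            // Rename each block to its new id and record the old -> new mapping so the index
            // can be translated afterwards.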
            for (old_id, new_id) in map {
                sqlx::query("UPDATE blocks SET id = ? WHERE id = ?")
                    .bind(&new_id)
                    .bind(&old_id)
                    .execute(tx.db())
                    .await?;

                sqlx::query(
                    "INSERT INTO block_id_translations (old_block_id, new_block_id) VALUES (?, ?)",
                )
                .bind(&old_id)
                .bind(&new_id)
                .execute(tx.db())
                .await?;
            }
        }

        Ok(())
    }

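    // Rebuild the index of every branch so its leaf nodes reference the new block ids.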
    async fn recompute_index_hashes(
        tx: &mut WriteTransaction,
        this_writer_id: PublicKey,
        write_keys: &Keypair,
    ) -> Result<(), Error> {
        let root_nodes: Vec<_> = root_node::load_all_latest_approved(tx.db())
            .try_collect()
            .await?;

        for root_node in root_nodes {
            recompute_index_hashes_at(tx, root_node, this_writer_id, write_keys).await?;
        }

        Ok(())
    }

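    // Rebuild the index of a single branch: re-link its leaf nodes to the translated block ids
    // in batches, bump the version vector if the branch is local, then remove the snapshots
    // superseded by the migration.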
    async fn recompute_index_hashes_at(
        tx: &mut WriteTransaction,
        root_node: RootNode,
        this_writer_id: PublicKey,
        write_keys: &Keypair,
    ) -> Result<(), Error> {
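        // Cursor for paging through the leaf nodes in ascending locator order.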
        let mut last_locator = Hash::from([0; Hash::SIZE]);

        loop {
            let leaf_nodes = load_leaf_nodes(tx, &root_node, &last_locator, 1024).await?;
            if leaf_nodes.is_empty() {
                break;
            }

            let mut changeset = Changeset::new();

            // Link the locators to the new block ids
            for leaf_node in leaf_nodes {
                changeset.link_block(
                    leaf_node.locator,
                    leaf_node.block_id,
                    leaf_node.block_presence,
                );

                last_locator = leaf_node.locator;
            }

            changeset
                .apply(tx, &root_node.proof.writer_id, write_keys)
                .await?;
        }

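        // Load the root node of the latest snapshot, which now reflects the re-linked blocks.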
        let new_root_node =
            root_node::load_latest_approved(tx.db(), &root_node.proof.writer_id).await?;
        let new_root_node = if new_root_node.proof.writer_id == this_writer_id {
            // Bump the version vector of the local branch
            let hash = new_root_node.proof.hash;
            let version_vector = new_root_node
                .proof
                .into_version_vector()
                .incremented(this_writer_id);
            let proof = Proof::new(this_writer_id, version_vector, hash, write_keys);
            let (new_root_node, _) =
                root_node::create(tx.db(), proof, new_root_node.summary, RootNodeFilter::Any)
                    .await?;

            new_root_node
        } else {
            new_root_node
        };

        // Remove the original snapshot and any intermediate snapshots created during the migration
        // (child nodes are removed by db triggers).
        sqlx::query(
            "DELETE FROM snapshot_root_nodes
             WHERE writer_id = ? AND snapshot_id >= ? AND snapshot_id < ?",
        )
        .bind(&root_node.proof.writer_id)
        .bind(root_node.snapshot_id)
        .bind(new_root_node.snapshot_id)
        .execute(tx.db())
        .await?;

        Ok(())
    }

    // Load a batch of leaf nodes belonging to the given root node, with their block ids
    // translated from the old ones to the new ones.
    async fn load_leaf_nodes(
        r: &mut Reader,
        root_node: &RootNode,
        last_locator: &Hash,
        batch_size: u32,
    ) -> Result<Vec<LeafNode>, Error> {
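        // Recursively collect the hashes of all inner nodes reachable from the given root, then
        // select the leaf nodes under them, joining against the temp table to translate each
        // block id. Paging resumes after `last_locator`.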
        sqlx::query(
            "WITH RECURSIVE
                 inner_nodes(hash) AS (
                     SELECT i.hash
                         FROM snapshot_inner_nodes AS i
                             INNER JOIN snapshot_root_nodes AS r ON r.hash = i.parent
                         WHERE r.snapshot_id = ?
                     UNION ALL
                     SELECT c.hash
                         FROM snapshot_inner_nodes AS c
                             INNER JOIN inner_nodes AS p ON p.hash = c.parent
                 )
             SELECT l.locator, l.block_presence, t.new_block_id
                 FROM snapshot_leaf_nodes AS l
                     INNER JOIN block_id_translations AS t ON t.old_block_id = l.block_id
                 WHERE l.parent IN inner_nodes AND l.locator > ?
                 ORDER BY l.locator
                 LIMIT ?
             ",
        )
        .bind(root_node.snapshot_id)
        .bind(last_locator)
        .bind(batch_size)
        .fetch(r.db())
        .map_ok(|row| LeafNode {
            locator: row.get(0),
            block_id: row.get(2),
            block_presence: row.get(1),
        })
        .err_into()
        .try_collect()
        .await
    }
}