ouisync/store/
root_node.rs

1use super::error::Error;
2use crate::{
3    crypto::{sign::PublicKey, Hash},
4    db,
5    debug::DebugPrinter,
6    protocol::{
7        BlockId, MultiBlockPresence, NodeState, Proof, RootNode, RootNodeFilter, RootNodeKind,
8        SingleBlockPresence, Summary,
9    },
10    version_vector::VersionVector,
11};
12use futures_util::{Stream, StreamExt, TryStreamExt};
13use sqlx::{sqlite::SqliteRow, FromRow, Row};
14use std::{cmp::Ordering, fmt, future};
15
/// Status of receiving a root node
#[derive(PartialEq, Eq, Debug)]
pub(crate) enum RootNodeStatus {
    /// The incoming node is more up to date than the nodes we already have. Contains info about
    /// which part of the node is more up to date and the block presence of the latest node we have
    /// from the same branch as the incoming node.
    Updated(RootNodeUpdated, MultiBlockPresence),
    /// The node is outdated - discard it.
    Outdated,
}
26
#[derive(PartialEq, Eq, Debug)]
pub(crate) enum RootNodeUpdated {
    /// The node represents a new snapshot - write it into the store and request its children.
    Snapshot,
    /// The node represents a snapshot we already have but its block presence indicates the peer
    /// potentially has some blocks we don't have. Don't write it into the store but do request
    /// its children.
    Blocks,
}
35
36impl RootNodeStatus {
37    pub fn request_children(&self) -> Option<MultiBlockPresence> {
38        match self {
39            Self::Updated(_, block_presence) => Some(*block_presence),
40            Self::Outdated => None,
41        }
42    }
43
44    pub fn write(&self) -> bool {
45        match self {
46            Self::Updated(RootNodeUpdated::Snapshot, _) => true,
47            Self::Updated(RootNodeUpdated::Blocks, _) | Self::Outdated => false,
48        }
49    }
50}
51
52impl fmt::Display for RootNodeStatus {
53    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
54        match self {
55            Self::Updated(RootNodeUpdated::Snapshot, _) => write!(f, "new snapshot"),
56            Self::Updated(RootNodeUpdated::Blocks, _) => write!(f, "new blocks"),
57            Self::Outdated => write!(f, "outdated"),
58        }
59    }
60}
61
/// Creates a root node with the specified proof and summary.
///
/// A root node can be either "published" or "draft". A published node is one whose version vector
/// is strictly greater than the version vector of any previous node in the same branch.
/// A draft node is one whose version vector is equal to the version vector of the previous node.
///
/// The `filter` parameter determines what kind of node can be created.
///
/// Attempt to create a node whose version vector is less than or concurrent to the previous one is
/// not allowed.
///
/// Returns also the kind of node (published or draft) that was actually created.
///
/// # Errors
///
/// - `Error::OutdatedRootNode` if the new version vector is less than the previous one, or equal
///   to it while `filter` is `RootNodeFilter::Published`.
/// - `Error::ConcurrentRootNode` if the new version vector is concurrent with the previous one.
pub(super) async fn create(
    tx: &mut db::WriteTransaction,
    proof: Proof,
    mut summary: Summary,
    filter: RootNodeFilter,
) -> Result<(RootNode, RootNodeKind), Error> {
    // Check that the root node to be created is newer than the latest existing root node in
    // the same branch. If the branch has no nodes yet, compare against the empty version vector
    // so the very first node is always considered newer (published).
    let old_vv: VersionVector = sqlx::query(
        "SELECT versions
         FROM snapshot_root_nodes
         WHERE snapshot_id = (
             SELECT MAX(snapshot_id)
             FROM snapshot_root_nodes
             WHERE writer_id = ?
         )",
    )
    .bind(&proof.writer_id)
    .map(|row| row.get(0))
    .fetch_optional(&mut *tx)
    .await?
    .unwrap_or_else(VersionVector::new);

    let kind = match (proof.version_vector.partial_cmp(&old_vv), filter) {
        (Some(Ordering::Greater), _) => RootNodeKind::Published,
        (Some(Ordering::Equal), RootNodeFilter::Any) => RootNodeKind::Draft,
        (Some(Ordering::Equal), RootNodeFilter::Published) => return Err(Error::OutdatedRootNode),
        (Some(Ordering::Less), _) => return Err(Error::OutdatedRootNode),
        (None, _) => return Err(Error::ConcurrentRootNode),
    };

    // Inherit non-incomplete state from existing nodes with the same hash.
    // (All nodes with the same hash have the same state so it's enough to fetch only the first one)
    if summary.state == NodeState::Incomplete {
        let state = sqlx::query(
            "SELECT state FROM snapshot_root_nodes WHERE hash = ? AND state <> ? LIMIT 1",
        )
        .bind(&proof.hash)
        .bind(NodeState::Incomplete)
        .fetch_optional(&mut *tx)
        .await?
        .map(|row| row.get(0));

        if let Some(state) = state {
            summary.state = state;
        }
    }

    // Insert the new node and let the db assign its snapshot id.
    let snapshot_id = sqlx::query(
        "INSERT INTO snapshot_root_nodes (
             writer_id,
             versions,
             hash,
             signature,
             state,
             block_presence
         )
         VALUES (?, ?, ?, ?, ?, ?)
         RETURNING snapshot_id",
    )
    .bind(&proof.writer_id)
    .bind(&proof.version_vector)
    .bind(&proof.hash)
    .bind(&proof.signature)
    .bind(summary.state)
    .bind(&summary.block_presence)
    .map(|row| row.get(0))
    .fetch_one(tx)
    .await?;

    let node = RootNode {
        snapshot_id,
        proof,
        summary,
    };

    Ok((node, kind))
}
152
/// Returns the latest approved root node of the specified branch.
///
/// # Errors
///
/// Returns `Error::BranchNotFound` if the branch has no approved root node.
pub(super) async fn load_latest_approved(
    conn: &mut db::Connection,
    branch_id: &PublicKey,
) -> Result<RootNode, Error> {
    // The latest node is the one with the greatest snapshot_id within the branch.
    sqlx::query_as(
        "SELECT
             snapshot_id,
             writer_id,
             versions,
             hash,
             signature,
             state,
             block_presence
         FROM
             snapshot_root_nodes
         WHERE
             snapshot_id = (
                 SELECT MAX(snapshot_id)
                 FROM snapshot_root_nodes
                 WHERE writer_id = ? AND state = ?
             )
        ",
    )
    .bind(branch_id)
    .bind(NodeState::Approved)
    .fetch_optional(conn)
    .await?
    .ok_or(Error::BranchNotFound)
}
183
184/// Load the previous approved root node of the same writer.
185pub(super) async fn load_prev_approved(
186    conn: &mut db::Connection,
187    node: &RootNode,
188) -> Result<Option<RootNode>, Error> {
189    sqlx::query_as(
190        "SELECT
191            snapshot_id,
192            writer_id,
193            versions,
194            hash,
195            signature,
196            state,
197            block_presence
198         FROM snapshot_root_nodes
199         WHERE writer_id = ? AND state = ? AND snapshot_id < ?
200         ORDER BY snapshot_id DESC
201         LIMIT 1",
202    )
203    .bind(&node.proof.writer_id)
204    .bind(NodeState::Approved)
205    .bind(node.snapshot_id)
206    .fetch(conn)
207    .err_into()
208    .try_next()
209    .await
210}
211
/// Return the latest approved root nodes of all known writers.
pub(super) fn load_all_latest_approved(
    conn: &mut db::Connection,
) -> impl Stream<Item = Result<RootNode, Error>> + '_ {
    // For each writer, pick the approved row with the highest snapshot_id.
    sqlx::query_as(
        "SELECT
             snapshot_id,
             writer_id,
             versions,
             hash,
             signature,
             state,
             block_presence
         FROM
             snapshot_root_nodes
         WHERE
             snapshot_id IN (
                 SELECT MAX(snapshot_id)
                 FROM snapshot_root_nodes
                 WHERE state = ?
                 GROUP BY writer_id
             )",
    )
    .bind(NodeState::Approved)
    .fetch(conn)
    .err_into()
}
239
/// Return the latest root nodes of all known writers according to the following order of
/// preference: approved, complete, incomplete, rejected. That is, returns the latest approved
/// node if it exists, otherwise the latest complete, etc...
pub(super) fn load_all_latest_preferred(
    conn: &mut db::Connection,
) -> impl Stream<Item = Result<RootNode, Error>> + '_ {
    // Partition all root nodes by their writer_id. Then sort each partition according to the
    // preference as described in the above doc comment. Then take the first row from each
    // partition.

    // TODO: Is this the best way to do this (simple, efficient, etc...)?

    // The CASE expression maps each state to its preference rank (lower = more preferred) and
    // `snapshot_id DESC` breaks ties in favor of the most recent node.
    sqlx::query_as(
        "SELECT
             snapshot_id,
             writer_id,
             versions,
             hash,
             signature,
             state,
             block_presence
         FROM (
             SELECT
                 *,
                 ROW_NUMBER() OVER (
                     PARTITION BY writer_id
                     ORDER BY
                         CASE state
                             WHEN ? THEN 0
                             WHEN ? THEN 1
                             WHEN ? THEN 2
                             WHEN ? THEN 3
                         END,
                         snapshot_id DESC
                 ) AS position
             FROM snapshot_root_nodes
         )
         WHERE position = 1",
    )
    .bind(NodeState::Approved)
    .bind(NodeState::Complete)
    .bind(NodeState::Incomplete)
    .bind(NodeState::Rejected)
    .fetch(conn)
    .err_into()
}
286
/// Return the latest root nodes of all known writers in any state.
pub(super) fn load_all_latest(
    conn: &mut db::Connection,
) -> impl Stream<Item = Result<RootNode, Error>> + '_ {
    // For each writer, pick the row with the highest snapshot_id regardless of state.
    sqlx::query_as(
        "SELECT
             snapshot_id,
             writer_id,
             versions,
             hash,
             signature,
             state,
             block_presence
         FROM
             snapshot_root_nodes
         WHERE
             snapshot_id IN (
                 SELECT MAX(snapshot_id)
                 FROM snapshot_root_nodes
                 GROUP BY writer_id
             )",
    )
    .fetch(conn)
    .err_into()
}
312
/// Returns all nodes with the specified hash (there can be more than one, e.g. when multiple
/// writers have nodes referring to the same snapshot).
pub(super) fn load_all_by_hash<'a>(
    conn: &'a mut db::Connection,
    hash: &'a Hash,
) -> impl Stream<Item = Result<RootNode, Error>> + 'a {
    sqlx::query_as(
        "SELECT
             snapshot_id,
             writer_id,
             versions,
             hash,
             signature,
             state,
             block_presence
         FROM snapshot_root_nodes
         WHERE hash = ?",
    )
    .bind(hash)
    .fetch(conn)
    .err_into()
}
334
/// Loads the "best" `NodeState` of all the root nodes that reference the given missing block. If
/// the block is not referenced from any root node or if it's not missing, falls back to returning
/// `Rejected`.
pub(super) async fn load_node_state_of_missing(
    conn: &mut db::Connection,
    block_id: &BlockId,
) -> Result<NodeState, Error> {
    use NodeState as S;

    // The recursive CTE walks up the node hierarchy: it starts from the leaf nodes that
    // reference the missing block and repeatedly follows the `parent` hashes. The final SELECT
    // picks the states of the root nodes whose hash appears among the collected parents.
    sqlx::query(
        "WITH RECURSIVE
             inner_nodes(parent) AS (
                 SELECT parent
                     FROM snapshot_leaf_nodes
                     WHERE block_id = ? AND block_presence = ?
                 UNION ALL
                 SELECT i.parent
                     FROM snapshot_inner_nodes i INNER JOIN inner_nodes c
                     WHERE i.hash = c.parent
             )
         SELECT state
             FROM snapshot_root_nodes r INNER JOIN inner_nodes c
             WHERE r.hash = c.parent",
    )
    .bind(block_id)
    .bind(SingleBlockPresence::Missing)
    .fetch(conn)
    .map_ok(|row| row.get(0))
    .err_into()
    // Reduce all the fetched states to the "best" one, using the preference
    // Approved > Complete > Incomplete > Rejected and starting from `Rejected`.
    .try_fold(S::Rejected, |old, new| {
        let new = match (old, new) {
            (S::Incomplete | S::Complete | S::Approved | S::Rejected, S::Approved)
            | (S::Approved, S::Incomplete | S::Complete | S::Rejected) => S::Approved,
            (S::Incomplete | S::Complete | S::Rejected, S::Complete)
            | (S::Complete, S::Incomplete | S::Rejected) => S::Complete,
            (S::Incomplete | S::Rejected, S::Incomplete) | (S::Incomplete, S::Rejected) => {
                S::Incomplete
            }
            (S::Rejected, S::Rejected) => S::Rejected,
        };

        future::ready(Ok(new))
    })
    .await
}
380
381/// Removes the given root node including all its descendants that are not referenced from any
382/// other root nodes.
383pub(super) async fn remove(tx: &mut db::WriteTransaction, node: &RootNode) -> Result<(), Error> {
384    // This uses db triggers to delete the whole snapshot.
385    sqlx::query("DELETE FROM snapshot_root_nodes WHERE snapshot_id = ?")
386        .bind(node.snapshot_id)
387        .execute(tx)
388        .await?;
389
390    Ok(())
391}
392
393/// Removes all root nodes that are older than the given node and are on the same branch.
394pub(super) async fn remove_older(
395    tx: &mut db::WriteTransaction,
396    node: &RootNode,
397) -> Result<(), Error> {
398    // This uses db triggers to delete the whole snapshot.
399    sqlx::query("DELETE FROM snapshot_root_nodes WHERE snapshot_id < ? AND writer_id = ?")
400        .bind(node.snapshot_id)
401        .bind(&node.proof.writer_id)
402        .execute(tx)
403        .await?;
404
405    Ok(())
406}
407
/// Removes all root nodes that are older than the given node and are on the same branch and are
/// not complete (i.e., are in the `Incomplete` or `Rejected` state).
pub(super) async fn remove_older_incomplete(
    tx: &mut db::WriteTransaction,
    node: &RootNode,
) -> Result<(), Error> {
    // This uses db triggers to delete the whole snapshot.
    // `RETURNING hash` exists only so the removed snapshots can be traced below.
    sqlx::query(
        "DELETE FROM snapshot_root_nodes
         WHERE snapshot_id < ? AND writer_id = ? AND state IN (?, ?)
         RETURNING hash",
    )
    .bind(node.snapshot_id)
    .bind(&node.proof.writer_id)
    .bind(NodeState::Incomplete)
    .bind(NodeState::Rejected)
    .fetch(tx)
    .try_for_each(|row| {
        tracing::trace!(hash = ?row.get::<Hash, _>(0), "outdated incomplete snapshot removed");
        future::ready(Ok(()))
    })
    .await?;

    Ok(())
}
433
/// Update the summaries of all nodes with the specified hash.
///
/// Returns the state the nodes are in after the update, or `NodeState::Incomplete` if no node
/// with the given hash exists.
pub(super) async fn update_summaries(
    tx: &mut db::WriteTransaction,
    hash: &Hash,
    summary: Summary,
) -> Result<NodeState, Error> {
    // The block presence is always overwritten but the state is replaced only when it's
    // currently `Incomplete` - the CASE expression keeps any other state intact.
    let state = sqlx::query(
        "UPDATE snapshot_root_nodes
         SET block_presence = ?,
             state = CASE state WHEN ? THEN ? ELSE state END
         WHERE hash = ?
         RETURNING state
         ",
    )
    .bind(&summary.block_presence)
    .bind(NodeState::Incomplete)
    .bind(summary.state)
    .bind(hash)
    .fetch_optional(tx)
    .await?
    .map(|row| row.get(0))
    .unwrap_or(NodeState::Incomplete);

    Ok(state)
}
459
/// Check whether the `old` snapshot can serve as a fallback for the `new` snapshot.
/// A snapshot can serve as a fallback if there is at least one locator that points to a missing
/// block in `new` but present block in `old`.
pub(super) async fn check_fallback(
    conn: &mut db::Connection,
    old: &RootNode,
    new: &RootNode,
) -> Result<bool, Error> {
    // TODO: verify this query is efficient, especially on large repositories

    // The two recursive CTEs collect the hashes of all inner nodes reachable from the `old` and
    // `new` root nodes respectively. The final query intersects the locators of present leaf
    // nodes under `old` with the locators of missing leaf nodes under `new`. `LIMIT 1` because
    // we only care whether at least one such locator exists.
    Ok(sqlx::query(
        "WITH RECURSIVE
             inner_nodes_old(hash) AS (
                 SELECT i.hash
                     FROM snapshot_inner_nodes AS i
                     INNER JOIN snapshot_root_nodes AS r ON r.hash = i.parent
                     WHERE r.snapshot_id = ?
                 UNION ALL
                 SELECT c.hash
                     FROM snapshot_inner_nodes AS c
                     INNER JOIN inner_nodes_old AS p ON p.hash = c.parent
             ),
             inner_nodes_new(hash) AS (
                 SELECT i.hash
                     FROM snapshot_inner_nodes AS i
                     INNER JOIN snapshot_root_nodes AS r ON r.hash = i.parent
                     WHERE r.snapshot_id = ?
                 UNION ALL
                 SELECT c.hash
                     FROM snapshot_inner_nodes AS c
                     INNER JOIN inner_nodes_new AS p ON p.hash = c.parent
             )
         SELECT locator
             FROM snapshot_leaf_nodes
             WHERE block_presence = ? AND parent IN inner_nodes_old
         INTERSECT
         SELECT locator
             FROM snapshot_leaf_nodes
             WHERE block_presence = ? AND parent IN inner_nodes_new
         LIMIT 1",
    )
    .bind(old.snapshot_id)
    .bind(new.snapshot_id)
    .bind(SingleBlockPresence::Present)
    .bind(SingleBlockPresence::Missing)
    .fetch_optional(conn)
    .await?
    .is_some())
}
509
/// Approve the nodes with the specified hash.
///
/// Returns the writer ids of the updated nodes.
pub(super) fn approve<'a>(
    tx: &'a mut db::WriteTransaction,
    hash: &'a Hash,
) -> impl Stream<Item = Result<PublicKey, Error>> + 'a {
    set_state(tx, hash, NodeState::Approved)
}
517
/// Reject the nodes with the specified hash.
///
/// Returns the writer ids of the updated nodes.
pub(super) fn reject<'a>(
    tx: &'a mut db::WriteTransaction,
    hash: &'a Hash,
) -> impl Stream<Item = Result<PublicKey, Error>> + 'a {
    set_state(tx, hash, NodeState::Rejected)
}
525
526/// Set the state of the nodes with the specified hash and returns the writer ids of the updated
527/// nodes.
528fn set_state<'a>(
529    tx: &'a mut db::WriteTransaction,
530    hash: &'a Hash,
531    state: NodeState,
532) -> impl Stream<Item = Result<PublicKey, Error>> + 'a {
533    sqlx::query("UPDATE snapshot_root_nodes SET state = ? WHERE hash = ? RETURNING writer_id")
534        .bind(state)
535        .bind(hash)
536        .fetch(tx)
537        .map_ok(|row| row.get(0))
538        .err_into()
539}
540
541/// Returns all writer ids.
542pub(super) fn load_writer_ids(
543    conn: &mut db::Connection,
544) -> impl Stream<Item = Result<PublicKey, Error>> + '_ {
545    sqlx::query("SELECT DISTINCT writer_id FROM snapshot_root_nodes")
546        .fetch(conn)
547        .map_ok(|row| row.get(0))
548        .err_into()
549}
550
551/// Returns the writer ids of the nodes with the specified hash.
552pub(super) fn load_writer_ids_by_hash<'a>(
553    conn: &'a mut db::Connection,
554    hash: &'a Hash,
555) -> impl Stream<Item = Result<PublicKey, Error>> + 'a {
556    sqlx::query("SELECT DISTINCT writer_id FROM snapshot_root_nodes WHERE hash = ?")
557        .bind(hash)
558        .fetch(conn)
559        .map_ok(|row| row.get(0))
560        .err_into()
561}
562
/// Determines the status of an incoming root node (described by `new_proof` and
/// `new_block_presence`) by comparing it against the latest nodes we already have from all
/// known writers.
pub(super) async fn status(
    conn: &mut db::Connection,
    new_proof: &Proof,
    new_block_presence: &MultiBlockPresence,
) -> Result<RootNodeStatus, Error> {
    let mut updated = RootNodeUpdated::Snapshot;
    let mut block_presence = MultiBlockPresence::None;

    let mut old_nodes = load_all_latest(conn);
    while let Some(old_node) = old_nodes.try_next().await? {
        match new_proof
            .version_vector
            .partial_cmp(&old_node.proof.version_vector)
        {
            Some(Ordering::Less) => {
                // The incoming node is outdated compared to at least one existing node - discard
                // it.
                return Ok(RootNodeStatus::Outdated);
            }
            Some(Ordering::Equal) => {
                if new_proof.hash == old_node.proof.hash {
                    // The incoming node has the same version vector and the same hash as one of
                    // the existing nodes which means it's effectively the same node (except
                    // possibly in a different branch). There is no point inserting it but if the
                    // incoming summary is potentially more up-to-date than the existing one, we
                    // still want to request the children. Otherwise we discard it.
                    if old_node
                        .summary
                        .block_presence
                        .is_outdated(new_block_presence)
                    {
                        updated = RootNodeUpdated::Blocks;
                    } else {
                        return Ok(RootNodeStatus::Outdated);
                    }
                } else {
                    tracing::warn!(
                        vv = ?old_node.proof.version_vector,
                        old_writer_id = ?old_node.proof.writer_id,
                        new_writer_id = ?new_proof.writer_id,
                        old_hash = ?old_node.proof.hash,
                        new_hash = ?new_proof.hash,
                        "Received root node invalid - broken invariant: same vv but different hash"
                    );

                    return Ok(RootNodeStatus::Outdated);
                }
            }
            Some(Ordering::Greater) => (),
            None => {
                // Concurrent version vectors are tolerated across branches but never within a
                // single branch.
                if new_proof.writer_id == old_node.proof.writer_id {
                    tracing::warn!(
                        old_vv = ?old_node.proof.version_vector,
                        new_vv = ?new_proof.version_vector,
                        writer_id = ?new_proof.writer_id,
                        "Received root node invalid - broken invariant: concurrency within branch is not allowed"
                    );

                    return Ok(RootNodeStatus::Outdated);
                }
            }
        }

        // Remember the block presence of our latest node from the same branch as the incoming
        // node - it's returned as part of the `Updated` status.
        if old_node.proof.writer_id == new_proof.writer_id {
            block_presence = old_node.summary.block_presence;
        }
    }

    Ok(RootNodeStatus::Updated(updated, block_presence))
}
633
634pub(super) async fn debug_print(conn: &mut db::Connection, printer: DebugPrinter) {
635    let mut roots = sqlx::query_as::<_, RootNode>(
636        "SELECT
637             snapshot_id,
638             writer_id,
639             versions,
640             hash,
641             signature,
642             state,
643             block_presence
644         FROM snapshot_root_nodes
645         ORDER BY snapshot_id DESC",
646    )
647    .fetch(conn);
648
649    while let Some(root_node) = roots.next().await {
650        match root_node {
651            Ok(root_node) => {
652                printer.debug(&format_args!(
653                    "RootNode: snapshot_id:{:?}, writer_id:{:?}, vv:{:?}, state:{:?}",
654                    root_node.snapshot_id,
655                    root_node.proof.writer_id,
656                    root_node.proof.version_vector,
657                    root_node.summary.state
658                ));
659            }
660            Err(err) => {
661                printer.debug(&format_args!("RootNode: error: {:?}", err));
662            }
663        }
664    }
665}
666
/// Returns a stream of all root nodes corresponding to the specified writer ordered from the
/// most recent to the least recent.
///
/// Test-only helper.
#[cfg(test)]
pub(super) fn load_all_by_writer<'a>(
    conn: &'a mut db::Connection,
    writer_id: &'a PublicKey,
) -> impl Stream<Item = Result<RootNode, Error>> + 'a {
    sqlx::query_as(
        "SELECT
             snapshot_id,
             writer_id,
             versions,
             hash,
             signature,
             state,
             block_presence
         FROM snapshot_root_nodes
         WHERE writer_id = ?
         ORDER BY snapshot_id DESC",
    )
    .bind(writer_id) // needed to satisfy the borrow checker.
    .fetch(conn)
    .err_into()
}
691
692impl FromRow<'_, SqliteRow> for RootNode {
693    fn from_row(row: &SqliteRow) -> Result<Self, sqlx::Error> {
694        Ok(RootNode {
695            snapshot_id: row.try_get(0)?,
696            proof: Proof::new_unchecked(
697                row.try_get(1)?,
698                row.try_get(2)?,
699                row.try_get(3)?,
700                row.try_get(4)?,
701            ),
702            summary: Summary {
703                state: row.try_get(5)?,
704                block_presence: row.try_get(6)?,
705            },
706        })
707    }
708}
709
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypto::sign::Keypair;
    use assert_matches::assert_matches;
    use tempfile::TempDir;

    // Creating the first node of a branch stores it and makes it loadable.
    #[tokio::test]
    async fn create_new() {
        let (_base_dir, pool) = setup().await;

        let writer_id = PublicKey::random();
        let write_keys = Keypair::random();
        let hash = rand::random();

        let mut tx = pool.begin_write().await.unwrap();

        let (node0, _) = create(
            &mut tx,
            Proof::new(
                writer_id,
                VersionVector::first(writer_id),
                hash,
                &write_keys,
            ),
            Summary::INCOMPLETE,
            RootNodeFilter::Any,
        )
        .await
        .unwrap();
        assert_eq!(node0.proof.hash, hash);

        let nodes: Vec<_> = load_all_by_writer(&mut tx, &writer_id)
            .try_collect()
            .await
            .unwrap();
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0], node0);
    }

    // A node with the same version vector as the previous one is created as a draft
    // (when `RootNodeFilter::Any` allows it).
    #[tokio::test]
    async fn create_draft() {
        let (_base_dir, pool) = setup().await;

        let writer_id = PublicKey::random();
        let write_keys = Keypair::random();

        let mut tx = pool.begin_write().await.unwrap();

        let (node0, kind) = create(
            &mut tx,
            Proof::new(
                writer_id,
                VersionVector::first(writer_id),
                rand::random(),
                &write_keys,
            ),
            Summary::INCOMPLETE,
            RootNodeFilter::Any,
        )
        .await
        .unwrap();
        assert_eq!(kind, RootNodeKind::Published);

        let (_node1, kind) = create(
            &mut tx,
            Proof::new(
                writer_id,
                node0.proof.version_vector.clone(),
                rand::random(),
                &write_keys,
            ),
            Summary::INCOMPLETE,
            RootNodeFilter::Any,
        )
        .await
        .unwrap();
        assert_eq!(kind, RootNodeKind::Draft);
    }

    // Creating a node whose version vector is not strictly newer fails with `OutdatedRootNode`.
    #[tokio::test]
    async fn attempt_to_create_outdated_node() {
        let (_base_dir, pool) = setup().await;

        let writer_id = PublicKey::random();
        let write_keys = Keypair::random();
        let hash = rand::random();

        let mut tx = pool.begin_write().await.unwrap();

        let vv0 = VersionVector::first(writer_id);
        let vv1 = vv0.clone().incremented(writer_id);

        create(
            &mut tx,
            Proof::new(writer_id, vv1.clone(), hash, &write_keys),
            Summary::INCOMPLETE,
            RootNodeFilter::Any,
        )
        .await
        .unwrap();

        // Same vv
        assert_matches!(
            create(
                &mut tx,
                Proof::new(writer_id, vv1, hash, &write_keys),
                Summary::INCOMPLETE,
                RootNodeFilter::Published, // With `RootNodeFilter::Any` this would be allowed
            )
            .await,
            Err(Error::OutdatedRootNode)
        );

        // Old vv
        assert_matches!(
            create(
                &mut tx,
                Proof::new(writer_id, vv0, hash, &write_keys),
                Summary::INCOMPLETE,
                RootNodeFilter::Any,
            )
            .await,
            Err(Error::OutdatedRootNode)
        );
    }

    // Proptest for the `load_all_latest_preferred` function.
    mod load_all_latest_preferred {
        use super::*;
        use crate::protocol::SnapshotId;
        use proptest::{arbitrary::any, collection::vec, sample::select, strategy::Strategy};
        use test_strategy::proptest;

        #[proptest(async = "tokio")]
        async fn proptest(
            write_keys: Keypair,
            #[strategy(root_node_params_strategy())] input: Vec<(
                SnapshotId,
                PublicKey,
                Hash,
                NodeState,
            )>,
        ) {
            case(write_keys, input).await
        }

        async fn case(write_keys: Keypair, input: Vec<(SnapshotId, PublicKey, Hash, NodeState)>) {
            let (_base_dir, pool) = setup().await;

            // All distinct writer ids present in the input.
            let mut writer_ids: Vec<_> = input
                .iter()
                .map(|(_, writer_id, _, _)| *writer_id)
                .collect();
            writer_ids.sort();
            writer_ids.dedup();

            // For each writer compute the expected preferred node: the one with the "best" state
            // (Approved > Complete > Incomplete > Rejected), ties broken by the highest
            // snapshot id.
            let mut expected: Vec<_> = writer_ids
                .into_iter()
                .filter_map(|this_writer_id| {
                    input
                        .iter()
                        .filter(|(_, that_writer_id, _, _)| *that_writer_id == this_writer_id)
                        .map(|(snapshot_id, _, _, state)| (*snapshot_id, *state))
                        .max_by_key(|(snapshot_id, state)| {
                            (
                                match state {
                                    NodeState::Approved => 3,
                                    NodeState::Complete => 2,
                                    NodeState::Incomplete => 1,
                                    NodeState::Rejected => 0,
                                },
                                *snapshot_id,
                            )
                        })
                        .map(|(snapshot_id, state)| (this_writer_id, snapshot_id, state))
                })
                .collect();
            expected.sort_by_key(|(writer_id, _, _)| *writer_id);

            let mut vv = VersionVector::default();
            let mut tx = pool.begin_write().await.unwrap();

            for (expected_snapshot_id, writer_id, hash, state) in input {
                vv.increment(writer_id);

                let (node, _) = create(
                    &mut tx,
                    Proof::new(writer_id, vv.clone(), hash, &write_keys),
                    Summary {
                        state,
                        block_presence: MultiBlockPresence::None,
                    },
                    RootNodeFilter::Any,
                )
                .await
                .unwrap();

                // Sanity check: the db assigns snapshot ids in the same sequence the strategy
                // assumed when it generated the expected ids.
                assert_eq!(node.snapshot_id, expected_snapshot_id);
            }

            let mut actual: Vec<_> = load_all_latest_preferred(&mut tx)
                .map_ok(|node| (node.proof.writer_id, node.snapshot_id, node.summary.state))
                .try_collect()
                .await
                .unwrap();
            actual.sort_by_key(|(writer_id, _, _)| *writer_id);

            assert_eq!(actual, expected);

            drop(tx);
            pool.close().await.unwrap();
        }

        // Generates up to 32 (snapshot_id, writer_id, hash, state) tuples spread over 1 to 3
        // writers. Snapshot ids are assigned sequentially starting from 1.
        fn root_node_params_strategy(
        ) -> impl Strategy<Value = Vec<(SnapshotId, PublicKey, Hash, NodeState)>> {
            vec(any::<PublicKey>(), 1..=3)
                .prop_flat_map(|writer_ids| {
                    vec(
                        (select(writer_ids), any::<Hash>(), any::<NodeState>()),
                        0..=32,
                    )
                })
                .prop_map(|params| {
                    params
                        .into_iter()
                        .enumerate()
                        .map(|(index, (writer_id, hash, state))| {
                            ((index + 1) as u32, writer_id, hash, state)
                        })
                        .collect()
                })
        }
    }

    // Creates a fresh temporary database for a single test.
    async fn setup() -> (TempDir, db::Pool) {
        db::create_temp().await.unwrap()
    }
}