// ouisync/store/root_node.rs

1use super::error::Error;
2use crate::{
3    crypto::{Hash, sign::PublicKey},
4    db,
5    debug::DebugPrinter,
6    protocol::{
7        BlockId, MultiBlockPresence, NodeState, Proof, RootNode, RootNodeFilter, RootNodeKind,
8        SingleBlockPresence, Summary,
9    },
10    version_vector::VersionVector,
11};
12use futures_util::{Stream, StreamExt, TryStreamExt};
13use sqlx::{FromRow, Row, sqlite::SqliteRow};
14use std::{cmp::Ordering, fmt, future};
15
/// Status of receiving a root node
#[derive(PartialEq, Eq, Debug)]
pub(crate) enum RootNodeStatus {
    /// The incoming node is more up to date than the nodes we already have. Contains info about
    /// which part of the node is more up to date and the block presence of the latest node we have
    /// from the same branch as the incoming node.
    Updated(RootNodeUpdated, MultiBlockPresence),
    /// The node is outdated - discard it.
    Outdated,
}
26
#[derive(PartialEq, Eq, Debug)]
pub(crate) enum RootNodeUpdated {
    /// The node represents a new snapshot - write it into the store and request its children.
    Snapshot,
    /// The node represents a snapshot we already have but its block presence indicates the peer
    /// potentially has some blocks we don't have. Don't write it into the store but do request its
    /// children.
    Blocks,
}
35
36impl RootNodeStatus {
37    pub fn request_children(&self) -> Option<MultiBlockPresence> {
38        match self {
39            Self::Updated(_, block_presence) => Some(*block_presence),
40            Self::Outdated => None,
41        }
42    }
43
44    pub fn write(&self) -> bool {
45        match self {
46            Self::Updated(RootNodeUpdated::Snapshot, _) => true,
47            Self::Updated(RootNodeUpdated::Blocks, _) | Self::Outdated => false,
48        }
49    }
50}
51
52impl fmt::Display for RootNodeStatus {
53    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
54        match self {
55            Self::Updated(RootNodeUpdated::Snapshot, _) => write!(f, "new snapshot"),
56            Self::Updated(RootNodeUpdated::Blocks, _) => write!(f, "new blocks"),
57            Self::Outdated => write!(f, "outdated"),
58        }
59    }
60}
61
/// Creates a root node with the specified proof and summary.
///
/// A root node can be either "published" or "draft". A published node is one whose version vector
/// is strictly greater than the version vector of any previous node in the same branch.
/// A draft node is one whose version vector is equal to the version vector of the previous node.
///
/// The `filter` parameter determines what kind of node can be created.
///
/// Attempt to create a node whose version vector is less than or concurrent to the previous one is
/// not allowed.
///
/// Returns also the kind of node (published or draft) that was actually created.
///
/// # Errors
///
/// - `Error::OutdatedRootNode` if the version vector is less than the previous one, or equal to
///   it while `filter` is `RootNodeFilter::Published`.
/// - `Error::ConcurrentRootNode` if the version vector is concurrent to the previous one.
pub(super) async fn create(
    tx: &mut db::WriteTransaction,
    proof: Proof,
    mut summary: Summary,
    filter: RootNodeFilter,
) -> Result<(RootNode, RootNodeKind), Error> {
    // Check that the root node to be created is newer than the latest existing root node in
    // the same branch.
    let old_vv: VersionVector = sqlx::query(
        "SELECT versions
         FROM snapshot_root_nodes
         WHERE snapshot_id = (
             SELECT MAX(snapshot_id)
             FROM snapshot_root_nodes
             WHERE writer_id = ?
         )",
    )
    .bind(&proof.writer_id)
    .map(|row| row.get(0))
    .fetch_optional(&mut *tx)
    .await?
    // No previous node in this branch - compare against the empty version vector.
    .unwrap_or_else(VersionVector::new);

    // Classify the new node (or bail out) based on how its version vector compares to the
    // previous one and on the requested filter.
    let kind = match (proof.version_vector.partial_cmp(&old_vv), filter) {
        (Some(Ordering::Greater), _) => RootNodeKind::Published,
        (Some(Ordering::Equal), RootNodeFilter::Any) => RootNodeKind::Draft,
        (Some(Ordering::Equal), RootNodeFilter::Published) => return Err(Error::OutdatedRootNode),
        (Some(Ordering::Less), _) => return Err(Error::OutdatedRootNode),
        (None, _) => return Err(Error::ConcurrentRootNode),
    };

    // Inherit non-incomplete state from existing nodes with the same hash.
    // (All nodes with the same hash have the same state so it's enough to fetch only the first one)
    if summary.state == NodeState::Incomplete {
        let state = sqlx::query(
            "SELECT state FROM snapshot_root_nodes WHERE hash = ? AND state <> ? LIMIT 1",
        )
        .bind(&proof.hash)
        .bind(NodeState::Incomplete)
        .fetch_optional(&mut *tx)
        .await?
        .map(|row| row.get(0));

        if let Some(state) = state {
            summary.state = state;
        }
    }

    // Insert the new node and let the db assign its snapshot id.
    let snapshot_id = sqlx::query(
        "INSERT INTO snapshot_root_nodes (
             writer_id,
             versions,
             hash,
             signature,
             state,
             block_presence
         )
         VALUES (?, ?, ?, ?, ?, ?)
         RETURNING snapshot_id",
    )
    .bind(&proof.writer_id)
    .bind(&proof.version_vector)
    .bind(&proof.hash)
    .bind(&proof.signature)
    .bind(summary.state)
    .bind(&summary.block_presence)
    .map(|row| row.get(0))
    .fetch_one(tx)
    .await?;

    let node = RootNode {
        snapshot_id,
        proof,
        summary,
    };

    Ok((node, kind))
}
152
/// Returns the latest approved root node of the specified branch.
///
/// # Errors
///
/// Returns `Error::BranchNotFound` if the branch has no approved root node.
pub(super) async fn load_latest_approved(
    conn: &mut db::Connection,
    branch_id: &PublicKey,
) -> Result<RootNode, Error> {
    sqlx::query_as(
        "SELECT
             snapshot_id,
             writer_id,
             versions,
             hash,
             signature,
             state,
             block_presence
         FROM
             snapshot_root_nodes
         WHERE
             snapshot_id = (
                 SELECT MAX(snapshot_id)
                 FROM snapshot_root_nodes
                 WHERE writer_id = ? AND state = ?
             )
        ",
    )
    .bind(branch_id)
    .bind(NodeState::Approved)
    .fetch_optional(conn)
    .await?
    .ok_or(Error::BranchNotFound)
}
183
/// Load the previous approved root node of the same writer.
///
/// Returns `Ok(None)` if `node` is the oldest approved node of its writer.
pub(super) async fn load_prev_approved(
    conn: &mut db::Connection,
    node: &RootNode,
) -> Result<Option<RootNode>, Error> {
    sqlx::query_as(
        "SELECT
            snapshot_id,
            writer_id,
            versions,
            hash,
            signature,
            state,
            block_presence
         FROM snapshot_root_nodes
         WHERE writer_id = ? AND state = ? AND snapshot_id < ?
         ORDER BY snapshot_id DESC
         LIMIT 1",
    )
    .bind(&node.proof.writer_id)
    .bind(NodeState::Approved)
    .bind(node.snapshot_id)
    // Take at most the first row of the stream (the query is LIMIT 1 anyway).
    .fetch(conn)
    .err_into()
    .try_next()
    .await
}
211
/// Return the latest approved root nodes of all known writers.
pub(super) fn load_all_latest_approved(
    conn: &mut db::Connection,
) -> impl Stream<Item = Result<RootNode, Error>> + '_ {
    sqlx::query_as(
        "SELECT
             snapshot_id,
             writer_id,
             versions,
             hash,
             signature,
             state,
             block_presence
         FROM
             snapshot_root_nodes
         WHERE
             snapshot_id IN (
                 SELECT MAX(snapshot_id)
                 FROM snapshot_root_nodes
                 WHERE state = ?
                 GROUP BY writer_id
             )",
    )
    .bind(NodeState::Approved)
    .fetch(conn)
    .err_into()
}
239
/// Return the latest root nodes of all known writers according to the following order of
/// preference: approved, complete, incomplete, rejected. That is, returns the latest approved
/// node if it exists, otherwise the latest complete, etc...
pub(super) fn load_all_latest_preferred(
    conn: &mut db::Connection,
) -> impl Stream<Item = Result<RootNode, Error>> + '_ {
    // Partition all root nodes by their writer_id. Then sort each partition according to the
    // preference as described in the above doc comment. Then take the first row from each
    // partition.

    // TODO: Is this the best way to do this (simple, efficient, etc...)?

    sqlx::query_as(
        "SELECT
             snapshot_id,
             writer_id,
             versions,
             hash,
             signature,
             state,
             block_presence
         FROM (
             SELECT
                 *,
                 ROW_NUMBER() OVER (
                     PARTITION BY writer_id
                     ORDER BY
                         CASE state
                             WHEN ? THEN 0
                             WHEN ? THEN 1
                             WHEN ? THEN 2
                             WHEN ? THEN 3
                         END,
                         snapshot_id DESC
                 ) AS position
             FROM snapshot_root_nodes
         )
         WHERE position = 1",
    )
    // Bind order defines the preference: Approved (0) sorts before Complete (1), etc.
    .bind(NodeState::Approved)
    .bind(NodeState::Complete)
    .bind(NodeState::Incomplete)
    .bind(NodeState::Rejected)
    .fetch(conn)
    .err_into()
}
286
/// Return the latest root nodes of all known writers in any state.
pub(super) fn load_all_latest(
    conn: &mut db::Connection,
) -> impl Stream<Item = Result<RootNode, Error>> + '_ {
    sqlx::query_as(
        "SELECT
             snapshot_id,
             writer_id,
             versions,
             hash,
             signature,
             state,
             block_presence
         FROM
             snapshot_root_nodes
         WHERE
             snapshot_id IN (
                 SELECT MAX(snapshot_id)
                 FROM snapshot_root_nodes
                 GROUP BY writer_id
             )",
    )
    .fetch(conn)
    .err_into()
}
312
/// Returns all nodes with the specified hash
pub(super) fn load_all_by_hash<'a>(
    conn: &'a mut db::Connection,
    hash: &'a Hash,
) -> impl Stream<Item = Result<RootNode, Error>> + 'a {
    sqlx::query_as(
        "SELECT
             snapshot_id,
             writer_id,
             versions,
             hash,
             signature,
             state,
             block_presence
         FROM snapshot_root_nodes
         WHERE hash = ?",
    )
    .bind(hash)
    .fetch(conn)
    .err_into()
}
334
/// Loads the "best" `NodeState` of all the root nodes that reference the given missing block. If
/// the block is not referenced from any root node or if it's not missing, falls back to returning
/// `Rejected`.
pub(super) async fn load_node_state_of_missing(
    conn: &mut db::Connection,
    block_id: &BlockId,
) -> Result<NodeState, Error> {
    use NodeState as S;

    // Walk from the leaf nodes referencing the missing block up through their ancestors to the
    // root nodes, then collect the states of those roots.
    sqlx::query(
        "WITH RECURSIVE
             inner_nodes(parent) AS (
                 SELECT parent
                     FROM snapshot_leaf_nodes
                     WHERE block_id = ? AND block_presence = ?
                 UNION ALL
                 SELECT i.parent
                     FROM snapshot_inner_nodes i INNER JOIN inner_nodes c
                     WHERE i.hash = c.parent
             )
         SELECT state
             FROM snapshot_root_nodes r INNER JOIN inner_nodes c
             WHERE r.hash = c.parent",
    )
    .bind(block_id)
    .bind(SingleBlockPresence::Missing)
    .fetch(conn)
    .map_ok(|row| row.get(0))
    .err_into()
    // Fold the states, keeping the "best" one seen so far. Precedence (best to worst):
    // Approved > Complete > Incomplete > Rejected. Starts from the `Rejected` fallback.
    .try_fold(S::Rejected, |old, new| {
        let new = match (old, new) {
            (S::Incomplete | S::Complete | S::Approved | S::Rejected, S::Approved)
            | (S::Approved, S::Incomplete | S::Complete | S::Rejected) => S::Approved,
            (S::Incomplete | S::Complete | S::Rejected, S::Complete)
            | (S::Complete, S::Incomplete | S::Rejected) => S::Complete,
            (S::Incomplete | S::Rejected, S::Incomplete) | (S::Incomplete, S::Rejected) => {
                S::Incomplete
            }
            (S::Rejected, S::Rejected) => S::Rejected,
        };

        future::ready(Ok(new))
    })
    .await
}
380
381/// Removes the given root node including all its descendants that are not referenced from any
382/// other root nodes.
383pub(super) async fn remove(tx: &mut db::WriteTransaction, node: &RootNode) -> Result<(), Error> {
384    // This uses db triggers to delete the whole snapshot.
385    sqlx::query("DELETE FROM snapshot_root_nodes WHERE snapshot_id = ?")
386        .bind(node.snapshot_id)
387        .execute(tx)
388        .await?;
389
390    Ok(())
391}
392
393/// Removes all root nodes that are older than the given node and are on the same branch.
394pub(super) async fn remove_older(
395    tx: &mut db::WriteTransaction,
396    node: &RootNode,
397) -> Result<(), Error> {
398    // This uses db triggers to delete the whole snapshot.
399    sqlx::query("DELETE FROM snapshot_root_nodes WHERE snapshot_id < ? AND writer_id = ?")
400        .bind(node.snapshot_id)
401        .bind(&node.proof.writer_id)
402        .execute(tx)
403        .await?;
404
405    Ok(())
406}
407
/// Removes all root nodes that are older than the given node and are on the same branch and are
/// not complete.
///
/// Returns whether any node was removed.
pub(super) async fn remove_older_incomplete(
    tx: &mut db::WriteTransaction,
    node: &RootNode,
) -> Result<bool, Error> {
    // Flipped to `true` by the closure below if at least one row is returned.
    let mut changed = false;

    // This uses db triggers to delete the whole snapshot.
    sqlx::query(
        "DELETE FROM snapshot_root_nodes
         WHERE snapshot_id < ? AND writer_id = ? AND state IN (?, ?)
         RETURNING hash",
    )
    .bind(node.snapshot_id)
    .bind(&node.proof.writer_id)
    .bind(NodeState::Incomplete)
    .bind(NodeState::Rejected)
    .fetch(tx)
    .try_for_each(|row| {
        tracing::trace!(hash = ?row.get::<Hash, _>(0), "outdated incomplete snapshot removed");
        changed = true;
        future::ready(Ok(()))
    })
    .await?;

    Ok(changed)
}
438
/// Update the summaries of all nodes with the specified hash.
///
/// Returns the state the nodes have after the update. If no node with the given hash exists,
/// returns `NodeState::Incomplete`.
pub(super) async fn update_summaries(
    tx: &mut db::WriteTransaction,
    hash: &Hash,
    summary: Summary,
) -> Result<NodeState, Error> {
    // The state is overwritten only when it's currently `Incomplete`, otherwise it's kept as is.
    let state = sqlx::query(
        "UPDATE snapshot_root_nodes
         SET block_presence = ?,
             state = CASE state WHEN ? THEN ? ELSE state END
         WHERE hash = ?
         RETURNING state
         ",
    )
    .bind(&summary.block_presence)
    .bind(NodeState::Incomplete)
    .bind(summary.state)
    .bind(hash)
    .fetch_optional(tx)
    .await?
    .map(|row| row.get(0))
    .unwrap_or(NodeState::Incomplete);

    Ok(state)
}
464
/// Check whether the `old` snapshot can serve as a fallback for the `new` snapshot.
/// A snapshot can serve as a fallback if there is at least one locator that points to a missing
/// block in `new` but present block in `old`.
pub(super) async fn check_fallback(
    conn: &mut db::Connection,
    old: &RootNode,
    new: &RootNode,
) -> Result<bool, Error> {
    // TODO: verify this query is efficient, especially on large repositories

    // The two recursive CTEs collect all inner node hashes reachable from the `old` and `new`
    // root nodes respectively. The final query then intersects the locators of present leaf
    // nodes in `old` with missing leaf nodes in `new`; a single match is enough (LIMIT 1).
    Ok(sqlx::query(
        "WITH RECURSIVE
             inner_nodes_old(hash) AS (
                 SELECT i.hash
                     FROM snapshot_inner_nodes AS i
                     INNER JOIN snapshot_root_nodes AS r ON r.hash = i.parent
                     WHERE r.snapshot_id = ?
                 UNION ALL
                 SELECT c.hash
                     FROM snapshot_inner_nodes AS c
                     INNER JOIN inner_nodes_old AS p ON p.hash = c.parent
             ),
             inner_nodes_new(hash) AS (
                 SELECT i.hash
                     FROM snapshot_inner_nodes AS i
                     INNER JOIN snapshot_root_nodes AS r ON r.hash = i.parent
                     WHERE r.snapshot_id = ?
                 UNION ALL
                 SELECT c.hash
                     FROM snapshot_inner_nodes AS c
                     INNER JOIN inner_nodes_new AS p ON p.hash = c.parent
             )
         SELECT locator
             FROM snapshot_leaf_nodes
             WHERE block_presence = ? AND parent IN inner_nodes_old
         INTERSECT
         SELECT locator
             FROM snapshot_leaf_nodes
             WHERE block_presence = ? AND parent IN inner_nodes_new
         LIMIT 1",
    )
    .bind(old.snapshot_id)
    .bind(new.snapshot_id)
    .bind(SingleBlockPresence::Present)
    .bind(SingleBlockPresence::Missing)
    .fetch_optional(conn)
    .await?
    .is_some())
}
514
/// Approve the nodes with the specified hash.
///
/// Returns a stream of the writer ids of the updated nodes.
pub(super) fn approve<'a>(
    tx: &'a mut db::WriteTransaction,
    hash: &'a Hash,
) -> impl Stream<Item = Result<PublicKey, Error>> + 'a {
    set_state(tx, hash, NodeState::Approved)
}
522
/// Reject the nodes with the specified hash.
///
/// Returns a stream of the writer ids of the updated nodes.
pub(super) fn reject<'a>(
    tx: &'a mut db::WriteTransaction,
    hash: &'a Hash,
) -> impl Stream<Item = Result<PublicKey, Error>> + 'a {
    set_state(tx, hash, NodeState::Rejected)
}
530
531/// Set the state of the nodes with the specified hash and returns the writer ids of the updated
532/// nodes.
533fn set_state<'a>(
534    tx: &'a mut db::WriteTransaction,
535    hash: &'a Hash,
536    state: NodeState,
537) -> impl Stream<Item = Result<PublicKey, Error>> + 'a {
538    sqlx::query("UPDATE snapshot_root_nodes SET state = ? WHERE hash = ? RETURNING writer_id")
539        .bind(state)
540        .bind(hash)
541        .fetch(tx)
542        .map_ok(|row| row.get(0))
543        .err_into()
544}
545
546/// Returns all writer ids.
547pub(super) fn load_writer_ids(
548    conn: &mut db::Connection,
549) -> impl Stream<Item = Result<PublicKey, Error>> + '_ {
550    sqlx::query("SELECT DISTINCT writer_id FROM snapshot_root_nodes")
551        .fetch(conn)
552        .map_ok(|row| row.get(0))
553        .err_into()
554}
555
556/// Returns the writer ids of the nodes with the specified hash.
557pub(super) fn load_writer_ids_by_hash<'a>(
558    conn: &'a mut db::Connection,
559    hash: &'a Hash,
560) -> impl Stream<Item = Result<PublicKey, Error>> + 'a {
561    sqlx::query("SELECT DISTINCT writer_id FROM snapshot_root_nodes WHERE hash = ?")
562        .bind(hash)
563        .fetch(conn)
564        .map_ok(|row| row.get(0))
565        .err_into()
566}
567
/// Determines the status of an incoming root node (given by its proof and block presence)
/// relative to the root nodes we already have. See [`RootNodeStatus`] for the possible outcomes.
pub(super) async fn status(
    conn: &mut db::Connection,
    new_proof: &Proof,
    new_block_presence: &MultiBlockPresence,
) -> Result<RootNodeStatus, Error> {
    let mut updated = RootNodeUpdated::Snapshot;
    let mut block_presence = MultiBlockPresence::None;

    // Compare the incoming node against the latest node of every known writer.
    let mut old_nodes = load_all_latest(conn);
    while let Some(old_node) = old_nodes.try_next().await? {
        match new_proof
            .version_vector
            .partial_cmp(&old_node.proof.version_vector)
        {
            Some(Ordering::Less) => {
                // The incoming node is outdated compared to at least one existing node - discard
                // it.
                return Ok(RootNodeStatus::Outdated);
            }
            Some(Ordering::Equal) => {
                if new_proof.hash == old_node.proof.hash {
                    // The incoming node has the same version vector and the same hash as one of
                    // the existing nodes which means it's effectively the same node (except
                    // possibly in a different branch). There is no point inserting it but if the
                    // incoming summary is potentially more up-to-date than the existing one, we
                    // still want to request the children. Otherwise we discard it.
                    if old_node
                        .summary
                        .block_presence
                        .is_outdated(new_block_presence)
                    {
                        updated = RootNodeUpdated::Blocks;
                    } else {
                        return Ok(RootNodeStatus::Outdated);
                    }
                } else {
                    tracing::warn!(
                        vv = ?old_node.proof.version_vector,
                        old_writer_id = ?old_node.proof.writer_id,
                        new_writer_id = ?new_proof.writer_id,
                        old_hash = ?old_node.proof.hash,
                        new_hash = ?new_proof.hash,
                        "Received root node invalid - broken invariant: same vv but different hash"
                    );

                    return Ok(RootNodeStatus::Outdated);
                }
            }
            Some(Ordering::Greater) => (),
            None => {
                if new_proof.writer_id == old_node.proof.writer_id {
                    tracing::warn!(
                        old_vv = ?old_node.proof.version_vector,
                        new_vv = ?new_proof.version_vector,
                        writer_id = ?new_proof.writer_id,
                        "Received root node invalid - broken invariant: concurrency within branch is not allowed"
                    );

                    return Ok(RootNodeStatus::Outdated);
                }
            }
        }

        // Remember the block presence of the latest node from the incoming node's own branch.
        if old_node.proof.writer_id == new_proof.writer_id {
            block_presence = old_node.summary.block_presence;
        }
    }

    Ok(RootNodeStatus::Updated(updated, block_presence))
}
638
639pub(super) async fn debug_print(conn: &mut db::Connection, printer: DebugPrinter) {
640    let mut roots = sqlx::query_as::<_, RootNode>(
641        "SELECT
642             snapshot_id,
643             writer_id,
644             versions,
645             hash,
646             signature,
647             state,
648             block_presence
649         FROM snapshot_root_nodes
650         ORDER BY snapshot_id DESC",
651    )
652    .fetch(conn);
653
654    while let Some(root_node) = roots.next().await {
655        match root_node {
656            Ok(root_node) => {
657                printer.debug(&format_args!(
658                    "RootNode: snapshot_id:{:?}, writer_id:{:?}, vv:{:?}, state:{:?}",
659                    root_node.snapshot_id,
660                    root_node.proof.writer_id,
661                    root_node.proof.version_vector,
662                    root_node.summary.state
663                ));
664            }
665            Err(err) => {
666                printer.debug(&format_args!("RootNode: error: {err:?}"));
667            }
668        }
669    }
670}
671
/// Returns a stream of all root nodes corresponding to the specified writer ordered from the
/// most recent to the least recent.
#[cfg(test)]
pub(super) fn load_all_by_writer<'a>(
    conn: &'a mut db::Connection,
    writer_id: &'a PublicKey,
) -> impl Stream<Item = Result<RootNode, Error>> + 'a {
    sqlx::query_as(
        "SELECT
             snapshot_id,
             writer_id,
             versions,
             hash,
             signature,
             state,
             block_presence
         FROM snapshot_root_nodes
         WHERE writer_id = ?
         ORDER BY snapshot_id DESC",
    )
    .bind(writer_id) // needed to satisfy the borrow checker.
    .fetch(conn)
    .err_into()
}
696
// Decodes a `RootNode` from a row. The column indices correspond to the column order used by all
// the `SELECT` queries in this module: snapshot_id, writer_id, versions, hash, signature, state,
// block_presence.
impl FromRow<'_, SqliteRow> for RootNode {
    fn from_row(row: &SqliteRow) -> Result<Self, sqlx::Error> {
        Ok(RootNode {
            snapshot_id: row.try_get(0)?,
            proof: Proof::new_unchecked(
                row.try_get(1)?,
                row.try_get(2)?,
                row.try_get(3)?,
                row.try_get(4)?,
            ),
            summary: Summary {
                state: row.try_get(5)?,
                block_presence: row.try_get(6)?,
            },
        })
    }
}
714
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypto::sign::Keypair;
    use assert_matches::assert_matches;
    use tempfile::TempDir;

    #[tokio::test]
    async fn create_new() {
        let (_base_dir, pool) = setup().await;

        let writer_id = PublicKey::random();
        let write_keys = Keypair::random();
        let hash = rand::random();

        let mut tx = pool.begin_write().await.unwrap();

        // Creating the very first node of a branch succeeds.
        let (node0, _) = create(
            &mut tx,
            Proof::new(
                writer_id,
                VersionVector::first(writer_id),
                hash,
                &write_keys,
            ),
            Summary::INCOMPLETE,
            RootNodeFilter::Any,
        )
        .await
        .unwrap();
        assert_eq!(node0.proof.hash, hash);

        // The created node is the only node of that branch.
        let nodes: Vec<_> = load_all_by_writer(&mut tx, &writer_id)
            .try_collect()
            .await
            .unwrap();
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0], node0);
    }

    #[tokio::test]
    async fn create_draft() {
        let (_base_dir, pool) = setup().await;

        let writer_id = PublicKey::random();
        let write_keys = Keypair::random();

        let mut tx = pool.begin_write().await.unwrap();

        // The first node's version vector is strictly greater than nothing -> published.
        let (node0, kind) = create(
            &mut tx,
            Proof::new(
                writer_id,
                VersionVector::first(writer_id),
                rand::random(),
                &write_keys,
            ),
            Summary::INCOMPLETE,
            RootNodeFilter::Any,
        )
        .await
        .unwrap();
        assert_eq!(kind, RootNodeKind::Published);

        // A node with the same version vector as the previous one -> draft.
        let (_node1, kind) = create(
            &mut tx,
            Proof::new(
                writer_id,
                node0.proof.version_vector.clone(),
                rand::random(),
                &write_keys,
            ),
            Summary::INCOMPLETE,
            RootNodeFilter::Any,
        )
        .await
        .unwrap();
        assert_eq!(kind, RootNodeKind::Draft);
    }

    #[tokio::test]
    async fn attempt_to_create_outdated_node() {
        let (_base_dir, pool) = setup().await;

        let writer_id = PublicKey::random();
        let write_keys = Keypair::random();
        let hash = rand::random();

        let mut tx = pool.begin_write().await.unwrap();

        let vv0 = VersionVector::first(writer_id);
        let vv1 = vv0.clone().incremented(writer_id);

        create(
            &mut tx,
            Proof::new(writer_id, vv1.clone(), hash, &write_keys),
            Summary::INCOMPLETE,
            RootNodeFilter::Any,
        )
        .await
        .unwrap();

        // Same vv
        assert_matches!(
            create(
                &mut tx,
                Proof::new(writer_id, vv1, hash, &write_keys),
                Summary::INCOMPLETE,
                RootNodeFilter::Published, // With `RootNodeFilter::Any` this would be allowed
            )
            .await,
            Err(Error::OutdatedRootNode)
        );

        // Old vv
        assert_matches!(
            create(
                &mut tx,
                Proof::new(writer_id, vv0, hash, &write_keys),
                Summary::INCOMPLETE,
                RootNodeFilter::Any,
            )
            .await,
            Err(Error::OutdatedRootNode)
        );
    }

    // Proptest for the `load_all_latest_preferred` function.
    mod load_all_latest_preferred {
        use super::*;
        use crate::protocol::SnapshotId;
        use proptest::{arbitrary::any, collection::vec, sample::select, strategy::Strategy};
        use test_strategy::proptest;

        #[proptest(async = "tokio")]
        async fn proptest(
            write_keys: Keypair,
            #[strategy(root_node_params_strategy())] input: Vec<(
                SnapshotId,
                PublicKey,
                Hash,
                NodeState,
            )>,
        ) {
            case(write_keys, input).await
        }

        async fn case(write_keys: Keypair, input: Vec<(SnapshotId, PublicKey, Hash, NodeState)>) {
            let (_base_dir, pool) = setup().await;

            // Collect the distinct writer ids present in the input.
            let mut writer_ids: Vec<_> = input
                .iter()
                .map(|(_, writer_id, _, _)| *writer_id)
                .collect();
            writer_ids.sort();
            writer_ids.dedup();

            // Compute the expected result: for each writer, the node with the "best" state
            // (Approved > Complete > Incomplete > Rejected), ties broken by latest snapshot id.
            let mut expected: Vec<_> = writer_ids
                .into_iter()
                .filter_map(|this_writer_id| {
                    input
                        .iter()
                        .filter(|(_, that_writer_id, _, _)| *that_writer_id == this_writer_id)
                        .map(|(snapshot_id, _, _, state)| (*snapshot_id, *state))
                        .max_by_key(|(snapshot_id, state)| {
                            (
                                match state {
                                    NodeState::Approved => 3,
                                    NodeState::Complete => 2,
                                    NodeState::Incomplete => 1,
                                    NodeState::Rejected => 0,
                                },
                                *snapshot_id,
                            )
                        })
                        .map(|(snapshot_id, state)| (this_writer_id, snapshot_id, state))
                })
                .collect();
            expected.sort_by_key(|(writer_id, _, _)| *writer_id);

            let mut vv = VersionVector::default();
            let mut tx = pool.begin_write().await.unwrap();

            for (expected_snapshot_id, writer_id, hash, state) in input {
                vv.increment(writer_id);

                let (node, _) = create(
                    &mut tx,
                    Proof::new(writer_id, vv.clone(), hash, &write_keys),
                    Summary {
                        state,
                        block_presence: MultiBlockPresence::None,
                    },
                    RootNodeFilter::Any,
                )
                .await
                .unwrap();

                // Sanity check: the db assigns snapshot ids sequentially, matching the 1-based
                // ids produced by `root_node_params_strategy`.
                assert_eq!(node.snapshot_id, expected_snapshot_id);
            }

            let mut actual: Vec<_> = load_all_latest_preferred(&mut tx)
                .map_ok(|node| (node.proof.writer_id, node.snapshot_id, node.summary.state))
                .try_collect()
                .await
                .unwrap();
            actual.sort_by_key(|(writer_id, _, _)| *writer_id);

            assert_eq!(actual, expected);

            drop(tx);
            pool.close().await.unwrap();
        }

        // Generates up to 32 root node params spread over 1 to 3 writers. The snapshot id of each
        // entry is its 1-based insertion order.
        fn root_node_params_strategy()
        -> impl Strategy<Value = Vec<(SnapshotId, PublicKey, Hash, NodeState)>> {
            vec(any::<PublicKey>(), 1..=3)
                .prop_flat_map(|writer_ids| {
                    vec(
                        (select(writer_ids), any::<Hash>(), any::<NodeState>()),
                        0..=32,
                    )
                })
                .prop_map(|params| {
                    params
                        .into_iter()
                        .enumerate()
                        .map(|(index, (writer_id, hash, state))| {
                            ((index + 1) as u32, writer_id, hash, state)
                        })
                        .collect()
                })
        }
    }

    // Creates a fresh temporary database pool for a test.
    async fn setup() -> (TempDir, db::Pool) {
        db::create_temp().await.unwrap()
    }
}