1use super::error::Error;
2use crate::{
3 crypto::{sign::PublicKey, Hash},
4 db,
5 debug::DebugPrinter,
6 protocol::{
7 BlockId, MultiBlockPresence, NodeState, Proof, RootNode, RootNodeFilter, RootNodeKind,
8 SingleBlockPresence, Summary,
9 },
10 version_vector::VersionVector,
11};
12use futures_util::{Stream, StreamExt, TryStreamExt};
13use sqlx::{sqlite::SqliteRow, FromRow, Row};
14use std::{cmp::Ordering, fmt, future};
15
16#[derive(PartialEq, Eq, Debug)]
18pub(crate) enum RootNodeStatus {
19 Updated(RootNodeUpdated, MultiBlockPresence),
23 Outdated,
25}
26
/// What exactly is new about an incoming root node reported as `RootNodeStatus::Updated`.
#[derive(Debug, Eq, PartialEq)]
pub(crate) enum RootNodeUpdated {
    /// The node is a new snapshot and should be written.
    Snapshot,
    /// The node has the same version vector and hash as a stored one, but the stored node's
    /// block presence is reported as outdated relative to the incoming one.
    Blocks,
}
35
36impl RootNodeStatus {
37 pub fn request_children(&self) -> Option<MultiBlockPresence> {
38 match self {
39 Self::Updated(_, block_presence) => Some(*block_presence),
40 Self::Outdated => None,
41 }
42 }
43
44 pub fn write(&self) -> bool {
45 match self {
46 Self::Updated(RootNodeUpdated::Snapshot, _) => true,
47 Self::Updated(RootNodeUpdated::Blocks, _) | Self::Outdated => false,
48 }
49 }
50}
51
52impl fmt::Display for RootNodeStatus {
53 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
54 match self {
55 Self::Updated(RootNodeUpdated::Snapshot, _) => write!(f, "new snapshot"),
56 Self::Updated(RootNodeUpdated::Blocks, _) => write!(f, "new blocks"),
57 Self::Outdated => write!(f, "outdated"),
58 }
59 }
60}
61
62pub(super) async fn create(
75 tx: &mut db::WriteTransaction,
76 proof: Proof,
77 mut summary: Summary,
78 filter: RootNodeFilter,
79) -> Result<(RootNode, RootNodeKind), Error> {
80 let old_vv: VersionVector = sqlx::query(
83 "SELECT versions
84 FROM snapshot_root_nodes
85 WHERE snapshot_id = (
86 SELECT MAX(snapshot_id)
87 FROM snapshot_root_nodes
88 WHERE writer_id = ?
89 )",
90 )
91 .bind(&proof.writer_id)
92 .map(|row| row.get(0))
93 .fetch_optional(&mut *tx)
94 .await?
95 .unwrap_or_else(VersionVector::new);
96
97 let kind = match (proof.version_vector.partial_cmp(&old_vv), filter) {
98 (Some(Ordering::Greater), _) => RootNodeKind::Published,
99 (Some(Ordering::Equal), RootNodeFilter::Any) => RootNodeKind::Draft,
100 (Some(Ordering::Equal), RootNodeFilter::Published) => return Err(Error::OutdatedRootNode),
101 (Some(Ordering::Less), _) => return Err(Error::OutdatedRootNode),
102 (None, _) => return Err(Error::ConcurrentRootNode),
103 };
104
105 if summary.state == NodeState::Incomplete {
108 let state = sqlx::query(
109 "SELECT state FROM snapshot_root_nodes WHERE hash = ? AND state <> ? LIMIT 1",
110 )
111 .bind(&proof.hash)
112 .bind(NodeState::Incomplete)
113 .fetch_optional(&mut *tx)
114 .await?
115 .map(|row| row.get(0));
116
117 if let Some(state) = state {
118 summary.state = state;
119 }
120 }
121
122 let snapshot_id = sqlx::query(
123 "INSERT INTO snapshot_root_nodes (
124 writer_id,
125 versions,
126 hash,
127 signature,
128 state,
129 block_presence
130 )
131 VALUES (?, ?, ?, ?, ?, ?)
132 RETURNING snapshot_id",
133 )
134 .bind(&proof.writer_id)
135 .bind(&proof.version_vector)
136 .bind(&proof.hash)
137 .bind(&proof.signature)
138 .bind(summary.state)
139 .bind(&summary.block_presence)
140 .map(|row| row.get(0))
141 .fetch_one(tx)
142 .await?;
143
144 let node = RootNode {
145 snapshot_id,
146 proof,
147 summary,
148 };
149
150 Ok((node, kind))
151}
152
153pub(super) async fn load_latest_approved(
155 conn: &mut db::Connection,
156 branch_id: &PublicKey,
157) -> Result<RootNode, Error> {
158 sqlx::query_as(
159 "SELECT
160 snapshot_id,
161 writer_id,
162 versions,
163 hash,
164 signature,
165 state,
166 block_presence
167 FROM
168 snapshot_root_nodes
169 WHERE
170 snapshot_id = (
171 SELECT MAX(snapshot_id)
172 FROM snapshot_root_nodes
173 WHERE writer_id = ? AND state = ?
174 )
175 ",
176 )
177 .bind(branch_id)
178 .bind(NodeState::Approved)
179 .fetch_optional(conn)
180 .await?
181 .ok_or(Error::BranchNotFound)
182}
183
184pub(super) async fn load_prev_approved(
186 conn: &mut db::Connection,
187 node: &RootNode,
188) -> Result<Option<RootNode>, Error> {
189 sqlx::query_as(
190 "SELECT
191 snapshot_id,
192 writer_id,
193 versions,
194 hash,
195 signature,
196 state,
197 block_presence
198 FROM snapshot_root_nodes
199 WHERE writer_id = ? AND state = ? AND snapshot_id < ?
200 ORDER BY snapshot_id DESC
201 LIMIT 1",
202 )
203 .bind(&node.proof.writer_id)
204 .bind(NodeState::Approved)
205 .bind(node.snapshot_id)
206 .fetch(conn)
207 .err_into()
208 .try_next()
209 .await
210}
211
212pub(super) fn load_all_latest_approved(
214 conn: &mut db::Connection,
215) -> impl Stream<Item = Result<RootNode, Error>> + '_ {
216 sqlx::query_as(
217 "SELECT
218 snapshot_id,
219 writer_id,
220 versions,
221 hash,
222 signature,
223 state,
224 block_presence
225 FROM
226 snapshot_root_nodes
227 WHERE
228 snapshot_id IN (
229 SELECT MAX(snapshot_id)
230 FROM snapshot_root_nodes
231 WHERE state = ?
232 GROUP BY writer_id
233 )",
234 )
235 .bind(NodeState::Approved)
236 .fetch(conn)
237 .err_into()
238}
239
240pub(super) fn load_all_latest_preferred(
244 conn: &mut db::Connection,
245) -> impl Stream<Item = Result<RootNode, Error>> + '_ {
246 sqlx::query_as(
253 "SELECT
254 snapshot_id,
255 writer_id,
256 versions,
257 hash,
258 signature,
259 state,
260 block_presence
261 FROM (
262 SELECT
263 *,
264 ROW_NUMBER() OVER (
265 PARTITION BY writer_id
266 ORDER BY
267 CASE state
268 WHEN ? THEN 0
269 WHEN ? THEN 1
270 WHEN ? THEN 2
271 WHEN ? THEN 3
272 END,
273 snapshot_id DESC
274 ) AS position
275 FROM snapshot_root_nodes
276 )
277 WHERE position = 1",
278 )
279 .bind(NodeState::Approved)
280 .bind(NodeState::Complete)
281 .bind(NodeState::Incomplete)
282 .bind(NodeState::Rejected)
283 .fetch(conn)
284 .err_into()
285}
286
287pub(super) fn load_all_latest(
289 conn: &mut db::Connection,
290) -> impl Stream<Item = Result<RootNode, Error>> + '_ {
291 sqlx::query_as(
292 "SELECT
293 snapshot_id,
294 writer_id,
295 versions,
296 hash,
297 signature,
298 state,
299 block_presence
300 FROM
301 snapshot_root_nodes
302 WHERE
303 snapshot_id IN (
304 SELECT MAX(snapshot_id)
305 FROM snapshot_root_nodes
306 GROUP BY writer_id
307 )",
308 )
309 .fetch(conn)
310 .err_into()
311}
312
313pub(super) fn load_all_by_hash<'a>(
315 conn: &'a mut db::Connection,
316 hash: &'a Hash,
317) -> impl Stream<Item = Result<RootNode, Error>> + 'a {
318 sqlx::query_as(
319 "SELECT
320 snapshot_id,
321 writer_id,
322 versions,
323 hash,
324 signature,
325 state,
326 block_presence
327 FROM snapshot_root_nodes
328 WHERE hash = ?",
329 )
330 .bind(hash)
331 .fetch(conn)
332 .err_into()
333}
334
335pub(super) async fn load_node_state_of_missing(
339 conn: &mut db::Connection,
340 block_id: &BlockId,
341) -> Result<NodeState, Error> {
342 use NodeState as S;
343
344 sqlx::query(
345 "WITH RECURSIVE
346 inner_nodes(parent) AS (
347 SELECT parent
348 FROM snapshot_leaf_nodes
349 WHERE block_id = ? AND block_presence = ?
350 UNION ALL
351 SELECT i.parent
352 FROM snapshot_inner_nodes i INNER JOIN inner_nodes c
353 WHERE i.hash = c.parent
354 )
355 SELECT state
356 FROM snapshot_root_nodes r INNER JOIN inner_nodes c
357 WHERE r.hash = c.parent",
358 )
359 .bind(block_id)
360 .bind(SingleBlockPresence::Missing)
361 .fetch(conn)
362 .map_ok(|row| row.get(0))
363 .err_into()
364 .try_fold(S::Rejected, |old, new| {
365 let new = match (old, new) {
366 (S::Incomplete | S::Complete | S::Approved | S::Rejected, S::Approved)
367 | (S::Approved, S::Incomplete | S::Complete | S::Rejected) => S::Approved,
368 (S::Incomplete | S::Complete | S::Rejected, S::Complete)
369 | (S::Complete, S::Incomplete | S::Rejected) => S::Complete,
370 (S::Incomplete | S::Rejected, S::Incomplete) | (S::Incomplete, S::Rejected) => {
371 S::Incomplete
372 }
373 (S::Rejected, S::Rejected) => S::Rejected,
374 };
375
376 future::ready(Ok(new))
377 })
378 .await
379}
380
381pub(super) async fn remove(tx: &mut db::WriteTransaction, node: &RootNode) -> Result<(), Error> {
384 sqlx::query("DELETE FROM snapshot_root_nodes WHERE snapshot_id = ?")
386 .bind(node.snapshot_id)
387 .execute(tx)
388 .await?;
389
390 Ok(())
391}
392
393pub(super) async fn remove_older(
395 tx: &mut db::WriteTransaction,
396 node: &RootNode,
397) -> Result<(), Error> {
398 sqlx::query("DELETE FROM snapshot_root_nodes WHERE snapshot_id < ? AND writer_id = ?")
400 .bind(node.snapshot_id)
401 .bind(&node.proof.writer_id)
402 .execute(tx)
403 .await?;
404
405 Ok(())
406}
407
408pub(super) async fn remove_older_incomplete(
411 tx: &mut db::WriteTransaction,
412 node: &RootNode,
413) -> Result<(), Error> {
414 sqlx::query(
416 "DELETE FROM snapshot_root_nodes
417 WHERE snapshot_id < ? AND writer_id = ? AND state IN (?, ?)
418 RETURNING hash",
419 )
420 .bind(node.snapshot_id)
421 .bind(&node.proof.writer_id)
422 .bind(NodeState::Incomplete)
423 .bind(NodeState::Rejected)
424 .fetch(tx)
425 .try_for_each(|row| {
426 tracing::trace!(hash = ?row.get::<Hash, _>(0), "outdated incomplete snapshot removed");
427 future::ready(Ok(()))
428 })
429 .await?;
430
431 Ok(())
432}
433
434pub(super) async fn update_summaries(
436 tx: &mut db::WriteTransaction,
437 hash: &Hash,
438 summary: Summary,
439) -> Result<NodeState, Error> {
440 let state = sqlx::query(
441 "UPDATE snapshot_root_nodes
442 SET block_presence = ?,
443 state = CASE state WHEN ? THEN ? ELSE state END
444 WHERE hash = ?
445 RETURNING state
446 ",
447 )
448 .bind(&summary.block_presence)
449 .bind(NodeState::Incomplete)
450 .bind(summary.state)
451 .bind(hash)
452 .fetch_optional(tx)
453 .await?
454 .map(|row| row.get(0))
455 .unwrap_or(NodeState::Incomplete);
456
457 Ok(state)
458}
459
460pub(super) async fn check_fallback(
464 conn: &mut db::Connection,
465 old: &RootNode,
466 new: &RootNode,
467) -> Result<bool, Error> {
468 Ok(sqlx::query(
471 "WITH RECURSIVE
472 inner_nodes_old(hash) AS (
473 SELECT i.hash
474 FROM snapshot_inner_nodes AS i
475 INNER JOIN snapshot_root_nodes AS r ON r.hash = i.parent
476 WHERE r.snapshot_id = ?
477 UNION ALL
478 SELECT c.hash
479 FROM snapshot_inner_nodes AS c
480 INNER JOIN inner_nodes_old AS p ON p.hash = c.parent
481 ),
482 inner_nodes_new(hash) AS (
483 SELECT i.hash
484 FROM snapshot_inner_nodes AS i
485 INNER JOIN snapshot_root_nodes AS r ON r.hash = i.parent
486 WHERE r.snapshot_id = ?
487 UNION ALL
488 SELECT c.hash
489 FROM snapshot_inner_nodes AS c
490 INNER JOIN inner_nodes_new AS p ON p.hash = c.parent
491 )
492 SELECT locator
493 FROM snapshot_leaf_nodes
494 WHERE block_presence = ? AND parent IN inner_nodes_old
495 INTERSECT
496 SELECT locator
497 FROM snapshot_leaf_nodes
498 WHERE block_presence = ? AND parent IN inner_nodes_new
499 LIMIT 1",
500 )
501 .bind(old.snapshot_id)
502 .bind(new.snapshot_id)
503 .bind(SingleBlockPresence::Present)
504 .bind(SingleBlockPresence::Missing)
505 .fetch_optional(conn)
506 .await?
507 .is_some())
508}
509
510pub(super) fn approve<'a>(
512 tx: &'a mut db::WriteTransaction,
513 hash: &'a Hash,
514) -> impl Stream<Item = Result<PublicKey, Error>> + 'a {
515 set_state(tx, hash, NodeState::Approved)
516}
517
518pub(super) fn reject<'a>(
520 tx: &'a mut db::WriteTransaction,
521 hash: &'a Hash,
522) -> impl Stream<Item = Result<PublicKey, Error>> + 'a {
523 set_state(tx, hash, NodeState::Rejected)
524}
525
526fn set_state<'a>(
529 tx: &'a mut db::WriteTransaction,
530 hash: &'a Hash,
531 state: NodeState,
532) -> impl Stream<Item = Result<PublicKey, Error>> + 'a {
533 sqlx::query("UPDATE snapshot_root_nodes SET state = ? WHERE hash = ? RETURNING writer_id")
534 .bind(state)
535 .bind(hash)
536 .fetch(tx)
537 .map_ok(|row| row.get(0))
538 .err_into()
539}
540
541pub(super) fn load_writer_ids(
543 conn: &mut db::Connection,
544) -> impl Stream<Item = Result<PublicKey, Error>> + '_ {
545 sqlx::query("SELECT DISTINCT writer_id FROM snapshot_root_nodes")
546 .fetch(conn)
547 .map_ok(|row| row.get(0))
548 .err_into()
549}
550
551pub(super) fn load_writer_ids_by_hash<'a>(
553 conn: &'a mut db::Connection,
554 hash: &'a Hash,
555) -> impl Stream<Item = Result<PublicKey, Error>> + 'a {
556 sqlx::query("SELECT DISTINCT writer_id FROM snapshot_root_nodes WHERE hash = ?")
557 .bind(hash)
558 .fetch(conn)
559 .map_ok(|row| row.get(0))
560 .err_into()
561}
562
563pub(super) async fn status(
564 conn: &mut db::Connection,
565 new_proof: &Proof,
566 new_block_presence: &MultiBlockPresence,
567) -> Result<RootNodeStatus, Error> {
568 let mut updated = RootNodeUpdated::Snapshot;
569 let mut block_presence = MultiBlockPresence::None;
570
571 let mut old_nodes = load_all_latest(conn);
572 while let Some(old_node) = old_nodes.try_next().await? {
573 match new_proof
574 .version_vector
575 .partial_cmp(&old_node.proof.version_vector)
576 {
577 Some(Ordering::Less) => {
578 return Ok(RootNodeStatus::Outdated);
581 }
582 Some(Ordering::Equal) => {
583 if new_proof.hash == old_node.proof.hash {
584 if old_node
590 .summary
591 .block_presence
592 .is_outdated(new_block_presence)
593 {
594 updated = RootNodeUpdated::Blocks;
595 } else {
596 return Ok(RootNodeStatus::Outdated);
597 }
598 } else {
599 tracing::warn!(
600 vv = ?old_node.proof.version_vector,
601 old_writer_id = ?old_node.proof.writer_id,
602 new_writer_id = ?new_proof.writer_id,
603 old_hash = ?old_node.proof.hash,
604 new_hash = ?new_proof.hash,
605 "Received root node invalid - broken invariant: same vv but different hash"
606 );
607
608 return Ok(RootNodeStatus::Outdated);
609 }
610 }
611 Some(Ordering::Greater) => (),
612 None => {
613 if new_proof.writer_id == old_node.proof.writer_id {
614 tracing::warn!(
615 old_vv = ?old_node.proof.version_vector,
616 new_vv = ?new_proof.version_vector,
617 writer_id = ?new_proof.writer_id,
618 "Received root node invalid - broken invariant: concurrency within branch is not allowed"
619 );
620
621 return Ok(RootNodeStatus::Outdated);
622 }
623 }
624 }
625
626 if old_node.proof.writer_id == new_proof.writer_id {
627 block_presence = old_node.summary.block_presence;
628 }
629 }
630
631 Ok(RootNodeStatus::Updated(updated, block_presence))
632}
633
634pub(super) async fn debug_print(conn: &mut db::Connection, printer: DebugPrinter) {
635 let mut roots = sqlx::query_as::<_, RootNode>(
636 "SELECT
637 snapshot_id,
638 writer_id,
639 versions,
640 hash,
641 signature,
642 state,
643 block_presence
644 FROM snapshot_root_nodes
645 ORDER BY snapshot_id DESC",
646 )
647 .fetch(conn);
648
649 while let Some(root_node) = roots.next().await {
650 match root_node {
651 Ok(root_node) => {
652 printer.debug(&format_args!(
653 "RootNode: snapshot_id:{:?}, writer_id:{:?}, vv:{:?}, state:{:?}",
654 root_node.snapshot_id,
655 root_node.proof.writer_id,
656 root_node.proof.version_vector,
657 root_node.summary.state
658 ));
659 }
660 Err(err) => {
661 printer.debug(&format_args!("RootNode: error: {:?}", err));
662 }
663 }
664 }
665}
666
/// Streams every root node of the given writer, newest (`snapshot_id` descending) first.
#[cfg(test)]
pub(super) fn load_all_by_writer<'a>(
    conn: &'a mut db::Connection,
    writer_id: &'a PublicKey,
) -> impl Stream<Item = Result<RootNode, Error>> + 'a {
    sqlx::query_as::<_, RootNode>(
        "SELECT
             snapshot_id,
             writer_id,
             versions,
             hash,
             signature,
             state,
             block_presence
         FROM snapshot_root_nodes
         WHERE writer_id = ?
         ORDER BY snapshot_id DESC",
    )
    .bind(writer_id)
    .fetch(conn)
    .map_err(Error::from)
}
691
692impl FromRow<'_, SqliteRow> for RootNode {
693 fn from_row(row: &SqliteRow) -> Result<Self, sqlx::Error> {
694 Ok(RootNode {
695 snapshot_id: row.try_get(0)?,
696 proof: Proof::new_unchecked(
697 row.try_get(1)?,
698 row.try_get(2)?,
699 row.try_get(3)?,
700 row.try_get(4)?,
701 ),
702 summary: Summary {
703 state: row.try_get(5)?,
704 block_presence: row.try_get(6)?,
705 },
706 })
707 }
708}
709
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypto::sign::Keypair;
    use assert_matches::assert_matches;
    use tempfile::TempDir;

    /// Creates a fresh temporary database. Keep the `TempDir` alive for the pool's lifetime.
    async fn setup() -> (TempDir, db::Pool) {
        db::create_temp().await.unwrap()
    }

    #[tokio::test]
    async fn create_new() {
        let (_base_dir, pool) = setup().await;

        let writer_id = PublicKey::random();
        let write_keys = Keypair::random();
        let hash = rand::random();

        let mut tx = pool.begin_write().await.unwrap();

        let proof = Proof::new(
            writer_id,
            VersionVector::first(writer_id),
            hash,
            &write_keys,
        );
        let (created, _) = create(&mut tx, proof, Summary::INCOMPLETE, RootNodeFilter::Any)
            .await
            .unwrap();
        assert_eq!(created.proof.hash, hash);

        let stored: Vec<_> = load_all_by_writer(&mut tx, &writer_id)
            .try_collect()
            .await
            .unwrap();
        assert_eq!(stored, [created]);
    }

    #[tokio::test]
    async fn create_draft() {
        let (_base_dir, pool) = setup().await;

        let writer_id = PublicKey::random();
        let write_keys = Keypair::random();

        let mut tx = pool.begin_write().await.unwrap();

        let first_proof = Proof::new(
            writer_id,
            VersionVector::first(writer_id),
            rand::random(),
            &write_keys,
        );
        let (first, kind) = create(&mut tx, first_proof, Summary::INCOMPLETE, RootNodeFilter::Any)
            .await
            .unwrap();
        assert_eq!(kind, RootNodeKind::Published);

        // Same version vector as the previous node but a different hash.
        let draft_proof = Proof::new(
            writer_id,
            first.proof.version_vector.clone(),
            rand::random(),
            &write_keys,
        );
        let (_, kind) = create(&mut tx, draft_proof, Summary::INCOMPLETE, RootNodeFilter::Any)
            .await
            .unwrap();
        assert_eq!(kind, RootNodeKind::Draft);
    }

    #[tokio::test]
    async fn attempt_to_create_outdated_node() {
        let (_base_dir, pool) = setup().await;

        let writer_id = PublicKey::random();
        let write_keys = Keypair::random();
        let hash = rand::random();

        let mut tx = pool.begin_write().await.unwrap();

        let vv0 = VersionVector::first(writer_id);
        let vv1 = vv0.clone().incremented(writer_id);

        let latest_proof = Proof::new(writer_id, vv1.clone(), hash, &write_keys);
        create(&mut tx, latest_proof, Summary::INCOMPLETE, RootNodeFilter::Any)
            .await
            .unwrap();

        // Equal version vector is rejected when only published nodes are accepted.
        let equal_proof = Proof::new(writer_id, vv1, hash, &write_keys);
        assert_matches!(
            create(
                &mut tx,
                equal_proof,
                Summary::INCOMPLETE,
                RootNodeFilter::Published
            )
            .await,
            Err(Error::OutdatedRootNode)
        );

        // Lower version vector is always rejected.
        let older_proof = Proof::new(writer_id, vv0, hash, &write_keys);
        assert_matches!(
            create(&mut tx, older_proof, Summary::INCOMPLETE, RootNodeFilter::Any).await,
            Err(Error::OutdatedRootNode)
        );
    }

    mod load_all_latest_preferred {
        use super::*;
        use crate::protocol::SnapshotId;
        use proptest::{arbitrary::any, collection::vec, sample::select, strategy::Strategy};
        use test_strategy::proptest;

        #[proptest(async = "tokio")]
        async fn proptest(
            write_keys: Keypair,
            #[strategy(root_node_params_strategy())] input: Vec<(
                SnapshotId,
                PublicKey,
                Hash,
                NodeState,
            )>,
        ) {
            case(write_keys, input).await
        }

        async fn case(write_keys: Keypair, input: Vec<(SnapshotId, PublicKey, Hash, NodeState)>) {
            /// Higher value means more preferred.
            fn preference(state: NodeState) -> u8 {
                match state {
                    NodeState::Approved => 3,
                    NodeState::Complete => 2,
                    NodeState::Incomplete => 1,
                    NodeState::Rejected => 0,
                }
            }

            let (_base_dir, pool) = setup().await;

            // Expected result: for each writer (in ascending order), its most preferred node,
            // ties broken by the highest snapshot id.
            let mut writers: Vec<PublicKey> = input.iter().map(|(_, w, _, _)| *w).collect();
            writers.sort();
            writers.dedup();

            let mut expected = Vec::new();
            for writer in writers {
                let best = input
                    .iter()
                    .filter(|(_, w, _, _)| *w == writer)
                    .max_by_key(|(id, _, _, state)| (preference(*state), *id));

                if let Some((id, _, _, state)) = best {
                    expected.push((writer, *id, *state));
                }
            }

            let mut tx = pool.begin_write().await.unwrap();
            let mut vv = VersionVector::default();

            for (snapshot_id, writer_id, hash, state) in input {
                vv.increment(writer_id);

                let proof = Proof::new(writer_id, vv.clone(), hash, &write_keys);
                let summary = Summary {
                    state,
                    block_presence: MultiBlockPresence::None,
                };

                let (node, _) = create(&mut tx, proof, summary, RootNodeFilter::Any)
                    .await
                    .unwrap();

                assert_eq!(node.snapshot_id, snapshot_id);
            }

            let mut actual: Vec<_> = load_all_latest_preferred(&mut tx)
                .map_ok(|node| (node.proof.writer_id, node.snapshot_id, node.summary.state))
                .try_collect()
                .await
                .unwrap();
            actual.sort_by_key(|(writer_id, _, _)| *writer_id);

            assert_eq!(actual, expected);

            drop(tx);
            pool.close().await.unwrap();
        }

        /// Generates up to 32 root node parameter tuples over 1 to 3 distinct writers, with
        /// snapshot ids assigned sequentially starting at 1.
        fn root_node_params_strategy(
        ) -> impl Strategy<Value = Vec<(SnapshotId, PublicKey, Hash, NodeState)>> {
            vec(any::<PublicKey>(), 1..=3)
                .prop_flat_map(|writer_ids| {
                    vec(
                        (select(writer_ids), any::<Hash>(), any::<NodeState>()),
                        0..=32,
                    )
                })
                .prop_map(|params| {
                    params
                        .into_iter()
                        .zip(1u32..)
                        .map(|((writer_id, hash, state), snapshot_id)| {
                            (snapshot_id, writer_id, hash, state)
                        })
                        .collect()
                })
        }
    }
}