1use super::error::Error;
2use crate::{
3 crypto::{Hash, sign::PublicKey},
4 db,
5 debug::DebugPrinter,
6 protocol::{
7 BlockId, MultiBlockPresence, NodeState, Proof, RootNode, RootNodeFilter, RootNodeKind,
8 SingleBlockPresence, Summary,
9 },
10 version_vector::VersionVector,
11};
12use futures_util::{Stream, StreamExt, TryStreamExt};
13use sqlx::{FromRow, Row, sqlite::SqliteRow};
14use std::{cmp::Ordering, fmt, future};
15
16#[derive(PartialEq, Eq, Debug)]
18pub(crate) enum RootNodeStatus {
19 Updated(RootNodeUpdated, MultiBlockPresence),
23 Outdated,
25}
26
/// Which part of an incoming root node is newer than what is already stored.
#[derive(Debug, PartialEq, Eq)]
pub(crate) enum RootNodeUpdated {
    /// The node is a new snapshot. This is the only case for which `RootNodeStatus::write`
    /// returns `true`.
    Snapshot,
    /// A node with the same version vector and hash is already known, but the stored block
    /// presence is outdated relative to the incoming one. The node is not written.
    Blocks,
}
35
36impl RootNodeStatus {
37 pub fn request_children(&self) -> Option<MultiBlockPresence> {
38 match self {
39 Self::Updated(_, block_presence) => Some(*block_presence),
40 Self::Outdated => None,
41 }
42 }
43
44 pub fn write(&self) -> bool {
45 match self {
46 Self::Updated(RootNodeUpdated::Snapshot, _) => true,
47 Self::Updated(RootNodeUpdated::Blocks, _) | Self::Outdated => false,
48 }
49 }
50}
51
52impl fmt::Display for RootNodeStatus {
53 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
54 match self {
55 Self::Updated(RootNodeUpdated::Snapshot, _) => write!(f, "new snapshot"),
56 Self::Updated(RootNodeUpdated::Blocks, _) => write!(f, "new blocks"),
57 Self::Outdated => write!(f, "outdated"),
58 }
59 }
60}
61
62pub(super) async fn create(
75 tx: &mut db::WriteTransaction,
76 proof: Proof,
77 mut summary: Summary,
78 filter: RootNodeFilter,
79) -> Result<(RootNode, RootNodeKind), Error> {
80 let old_vv: VersionVector = sqlx::query(
83 "SELECT versions
84 FROM snapshot_root_nodes
85 WHERE snapshot_id = (
86 SELECT MAX(snapshot_id)
87 FROM snapshot_root_nodes
88 WHERE writer_id = ?
89 )",
90 )
91 .bind(&proof.writer_id)
92 .map(|row| row.get(0))
93 .fetch_optional(&mut *tx)
94 .await?
95 .unwrap_or_else(VersionVector::new);
96
97 let kind = match (proof.version_vector.partial_cmp(&old_vv), filter) {
98 (Some(Ordering::Greater), _) => RootNodeKind::Published,
99 (Some(Ordering::Equal), RootNodeFilter::Any) => RootNodeKind::Draft,
100 (Some(Ordering::Equal), RootNodeFilter::Published) => return Err(Error::OutdatedRootNode),
101 (Some(Ordering::Less), _) => return Err(Error::OutdatedRootNode),
102 (None, _) => return Err(Error::ConcurrentRootNode),
103 };
104
105 if summary.state == NodeState::Incomplete {
108 let state = sqlx::query(
109 "SELECT state FROM snapshot_root_nodes WHERE hash = ? AND state <> ? LIMIT 1",
110 )
111 .bind(&proof.hash)
112 .bind(NodeState::Incomplete)
113 .fetch_optional(&mut *tx)
114 .await?
115 .map(|row| row.get(0));
116
117 if let Some(state) = state {
118 summary.state = state;
119 }
120 }
121
122 let snapshot_id = sqlx::query(
123 "INSERT INTO snapshot_root_nodes (
124 writer_id,
125 versions,
126 hash,
127 signature,
128 state,
129 block_presence
130 )
131 VALUES (?, ?, ?, ?, ?, ?)
132 RETURNING snapshot_id",
133 )
134 .bind(&proof.writer_id)
135 .bind(&proof.version_vector)
136 .bind(&proof.hash)
137 .bind(&proof.signature)
138 .bind(summary.state)
139 .bind(&summary.block_presence)
140 .map(|row| row.get(0))
141 .fetch_one(tx)
142 .await?;
143
144 let node = RootNode {
145 snapshot_id,
146 proof,
147 summary,
148 };
149
150 Ok((node, kind))
151}
152
153pub(super) async fn load_latest_approved(
155 conn: &mut db::Connection,
156 branch_id: &PublicKey,
157) -> Result<RootNode, Error> {
158 sqlx::query_as(
159 "SELECT
160 snapshot_id,
161 writer_id,
162 versions,
163 hash,
164 signature,
165 state,
166 block_presence
167 FROM
168 snapshot_root_nodes
169 WHERE
170 snapshot_id = (
171 SELECT MAX(snapshot_id)
172 FROM snapshot_root_nodes
173 WHERE writer_id = ? AND state = ?
174 )
175 ",
176 )
177 .bind(branch_id)
178 .bind(NodeState::Approved)
179 .fetch_optional(conn)
180 .await?
181 .ok_or(Error::BranchNotFound)
182}
183
184pub(super) async fn load_prev_approved(
186 conn: &mut db::Connection,
187 node: &RootNode,
188) -> Result<Option<RootNode>, Error> {
189 sqlx::query_as(
190 "SELECT
191 snapshot_id,
192 writer_id,
193 versions,
194 hash,
195 signature,
196 state,
197 block_presence
198 FROM snapshot_root_nodes
199 WHERE writer_id = ? AND state = ? AND snapshot_id < ?
200 ORDER BY snapshot_id DESC
201 LIMIT 1",
202 )
203 .bind(&node.proof.writer_id)
204 .bind(NodeState::Approved)
205 .bind(node.snapshot_id)
206 .fetch(conn)
207 .err_into()
208 .try_next()
209 .await
210}
211
212pub(super) fn load_all_latest_approved(
214 conn: &mut db::Connection,
215) -> impl Stream<Item = Result<RootNode, Error>> + '_ {
216 sqlx::query_as(
217 "SELECT
218 snapshot_id,
219 writer_id,
220 versions,
221 hash,
222 signature,
223 state,
224 block_presence
225 FROM
226 snapshot_root_nodes
227 WHERE
228 snapshot_id IN (
229 SELECT MAX(snapshot_id)
230 FROM snapshot_root_nodes
231 WHERE state = ?
232 GROUP BY writer_id
233 )",
234 )
235 .bind(NodeState::Approved)
236 .fetch(conn)
237 .err_into()
238}
239
240pub(super) fn load_all_latest_preferred(
244 conn: &mut db::Connection,
245) -> impl Stream<Item = Result<RootNode, Error>> + '_ {
246 sqlx::query_as(
253 "SELECT
254 snapshot_id,
255 writer_id,
256 versions,
257 hash,
258 signature,
259 state,
260 block_presence
261 FROM (
262 SELECT
263 *,
264 ROW_NUMBER() OVER (
265 PARTITION BY writer_id
266 ORDER BY
267 CASE state
268 WHEN ? THEN 0
269 WHEN ? THEN 1
270 WHEN ? THEN 2
271 WHEN ? THEN 3
272 END,
273 snapshot_id DESC
274 ) AS position
275 FROM snapshot_root_nodes
276 )
277 WHERE position = 1",
278 )
279 .bind(NodeState::Approved)
280 .bind(NodeState::Complete)
281 .bind(NodeState::Incomplete)
282 .bind(NodeState::Rejected)
283 .fetch(conn)
284 .err_into()
285}
286
287pub(super) fn load_all_latest(
289 conn: &mut db::Connection,
290) -> impl Stream<Item = Result<RootNode, Error>> + '_ {
291 sqlx::query_as(
292 "SELECT
293 snapshot_id,
294 writer_id,
295 versions,
296 hash,
297 signature,
298 state,
299 block_presence
300 FROM
301 snapshot_root_nodes
302 WHERE
303 snapshot_id IN (
304 SELECT MAX(snapshot_id)
305 FROM snapshot_root_nodes
306 GROUP BY writer_id
307 )",
308 )
309 .fetch(conn)
310 .err_into()
311}
312
313pub(super) fn load_all_by_hash<'a>(
315 conn: &'a mut db::Connection,
316 hash: &'a Hash,
317) -> impl Stream<Item = Result<RootNode, Error>> + 'a {
318 sqlx::query_as(
319 "SELECT
320 snapshot_id,
321 writer_id,
322 versions,
323 hash,
324 signature,
325 state,
326 block_presence
327 FROM snapshot_root_nodes
328 WHERE hash = ?",
329 )
330 .bind(hash)
331 .fetch(conn)
332 .err_into()
333}
334
335pub(super) async fn load_node_state_of_missing(
339 conn: &mut db::Connection,
340 block_id: &BlockId,
341) -> Result<NodeState, Error> {
342 use NodeState as S;
343
344 sqlx::query(
345 "WITH RECURSIVE
346 inner_nodes(parent) AS (
347 SELECT parent
348 FROM snapshot_leaf_nodes
349 WHERE block_id = ? AND block_presence = ?
350 UNION ALL
351 SELECT i.parent
352 FROM snapshot_inner_nodes i INNER JOIN inner_nodes c
353 WHERE i.hash = c.parent
354 )
355 SELECT state
356 FROM snapshot_root_nodes r INNER JOIN inner_nodes c
357 WHERE r.hash = c.parent",
358 )
359 .bind(block_id)
360 .bind(SingleBlockPresence::Missing)
361 .fetch(conn)
362 .map_ok(|row| row.get(0))
363 .err_into()
364 .try_fold(S::Rejected, |old, new| {
365 let new = match (old, new) {
366 (S::Incomplete | S::Complete | S::Approved | S::Rejected, S::Approved)
367 | (S::Approved, S::Incomplete | S::Complete | S::Rejected) => S::Approved,
368 (S::Incomplete | S::Complete | S::Rejected, S::Complete)
369 | (S::Complete, S::Incomplete | S::Rejected) => S::Complete,
370 (S::Incomplete | S::Rejected, S::Incomplete) | (S::Incomplete, S::Rejected) => {
371 S::Incomplete
372 }
373 (S::Rejected, S::Rejected) => S::Rejected,
374 };
375
376 future::ready(Ok(new))
377 })
378 .await
379}
380
381pub(super) async fn remove(tx: &mut db::WriteTransaction, node: &RootNode) -> Result<(), Error> {
384 sqlx::query("DELETE FROM snapshot_root_nodes WHERE snapshot_id = ?")
386 .bind(node.snapshot_id)
387 .execute(tx)
388 .await?;
389
390 Ok(())
391}
392
393pub(super) async fn remove_older(
395 tx: &mut db::WriteTransaction,
396 node: &RootNode,
397) -> Result<(), Error> {
398 sqlx::query("DELETE FROM snapshot_root_nodes WHERE snapshot_id < ? AND writer_id = ?")
400 .bind(node.snapshot_id)
401 .bind(&node.proof.writer_id)
402 .execute(tx)
403 .await?;
404
405 Ok(())
406}
407
408pub(super) async fn remove_older_incomplete(
413 tx: &mut db::WriteTransaction,
414 node: &RootNode,
415) -> Result<bool, Error> {
416 let mut changed = false;
417
418 sqlx::query(
420 "DELETE FROM snapshot_root_nodes
421 WHERE snapshot_id < ? AND writer_id = ? AND state IN (?, ?)
422 RETURNING hash",
423 )
424 .bind(node.snapshot_id)
425 .bind(&node.proof.writer_id)
426 .bind(NodeState::Incomplete)
427 .bind(NodeState::Rejected)
428 .fetch(tx)
429 .try_for_each(|row| {
430 tracing::trace!(hash = ?row.get::<Hash, _>(0), "outdated incomplete snapshot removed");
431 changed = true;
432 future::ready(Ok(()))
433 })
434 .await?;
435
436 Ok(changed)
437}
438
439pub(super) async fn update_summaries(
441 tx: &mut db::WriteTransaction,
442 hash: &Hash,
443 summary: Summary,
444) -> Result<NodeState, Error> {
445 let state = sqlx::query(
446 "UPDATE snapshot_root_nodes
447 SET block_presence = ?,
448 state = CASE state WHEN ? THEN ? ELSE state END
449 WHERE hash = ?
450 RETURNING state
451 ",
452 )
453 .bind(&summary.block_presence)
454 .bind(NodeState::Incomplete)
455 .bind(summary.state)
456 .bind(hash)
457 .fetch_optional(tx)
458 .await?
459 .map(|row| row.get(0))
460 .unwrap_or(NodeState::Incomplete);
461
462 Ok(state)
463}
464
465pub(super) async fn check_fallback(
469 conn: &mut db::Connection,
470 old: &RootNode,
471 new: &RootNode,
472) -> Result<bool, Error> {
473 Ok(sqlx::query(
476 "WITH RECURSIVE
477 inner_nodes_old(hash) AS (
478 SELECT i.hash
479 FROM snapshot_inner_nodes AS i
480 INNER JOIN snapshot_root_nodes AS r ON r.hash = i.parent
481 WHERE r.snapshot_id = ?
482 UNION ALL
483 SELECT c.hash
484 FROM snapshot_inner_nodes AS c
485 INNER JOIN inner_nodes_old AS p ON p.hash = c.parent
486 ),
487 inner_nodes_new(hash) AS (
488 SELECT i.hash
489 FROM snapshot_inner_nodes AS i
490 INNER JOIN snapshot_root_nodes AS r ON r.hash = i.parent
491 WHERE r.snapshot_id = ?
492 UNION ALL
493 SELECT c.hash
494 FROM snapshot_inner_nodes AS c
495 INNER JOIN inner_nodes_new AS p ON p.hash = c.parent
496 )
497 SELECT locator
498 FROM snapshot_leaf_nodes
499 WHERE block_presence = ? AND parent IN inner_nodes_old
500 INTERSECT
501 SELECT locator
502 FROM snapshot_leaf_nodes
503 WHERE block_presence = ? AND parent IN inner_nodes_new
504 LIMIT 1",
505 )
506 .bind(old.snapshot_id)
507 .bind(new.snapshot_id)
508 .bind(SingleBlockPresence::Present)
509 .bind(SingleBlockPresence::Missing)
510 .fetch_optional(conn)
511 .await?
512 .is_some())
513}
514
515pub(super) fn approve<'a>(
517 tx: &'a mut db::WriteTransaction,
518 hash: &'a Hash,
519) -> impl Stream<Item = Result<PublicKey, Error>> + 'a {
520 set_state(tx, hash, NodeState::Approved)
521}
522
523pub(super) fn reject<'a>(
525 tx: &'a mut db::WriteTransaction,
526 hash: &'a Hash,
527) -> impl Stream<Item = Result<PublicKey, Error>> + 'a {
528 set_state(tx, hash, NodeState::Rejected)
529}
530
531fn set_state<'a>(
534 tx: &'a mut db::WriteTransaction,
535 hash: &'a Hash,
536 state: NodeState,
537) -> impl Stream<Item = Result<PublicKey, Error>> + 'a {
538 sqlx::query("UPDATE snapshot_root_nodes SET state = ? WHERE hash = ? RETURNING writer_id")
539 .bind(state)
540 .bind(hash)
541 .fetch(tx)
542 .map_ok(|row| row.get(0))
543 .err_into()
544}
545
546pub(super) fn load_writer_ids(
548 conn: &mut db::Connection,
549) -> impl Stream<Item = Result<PublicKey, Error>> + '_ {
550 sqlx::query("SELECT DISTINCT writer_id FROM snapshot_root_nodes")
551 .fetch(conn)
552 .map_ok(|row| row.get(0))
553 .err_into()
554}
555
556pub(super) fn load_writer_ids_by_hash<'a>(
558 conn: &'a mut db::Connection,
559 hash: &'a Hash,
560) -> impl Stream<Item = Result<PublicKey, Error>> + 'a {
561 sqlx::query("SELECT DISTINCT writer_id FROM snapshot_root_nodes WHERE hash = ?")
562 .bind(hash)
563 .fetch(conn)
564 .map_ok(|row| row.get(0))
565 .err_into()
566}
567
568pub(super) async fn status(
569 conn: &mut db::Connection,
570 new_proof: &Proof,
571 new_block_presence: &MultiBlockPresence,
572) -> Result<RootNodeStatus, Error> {
573 let mut updated = RootNodeUpdated::Snapshot;
574 let mut block_presence = MultiBlockPresence::None;
575
576 let mut old_nodes = load_all_latest(conn);
577 while let Some(old_node) = old_nodes.try_next().await? {
578 match new_proof
579 .version_vector
580 .partial_cmp(&old_node.proof.version_vector)
581 {
582 Some(Ordering::Less) => {
583 return Ok(RootNodeStatus::Outdated);
586 }
587 Some(Ordering::Equal) => {
588 if new_proof.hash == old_node.proof.hash {
589 if old_node
595 .summary
596 .block_presence
597 .is_outdated(new_block_presence)
598 {
599 updated = RootNodeUpdated::Blocks;
600 } else {
601 return Ok(RootNodeStatus::Outdated);
602 }
603 } else {
604 tracing::warn!(
605 vv = ?old_node.proof.version_vector,
606 old_writer_id = ?old_node.proof.writer_id,
607 new_writer_id = ?new_proof.writer_id,
608 old_hash = ?old_node.proof.hash,
609 new_hash = ?new_proof.hash,
610 "Received root node invalid - broken invariant: same vv but different hash"
611 );
612
613 return Ok(RootNodeStatus::Outdated);
614 }
615 }
616 Some(Ordering::Greater) => (),
617 None => {
618 if new_proof.writer_id == old_node.proof.writer_id {
619 tracing::warn!(
620 old_vv = ?old_node.proof.version_vector,
621 new_vv = ?new_proof.version_vector,
622 writer_id = ?new_proof.writer_id,
623 "Received root node invalid - broken invariant: concurrency within branch is not allowed"
624 );
625
626 return Ok(RootNodeStatus::Outdated);
627 }
628 }
629 }
630
631 if old_node.proof.writer_id == new_proof.writer_id {
632 block_presence = old_node.summary.block_presence;
633 }
634 }
635
636 Ok(RootNodeStatus::Updated(updated, block_presence))
637}
638
639pub(super) async fn debug_print(conn: &mut db::Connection, printer: DebugPrinter) {
640 let mut roots = sqlx::query_as::<_, RootNode>(
641 "SELECT
642 snapshot_id,
643 writer_id,
644 versions,
645 hash,
646 signature,
647 state,
648 block_presence
649 FROM snapshot_root_nodes
650 ORDER BY snapshot_id DESC",
651 )
652 .fetch(conn);
653
654 while let Some(root_node) = roots.next().await {
655 match root_node {
656 Ok(root_node) => {
657 printer.debug(&format_args!(
658 "RootNode: snapshot_id:{:?}, writer_id:{:?}, vv:{:?}, state:{:?}",
659 root_node.snapshot_id,
660 root_node.proof.writer_id,
661 root_node.proof.version_vector,
662 root_node.summary.state
663 ));
664 }
665 Err(err) => {
666 printer.debug(&format_args!("RootNode: error: {err:?}"));
667 }
668 }
669 }
670}
671
/// Streams all root nodes written by `writer_id`, newest first (descending `snapshot_id`).
/// Compiled for tests only.
#[cfg(test)]
pub(super) fn load_all_by_writer<'a>(
    conn: &'a mut db::Connection,
    writer_id: &'a PublicKey,
) -> impl Stream<Item = Result<RootNode, Error>> + 'a {
    const SQL: &str = "SELECT
             snapshot_id,
             writer_id,
             versions,
             hash,
             signature,
             state,
             block_presence
         FROM snapshot_root_nodes
         WHERE writer_id = ?
         ORDER BY snapshot_id DESC";

    let query = sqlx::query_as::<_, RootNode>(SQL).bind(writer_id);
    query.fetch(conn).err_into()
}
696
697impl FromRow<'_, SqliteRow> for RootNode {
698 fn from_row(row: &SqliteRow) -> Result<Self, sqlx::Error> {
699 Ok(RootNode {
700 snapshot_id: row.try_get(0)?,
701 proof: Proof::new_unchecked(
702 row.try_get(1)?,
703 row.try_get(2)?,
704 row.try_get(3)?,
705 row.try_get(4)?,
706 ),
707 summary: Summary {
708 state: row.try_get(5)?,
709 block_presence: row.try_get(6)?,
710 },
711 })
712 }
713}
714
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypto::sign::Keypair;
    use assert_matches::assert_matches;
    use tempfile::TempDir;

    /// Creates a temporary database. The tests keep the returned `TempDir` bound for their whole
    /// duration.
    async fn setup() -> (TempDir, db::Pool) {
        db::create_temp().await.unwrap()
    }

    /// A node created in an empty branch is stored and can be loaded back unchanged.
    #[tokio::test]
    async fn create_new() {
        let (_base_dir, pool) = setup().await;

        let writer_id = PublicKey::random();
        let write_keys = Keypair::random();
        let hash = rand::random();

        let mut tx = pool.begin_write().await.unwrap();

        let proof = Proof::new(
            writer_id,
            VersionVector::first(writer_id),
            hash,
            &write_keys,
        );
        let (created, _) = create(&mut tx, proof, Summary::INCOMPLETE, RootNodeFilter::Any)
            .await
            .unwrap();
        assert_eq!(created.proof.hash, hash);

        let stored: Vec<_> = load_all_by_writer(&mut tx, &writer_id)
            .try_collect()
            .await
            .unwrap();
        assert_eq!(stored.len(), 1);
        assert_eq!(stored[0], created);
    }

    /// The first node of a branch is published; a second node with the same version vector is
    /// accepted as a draft when the filter is `Any`.
    #[tokio::test]
    async fn create_draft() {
        let (_base_dir, pool) = setup().await;

        let writer_id = PublicKey::random();
        let write_keys = Keypair::random();

        let mut tx = pool.begin_write().await.unwrap();

        let first_proof = Proof::new(
            writer_id,
            VersionVector::first(writer_id),
            rand::random(),
            &write_keys,
        );
        let (first, first_kind) = create(
            &mut tx,
            first_proof,
            Summary::INCOMPLETE,
            RootNodeFilter::Any,
        )
        .await
        .unwrap();
        assert_eq!(first_kind, RootNodeKind::Published);

        let second_proof = Proof::new(
            writer_id,
            first.proof.version_vector.clone(),
            rand::random(),
            &write_keys,
        );
        let (_second, second_kind) = create(
            &mut tx,
            second_proof,
            Summary::INCOMPLETE,
            RootNodeFilter::Any,
        )
        .await
        .unwrap();
        assert_eq!(second_kind, RootNodeKind::Draft);
    }

    /// Creating a node that is not newer than the latest one is rejected as outdated.
    #[tokio::test]
    async fn attempt_to_create_outdated_node() {
        let (_base_dir, pool) = setup().await;

        let writer_id = PublicKey::random();
        let write_keys = Keypair::random();
        let hash = rand::random();

        let mut tx = pool.begin_write().await.unwrap();

        let vv0 = VersionVector::first(writer_id);
        let vv1 = vv0.clone().incremented(writer_id);

        create(
            &mut tx,
            Proof::new(writer_id, vv1.clone(), hash, &write_keys),
            Summary::INCOMPLETE,
            RootNodeFilter::Any,
        )
        .await
        .unwrap();

        // Same version vector, but drafts are excluded by the filter.
        let same_vv = create(
            &mut tx,
            Proof::new(writer_id, vv1, hash, &write_keys),
            Summary::INCOMPLETE,
            RootNodeFilter::Published,
        )
        .await;
        assert_matches!(same_vv, Err(Error::OutdatedRootNode));

        // Strictly older version vector.
        let older_vv = create(
            &mut tx,
            Proof::new(writer_id, vv0, hash, &write_keys),
            Summary::INCOMPLETE,
            RootNodeFilter::Any,
        )
        .await;
        assert_matches!(older_vv, Err(Error::OutdatedRootNode));
    }

    mod load_all_latest_preferred {
        use super::*;
        use crate::protocol::SnapshotId;
        use proptest::{arbitrary::any, collection::vec, sample::select, strategy::Strategy};
        use test_strategy::proptest;

        /// Parameters of one generated root node; the snapshot id is the id the node is
        /// expected to receive on insertion.
        type Params = (SnapshotId, PublicKey, Hash, NodeState);

        #[proptest(async = "tokio")]
        async fn proptest(
            write_keys: Keypair,
            #[strategy(root_node_params_strategy())] input: Vec<Params>,
        ) {
            case(write_keys, input).await
        }

        /// Reference ranking of states: a higher value is preferred.
        fn state_rank(state: NodeState) -> u8 {
            match state {
                NodeState::Approved => 3,
                NodeState::Complete => 2,
                NodeState::Incomplete => 1,
                NodeState::Rejected => 0,
            }
        }

        /// Computes the expected output in memory: for every writer the node with the best
        /// state and, among those, the highest snapshot id. Sorted by writer id.
        fn expected_preferred(input: &[Params]) -> Vec<(PublicKey, SnapshotId, NodeState)> {
            let mut writer_ids: Vec<_> = input
                .iter()
                .map(|(_, writer_id, _, _)| *writer_id)
                .collect();
            writer_ids.sort();
            writer_ids.dedup();

            let mut expected = Vec::with_capacity(writer_ids.len());

            for writer_id in writer_ids {
                let best = input
                    .iter()
                    .filter(|(_, candidate, _, _)| *candidate == writer_id)
                    .map(|(snapshot_id, _, _, state)| (*snapshot_id, *state))
                    .max_by_key(|(snapshot_id, state)| (state_rank(*state), *snapshot_id));

                if let Some((snapshot_id, state)) = best {
                    expected.push((writer_id, snapshot_id, state));
                }
            }

            expected.sort_by_key(|(writer_id, _, _)| *writer_id);
            expected
        }

        async fn case(write_keys: Keypair, input: Vec<Params>) {
            let (_base_dir, pool) = setup().await;

            let expected = expected_preferred(&input);

            // Every insert increments the shared version vector, so `create` always sees a
            // strictly greater version vector.
            let mut vv = VersionVector::default();
            let mut tx = pool.begin_write().await.unwrap();

            for (expected_snapshot_id, writer_id, hash, state) in input {
                vv.increment(writer_id);

                let proof = Proof::new(writer_id, vv.clone(), hash, &write_keys);
                let summary = Summary {
                    state,
                    block_presence: MultiBlockPresence::None,
                };

                let (node, _) = create(&mut tx, proof, summary, RootNodeFilter::Any)
                    .await
                    .unwrap();

                assert_eq!(node.snapshot_id, expected_snapshot_id);
            }

            let mut actual: Vec<_> = load_all_latest_preferred(&mut tx)
                .map_ok(|node| (node.proof.writer_id, node.snapshot_id, node.summary.state))
                .try_collect()
                .await
                .unwrap();
            actual.sort_by_key(|(writer_id, _, _)| *writer_id);

            assert_eq!(actual, expected);

            drop(tx);
            pool.close().await.unwrap();
        }

        /// Generates up to 32 nodes spread over 1 to 3 writers. Snapshot ids are assigned
        /// sequentially starting at 1, in generation order.
        fn root_node_params_strategy() -> impl Strategy<Value = Vec<Params>> {
            let writers = vec(any::<PublicKey>(), 1..=3);

            let nodes = writers.prop_flat_map(|writer_ids| {
                vec(
                    (select(writer_ids), any::<Hash>(), any::<NodeState>()),
                    0..=32,
                )
            });

            nodes.prop_map(|params| {
                params
                    .into_iter()
                    .zip(1u32..)
                    .map(|((writer_id, hash, state), snapshot_id)| {
                        (snapshot_id, writer_id, hash, state)
                    })
                    .collect()
            })
        }
    }
}