use super::error::Error;
use crate::{
crypto::{sign::PublicKey, Hash},
db,
debug::DebugPrinter,
protocol::{
BlockId, MultiBlockPresence, NodeState, Proof, RootNode, RootNodeFilter, RootNodeKind,
SingleBlockPresence, Summary,
},
version_vector::VersionVector,
};
use futures_util::{Stream, StreamExt, TryStreamExt};
use sqlx::{sqlite::SqliteRow, FromRow, Row};
use std::{cmp::Ordering, fmt, future};
/// What should be done with an incoming root node after comparing it against the root nodes
/// that are already stored (see `status`).
#[derive(Debug, PartialEq, Eq)]
pub(crate) enum RootNodeStatus {
    /// The node should be written and its children requested.
    NewSnapshot,
    /// The node should not be written, but its children should still be requested.
    NewBlocks,
    /// The node should neither be written nor have its children requested.
    Outdated,
}

impl RootNodeStatus {
    /// Returns `true` unless the node is `Outdated`.
    pub fn request_children(&self) -> bool {
        match self {
            Self::Outdated => false,
            Self::NewBlocks | Self::NewSnapshot => true,
        }
    }

    /// Returns `true` only for `NewSnapshot`.
    pub fn write(&self) -> bool {
        match self {
            Self::NewSnapshot => true,
            Self::Outdated | Self::NewBlocks => false,
        }
    }
}

impl fmt::Display for RootNodeStatus {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let label = match self {
            Self::NewSnapshot => "new snapshot",
            Self::NewBlocks => "new blocks",
            Self::Outdated => "outdated",
        };

        f.write_str(label)
    }
}
pub(super) async fn create(
tx: &mut db::WriteTransaction,
proof: Proof,
mut summary: Summary,
filter: RootNodeFilter,
) -> Result<(RootNode, RootNodeKind), Error> {
let old_vv: VersionVector = sqlx::query(
"SELECT versions
FROM snapshot_root_nodes
WHERE snapshot_id = (
SELECT MAX(snapshot_id)
FROM snapshot_root_nodes
WHERE writer_id = ?
)",
)
.bind(&proof.writer_id)
.map(|row| row.get(0))
.fetch_optional(&mut *tx)
.await?
.unwrap_or_else(VersionVector::new);
let kind = match (proof.version_vector.partial_cmp(&old_vv), filter) {
(Some(Ordering::Greater), _) => RootNodeKind::Published,
(Some(Ordering::Equal), RootNodeFilter::Any) => RootNodeKind::Draft,
(Some(Ordering::Equal), RootNodeFilter::Published) => return Err(Error::OutdatedRootNode),
(Some(Ordering::Less), _) => return Err(Error::OutdatedRootNode),
(None, _) => return Err(Error::ConcurrentRootNode),
};
if summary.state == NodeState::Incomplete {
let state = sqlx::query(
"SELECT state FROM snapshot_root_nodes WHERE hash = ? AND state <> ? LIMIT 1",
)
.bind(&proof.hash)
.bind(NodeState::Incomplete)
.fetch_optional(&mut *tx)
.await?
.map(|row| row.get(0));
if let Some(state) = state {
summary.state = state;
}
}
let snapshot_id = sqlx::query(
"INSERT INTO snapshot_root_nodes (
writer_id,
versions,
hash,
signature,
state,
block_presence
)
VALUES (?, ?, ?, ?, ?, ?)
RETURNING snapshot_id",
)
.bind(&proof.writer_id)
.bind(&proof.version_vector)
.bind(&proof.hash)
.bind(&proof.signature)
.bind(summary.state)
.bind(&summary.block_presence)
.map(|row| row.get(0))
.fetch_one(tx)
.await?;
let node = RootNode {
snapshot_id,
proof,
summary,
};
Ok((node, kind))
}
pub(super) async fn load_latest_approved(
conn: &mut db::Connection,
branch_id: &PublicKey,
) -> Result<RootNode, Error> {
sqlx::query_as(
"SELECT
snapshot_id,
writer_id,
versions,
hash,
signature,
state,
block_presence
FROM
snapshot_root_nodes
WHERE
snapshot_id = (
SELECT MAX(snapshot_id)
FROM snapshot_root_nodes
WHERE writer_id = ? AND state = ?
)
",
)
.bind(branch_id)
.bind(NodeState::Approved)
.fetch_optional(conn)
.await?
.ok_or(Error::BranchNotFound)
}
/// Loads the approved root node of `node`'s writer that immediately precedes `node`, meaning
/// the one with the highest `snapshot_id` lower than `node.snapshot_id`.
///
/// Returns `None` if there is no such node.
pub(super) async fn load_prev_approved(
    conn: &mut db::Connection,
    node: &RootNode,
) -> Result<Option<RootNode>, Error> {
    const SQL: &str = "SELECT
snapshot_id,
writer_id,
versions,
hash,
signature,
state,
block_presence
FROM snapshot_root_nodes
WHERE writer_id = ? AND state = ? AND snapshot_id < ?
ORDER BY snapshot_id DESC
LIMIT 1";

    let writer_id = &node.proof.writer_id;

    sqlx::query_as(SQL)
        .bind(writer_id)
        .bind(NodeState::Approved)
        .bind(node.snapshot_id)
        .fetch(conn)
        .err_into()
        .try_next()
        .await
}
/// Streams, for every writer that has one, its most recent approved root node.
pub(super) fn load_all_latest_approved(
    conn: &mut db::Connection,
) -> impl Stream<Item = Result<RootNode, Error>> + '_ {
    const SQL: &str = "SELECT
snapshot_id,
writer_id,
versions,
hash,
signature,
state,
block_presence
FROM
snapshot_root_nodes
WHERE
snapshot_id IN (
SELECT MAX(snapshot_id)
FROM snapshot_root_nodes
WHERE state = ?
GROUP BY writer_id
)";

    let query = sqlx::query_as(SQL).bind(NodeState::Approved);

    query.fetch(conn).err_into()
}
/// Streams exactly one root node per writer: the one with the most preferred state, in the
/// order `Approved`, `Complete`, `Incomplete`, `Rejected`. Ties go to the highest `snapshot_id`.
pub(super) fn load_all_latest_preferred(
    conn: &mut db::Connection,
) -> impl Stream<Item = Result<RootNode, Error>> + '_ {
    const SQL: &str = "SELECT
snapshot_id,
writer_id,
versions,
hash,
signature,
state,
block_presence
FROM (
SELECT
*,
ROW_NUMBER() OVER (
PARTITION BY writer_id
ORDER BY
CASE state
WHEN ? THEN 0
WHEN ? THEN 1
WHEN ? THEN 2
WHEN ? THEN 3
END,
snapshot_id DESC
) AS position
FROM snapshot_root_nodes
)
WHERE position = 1";

    // The binds fill the `CASE` branches in preference order (rank 0 is the most preferred).
    sqlx::query_as(SQL)
        .bind(NodeState::Approved)
        .bind(NodeState::Complete)
        .bind(NodeState::Incomplete)
        .bind(NodeState::Rejected)
        .fetch(conn)
        .err_into()
}
/// Streams the most recent (highest `snapshot_id`) root node of every writer, whatever its
/// state.
pub(super) fn load_all_latest(
    conn: &mut db::Connection,
) -> impl Stream<Item = Result<RootNode, Error>> + '_ {
    const SQL: &str = "SELECT
snapshot_id,
writer_id,
versions,
hash,
signature,
state,
block_presence
FROM
snapshot_root_nodes
WHERE
snapshot_id IN (
SELECT MAX(snapshot_id)
FROM snapshot_root_nodes
GROUP BY writer_id
)";

    let rows = sqlx::query_as(SQL).fetch(conn);
    rows.err_into()
}
/// Streams every root node (from any writer) whose hash equals `hash`.
pub(super) fn load_all_by_hash<'a>(
    conn: &'a mut db::Connection,
    hash: &'a Hash,
) -> impl Stream<Item = Result<RootNode, Error>> + 'a {
    const SQL: &str = "SELECT
snapshot_id,
writer_id,
versions,
hash,
signature,
state,
block_presence
FROM snapshot_root_nodes
WHERE hash = ?";

    let query = sqlx::query_as(SQL).bind(hash);

    query.fetch(conn).err_into()
}
pub(super) async fn load_node_state_of_missing(
conn: &mut db::Connection,
block_id: &BlockId,
) -> Result<NodeState, Error> {
use NodeState as S;
sqlx::query(
"WITH RECURSIVE
inner_nodes(parent) AS (
SELECT parent
FROM snapshot_leaf_nodes
WHERE block_id = ? AND block_presence = ?
UNION ALL
SELECT i.parent
FROM snapshot_inner_nodes i INNER JOIN inner_nodes c
WHERE i.hash = c.parent
)
SELECT state
FROM snapshot_root_nodes r INNER JOIN inner_nodes c
WHERE r.hash = c.parent",
)
.bind(block_id)
.bind(SingleBlockPresence::Missing)
.fetch(conn)
.map_ok(|row| row.get(0))
.err_into()
.try_fold(S::Rejected, |old, new| {
let new = match (old, new) {
(S::Incomplete | S::Complete | S::Approved | S::Rejected, S::Approved)
| (S::Approved, S::Incomplete | S::Complete | S::Rejected) => S::Approved,
(S::Incomplete | S::Complete | S::Rejected, S::Complete)
| (S::Complete, S::Incomplete | S::Rejected) => S::Complete,
(S::Incomplete | S::Rejected, S::Incomplete) | (S::Incomplete, S::Rejected) => {
S::Incomplete
}
(S::Rejected, S::Rejected) => S::Rejected,
};
future::ready(Ok(new))
})
.await
}
/// Deletes the single root node identified by `node.snapshot_id`.
pub(super) async fn remove(tx: &mut db::WriteTransaction, node: &RootNode) -> Result<(), Error> {
    const SQL: &str = "DELETE FROM snapshot_root_nodes WHERE snapshot_id = ?";

    let snapshot_id = node.snapshot_id;
    sqlx::query(SQL).bind(snapshot_id).execute(tx).await?;

    Ok(())
}
/// Deletes every root node of `node`'s writer with a lower `snapshot_id` than `node`,
/// regardless of state.
pub(super) async fn remove_older(
    tx: &mut db::WriteTransaction,
    node: &RootNode,
) -> Result<(), Error> {
    const SQL: &str = "DELETE FROM snapshot_root_nodes WHERE snapshot_id < ? AND writer_id = ?";

    let query = sqlx::query(SQL)
        .bind(node.snapshot_id)
        .bind(&node.proof.writer_id);
    query.execute(tx).await?;

    Ok(())
}
/// Deletes every root node of `node`'s writer that has a lower `snapshot_id` than `node` and
/// whose state is either `Incomplete` or `Rejected`.
pub(super) async fn remove_older_incomplete(
    tx: &mut db::WriteTransaction,
    node: &RootNode,
) -> Result<(), Error> {
    const SQL: &str = "DELETE FROM snapshot_root_nodes
WHERE snapshot_id < ? AND writer_id = ? AND state IN (?, ?)";

    let query = sqlx::query(SQL)
        .bind(node.snapshot_id)
        .bind(&node.proof.writer_id)
        .bind(NodeState::Incomplete)
        .bind(NodeState::Rejected);
    query.execute(tx).await?;

    Ok(())
}
/// Sets `summary.block_presence` on every root node with the given `hash`. Nodes whose state is
/// still `Incomplete` also have their state replaced by `summary.state`.
///
/// Returns the resulting state of the first updated row, or `NodeState::Incomplete` if no
/// root node has that hash.
pub(super) async fn update_summaries(
    tx: &mut db::WriteTransaction,
    hash: &Hash,
    summary: Summary,
) -> Result<NodeState, Error> {
    const SQL: &str = "UPDATE snapshot_root_nodes
SET block_presence = ?,
state = CASE state WHEN ? THEN ? ELSE state END
WHERE hash = ?
RETURNING state
";

    let updated: Option<NodeState> = sqlx::query(SQL)
        .bind(&summary.block_presence)
        .bind(NodeState::Incomplete)
        .bind(summary.state)
        .bind(hash)
        .fetch_optional(tx)
        .await?
        .map(|row| row.get(0));

    Ok(updated.unwrap_or(NodeState::Incomplete))
}
/// Checks whether some locator has a `Present` block in the tree of `old` and a `Missing`
/// block in the tree of `new`.
///
/// Each tree is walked from the root node (by `snapshot_id`) through its inner nodes down to
/// the leaf nodes. Returns `true` if at least one such locator exists.
pub(super) async fn check_fallback(
    conn: &mut db::Connection,
    old: &RootNode,
    new: &RootNode,
) -> Result<bool, Error> {
    const SQL: &str = "WITH RECURSIVE
inner_nodes_old(hash) AS (
SELECT i.hash
FROM snapshot_inner_nodes AS i
INNER JOIN snapshot_root_nodes AS r ON r.hash = i.parent
WHERE r.snapshot_id = ?
UNION ALL
SELECT c.hash
FROM snapshot_inner_nodes AS c
INNER JOIN inner_nodes_old AS p ON p.hash = c.parent
),
inner_nodes_new(hash) AS (
SELECT i.hash
FROM snapshot_inner_nodes AS i
INNER JOIN snapshot_root_nodes AS r ON r.hash = i.parent
WHERE r.snapshot_id = ?
UNION ALL
SELECT c.hash
FROM snapshot_inner_nodes AS c
INNER JOIN inner_nodes_new AS p ON p.hash = c.parent
)
SELECT locator
FROM snapshot_leaf_nodes
WHERE block_presence = ? AND parent IN inner_nodes_old
INTERSECT
SELECT locator
FROM snapshot_leaf_nodes
WHERE block_presence = ? AND parent IN inner_nodes_new
LIMIT 1";

    // Binds, in order: old root, new root, presence in old tree, presence in new tree.
    let shared_locator = sqlx::query(SQL)
        .bind(old.snapshot_id)
        .bind(new.snapshot_id)
        .bind(SingleBlockPresence::Present)
        .bind(SingleBlockPresence::Missing)
        .fetch_optional(conn)
        .await?;

    Ok(shared_locator.is_some())
}
/// Marks every root node with the given `hash` as `Approved`, yielding the writer id of each
/// updated node.
pub(super) fn approve<'a>(
    tx: &'a mut db::WriteTransaction,
    hash: &'a Hash,
) -> impl Stream<Item = Result<PublicKey, Error>> + 'a {
    let state = NodeState::Approved;
    set_state(tx, hash, state)
}

/// Marks every root node with the given `hash` as `Rejected`, yielding the writer id of each
/// updated node.
pub(super) fn reject<'a>(
    tx: &'a mut db::WriteTransaction,
    hash: &'a Hash,
) -> impl Stream<Item = Result<PublicKey, Error>> + 'a {
    let state = NodeState::Rejected;
    set_state(tx, hash, state)
}

/// Overwrites the state of all root nodes with the given `hash` and streams their
/// `writer_id`s.
fn set_state<'a>(
    tx: &'a mut db::WriteTransaction,
    hash: &'a Hash,
    state: NodeState,
) -> impl Stream<Item = Result<PublicKey, Error>> + 'a {
    const SQL: &str = "UPDATE snapshot_root_nodes SET state = ? WHERE hash = ? RETURNING writer_id";

    let query = sqlx::query(SQL).bind(state).bind(hash);

    query.fetch(tx).map_ok(|row| row.get(0)).err_into()
}
/// Streams the distinct ids of all writers that have at least one root node.
pub(super) fn load_writer_ids(
    conn: &mut db::Connection,
) -> impl Stream<Item = Result<PublicKey, Error>> + '_ {
    const SQL: &str = "SELECT DISTINCT writer_id FROM snapshot_root_nodes";

    let rows = sqlx::query(SQL).fetch(conn);
    rows.map_ok(|row| row.get(0)).err_into()
}

/// Streams the distinct ids of all writers that have a root node with the given `hash`.
pub(super) fn load_writer_ids_by_hash<'a>(
    conn: &'a mut db::Connection,
    hash: &'a Hash,
) -> impl Stream<Item = Result<PublicKey, Error>> + 'a {
    const SQL: &str = "SELECT DISTINCT writer_id FROM snapshot_root_nodes WHERE hash = ?";

    let rows = sqlx::query(SQL).bind(hash).fetch(conn);
    rows.map_ok(|row| row.get(0)).err_into()
}
/// Decides what to do with an incoming root node by comparing it with the latest root node of
/// every writer.
///
/// Returns `Outdated` as soon as any stored node makes the incoming one redundant or invalid.
/// Otherwise returns `NewBlocks` if some stored node has the same version vector and hash but
/// outdated block presence relative to `new_block_presence`, and `NewSnapshot` in all other
/// cases.
pub(super) async fn status(
    conn: &mut db::Connection,
    new_proof: &Proof,
    new_block_presence: &MultiBlockPresence,
) -> Result<RootNodeStatus, Error> {
    let mut status = RootNodeStatus::NewSnapshot;
    let mut latest_nodes = load_all_latest(conn);

    while let Some(existing) = latest_nodes.try_next().await? {
        let ordering = new_proof
            .version_vector
            .partial_cmp(&existing.proof.version_vector);

        let outdated = match ordering {
            // Newer than this stored node: it does not affect the decision.
            Some(Ordering::Greater) => false,
            // Older than a stored node: discard.
            Some(Ordering::Less) => true,
            // Same version vector and hash: effectively the same node. It is only useful if
            // its block presence is more up to date than what we have.
            Some(Ordering::Equal) if new_proof.hash == existing.proof.hash => {
                if existing
                    .summary
                    .block_presence
                    .is_outdated(new_block_presence)
                {
                    status = RootNodeStatus::NewBlocks;
                    false
                } else {
                    true
                }
            }
            // Same version vector but a different hash violates an invariant.
            Some(Ordering::Equal) => {
                tracing::warn!(
                    vv = ?existing.proof.version_vector,
                    old_writer_id = ?existing.proof.writer_id,
                    new_writer_id = ?new_proof.writer_id,
                    old_hash = ?existing.proof.hash,
                    new_hash = ?new_proof.hash,
                    "Received root node invalid - broken invariant: same vv but different hash"
                );
                true
            }
            // Concurrent version vectors are only a problem within the same branch.
            None => {
                let same_branch = new_proof.writer_id == existing.proof.writer_id;

                if same_branch {
                    tracing::warn!(
                        old_vv = ?existing.proof.version_vector,
                        new_vv = ?new_proof.version_vector,
                        writer_id = ?new_proof.writer_id,
                        "Received root node invalid - broken invariant: concurrency within branch is not allowed"
                    );
                }

                same_branch
            }
        };

        if outdated {
            return Ok(RootNodeStatus::Outdated);
        }
    }

    Ok(status)
}
/// Prints every stored root node, newest `snapshot_id` first, one line per node through
/// `printer`.
///
/// An error yielded by the row stream is printed in place of that node, and printing continues.
pub(super) async fn debug_print(conn: &mut db::Connection, printer: DebugPrinter) {
    const SQL: &str = "SELECT
snapshot_id,
writer_id,
versions,
hash,
signature,
state,
block_presence
FROM snapshot_root_nodes
ORDER BY snapshot_id DESC";

    let mut roots = sqlx::query_as::<_, RootNode>(SQL).fetch(conn);

    while let Some(entry) = roots.next().await {
        let root_node = match entry {
            Ok(root_node) => root_node,
            Err(err) => {
                printer.debug(&format_args!("RootNode: error: {:?}", err));
                continue;
            }
        };

        printer.debug(&format_args!(
            "RootNode: snapshot_id:{:?}, writer_id:{:?}, vv:{:?}, state:{:?}",
            root_node.snapshot_id,
            root_node.proof.writer_id,
            root_node.proof.version_vector,
            root_node.summary.state
        ));
    }
}
/// Streams every root node of `writer_id`, newest (highest `snapshot_id`) first. Test-only
/// helper.
#[cfg(test)]
pub(super) fn load_all_by_writer<'a>(
    conn: &'a mut db::Connection,
    writer_id: &'a PublicKey,
) -> impl Stream<Item = Result<RootNode, Error>> + 'a {
    const SQL: &str = "SELECT
snapshot_id,
writer_id,
versions,
hash,
signature,
state,
block_presence
FROM snapshot_root_nodes
WHERE writer_id = ?
ORDER BY snapshot_id DESC";

    let query = sqlx::query_as(SQL).bind(writer_id);

    query.fetch(conn).err_into()
}
/// Decodes a `RootNode` from a row with the columns `snapshot_id`, `writer_id`, `versions`,
/// `hash`, `signature`, `state` and `block_presence`, in that order. This is the column list
/// selected by the queries in this module.
impl FromRow<'_, SqliteRow> for RootNode {
    fn from_row(row: &SqliteRow) -> Result<Self, sqlx::Error> {
        // Columns are read strictly in index order, so the first failing column is reported.
        let snapshot_id = row.try_get(0)?;
        let writer_id = row.try_get(1)?;
        let version_vector = row.try_get(2)?;
        let hash = row.try_get(3)?;
        let signature = row.try_get(4)?;
        let state = row.try_get(5)?;
        let block_presence = row.try_get(6)?;

        let proof = Proof::new_unchecked(writer_id, version_vector, hash, signature);
        let summary = Summary {
            state,
            block_presence,
        };

        Ok(RootNode {
            snapshot_id,
            proof,
            summary,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypto::sign::Keypair;
    use assert_matches::assert_matches;
    use tempfile::TempDir;

    // A freshly created root node is stored and loads back unchanged.
    #[tokio::test]
    async fn create_new() {
        let (_base_dir, pool) = setup().await;

        let writer_id = PublicKey::random();
        let write_keys = Keypair::random();
        let hash = rand::random();

        let mut tx = pool.begin_write().await.unwrap();

        let (node0, _) = create(
            &mut tx,
            Proof::new(
                writer_id,
                VersionVector::first(writer_id),
                hash,
                &write_keys,
            ),
            Summary::INCOMPLETE,
            RootNodeFilter::Any,
        )
        .await
        .unwrap();
        assert_eq!(node0.proof.hash, hash);

        let nodes: Vec<_> = load_all_by_writer(&mut tx, &writer_id)
            .try_collect()
            .await
            .unwrap();
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0], node0);
    }

    // The first node of a writer is published. A second node with the same version vector
    // is a draft when created with `RootNodeFilter::Any`.
    #[tokio::test]
    async fn create_draft() {
        let (_base_dir, pool) = setup().await;

        let writer_id = PublicKey::random();
        let write_keys = Keypair::random();

        let mut tx = pool.begin_write().await.unwrap();

        let (node0, kind) = create(
            &mut tx,
            Proof::new(
                writer_id,
                VersionVector::first(writer_id),
                rand::random(),
                &write_keys,
            ),
            Summary::INCOMPLETE,
            RootNodeFilter::Any,
        )
        .await
        .unwrap();
        assert_eq!(kind, RootNodeKind::Published);

        let (_node1, kind) = create(
            &mut tx,
            Proof::new(
                writer_id,
                node0.proof.version_vector.clone(),
                rand::random(),
                &write_keys,
            ),
            Summary::INCOMPLETE,
            RootNodeFilter::Any,
        )
        .await
        .unwrap();
        assert_eq!(kind, RootNodeKind::Draft);
    }

    // Nodes whose version vector is not newer than the writer's latest one are refused.
    #[tokio::test]
    async fn attempt_to_create_outdated_node() {
        let (_base_dir, pool) = setup().await;

        let writer_id = PublicKey::random();
        let write_keys = Keypair::random();
        let hash = rand::random();

        let mut tx = pool.begin_write().await.unwrap();

        let vv0 = VersionVector::first(writer_id);
        let vv1 = vv0.clone().incremented(writer_id);

        create(
            &mut tx,
            Proof::new(writer_id, vv1.clone(), hash, &write_keys),
            Summary::INCOMPLETE,
            RootNodeFilter::Any,
        )
        .await
        .unwrap();

        // Equal version vector, but only published (strictly newer) nodes are accepted.
        assert_matches!(
            create(
                &mut tx,
                Proof::new(writer_id, vv1, hash, &write_keys),
                Summary::INCOMPLETE,
                RootNodeFilter::Published,
            )
            .await,
            Err(Error::OutdatedRootNode)
        );

        // Older version vector than the writer's latest node.
        assert_matches!(
            create(
                &mut tx,
                Proof::new(writer_id, vv0, hash, &write_keys),
                Summary::INCOMPLETE,
                RootNodeFilter::Any,
            )
            .await,
            Err(Error::OutdatedRootNode)
        );
    }

    mod load_all_latest_preferred {
        use super::*;
        use crate::protocol::SnapshotId;
        use proptest::{arbitrary::any, collection::vec, sample::select, strategy::Strategy};
        use test_strategy::proptest;

        #[proptest]
        fn proptest(
            write_keys: Keypair,
            #[strategy(root_node_params_strategy())] input: Vec<(
                SnapshotId,
                PublicKey,
                Hash,
                NodeState,
            )>,
        ) {
            crate::test_utils::run(case(write_keys, input))
        }

        async fn case(write_keys: Keypair, input: Vec<(SnapshotId, PublicKey, Hash, NodeState)>) {
            let (_base_dir, pool) = setup().await;

            let mut writer_ids: Vec<_> = input
                .iter()
                .map(|(_, writer_id, _, _)| *writer_id)
                .collect();
            writer_ids.sort();
            writer_ids.dedup();

            // For every writer, the expected node has the most preferred state
            // (Approved > Complete > Incomplete > Rejected). Ties go to the highest snapshot id.
            let mut expected: Vec<_> = writer_ids
                .into_iter()
                .filter_map(|this_writer_id| {
                    input
                        .iter()
                        .filter(|(_, that_writer_id, _, _)| *that_writer_id == this_writer_id)
                        .map(|(snapshot_id, _, _, state)| (*snapshot_id, *state))
                        .max_by_key(|(snapshot_id, state)| {
                            (
                                match state {
                                    NodeState::Approved => 3,
                                    NodeState::Complete => 2,
                                    NodeState::Incomplete => 1,
                                    NodeState::Rejected => 0,
                                },
                                *snapshot_id,
                            )
                        })
                        .map(|(snapshot_id, state)| (this_writer_id, snapshot_id, state))
                })
                .collect();
            expected.sort_by_key(|(writer_id, _, _)| *writer_id);

            // One shared version vector, incremented before every insert. Each new node is
            // therefore strictly newer than everything created before it.
            let mut vv = VersionVector::default();
            let mut tx = pool.begin_write().await.unwrap();

            for (expected_snapshot_id, writer_id, hash, state) in input {
                vv.increment(writer_id);

                let (node, _) = create(
                    &mut tx,
                    Proof::new(writer_id, vv.clone(), hash, &write_keys),
                    Summary {
                        state,
                        block_presence: MultiBlockPresence::None,
                    },
                    RootNodeFilter::Any,
                )
                .await
                .unwrap();

                // NOTE(review): relies on snapshot ids being assigned sequentially from 1 in a
                // fresh database, matching how the strategy numbers its inputs.
                assert_eq!(node.snapshot_id, expected_snapshot_id);
            }

            let mut actual: Vec<_> = load_all_latest_preferred(&mut tx)
                .map_ok(|node| (node.proof.writer_id, node.snapshot_id, node.summary.state))
                .try_collect()
                .await
                .unwrap();
            actual.sort_by_key(|(writer_id, _, _)| *writer_id);

            assert_eq!(actual, expected);

            drop(tx);
            pool.close().await.unwrap();
        }

        // Generates 0 to 32 node parameters drawn from 1 to 3 random writer ids. Snapshot ids
        // are numbered 1, 2, 3, ... in generation order.
        fn root_node_params_strategy(
        ) -> impl Strategy<Value = Vec<(SnapshotId, PublicKey, Hash, NodeState)>> {
            vec(any::<PublicKey>(), 1..=3)
                .prop_flat_map(|writer_ids| {
                    vec(
                        (select(writer_ids), any::<Hash>(), any::<NodeState>()),
                        0..=32,
                    )
                })
                .prop_map(|params| {
                    params
                        .into_iter()
                        .enumerate()
                        .map(|(index, (writer_id, hash, state))| {
                            ((index + 1) as u32, writer_id, hash, state)
                        })
                        .collect()
                })
        }
    }

    // Creates a temporary database. The returned `TempDir` must be kept alive for as long as
    // the pool is used.
    async fn setup() -> (TempDir, db::Pool) {
        db::create_temp().await.unwrap()
    }
}