ouisync/network/event.rs

1use crate::collections::HashMap;
2use futures_util::{stream, Stream, StreamExt};
3use num_enum::{IntoPrimitive, TryFromPrimitive};
4use ouisync_macros::api;
5use serde::{Deserialize, Serialize};
6use std::{
7    pin::Pin,
8    task::{Context, Poll},
9};
10use tokio::{select, sync::watch};
11
12use super::{
13    connection::{ConnectionData, ConnectionKey},
14    protocol::{Version, VERSION},
15};
16
17pub(super) struct ProtocolVersions {
18    pub our: Version,
19    pub highest_seen: Version,
20}
21
22impl ProtocolVersions {
23    pub fn new() -> Self {
24        Self {
25            our: VERSION,
26            highest_seen: Version::ZERO,
27        }
28    }
29}
30
31/// Network notification event.
32#[derive(
33    Clone, Copy, Eq, PartialEq, Debug, Serialize, Deserialize, TryFromPrimitive, IntoPrimitive,
34)]
35#[repr(u8)]
36#[serde(into = "u8", try_from = "u8")]
37#[api]
38pub enum NetworkEvent {
39    /// A peer has appeared with higher protocol version than us. Probably means we are using
40    /// outdated library. This event can be used to notify the user that they should update the app.
41    ProtocolVersionMismatch = 0,
42    /// The set of known peers has changed (e.g., a new peer has been discovered)
43    PeerSetChange = 1,
44}
45
46pub struct NetworkEventReceiver {
47    protocol_versions: watch::Receiver<ProtocolVersions>,
48    connections: watch::Receiver<HashMap<ConnectionKey, ConnectionData>>,
49    highest_seen: Version,
50}
51
52impl NetworkEventReceiver {
53    pub(super) fn new(
54        protocol_versions: watch::Receiver<ProtocolVersions>,
55        connections: watch::Receiver<HashMap<ConnectionKey, ConnectionData>>,
56    ) -> Self {
57        Self {
58            protocol_versions,
59            connections,
60            highest_seen: Version::ZERO,
61        }
62    }
63
64    pub async fn recv(&mut self) -> Option<NetworkEvent> {
65        select! {
66            Ok(versions) = self.protocol_versions.wait_for(|versions| {
67                    versions.highest_seen > self.highest_seen && versions.highest_seen > versions.our
68            }) => {
69                self.highest_seen = versions.highest_seen;
70                Some(NetworkEvent::ProtocolVersionMismatch)
71            }
72            Ok(_) = self.connections.changed() => Some(NetworkEvent::PeerSetChange),
73            else => None,
74        }
75    }
76}
77
78/// Wrapper around `NetworkEventReceiver` that implements `Stream`.
79pub struct NetworkEventStream {
80    inner: Pin<Box<dyn Stream<Item = NetworkEvent> + Send + 'static>>,
81}
82
83impl NetworkEventStream {
84    pub fn new(rx: NetworkEventReceiver) -> Self {
85        Self {
86            inner: Box::pin(stream::unfold(rx, |mut rx| async move {
87                Some((rx.recv().await?, rx))
88            })),
89        }
90    }
91}
92
93impl Stream for NetworkEventStream {
94    type Item = NetworkEvent;
95
96    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
97        self.get_mut().inner.poll_next_unpin(cx)
98    }
99}