// ouisync/network/connection.rs

1use super::{
2    peer_addr::PeerAddr,
3    peer_info::PeerInfo,
4    peer_source::{ConnectionDirection, PeerSource},
5    peer_state::PeerState,
6    runtime_id::PublicRuntimeId,
7    stats::{ByteCounters, StatsTracker},
8};
9use crate::sync::{AwaitDrop, DropAwaitable, WatchSenderExt};
10use std::{
11    collections::{HashMap, hash_map::Entry},
12    fmt,
13    sync::{
14        Arc,
15        atomic::{AtomicU64, Ordering},
16    },
17    time::SystemTime,
18};
19use tokio::sync::watch;
20
21/// Container for known connections.
22pub(super) struct ConnectionSet {
23    connections: watch::Sender<HashMap<ConnectionKey, ConnectionData>>,
24}
25
26impl ConnectionSet {
27    pub fn new() -> Self {
28        Self {
29            connections: watch::Sender::new(HashMap::default()),
30        }
31    }
32
33    /// Attempt to reserve an connection to the given peer. If the connection hasn't been reserved
34    /// yet, it returns a `ConnectionPermit` which keeps the connection reserved as long as it
35    /// lives. Otherwise it returns `None`. To release a connection the permit needs to be dropped.
36    /// Also returns a notification object that can be used to wait until the permit gets released.
37    pub fn reserve(&self, addr: PeerAddr, source: PeerSource) -> ReserveResult {
38        let key = ConnectionKey {
39            addr,
40            dir: source.direction(),
41        };
42
43        self.connections
44            .send_if_modified_return(|connections| match connections.entry(key) {
45                Entry::Vacant(entry) => {
46                    let id = ConnectionId::next();
47
48                    entry.insert(ConnectionData {
49                        id,
50                        state: PeerState::Known,
51                        source,
52                        stats_tracker: StatsTracker::default(),
53                        on_release: DropAwaitable::new(),
54                    });
55
56                    (
57                        true,
58                        ReserveResult::Permit(ConnectionPermit {
59                            connections: self.connections.clone(),
60                            key,
61                            id,
62                        }),
63                    )
64                }
65                Entry::Occupied(entry) => {
66                    let peer_permit = entry.get();
67
68                    (
69                        false,
70                        ReserveResult::Occupied(
71                            peer_permit.on_release.subscribe(),
72                            peer_permit.source,
73                            peer_permit.id,
74                        ),
75                    )
76                }
77            })
78    }
79
80    pub fn peer_info_collector(&self) -> PeerInfoCollector {
81        PeerInfoCollector(self.connections.clone())
82    }
83
84    pub fn get_peer_info(&self, addr: PeerAddr) -> Option<PeerInfo> {
85        let connections = self.connections.borrow();
86
87        connections
88            .get(&ConnectionKey {
89                addr,
90                dir: ConnectionDirection::Incoming,
91            })
92            .or_else(|| {
93                connections.get(&ConnectionKey {
94                    addr,
95                    dir: ConnectionDirection::Outgoing,
96                })
97            })
98            .map(|data| data.peer_info(addr))
99    }
100
101    pub fn subscribe(&self) -> watch::Receiver<HashMap<ConnectionKey, ConnectionData>> {
102        self.connections.subscribe()
103    }
104}
105
106/// Unique identifier of a connection. Connections are mostly already identified by the peer address
107/// and direction (incoming / outgoing), but this type allows to distinguish even connections with
108/// the same address/direction but that were established in two separate occasions.
109#[derive(Clone, Copy, Eq, PartialEq, Debug)]
110#[repr(transparent)]
111pub(super) struct ConnectionId(u64);
112
113impl ConnectionId {
114    pub fn next() -> Self {
115        static NEXT: AtomicU64 = AtomicU64::new(0);
116        Self(NEXT.fetch_add(1, Ordering::Relaxed))
117    }
118}
119
120impl fmt::Display for ConnectionId {
121    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
122        write!(f, "{}", self.0)
123    }
124}
125
126pub(super) enum ReserveResult {
127    Permit(ConnectionPermit),
128    // Use the receiver to get notified when the existing permit is destroyed.
129    Occupied(AwaitDrop, PeerSource, ConnectionId),
130}
131
132#[derive(Clone)]
133pub struct PeerInfoCollector(watch::Sender<HashMap<ConnectionKey, ConnectionData>>);
134
135impl PeerInfoCollector {
136    pub fn collect(&self) -> Vec<PeerInfo> {
137        self.0
138            .borrow()
139            .iter()
140            .map(|(key, data)| data.peer_info(key.addr))
141            .collect()
142    }
143}
144
145/// Connection permit that prevents another connection to the same peer (socket address) to be
146/// established as long as it remains in scope.
147pub(super) struct ConnectionPermit {
148    connections: watch::Sender<HashMap<ConnectionKey, ConnectionData>>,
149    key: ConnectionKey,
150    id: ConnectionId,
151}
152
153impl ConnectionPermit {
154    pub fn mark_as_connecting(&self) {
155        self.set_state(PeerState::Connecting);
156    }
157
158    pub fn mark_as_handshaking(&self) {
159        self.set_state(PeerState::Handshaking);
160    }
161
162    pub fn mark_as_active(&self, runtime_id: PublicRuntimeId) {
163        self.set_state(PeerState::Active {
164            id: runtime_id,
165            since: SystemTime::now(),
166        });
167    }
168
169    fn set_state(&self, new_state: PeerState) {
170        self.connections.send_if_modified(|connections| {
171            // unwrap is ok because if `self` exists then the entry should exists as well.
172            let peer = connections.get_mut(&self.key).unwrap();
173
174            if peer.state != new_state {
175                peer.state = new_state;
176                true
177            } else {
178                false
179            }
180        });
181    }
182
183    /// Returns a `AwaitDrop` that gets notified when this permit gets released.
184    pub fn released(&self) -> AwaitDrop {
185        // We can't use unwrap here because this method is used in `ConnectionPermitHalf` which can
186        // outlive the entry if the other half gets dropped.
187        self.with(|data| data.on_release.subscribe())
188            .unwrap_or_else(|| DropAwaitable::new().subscribe())
189    }
190
191    pub fn addr(&self) -> PeerAddr {
192        self.key.addr
193    }
194
195    pub fn id(&self) -> ConnectionId {
196        self.id
197    }
198
199    pub fn source(&self) -> PeerSource {
200        // unwrap is ok because if `self` exists then the entry should exists as well.
201        self.with(|data| data.source).unwrap()
202    }
203
204    pub fn byte_counters(&self) -> Arc<ByteCounters> {
205        self.with(|data| data.stats_tracker.bytes.clone())
206            .unwrap_or_default()
207    }
208
209    fn with<F, R>(&self, f: F) -> Option<R>
210    where
211        F: FnOnce(&ConnectionData) -> R,
212    {
213        self.connections.borrow().get(&self.key).map(f)
214    }
215}
216
217impl fmt::Debug for ConnectionPermit {
218    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
219        f.debug_struct("ConnectionPermit")
220            .field("key", &self.key)
221            .field("id", &self.id)
222            .finish_non_exhaustive()
223    }
224}
225
226impl Drop for ConnectionPermit {
227    fn drop(&mut self) {
228        self.connections.send_if_modified(|connections| {
229            let Entry::Occupied(entry) = connections.entry(self.key) else {
230                return false;
231            };
232
233            if entry.get().id != self.id {
234                return false;
235            }
236
237            entry.remove();
238            true
239        });
240    }
241}
242
243#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
244pub(super) struct ConnectionKey {
245    addr: PeerAddr,
246    dir: ConnectionDirection,
247}
248
249pub(super) struct ConnectionData {
250    id: ConnectionId,
251    state: PeerState,
252    source: PeerSource,
253    stats_tracker: StatsTracker,
254    on_release: DropAwaitable,
255}
256
257impl ConnectionData {
258    fn peer_info(&self, addr: PeerAddr) -> PeerInfo {
259        let stats = self.stats_tracker.read();
260
261        PeerInfo {
262            addr,
263            source: self.source,
264            state: self.state,
265            stats,
266        }
267    }
268}