// ouisync/network/connection.rs

1use super::{
2    peer_addr::PeerAddr,
3    peer_info::PeerInfo,
4    peer_source::{ConnectionDirection, PeerSource},
5    peer_state::PeerState,
6    runtime_id::PublicRuntimeId,
7    stats::{ByteCounters, StatsTracker},
8};
9use crate::{
10    collections::{hash_map::Entry, HashMap},
11    sync::{AwaitDrop, DropAwaitable, WatchSenderExt},
12};
13use std::{
14    fmt,
15    sync::{
16        atomic::{AtomicU64, Ordering},
17        Arc,
18    },
19    time::SystemTime,
20};
21use tokio::sync::watch;
22
23/// Container for known connections.
24pub(super) struct ConnectionSet {
25    connections: watch::Sender<HashMap<ConnectionKey, ConnectionData>>,
26}
27
28impl ConnectionSet {
29    pub fn new() -> Self {
30        Self {
31            connections: watch::Sender::new(HashMap::default()),
32        }
33    }
34
35    /// Attempt to reserve an connection to the given peer. If the connection hasn't been reserved
36    /// yet, it returns a `ConnectionPermit` which keeps the connection reserved as long as it
37    /// lives. Otherwise it returns `None`. To release a connection the permit needs to be dropped.
38    /// Also returns a notification object that can be used to wait until the permit gets released.
39    pub fn reserve(&self, addr: PeerAddr, source: PeerSource) -> ReserveResult {
40        let key = ConnectionKey {
41            addr,
42            dir: source.direction(),
43        };
44
45        self.connections
46            .send_if_modified_return(|connections| match connections.entry(key) {
47                Entry::Vacant(entry) => {
48                    let id = ConnectionId::next();
49
50                    entry.insert(ConnectionData {
51                        id,
52                        state: PeerState::Known,
53                        source,
54                        stats_tracker: StatsTracker::default(),
55                        on_release: DropAwaitable::new(),
56                    });
57
58                    (
59                        true,
60                        ReserveResult::Permit(ConnectionPermit {
61                            connections: self.connections.clone(),
62                            key,
63                            id,
64                        }),
65                    )
66                }
67                Entry::Occupied(entry) => {
68                    let peer_permit = entry.get();
69
70                    (
71                        false,
72                        ReserveResult::Occupied(
73                            peer_permit.on_release.subscribe(),
74                            peer_permit.source,
75                            peer_permit.id,
76                        ),
77                    )
78                }
79            })
80    }
81
82    pub fn peer_info_collector(&self) -> PeerInfoCollector {
83        PeerInfoCollector(self.connections.clone())
84    }
85
86    pub fn get_peer_info(&self, addr: PeerAddr) -> Option<PeerInfo> {
87        let connections = self.connections.borrow();
88
89        connections
90            .get(&ConnectionKey {
91                addr,
92                dir: ConnectionDirection::Incoming,
93            })
94            .or_else(|| {
95                connections.get(&ConnectionKey {
96                    addr,
97                    dir: ConnectionDirection::Outgoing,
98                })
99            })
100            .map(|data| data.peer_info(addr))
101    }
102
103    pub fn subscribe(&self) -> watch::Receiver<HashMap<ConnectionKey, ConnectionData>> {
104        self.connections.subscribe()
105    }
106}
107
108/// Unique identifier of a connection. Connections are mostly already identified by the peer address
109/// and direction (incoming / outgoing), but this type allows to distinguish even connections with
110/// the same address/direction but that were established in two separate occasions.
111#[derive(Clone, Copy, Eq, PartialEq, Debug)]
112#[repr(transparent)]
113pub(super) struct ConnectionId(u64);
114
115impl ConnectionId {
116    pub fn next() -> Self {
117        static NEXT: AtomicU64 = AtomicU64::new(0);
118        Self(NEXT.fetch_add(1, Ordering::Relaxed))
119    }
120}
121
122impl fmt::Display for ConnectionId {
123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124        write!(f, "{}", self.0)
125    }
126}
127
128pub(super) enum ReserveResult {
129    Permit(ConnectionPermit),
130    // Use the receiver to get notified when the existing permit is destroyed.
131    Occupied(AwaitDrop, PeerSource, ConnectionId),
132}
133
134#[derive(Clone)]
135pub struct PeerInfoCollector(watch::Sender<HashMap<ConnectionKey, ConnectionData>>);
136
137impl PeerInfoCollector {
138    pub fn collect(&self) -> Vec<PeerInfo> {
139        self.0
140            .borrow()
141            .iter()
142            .map(|(key, data)| data.peer_info(key.addr))
143            .collect()
144    }
145}
146
147/// Connection permit that prevents another connection to the same peer (socket address) to be
148/// established as long as it remains in scope.
149pub(super) struct ConnectionPermit {
150    connections: watch::Sender<HashMap<ConnectionKey, ConnectionData>>,
151    key: ConnectionKey,
152    id: ConnectionId,
153}
154
155impl ConnectionPermit {
156    pub fn mark_as_connecting(&self) {
157        self.set_state(PeerState::Connecting);
158    }
159
160    pub fn mark_as_handshaking(&self) {
161        self.set_state(PeerState::Handshaking);
162    }
163
164    pub fn mark_as_active(&self, runtime_id: PublicRuntimeId) {
165        self.set_state(PeerState::Active {
166            id: runtime_id,
167            since: SystemTime::now(),
168        });
169    }
170
171    fn set_state(&self, new_state: PeerState) {
172        self.connections.send_if_modified(|connections| {
173            // unwrap is ok because if `self` exists then the entry should exists as well.
174            let peer = connections.get_mut(&self.key).unwrap();
175
176            if peer.state != new_state {
177                peer.state = new_state;
178                true
179            } else {
180                false
181            }
182        });
183    }
184
185    /// Returns a `AwaitDrop` that gets notified when this permit gets released.
186    pub fn released(&self) -> AwaitDrop {
187        // We can't use unwrap here because this method is used in `ConnectionPermitHalf` which can
188        // outlive the entry if the other half gets dropped.
189        self.with(|data| data.on_release.subscribe())
190            .unwrap_or_else(|| DropAwaitable::new().subscribe())
191    }
192
193    pub fn addr(&self) -> PeerAddr {
194        self.key.addr
195    }
196
197    pub fn id(&self) -> ConnectionId {
198        self.id
199    }
200
201    pub fn source(&self) -> PeerSource {
202        // unwrap is ok because if `self` exists then the entry should exists as well.
203        self.with(|data| data.source).unwrap()
204    }
205
206    pub fn byte_counters(&self) -> Arc<ByteCounters> {
207        self.with(|data| data.stats_tracker.bytes.clone())
208            .unwrap_or_default()
209    }
210
211    fn with<F, R>(&self, f: F) -> Option<R>
212    where
213        F: FnOnce(&ConnectionData) -> R,
214    {
215        self.connections.borrow().get(&self.key).map(f)
216    }
217}
218
219impl fmt::Debug for ConnectionPermit {
220    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
221        f.debug_struct("ConnectionPermit")
222            .field("key", &self.key)
223            .field("id", &self.id)
224            .finish_non_exhaustive()
225    }
226}
227
228impl Drop for ConnectionPermit {
229    fn drop(&mut self) {
230        self.connections.send_if_modified(|connections| {
231            let Entry::Occupied(entry) = connections.entry(self.key) else {
232                return false;
233            };
234
235            if entry.get().id != self.id {
236                return false;
237            }
238
239            entry.remove();
240            true
241        });
242    }
243}
244
245#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
246pub(super) struct ConnectionKey {
247    addr: PeerAddr,
248    dir: ConnectionDirection,
249}
250
251pub(super) struct ConnectionData {
252    id: ConnectionId,
253    state: PeerState,
254    source: PeerSource,
255    stats_tracker: StatsTracker,
256    on_release: DropAwaitable,
257}
258
259impl ConnectionData {
260    fn peer_info(&self, addr: PeerAddr) -> PeerInfo {
261        let stats = self.stats_tracker.read();
262
263        PeerInfo {
264            addr,
265            source: self.source,
266            state: self.state,
267            stats,
268        }
269    }
270}