// ouisync/network/mod.rs

mod choke;
mod client;
mod connection;
mod connection_monitor;
mod constants;
mod crypto;
mod debug_payload;
mod dht_discovery;
mod event;
mod gateway;
mod ip;
mod local_discovery;
mod message;
mod message_broker;
mod message_dispatcher;
mod peer_addr;
mod peer_exchange;
mod peer_info;
mod peer_source;
mod peer_state;
mod protocol;
mod request_tracker;
mod runtime_id;
mod seen_peers;
mod server;
mod stats;
mod stun;
mod stun_server_list;
#[cfg(test)]
mod tests;
mod upnp;

pub use self::{
    connection::PeerInfoCollector,
    dht_discovery::{DhtContactsStoreTrait, DHT_ROUTERS},
    event::{NetworkEvent, NetworkEventReceiver, NetworkEventStream},
    peer_addr::PeerAddr,
    peer_info::PeerInfo,
    peer_source::PeerSource,
    peer_state::PeerState,
    runtime_id::{PublicRuntimeId, SecretRuntimeId},
    stats::Stats,
};
use choke::Choker;
use constants::REQUEST_TIMEOUT;
use event::ProtocolVersions;
pub use net::stun::NatBehavior;
use request_tracker::RequestTracker;

use self::{
    connection::{ConnectionPermit, ConnectionSet, ReserveResult},
    connection_monitor::ConnectionMonitor,
    dht_discovery::DhtDiscovery,
    gateway::{Connectivity, Gateway, StackAddresses},
    local_discovery::LocalDiscovery,
    message_broker::MessageBroker,
    peer_addr::PeerPort,
    peer_exchange::{PexDiscovery, PexRepository},
    protocol::{Version, MAGIC, VERSION},
    seen_peers::{SeenPeer, SeenPeers},
    stats::{ByteCounters, StatsTracker},
    stun::StunClients,
};
use crate::{
    collections::HashSet,
    network::connection::ConnectionDirection,
    protocol::RepositoryId,
    repository::{RepositoryHandle, Vault},
};
use backoff::{backoff::Backoff, ExponentialBackoffBuilder};
use btdht::{self, InfoHash, INFO_HASH_LEN};
use deadlock::BlockingMutex;
use futures_util::future;
use net::unified::{Connection, ConnectionError};
use scoped_task::ScopedAbortHandle;
use slab::Slab;
use state_monitor::StateMonitor;
use std::{
    io, mem,
    net::{SocketAddr, SocketAddrV4, SocketAddrV6},
    sync::{Arc, Weak},
};
use thiserror::Error;
use tokio::{
    io::{AsyncReadExt, AsyncWriteExt},
    sync::{mpsc, watch},
    task::{AbortHandle, JoinSet},
    time::Duration,
};
use tracing::{Instrument, Span};

pub struct Network {
    inner: Arc<Inner>,
    // We keep the tasks here instead of in `Inner` because we want them to be
    // destroyed when `Network` is dropped.
    _tasks: Arc<BlockingMutex<JoinSet<()>>>,
}

impl Network {
    pub fn new(
        monitor: StateMonitor,
        dht_contacts: Option<Arc<dyn DhtContactsStoreTrait>>,
        this_runtime_id: Option<SecretRuntimeId>,
    ) -> Self {
        let (incoming_tx, incoming_rx) = mpsc::channel(1);
        let gateway = Gateway::new(incoming_tx);

        // Note that we currently use only QUIC for transports discovered over the DHT, because
        // the DHT doesn't let us specify whether a remote peer's SocketAddr is TCP, UDP or
        // anything else.
        // TODO: There are ways to address this: e.g. we could try both, or we could include
        // the protocol information in the info-hash generation. There are pros and cons to
        // these approaches.
        let dht_discovery = DhtDiscovery::new(None, None, dht_contacts, monitor.make_child("DHT"));
        // TODO: do we need an unbounded channel here?
        let (dht_discovery_tx, dht_discovery_rx) = mpsc::unbounded_channel();

        let port_forwarder = upnp::PortForwarder::new(monitor.make_child("UPnP"));

        let (pex_discovery_tx, pex_discovery_rx) = mpsc::channel(1);
        let pex_discovery = PexDiscovery::new(pex_discovery_tx);

        let user_provided_peers = SeenPeers::new();

        let this_runtime_id = this_runtime_id.unwrap_or_else(SecretRuntimeId::random);
        let this_runtime_id_public = this_runtime_id.public();

        let connections_monitor = monitor.make_child("Connections");
        let peers_monitor = monitor.make_child("Peers");

        let tasks = Arc::new(BlockingMutex::new(JoinSet::new()));

        let inner = Arc::new(Inner {
            main_monitor: monitor,
            connections_monitor,
            peers_monitor,
            span: Span::current(),
            gateway,
            this_runtime_id,
            registry: BlockingMutex::new(Registry {
                peers: Some(Slab::new()),
                repos: Slab::new(),
            }),
            port_forwarder,
            port_forwarder_state: BlockingMutex::new(ComponentState::disabled(
                DisableReason::Explicit,
            )),
            local_discovery_state: BlockingMutex::new(ComponentState::disabled(
                DisableReason::Explicit,
            )),
            dht_discovery,
            dht_discovery_tx,
            pex_discovery,
            stun_clients: StunClients::new(),
            connections: ConnectionSet::new(),
            user_provided_peers,
            tasks: Arc::downgrade(&tasks),
            protocol_versions: watch::Sender::new(ProtocolVersions::new()),
            our_addresses: BlockingMutex::new(HashSet::default()),
            stats_tracker: StatsTracker::default(),
        });

        inner.spawn(inner.clone().handle_incoming_connections(incoming_rx));
        inner.spawn(inner.clone().run_dht(dht_discovery_rx));
        inner.spawn(inner.clone().run_peer_exchange(pex_discovery_rx));

        tracing::debug!(this_runtime_id = ?this_runtime_id_public.as_public_key(), "Network created");

        Self {
            inner,
            _tasks: tasks,
        }
    }

    /// Binds the network to the specified addresses.
    /// Rebinds if already bound. Unbinds and disables the network if `addrs` is empty.
    ///
    /// NOTE: currently at most one address per protocol (QUIC/TCP) and family (IPv4/IPv6) is used
    /// and the rest are ignored, but this might change in the future.
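    ///
    /// # Example
    ///
    /// A minimal usage sketch (the concrete addresses are illustrative):
    ///
    /// ```ignore
    /// use std::net::{Ipv4Addr, SocketAddr};
    ///
    /// // Listen on QUIC and TCP over IPv4, on OS-assigned ports.
    /// network
    ///     .bind(&[
    ///         PeerAddr::Quic(SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0))),
    ///         PeerAddr::Tcp(SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0))),
    ///     ])
    ///     .await;
    ///
    /// // Unbind and disable the network again.
    /// network.bind(&[]).await;
    /// ```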
    pub async fn bind(&self, addrs: &[PeerAddr]) {
        self.inner.bind(addrs).await
    }

    pub fn listener_local_addrs(&self) -> Vec<PeerAddr> {
        self.inner.gateway.listener_local_addrs()
    }

    pub fn set_port_forwarding_enabled(&self, enabled: bool) {
        let mut state = self.inner.port_forwarder_state.lock().unwrap();

        if enabled {
            if state.is_enabled() {
                return;
            }

            state.enable(PortMappings::new(
                &self.inner.port_forwarder,
                &self.inner.gateway,
            ));
        } else {
            state.disable(DisableReason::Explicit);
        }
    }

    pub fn is_port_forwarding_enabled(&self) -> bool {
        self.inner.port_forwarder_state.lock().unwrap().is_enabled()
    }

    pub fn set_local_discovery_enabled(&self, enabled: bool) {
        let mut state = self.inner.local_discovery_state.lock().unwrap();

        if enabled {
            if state.is_enabled() {
                return;
            }

            if let Some(handle) = self.inner.spawn_local_discovery() {
                state.enable(handle.into());
            } else {
                state.disable(DisableReason::Implicit);
            }
        } else {
            state.disable(DisableReason::Explicit);
        }
    }

    pub fn is_local_discovery_enabled(&self) -> bool {
        self.inner
            .local_discovery_state
            .lock()
            .unwrap()
            .is_enabled()
    }

    /// Sets whether sending contacts to other peers over peer exchange is enabled.
    ///
    /// Note: PEX sending for a given repo is enabled only if it's enabled globally using this
    /// function and also for the repo using [Registration::set_pex_enabled].
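    ///
    /// # Example
    ///
    /// A minimal sketch enabling PEX globally in both directions (the per-repo side is enabled
    /// via [Registration::set_pex_enabled]):
    ///
    /// ```ignore
    /// network.set_pex_send_enabled(true);
    /// network.set_pex_recv_enabled(true);
    /// assert!(network.is_pex_send_enabled() && network.is_pex_recv_enabled());
    /// ```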
    pub fn set_pex_send_enabled(&self, enabled: bool) {
        self.inner.pex_discovery.set_send_enabled(enabled)
    }

    pub fn is_pex_send_enabled(&self) -> bool {
        self.inner.pex_discovery.is_send_enabled()
    }

    /// Sets whether receiving contacts over peer exchange is enabled.
    ///
    /// Note: PEX receiving for a given repo is enabled only if it's enabled globally using this
    /// function and also for the repo using [Registration::set_pex_enabled].
    pub fn set_pex_recv_enabled(&self, enabled: bool) {
        self.inner.pex_discovery.set_recv_enabled(enabled)
    }

    pub fn is_pex_recv_enabled(&self) -> bool {
        self.inner.pex_discovery.is_recv_enabled()
    }

    /// Find out the external address using the STUN protocol.
    /// Currently QUIC only.
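    ///
    /// # Example
    ///
    /// A minimal sketch; the call returns `None` when the address could not be determined
    /// (e.g. no QUIC socket is currently bound):
    ///
    /// ```ignore
    /// if let Some(addr) = network.external_addr_v4().await {
    ///     println!("reachable from the outside as {addr}");
    /// }
    /// ```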
    pub async fn external_addr_v4(&self) -> Option<SocketAddrV4> {
        self.inner.stun_clients.external_addr_v4().await
    }

    /// Find out the external address using the STUN protocol.
    /// Currently QUIC only.
    pub async fn external_addr_v6(&self) -> Option<SocketAddrV6> {
        self.inner.stun_clients.external_addr_v6().await
    }

    /// Determine the behaviour of the NAT we are behind. Returns `None` if it can't be
    /// determined. Currently IPv4 only.
    pub async fn nat_behavior(&self) -> Option<NatBehavior> {
        self.inner.stun_clients.nat_behavior().await
    }

    /// Get the network traffic stats.
    pub fn stats(&self) -> Stats {
        self.inner.stats_tracker.read()
    }

    pub fn add_user_provided_peer(&self, peer: &PeerAddr) {
        self.inner.clone().establish_user_provided_connection(peer);
    }

    pub fn remove_user_provided_peer(&self, peer: &PeerAddr) {
        self.inner.user_provided_peers.remove(peer)
    }

    pub fn this_runtime_id(&self) -> PublicRuntimeId {
        self.inner.this_runtime_id.public()
    }

    pub fn peer_info_collector(&self) -> PeerInfoCollector {
        self.inner.connections.peer_info_collector()
    }

    pub fn peer_info(&self, addr: PeerAddr) -> Option<PeerInfo> {
        self.inner.connections.get_peer_info(addr)
    }

    pub fn current_protocol_version(&self) -> u64 {
        self.inner.protocol_versions.borrow().our.into()
    }

    pub fn highest_seen_protocol_version(&self) -> u64 {
        self.inner.protocol_versions.borrow().highest_seen.into()
    }

    /// Subscribe to network events.
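    ///
    /// # Example
    ///
    /// A minimal consumption sketch (assuming `NetworkEventReceiver` exposes an async
    /// `recv`-style method; see the `event` module for the actual API):
    ///
    /// ```ignore
    /// let mut events = network.subscribe();
    /// while let Some(event) = events.recv().await {
    ///     tracing::debug!(?event, "network event");
    /// }
    /// ```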
    pub fn subscribe(&self) -> NetworkEventReceiver {
        NetworkEventReceiver::new(
            self.inner.protocol_versions.subscribe(),
            self.inner.connections.subscribe(),
        )
    }

    /// Register a local repository into the network. This links the repository with all matching
    /// repositories of currently connected remote replicas as well as any replicas connected in
    /// the future. The repository is automatically deregistered when the returned handle is
    /// dropped.
    ///
    /// Note: A repository should have at most one registration - creating more than one has
    /// undesired effects. This is currently not enforced and so it's a responsibility of the
    /// caller.
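    ///
    /// # Example
    ///
    /// A minimal sketch (`repo_handle` is a previously obtained [RepositoryHandle]):
    ///
    /// ```ignore
    /// let registration = network.register(repo_handle);
    ///
    /// // Advertise and look up the repository on the DHT.
    /// registration.set_dht_enabled(true);
    ///
    /// // Dropping the registration unlinks the repository from the network again.
    /// drop(registration);
    /// ```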
    pub fn register(&self, handle: RepositoryHandle) -> Registration {
        *handle.vault.monitor.info_hash.get() =
            Some(repository_info_hash(handle.vault.repository_id()));

        let pex = self.inner.pex_discovery.new_repository();

        let request_tracker = RequestTracker::new(handle.vault.monitor.traffic.clone());
        request_tracker.set_timeout(REQUEST_TIMEOUT);

        // TODO: Should this be global instead of per repo?
        let choker = Choker::new();

        let stats_tracker = StatsTracker::default();

        let mut registry = self.inner.registry.lock().unwrap();

        registry.create_link(
            handle.vault.clone(),
            &pex,
            &request_tracker,
            &choker,
            stats_tracker.bytes.clone(),
        );

        let key = registry.repos.insert(RegistrationHolder {
            vault: handle.vault,
            dht: None,
            pex,
            request_tracker,
            choker,
            stats_tracker,
        });

        Registration {
            inner: self.inner.clone(),
            key,
        }
    }

    /// Gracefully disconnect from peers. Failing to call this function on app termination means
    /// the peers won't learn that we disconnected just now. They will still find out later once
    /// the keep-alive mechanism kicks in, but in the meantime we will not be able to reconnect
    /// (by starting the app again) because the remote peer will keep dropping new connections
    /// from us.
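    ///
    /// # Example
    ///
    /// A minimal sketch of graceful app termination:
    ///
    /// ```ignore
    /// // Tell the peers we are going away before the process exits so they don't
    /// // keep treating the old connection as live.
    /// network.shutdown().await;
    /// ```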
    pub async fn shutdown(&self) {
        // TODO: Would be a nice-to-have to also wait for all the spawned tasks here (e.g.
        // discovery mechanisms).
        let Some(peers) = self.inner.registry.lock().unwrap().peers.take() else {
            tracing::warn!("Network already shut down");
            return;
        };

        shutdown_peers(peers).await;
    }

    /// Change the sync protocol request timeout. Useful mostly for testing and benchmarking as
    /// the default value should be sufficient for most use cases.
    pub fn set_request_timeout(&self, timeout: Duration) {
        for (_, holder) in &self.inner.registry.lock().unwrap().repos {
            holder.request_tracker.set_timeout(timeout);
        }
    }
}

pub struct Registration {
    inner: Arc<Inner>,
    key: usize,
}

impl Registration {
    pub fn set_dht_enabled(&self, enabled: bool) {
        let mut registry = self.inner.registry.lock().unwrap();
        let holder = &mut registry.repos[self.key];

        if enabled {
            holder.dht = Some(
                self.inner
                    .start_dht_lookup(repository_info_hash(holder.vault.repository_id())),
            );
        } else {
            holder.dht = None;
        }
    }

    /// Returns whether DHT is enabled for this repository, not necessarily whether the DHT tasks
    /// are currently running. The subtle difference is that this function returns true even
    /// when, e.g., the whole network is disabled.
    pub fn is_dht_enabled(&self) -> bool {
        self.inner.registry.lock().unwrap().repos[self.key]
            .dht
            .is_some()
    }

    /// Enables/disables peer exchange for this repo.
    ///
    /// Note: sending/receiving over PEX for this repo is enabled only if it's enabled using this
    /// function and also globally using [Network::set_pex_send_enabled] and/or
    /// [Network::set_pex_recv_enabled].
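    ///
    /// # Example
    ///
    /// A minimal sketch; PEX becomes effective for this repo only once it's also enabled
    /// globally:
    ///
    /// ```ignore
    /// network.set_pex_send_enabled(true);
    /// network.set_pex_recv_enabled(true);
    /// registration.set_pex_enabled(true);
    /// assert!(registration.is_pex_enabled());
    /// ```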
    pub fn set_pex_enabled(&self, enabled: bool) {
        let registry = self.inner.registry.lock().unwrap();
        registry.repos[self.key].pex.set_enabled(enabled);
    }

    pub fn is_pex_enabled(&self) -> bool {
        self.inner.registry.lock().unwrap().repos[self.key]
            .pex
            .is_enabled()
    }

    /// Fetch per-repository network statistics.
    pub fn stats(&self) -> Stats {
        self.inner.registry.lock().unwrap().repos[self.key]
            .stats_tracker
            .read()
    }
}

impl Drop for Registration {
    fn drop(&mut self) {
        let mut registry = self
            .inner
            .registry
            .lock()
            .unwrap_or_else(|error| error.into_inner());

        if let Some(holder) = registry.repos.try_remove(self.key) {
            for (_, peer) in registry.peers.as_mut().into_iter().flatten() {
                peer.destroy_link(holder.vault.repository_id());
            }
        }
    }
}

struct RegistrationHolder {
    vault: Vault,
    dht: Option<dht_discovery::LookupRequest>,
    pex: PexRepository,
    request_tracker: RequestTracker,
    choker: Choker,
    stats_tracker: StatsTracker,
}

struct Inner {
    main_monitor: StateMonitor,
    connections_monitor: StateMonitor,
    peers_monitor: StateMonitor,
    span: Span,
    gateway: Gateway,
    this_runtime_id: SecretRuntimeId,
    registry: BlockingMutex<Registry>,
    port_forwarder: upnp::PortForwarder,
    port_forwarder_state: BlockingMutex<ComponentState<PortMappings>>,
    local_discovery_state: BlockingMutex<ComponentState<ScopedAbortHandle>>,
    dht_discovery: DhtDiscovery,
    dht_discovery_tx: mpsc::UnboundedSender<SeenPeer>,
    pex_discovery: PexDiscovery,
    stun_clients: StunClients,
    connections: ConnectionSet,
    protocol_versions: watch::Sender<ProtocolVersions>,
    user_provided_peers: SeenPeers,
    // Note that unwrapping the upgraded weak pointer should be fine because if the underlying Arc
    // was dropped, we would not be asking for the upgrade in the first place.
    tasks: Weak<BlockingMutex<JoinSet<()>>>,
    // Used to prevent repeatedly connecting to self.
    our_addresses: BlockingMutex<HashSet<PeerAddr>>,
    stats_tracker: StatsTracker,
}

struct Registry {
    // This is `None` once the network has been shut down.
    peers: Option<Slab<MessageBroker>>,
    repos: Slab<RegistrationHolder>,
}

impl Registry {
    fn create_link(
        &mut self,
        repo: Vault,
        pex: &PexRepository,
        request_tracker: &RequestTracker,
        choker: &Choker,
        byte_counters: Arc<ByteCounters>,
    ) {
        if let Some(peers) = &mut self.peers {
            for (_, peer) in peers {
                peer.create_link(
                    repo.clone(),
                    pex,
                    request_tracker.clone(),
                    choker.clone(),
                    byte_counters.clone(),
                )
            }
        }
    }
}

impl Inner {
    fn is_shutdown(&self) -> bool {
        self.registry.lock().unwrap().peers.is_none()
    }

    async fn bind(self: &Arc<Self>, bind: &[PeerAddr]) {
        let bind = StackAddresses::from(bind);

        // TODO: Would be preferable to only rebind those stacks that actually need rebinding.
        if !self.gateway.addresses().any_stack_needs_rebind(&bind) {
            return;
        }

        // Gateway
        let side_channel_makers = self.span.in_scope(|| self.gateway.bind(&bind));

        let conn = self.gateway.connectivity();

        let (side_channel_maker_v4, side_channel_maker_v6) = match conn {
            Connectivity::Full => side_channel_makers,
            Connectivity::LocalOnly | Connectivity::Disabled => (None, None),
        };

        // STUN
        self.stun_clients.rebind(
            side_channel_maker_v4.as_ref().map(|m| m.make()),
            side_channel_maker_v6.as_ref().map(|m| m.make()),
        );

        // DHT
        self.dht_discovery
            .rebind(side_channel_maker_v4, side_channel_maker_v6);

        // Port forwarding
        match conn {
            Connectivity::Full => {
                let mut state = self.port_forwarder_state.lock().unwrap();
                if !state.is_disabled(DisableReason::Explicit) {
                    state.enable(PortMappings::new(&self.port_forwarder, &self.gateway));
                }
            }
            Connectivity::LocalOnly | Connectivity::Disabled => {
                self.port_forwarder_state
                    .lock()
                    .unwrap()
                    .disable_if_enabled(DisableReason::Implicit);
            }
        }

        // Local discovery
        //
        // Note: no need to check the Connectivity because local discovery depends only on whether
        // Gateway is bound.
        {
            let mut state = self.local_discovery_state.lock().unwrap();
            if !state.is_disabled(DisableReason::Explicit) {
                if let Some(handle) = self.spawn_local_discovery() {
                    state.enable(handle.into());
                } else {
                    state.disable(DisableReason::Implicit);
                }
            }
        }

        // - If we are disabling connectivity, disconnect from all existing peers.
        // - If we are going from `Full` -> `LocalOnly`, also disconnect from all with the
        //   assumption that the local ones will be subsequently re-established. Ideally we would
        //   disconnect only the non-local ones to avoid the reconnect overhead, but the
        //   implementation is simpler this way and the trade-off doesn't seem to be too bad.
        // - If we are going to `Full`, keep all existing connections.
        if matches!(conn, Connectivity::LocalOnly | Connectivity::Disabled) {
            self.disconnect_all().await;
        }
    }

    // Disconnect from all currently connected peers, regardless of their source.
    async fn disconnect_all(&self) {
        let Some(peers) = self.registry.lock().unwrap().peers.replace(Slab::default()) else {
            return;
        };

        shutdown_peers(peers).await;
    }

    fn spawn_local_discovery(self: &Arc<Self>) -> Option<AbortHandle> {
        let addrs = self.gateway.listener_local_addrs();
        let tcp_port = addrs
            .iter()
            .find(|addr| matches!(addr, PeerAddr::Tcp(SocketAddr::V4(_))))
            .map(|addr| PeerPort::Tcp(addr.port()));
        let quic_port = addrs
            .iter()
            .find(|addr| matches!(addr, PeerAddr::Quic(SocketAddr::V4(_))))
            .map(|addr| PeerPort::Quic(addr.port()));

        // Arbitrary order of preference.
        // TODO: Should we support all available?
        let port = tcp_port.or(quic_port);

        if let Some(port) = port {
            Some(
                self.spawn(
                    self.clone()
                        .run_local_discovery(port)
                        .instrument(self.span.clone()),
                ),
            )
        } else {
            tracing::error!("Not enabling local discovery because there is no IPv4 listener");
            None
        }
    }

    async fn run_local_discovery(self: Arc<Self>, listener_port: PeerPort) {
        let mut discovery = LocalDiscovery::new(
            listener_port,
            self.main_monitor.make_child("LocalDiscovery"),
        );

        loop {
            let peer = discovery.recv().await;

            if self.is_shutdown() {
                break;
            }

            self.spawn(
                self.clone()
                    .handle_peer_found(peer, PeerSource::LocalDiscovery),
            );
        }
    }

    fn start_dht_lookup(&self, info_hash: InfoHash) -> dht_discovery::LookupRequest {
        self.dht_discovery
            .start_lookup(info_hash, self.dht_discovery_tx.clone())
    }

    async fn run_dht(self: Arc<Self>, mut discovery_rx: mpsc::UnboundedReceiver<SeenPeer>) {
        while let Some(seen_peer) = discovery_rx.recv().await {
            if self.is_shutdown() {
                break;
            }

            self.spawn(self.clone().handle_peer_found(seen_peer, PeerSource::Dht));
        }
    }

    async fn run_peer_exchange(self: Arc<Self>, mut discovery_rx: mpsc::Receiver<SeenPeer>) {
        while let Some(peer) = discovery_rx.recv().await {
            if self.is_shutdown() {
                break;
            }

            self.spawn(
                self.clone()
                    .handle_peer_found(peer, PeerSource::PeerExchange),
            );
        }
    }

    fn establish_user_provided_connection(self: Arc<Self>, peer: &PeerAddr) {
        let peer = match self.user_provided_peers.insert(*peer) {
            Some(peer) => peer,
            // Already in `user_provided_peers`.
            None => return,
        };

        self.spawn(
            self.clone()
                .handle_peer_found(peer, PeerSource::UserProvided),
        );
    }

    async fn handle_incoming_connections(
        self: Arc<Self>,
        mut rx: mpsc::Receiver<(Connection, PeerAddr)>,
    ) {
        while let Some((connection, addr)) = rx.recv().await {
            match self.connections.reserve(addr, PeerSource::Listener) {
                ReserveResult::Permit(permit) => {
                    if self.is_shutdown() {
                        break;
                    }

                    let this = self.clone();

                    let monitor = self.span.in_scope(|| {
                        ConnectionMonitor::new(
                            &self.connections_monitor,
                            &permit.addr(),
                            permit.source(),
                        )
                    });
                    monitor.mark_as_connecting(permit.id());

                    self.spawn(async move {
                        this.handle_connection(connection, permit, &monitor).await;
                    });
                }
                ReserveResult::Occupied(_, _their_source, permit_id) => {
                    tracing::debug!(?addr, ?permit_id, "dropping accepted duplicate connection");
                }
            }
        }
    }

    async fn handle_peer_found(self: Arc<Self>, peer: SeenPeer, source: PeerSource) {
        let create_backoff = || {
            ExponentialBackoffBuilder::new()
                .with_initial_interval(Duration::from_millis(100))
                .with_max_interval(Duration::from_secs(8))
                .with_max_elapsed_time(None)
                .build()
        };

        let mut backoff = create_backoff();

        let mut next_sleep = None;

        loop {
            let monitor = self.span.in_scope(|| {
                ConnectionMonitor::new(&self.connections_monitor, peer.initial_addr(), source)
            });

            // TODO: We should also check whether the user still wants to accept connections from
            // the given `source` (the preference may have changed in the meantime).

            if self.is_shutdown() {
                return;
            }

            let addr = match peer.addr_if_seen() {
                Some(addr) => *addr,
                None => return,
            };

            if self.our_addresses.lock().unwrap().contains(&addr) {
                // Don't connect to self.
                return;
            }

            let permit = match self.connections.reserve(addr, source) {
                ReserveResult::Permit(permit) => permit,
                ReserveResult::Occupied(on_release, their_source, connection_id) => {
                    if source == their_source {
                        // This is a duplicate from the same source, ignore it.
                        return;
                    }

                    // This is a duplicate from a different source; if the other source releases
                    // it, then we may want to try to keep hold of it.
                    monitor.mark_as_awaiting_permit();
                    tracing::debug!(
                        parent: monitor.span(),
                        %connection_id,
                        "Duplicate from different source - awaiting permit"
                    );

                    on_release.await;

                    next_sleep = None;
                    backoff = create_backoff();

                    continue;
                }
            };

            if let Some(sleep) = next_sleep {
                tracing::debug!(parent: monitor.span(), "Next connection attempt in {:?}", sleep);
                tokio::time::sleep(sleep).await;
            }

            next_sleep = backoff.next_backoff();

            permit.mark_as_connecting();
            monitor.mark_as_connecting(permit.id());
            tracing::trace!(parent: monitor.span(), "Connecting");

            let socket = match self
                .gateway
                .connect_with_retries(&peer, source)
                .instrument(monitor.span().clone())
                .await
            {
                Some(socket) => socket,
                None => break,
            };

            if !self.handle_connection(socket, permit, &monitor).await {
                break;
            }
        }
    }

    /// Returns `true` iff the peer is suitable for reconnection.
    async fn handle_connection(
        &self,
        connection: Connection,
        permit: ConnectionPermit,
        monitor: &ConnectionMonitor,
    ) -> bool {
        tracing::trace!(parent: monitor.span(), "Handshaking");

        permit.mark_as_handshaking();
        monitor.mark_as_handshaking();

        let handshake_result = perform_handshake(
            &connection,
            VERSION,
            &self.this_runtime_id,
            ConnectionDirection::from_source(permit.source()),
        )
        .await;

        if let Err(error) = &handshake_result {
            tracing::debug!(parent: monitor.span(), ?error, "Handshake failed");
        }

        let that_runtime_id = match handshake_result {
            Ok(writer_id) => writer_id,
            Err(HandshakeError::ProtocolVersionMismatch(their_version)) => {
                self.on_protocol_mismatch(their_version);
                return false;
            }
            Err(
                HandshakeError::Timeout
                | HandshakeError::BadMagic
                | HandshakeError::Io(_)
                | HandshakeError::Connection(_),
            ) => return false,
        };

        // Prevent self-connections.
        if that_runtime_id == self.this_runtime_id.public() {
            tracing::debug!(parent: monitor.span(), "Connection from self, discarding");
            self.our_addresses.lock().unwrap().insert(permit.addr());
            return false;
        }

        permit.mark_as_active(that_runtime_id);
        monitor.mark_as_active(that_runtime_id);

        let closed = connection.closed();

        let key = {
            let mut registry = self.registry.lock().unwrap();
            let registry = &mut *registry;

            let Some(peers) = &mut registry.peers else {
                // Network has been shut down.
                return false;
            };

            let pex_peer = self.pex_discovery.new_peer();
            pex_peer.handle_connection(permit.addr(), permit.source(), permit.released());

            let mut peer = monitor.span().in_scope(|| {
                MessageBroker::new(
                    self.this_runtime_id.public(),
                    that_runtime_id,
                    connection,
                    pex_peer,
                    self.peers_monitor
                        .make_child(format!("{:?}", that_runtime_id.as_public_key())),
                    self.stats_tracker.bytes.clone(),
                    permit.byte_counters(),
                )
            });

            // TODO: for DHT connections we should only link the repository for which we did the
            // lookup, but make sure we correctly handle edge cases, for example, when we have
            // more than one repository shared with the peer.
            for (_, holder) in &registry.repos {
                peer.create_link(
                    holder.vault.clone(),
                    &holder.pex,
                    holder.request_tracker.clone(),
                    holder.choker.clone(),
                    holder.stats_tracker.bytes.clone(),
                );
            }

            peers.insert(peer)
        };

        // Wait until the connection gets closed, then remove the `MessageBroker` instance. Use an
        // RAII guard to also remove it in case this function gets cancelled.
        let _guard = PeerGuard {
            registry: &self.registry,
            key,
        };

        closed.await;

        true
    }

    fn on_protocol_mismatch(&self, their_version: Version) {
        self.protocol_versions.send_if_modified(|versions| {
            if versions.highest_seen < their_version {
                versions.highest_seen = their_version;
                true
            } else {
                false
            }
        });
    }

    fn spawn<Fut>(&self, f: Fut) -> AbortHandle
    where
        Fut: Future<Output = ()> + Send + 'static,
    {
        self.tasks
            .upgrade()
            // TODO: this `unwrap` is sketchy. Maybe we should simply not spawn if `tasks` can't be
            // upgraded?
            .unwrap()
            .lock()
            .unwrap()
            .spawn(f)
    }
}

//------------------------------------------------------------------------------

// Exchange runtime ids with the peer. Returns their (verified) runtime id.
async fn perform_handshake(
    connection: &Connection,
    this_version: Version,
    this_runtime_id: &SecretRuntimeId,
    dir: ConnectionDirection,
) -> Result<PublicRuntimeId, HandshakeError> {
    let result = tokio::time::timeout(std::time::Duration::from_secs(5), async move {
        let (mut writer, mut reader) = match dir {
            ConnectionDirection::Incoming => connection.incoming().await?,
            ConnectionDirection::Outgoing => connection.outgoing().await?,
        };

        writer.write_all(MAGIC).await?;

        // Backward fix for Ouisync *App* versions v0.8.3 (and possibly prior)
        {
            // Those versions of the Ouisync App had a race condition: if some peer with a higher
            // protocol version managed to connect and handshake before the app subscribed to
            // "higher protocol version" notifications, the app would never get the notification.
            // On a Pixel 9a, when connecting to a PC, the handshake happened ~100ms prior to the
            // subscribe. Using 700ms to account for slower devices.
            tokio::time::sleep(std::time::Duration::from_millis(700)).await;
        }

        this_version.write_into(&mut writer).await?;

        let mut that_magic = [0; MAGIC.len()];
        reader.read_exact(&mut that_magic).await?;

        if MAGIC != &that_magic {
            return Err(HandshakeError::BadMagic);
        }

        let that_version = Version::read_from(&mut reader).await?;
        if that_version > this_version {
            return Err(HandshakeError::ProtocolVersionMismatch(that_version));
        }

        let that_runtime_id =
            runtime_id::exchange(this_runtime_id, &mut writer, &mut reader).await?;

        writer.shutdown().await?;

        Ok(that_runtime_id)
    })
    .await;

    match result {
        Ok(subresult) => subresult,
        Err(_) => Err(HandshakeError::Timeout),
    }
}

#[derive(Debug, Error)]
enum HandshakeError {
    #[error("protocol version mismatch")]
    ProtocolVersionMismatch(Version),
    #[error("bad magic")]
    BadMagic,
    #[error("timeout")]
    Timeout,
    #[error("IO error")]
    Io(#[from] io::Error),
    #[error("connection error")]
    Connection(#[from] ConnectionError),
}

// RAII guard which, when dropped, removes the peer from the registry.
struct PeerGuard<'a> {
    registry: &'a BlockingMutex<Registry>,
    key: usize,
}

impl Drop for PeerGuard<'_> {
    fn drop(&mut self) {
        if let Some(peers) = &mut self
            .registry
            .lock()
            .unwrap_or_else(|error| error.into_inner())
            .peers
        {
            peers.try_remove(self.key);
        }
    }
}

struct PortMappings {
    _mappings: Vec<upnp::Mapping>,
}

impl PortMappings {
    fn new(forwarder: &upnp::PortForwarder, gateway: &Gateway) -> Self {
        let mappings = gateway
            .listener_local_addrs()
            .into_iter()
            .filter_map(|addr| {
                match addr {
                    PeerAddr::Quic(SocketAddr::V4(addr)) => {
                        Some(forwarder.add_mapping(
                            addr.port(), // internal
                            addr.port(), // external
                            ip::Protocol::Udp,
                        ))
                    }
                    PeerAddr::Tcp(SocketAddr::V4(addr)) => {
                        Some(forwarder.add_mapping(
                            addr.port(), // internal
                            addr.port(), // external
                            ip::Protocol::Tcp,
                        ))
                    }
                    PeerAddr::Quic(SocketAddr::V6(_)) | PeerAddr::Tcp(SocketAddr::V6(_)) => {
                        // TODO: the IPv6 port typically doesn't need to be port-mapped, but it
                        // might need to be opened in the firewall ("pinholed"). Consider using
                        // UPnP for that as well.
                        None
                    }
                }
            })
            .collect();

        Self {
            _mappings: mappings,
        }
    }
}

enum ComponentState<T> {
    Enabled(T),
    Disabled(DisableReason),
}

impl<T> ComponentState<T> {
    fn disabled(reason: DisableReason) -> Self {
        Self::Disabled(reason)
    }

    fn is_enabled(&self) -> bool {
        matches!(self, Self::Enabled(_))
    }

    fn is_disabled(&self, reason: DisableReason) -> bool {
        match self {
            Self::Disabled(current_reason) if *current_reason == reason => true,
            Self::Disabled(_) | Self::Enabled(_) => false,
        }
    }

    fn disable(&mut self, reason: DisableReason) -> Option<T> {
        match mem::replace(self, Self::Disabled(reason)) {
            Self::Enabled(payload) => Some(payload),
            Self::Disabled(_) => None,
        }
    }

    fn disable_if_enabled(&mut self, reason: DisableReason) -> Option<T> {
        match self {
            Self::Enabled(_) => match mem::replace(self, Self::Disabled(reason)) {
                Self::Enabled(payload) => Some(payload),
                Self::Disabled(_) => unreachable!(),
            },
            Self::Disabled(_) => None,
        }
    }

    fn enable(&mut self, payload: T) -> Option<T> {
        match mem::replace(self, Self::Enabled(payload)) {
            Self::Enabled(payload) => Some(payload),
            Self::Disabled(_) => None,
        }
    }
}

#[derive(Eq, PartialEq)]
enum DisableReason {
    // Disabled implicitly because `Network` was disabled
    Implicit,
    // Disabled explicitly
    Explicit,
}

pub fn repository_info_hash(id: &RepositoryId) -> InfoHash {
    // Calculate the info hash by hashing the id with BLAKE3 and taking the first 20 bytes.
    // (BitTorrent uses SHA-1, but that is less secure.)
    // `unwrap` is OK because the byte slice has the correct length.
    InfoHash::try_from(&id.salted_hash(b"ouisync repository info-hash").as_ref()[..INFO_HASH_LEN])
        .unwrap()
}

async fn shutdown_peers(peers: Slab<MessageBroker>) {
    future::join_all(peers.into_iter().map(|(_, peer)| peer.shutdown())).await;
}