Skip to main content

ouisync/network/
mod.rs

1mod addr_filter;
2mod choke;
3mod client;
4mod connection;
5mod connection_monitor;
6mod constants;
7mod crypto;
8mod debug_payload;
9mod dht;
10mod event;
11mod gateway;
12mod ip;
13mod local_discovery;
14mod message;
15mod message_broker;
16mod message_dispatcher;
17mod peer_addr;
18mod peer_exchange;
19mod peer_info;
20mod peer_source;
21mod peer_state;
22mod protocol;
23mod request_tracker;
24mod runtime_id;
25mod seen_peers;
26mod server;
27mod stats;
28mod stun;
29mod stun_server_list;
30#[cfg(test)]
31mod tests;
32mod upnp;
33
34pub use self::{
35    addr_filter::AddrFilter,
36    connection::PeerInfoCollector,
37    dht::{DEFAULT_DHT_ROUTERS, DhtContactsStoreTrait, DhtLookupStream, DhtPin},
38    event::{NetworkEvent, NetworkEventReceiver, NetworkEventStream},
39    peer_addr::PeerAddr,
40    peer_info::PeerInfo,
41    peer_source::PeerSource,
42    peer_state::PeerState,
43    runtime_id::{PublicRuntimeId, SecretRuntimeId},
44    stats::Stats,
45};
46use dht::DhtEvent;
47pub use net::{
48    bus::{BusRecvStream as RecvStream, BusSendStream as SendStream, TopicId},
49    stun::NatBehavior,
50};
51
52use self::{
53    choke::Choker,
54    connection::{ConnectionPermit, ConnectionSet, ReserveResult},
55    connection_monitor::ConnectionMonitor,
56    constants::REQUEST_TIMEOUT,
57    dht::DhtDiscovery,
58    event::ProtocolVersions,
59    gateway::{Connectivity, Gateway, StackAddresses},
60    local_discovery::LocalDiscovery,
61    message_broker::MessageBroker,
62    peer_addr::PeerPort,
63    peer_exchange::{PexDiscovery, PexRepository},
64    peer_source::ConnectionDirection,
65    protocol::{MAGIC, VERSION, Version},
66    request_tracker::RequestTracker,
67    seen_peers::{SeenPeer, SeenPeers},
68    stats::{ByteCounters, StatsTracker},
69    stun::StunClients,
70};
71use crate::{
72    protocol::RepositoryId,
73    repository::{RepositoryHandle, Vault},
74};
75use backoff::{ExponentialBackoffBuilder, backoff::Backoff};
76use btdht::{self, INFO_HASH_LEN, InfoHash};
77use deadlock::BlockingMutex;
78use futures_util::future;
79use net::{
80    quic,
81    unified::{Connection, ConnectionError},
82};
83use scoped_task::ScopedAbortHandle;
84use slab::Slab;
85use state_monitor::StateMonitor;
86use std::{
87    collections::HashSet,
88    io, mem,
89    net::{SocketAddr, SocketAddrV4, SocketAddrV6},
90    sync::{
91        Arc, Weak,
92        atomic::{AtomicBool, Ordering},
93    },
94};
95use thiserror::Error;
96use tokio::{
97    io::{AsyncReadExt, AsyncWriteExt},
98    sync::{mpsc, watch},
99    task::{AbortHandle, JoinSet},
100    time::Duration,
101};
102use tracing::{Instrument, Span};
103
104#[derive(Default)]
105pub struct NetworkBuilder {
106    dht_contacts: Option<Arc<dyn DhtContactsStoreTrait>>,
107    monitor: Option<StateMonitor>,
108    runtime_id: Option<SecretRuntimeId>,
109    addr_filter: AddrFilter,
110}
111
112impl NetworkBuilder {
113    pub fn dht_contacts(self, contacts: Arc<dyn DhtContactsStoreTrait>) -> Self {
114        Self {
115            dht_contacts: Some(contacts),
116            ..self
117        }
118    }
119
120    pub fn monitor(self, monitor: StateMonitor) -> Self {
121        Self {
122            monitor: Some(monitor),
123            ..self
124        }
125    }
126
127    pub fn runtime_id(self, runtime_id: SecretRuntimeId) -> Self {
128        Self {
129            runtime_id: Some(runtime_id),
130            ..self
131        }
132    }
133
134    pub fn addr_filter(self, addr_filter: AddrFilter) -> Self {
135        Self {
136            addr_filter,
137            ..self
138        }
139    }
140
141    pub fn build(self) -> Network {
142        let (incoming_tx, incoming_rx) = mpsc::channel(1);
143        let gateway = Gateway::new(incoming_tx);
144        let monitor = self.monitor.unwrap_or_else(StateMonitor::make_root);
145
146        // Note that we're now only using quic for the transport discovered over the dht.
147        // This is because the dht doesn't let us specify whether the remote peer SocketAddr is
148        // TCP, UDP or anything else.
149        // TODO: There are ways to address this: e.g. we could try both, or we could include
150        // the protocol information in the info-hash generation. There are pros and cons to
151        // these approaches.
152        let dht_discovery =
153            DhtDiscovery::new(None, None, self.dht_contacts, monitor.make_child("DHT"));
154        // TODO: do we need unbounded channel here?
155        let (dht_discovery_tx, dht_discovery_rx) = mpsc::unbounded_channel();
156
157        let port_forwarder = upnp::PortForwarder::new(monitor.make_child("UPnP"));
158
159        let (pex_discovery_tx, pex_discovery_rx) = mpsc::channel(1);
160        let pex_discovery = PexDiscovery::new(pex_discovery_tx);
161
162        let user_provided_peers = SeenPeers::new();
163
164        let this_runtime_id = self.runtime_id.unwrap_or_else(SecretRuntimeId::random);
165        let this_runtime_id_public = this_runtime_id.public();
166
167        let connections_monitor = monitor.make_child("Connections");
168        let peers_monitor = monitor.make_child("Peers");
169
170        let tasks = Arc::new(BlockingMutex::new(JoinSet::new()));
171
172        let inner = Arc::new(Inner {
173            main_monitor: monitor,
174            connections_monitor,
175            peers_monitor,
176            span: Span::current(),
177            gateway,
178            this_runtime_id,
179            registry: BlockingMutex::new(Registry {
180                peers: Some(Slab::new()),
181                repos: Slab::new(),
182            }),
183            port_forwarder,
184            port_forwarder_state: BlockingMutex::new(ComponentState::disabled(
185                DisableReason::Explicit,
186            )),
187            local_discovery_state: BlockingMutex::new(ComponentState::disabled(
188                DisableReason::Explicit,
189            )),
190            dht_discovery,
191            dht_discovery_tx,
192            local_dht_enabled: AtomicBool::new(false),
193            pex_discovery,
194            stun_clients: StunClients::new(),
195            connections: ConnectionSet::new(),
196            user_provided_peers,
197            tasks: Arc::downgrade(&tasks),
198            protocol_versions: watch::Sender::new(ProtocolVersions::new()),
199            our_addresses: BlockingMutex::new(HashSet::default()),
200            stats_tracker: StatsTracker::default(),
201            addr_filter: self.addr_filter,
202        });
203
204        inner.spawn(inner.clone().handle_incoming_connections(incoming_rx));
205        inner.spawn(inner.clone().run_dht(dht_discovery_rx));
206        inner.spawn(inner.clone().run_peer_exchange(pex_discovery_rx));
207
208        tracing::debug!(this_runtime_id = ?this_runtime_id_public.as_public_key(), "Network created");
209
210        Network {
211            inner,
212            _tasks: tasks,
213        }
214    }
215}
216
217pub struct Network {
218    inner: Arc<Inner>,
219    // We keep tasks here instead of in Inner because we want them to be
220    // destroyed when Network is Dropped.
221    _tasks: Arc<BlockingMutex<JoinSet<()>>>,
222}
223
224impl Network {
225    /// Returns builder to create `Network` with custom options.
226    pub fn builder() -> NetworkBuilder {
227        NetworkBuilder::default()
228    }
229
230    /// Create network with default options. Equal to `Self::builder().build()`.
231    pub fn new() -> Self {
232        Self::builder().build()
233    }
234
235    /// Binds the network to the specified addresses.
236    /// Rebinds if already bound. Unbinds and disables the network if `addrs` is empty.
237    ///
238    /// NOTE: currently at most one address per protocol (QUIC/TCP) and family (IPv4/IPv6) is used
239    /// and the rest are ignored, but this might change in the future.
240    pub async fn bind(&self, addrs: &[PeerAddr]) {
241        self.inner.bind(addrs).await
242    }
243
244    pub fn listener_local_addrs(&self) -> Vec<PeerAddr> {
245        self.inner.gateway.listener_local_addrs()
246    }
247
248    pub fn set_port_forwarding_enabled(&self, enabled: bool) {
249        let mut state = self.inner.port_forwarder_state.lock().unwrap();
250
251        if enabled {
252            if state.is_enabled() {
253                return;
254            }
255
256            state.enable(PortMappings::new(
257                &self.inner.port_forwarder,
258                &self.inner.gateway,
259            ));
260        } else {
261            state.disable(DisableReason::Explicit);
262        }
263    }
264
265    pub fn is_port_forwarding_enabled(&self) -> bool {
266        self.inner.port_forwarder_state.lock().unwrap().is_enabled()
267    }
268
269    pub fn set_local_discovery_enabled(&self, enabled: bool) {
270        let mut state = self.inner.local_discovery_state.lock().unwrap();
271
272        if enabled {
273            if state.is_enabled() {
274                return;
275            }
276
277            if let Some(handle) = self.inner.spawn_local_discovery() {
278                state.enable(handle.into());
279            } else {
280                state.disable(DisableReason::Implicit);
281            }
282        } else {
283            state.disable(DisableReason::Explicit);
284        }
285    }
286
287    pub fn is_local_discovery_enabled(&self) -> bool {
288        self.inner
289            .local_discovery_state
290            .lock()
291            .unwrap()
292            .is_enabled()
293    }
294
295    /// Sets whether sending contacts to other peer over peer exchange is enabled.
296    ///
297    /// Note: PEX sending for a given repo is enabled only if it's enabled globally using this
298    /// function and also for the repo using [Registration::set_pex_enabled].
299    pub fn set_pex_send_enabled(&self, enabled: bool) {
300        self.inner.pex_discovery.set_send_enabled(enabled)
301    }
302
303    pub fn is_pex_send_enabled(&self) -> bool {
304        self.inner.pex_discovery.is_send_enabled()
305    }
306
307    /// Sets whether receiving contacts over peer exchange is enabled.
308    ///
309    /// Note: PEX receiving for a given repo is enabled only if it's enabled globally using this
310    /// function and also for the repo using [Registration::set_pex_enabled].
311    pub fn set_pex_recv_enabled(&self, enabled: bool) {
312        self.inner.pex_discovery.set_recv_enabled(enabled)
313    }
314
315    pub fn is_pex_recv_enabled(&self) -> bool {
316        self.inner.pex_discovery.is_recv_enabled()
317    }
318    /// Find out external address using the STUN protocol.
319    /// Currently QUIC only.
320    pub async fn external_addr_v4(&self) -> Option<SocketAddrV4> {
321        self.inner.stun_clients.external_addr_v4().await
322    }
323
324    /// Find out external address using the STUN protocol.
325    /// Currently QUIC only.
326    pub async fn external_addr_v6(&self) -> Option<SocketAddrV6> {
327        self.inner.stun_clients.external_addr_v6().await
328    }
329
330    /// Determine the behaviour of the NAT we are behind. Returns `None` on unknown.
331    /// Currently IPv4 only.
332    pub async fn nat_behavior(&self) -> Option<NatBehavior> {
333        self.inner.stun_clients.nat_behavior().await
334    }
335
336    /// Get the network traffic stats.
337    pub fn stats(&self) -> Stats {
338        self.inner.stats_tracker.read()
339    }
340
341    pub fn add_user_provided_peer(&self, peer: &PeerAddr) {
342        self.inner.clone().establish_user_provided_connection(peer);
343    }
344
345    pub fn remove_user_provided_peer(&self, peer: &PeerAddr) {
346        self.inner.user_provided_peers.remove(peer)
347    }
348
349    pub fn this_runtime_id(&self) -> PublicRuntimeId {
350        self.inner.this_runtime_id.public()
351    }
352
353    pub fn peer_info_collector(&self) -> PeerInfoCollector {
354        self.inner.connections.peer_info_collector()
355    }
356
357    pub fn peer_info(&self, addr: PeerAddr) -> Option<PeerInfo> {
358        self.inner.connections.get_peer_info(addr)
359    }
360
361    pub fn current_protocol_version(&self) -> u64 {
362        self.inner.protocol_versions.borrow().our.into()
363    }
364
365    pub fn highest_seen_protocol_version(&self) -> u64 {
366        self.inner.protocol_versions.borrow().highest_seen.into()
367    }
368
369    /// Subscribe to network events.
370    pub fn subscribe(&self) -> NetworkEventReceiver {
371        NetworkEventReceiver::new(
372            self.inner.protocol_versions.subscribe(),
373            self.inner.connections.subscribe(),
374        )
375    }
376
377    /// Register a local repository into the network. This links the repository with all matching
378    /// repositories of currently connected remote replicas as well as any replicas connected in
379    /// the future. The repository is automatically deregistered when the returned handle is
380    /// dropped.
381    ///
382    /// Note: A repository should have at most one registration - creating more than one has
383    /// undesired effects. This is currently not enforced and so it's a responsibility of the
384    /// caller.
385    pub fn register(&self, handle: RepositoryHandle) -> Registration {
386        *handle.vault.monitor.info_hash.get() =
387            Some(repository_info_hash(handle.vault.repository_id()));
388
389        let pex = self.inner.pex_discovery.new_repository();
390
391        let request_tracker = RequestTracker::new(handle.vault.monitor.traffic.clone());
392        request_tracker.set_timeout(REQUEST_TIMEOUT);
393
394        // TODO: Should this be global instead of per repo?
395        let choker = Choker::new();
396
397        let stats_tracker = StatsTracker::default();
398
399        let mut registry = self.inner.registry.lock().unwrap();
400
401        registry.create_link(
402            handle.vault.clone(),
403            &pex,
404            &request_tracker,
405            &choker,
406            stats_tracker.bytes.clone(),
407        );
408
409        let key = registry.repos.insert(RegistrationHolder {
410            vault: handle.vault,
411            dht: None,
412            pex,
413            request_tracker,
414            choker,
415            stats_tracker,
416        });
417
418        Registration {
419            inner: self.inner.clone(),
420            key,
421        }
422    }
423
424    /// Gracefully disconnect from peers. Failing to call this function on app termination will
425    /// cause the peers to not learn that we disconnected just now. They will still find out later
426    /// once the keep-alive mechanism kicks in, but in the mean time we will not be able to
427    /// reconnect (by starting the app again) because the remote peer will keep dropping new
428    /// connections from us.
429    pub async fn shutdown(&self) {
430        // TODO: Would be a nice-to-have to also wait for all the spawned tasks here (e.g. dicovery
431        // mechanisms).
432        let Some(peers) = self.inner.registry.lock().unwrap().peers.take() else {
433            tracing::warn!("Network already shut down");
434            return;
435        };
436
437        shutdown_peers(peers).await;
438    }
439
440    /// Change the sync protocol request timeout. Useful mostly for testing and benchmarking as the
441    /// default value should be sufficient for most use cases.
442    pub fn set_request_timeout(&self, timeout: Duration) {
443        for (_, holder) in &self.inner.registry.lock().unwrap().repos {
444            holder.request_tracker.set_timeout(timeout);
445        }
446    }
447
448    /// Opens a side channel for the underlying IPv4 UDP socket, or `None` if IPv4 QUIC stack isn't
449    /// configured.
450    ///
451    /// The side channel is used to send/receive raw UDP datagrams on the same socket that the sync
452    /// protocol uses.
453    pub fn open_udp_side_channel_v4(&self) -> Option<quic::SideChannel> {
454        self.inner
455            .gateway
456            .udp_side_channel_maker_v4()
457            .as_ref()
458            .map(|m| m.make())
459    }
460
461    /// Opens a side channel for the underlying IPv6 UDP socket, or `None` if IPv6 QUIC stack isn't
462    /// configured.
463    ///
464    /// The side channel is used to send/receive raw UDP datagrams on the same socket that the sync
465    /// protocol uses.
466    pub fn open_udp_side_channel_v6(&self) -> Option<quic::SideChannel> {
467        self.inner
468            .gateway
469            .udp_side_channel_maker_v4()
470            .as_ref()
471            .map(|m| m.make())
472    }
473
474    /// Opens raw byte stream to the given peer, bound to the given topic. This can be used to
475    /// send/recv arbitrary data to the peer, outside of the ouisync protocol.
476    ///
477    /// Returns `None` if no active connection to the peer exists.
478    pub fn open_stream(
479        &self,
480        addr: PeerAddr,
481        topic_id: TopicId,
482    ) -> Option<(SendStream, RecvStream)> {
483        let key = self.inner.connections.get_peer_key(addr)?;
484        Some(
485            self.inner
486                .registry
487                .lock()
488                .unwrap()
489                .peers
490                .as_ref()?
491                .get(key)?
492                .open_stream(topic_id),
493        )
494    }
495
496    /// Changes the DHT routers (boostrap nodes), rebootstraps the DHTs and restarts any ongoing
497    /// lookups.
498    pub fn set_dht_routers(&self, routers: HashSet<String>) {
499        self.inner.dht_discovery.set_routers(routers);
500    }
501
502    /// Returns the current DHT routers (bootstrap nodes).
503    pub fn dht_routers(&self) -> HashSet<String> {
504        self.inner.dht_discovery.routers()
505    }
506
507    /// Performs explicit DHT lookup or announce for the given infohash and returns a stream of the
508    /// discovered peer addresses. It will not automatically connect to them.
509    pub fn dht_lookup(&self, info_hash: InfoHash, announce: bool) -> DhtLookupStream {
510        DhtLookupStream::start(
511            &self.inner.dht_discovery,
512            info_hash,
513            announce,
514            self.is_local_dht_enabled(),
515        )
516    }
517
518    /// Set whether DHT on the local network (or localhost) is enabled. By default this is `false`
519    /// because DHT is a global discovery mechanism and finding a local peer on it is unexpected
520    /// (and could indicate malice). However, is some situations it's still useful to enable it
521    /// (typically for testing).
522    ///
523    /// Note: this option is currently experimental and unstable (semver extempt). It's possible it
524    /// will be removed in the future.
525    pub fn set_local_dht_enabled(&self, enabled: bool) {
526        let prev = self
527            .inner
528            .local_dht_enabled
529            .swap(enabled, Ordering::Release);
530
531        if prev != enabled {
532            self.inner.rebind_dht(self.inner.gateway.connectivity());
533        }
534    }
535
536    pub fn is_local_dht_enabled(&self) -> bool {
537        self.inner.local_dht_enabled.load(Ordering::Acquire)
538    }
539
540    /// Creates a "pin" which starts the DHT instances and keeps them running. This prevents the
541    /// DHTs to shut down even when there are no more ongoing lookups. This is useful if one wants
542    /// to avoid having to rebootstrap the DHT when doing another lookup in the future.
543    ///
544    /// Note that DHT is automatically started and kept running when there is at least one
545    /// repository with DHT enabled. Thus, pinning the DHT while having DHT-enabled repos is
546    /// unnecessary (but harmless).
547    pub async fn pin_dht(&self) -> DhtPin {
548        self.inner.dht_discovery.pin().await
549    }
550}
551
552impl Default for Network {
553    fn default() -> Self {
554        Self::new()
555    }
556}
557
558pub struct Registration {
559    inner: Arc<Inner>,
560    key: usize,
561}
562
563impl Registration {
564    pub fn set_dht_enabled(&self, enabled: bool) {
565        let mut registry = self.inner.registry.lock().unwrap();
566        let holder = &mut registry.repos[self.key];
567
568        if enabled {
569            holder.dht = Some(
570                self.inner
571                    .start_dht_lookup(repository_info_hash(holder.vault.repository_id())),
572            );
573        } else {
574            holder.dht = None;
575        }
576    }
577
578    /// This function provides the information to the user whether DHT is enabled for this
579    /// repository, not necessarily whether the DHT tasks are currently running. The subtle
580    /// difference is in that this function should return true even in case e.g. the whole network
581    /// is disabled.
582    pub fn is_dht_enabled(&self) -> bool {
583        self.inner.registry.lock().unwrap().repos[self.key]
584            .dht
585            .is_some()
586    }
587
588    /// Enables/disables peer exchange for this repo.
589    ///
590    /// Note: sending/receiving over PEX for this repo is enabled only if it's enabled using this
591    /// function and also globally using [Network::set_pex_send_enabled] and/or
592    /// [Network::set_pex_recv_enabled].
593    pub fn set_pex_enabled(&self, enabled: bool) {
594        let registry = self.inner.registry.lock().unwrap();
595        registry.repos[self.key].pex.set_enabled(enabled);
596    }
597
598    pub fn is_pex_enabled(&self) -> bool {
599        self.inner.registry.lock().unwrap().repos[self.key]
600            .pex
601            .is_enabled()
602    }
603
604    /// Fetch per-repository network statistics.
605    pub fn stats(&self) -> Stats {
606        self.inner.registry.lock().unwrap().repos[self.key]
607            .stats_tracker
608            .read()
609    }
610}
611
612impl Drop for Registration {
613    fn drop(&mut self) {
614        let mut registry = self
615            .inner
616            .registry
617            .lock()
618            .unwrap_or_else(|error| error.into_inner());
619
620        if let Some(holder) = registry.repos.try_remove(self.key) {
621            for (_, peer) in registry.peers.as_mut().into_iter().flatten() {
622                peer.destroy_link(holder.vault.repository_id());
623            }
624        }
625    }
626}
627
628struct RegistrationHolder {
629    vault: Vault,
630    dht: Option<dht::LookupRequest>,
631    pex: PexRepository,
632    request_tracker: RequestTracker,
633    choker: Choker,
634    stats_tracker: StatsTracker,
635}
636
637struct Inner {
638    main_monitor: StateMonitor,
639    connections_monitor: StateMonitor,
640    peers_monitor: StateMonitor,
641    span: Span,
642    gateway: Gateway,
643    this_runtime_id: SecretRuntimeId,
644    registry: BlockingMutex<Registry>,
645    port_forwarder: upnp::PortForwarder,
646    port_forwarder_state: BlockingMutex<ComponentState<PortMappings>>,
647    local_discovery_state: BlockingMutex<ComponentState<ScopedAbortHandle>>,
648    dht_discovery: DhtDiscovery,
649    dht_discovery_tx: mpsc::UnboundedSender<DhtEvent>,
650    local_dht_enabled: AtomicBool,
651    pex_discovery: PexDiscovery,
652    stun_clients: StunClients,
653    connections: ConnectionSet,
654    protocol_versions: watch::Sender<ProtocolVersions>,
655    user_provided_peers: SeenPeers,
656    // Note that unwrapping the upgraded weak pointer should be fine because if the underlying Arc
657    // was Dropped, we would not be asking for the upgrade in the first place.
658    tasks: Weak<BlockingMutex<JoinSet<()>>>,
659    // Used to prevent repeatedly connecting to self.
660    our_addresses: BlockingMutex<HashSet<PeerAddr>>,
661    stats_tracker: StatsTracker,
662    addr_filter: AddrFilter,
663}
664
665struct Registry {
666    // This is None once the network calls shutdown.
667    peers: Option<Slab<MessageBroker>>,
668    repos: Slab<RegistrationHolder>,
669}
670
671impl Registry {
672    fn create_link(
673        &mut self,
674        repo: Vault,
675        pex: &PexRepository,
676        request_tracker: &RequestTracker,
677        choker: &Choker,
678        byte_counters: Arc<ByteCounters>,
679    ) {
680        if let Some(peers) = &mut self.peers {
681            for (_, peer) in peers {
682                peer.create_link(
683                    repo.clone(),
684                    pex,
685                    request_tracker.clone(),
686                    choker.clone(),
687                    byte_counters.clone(),
688                )
689            }
690        }
691    }
692}
693
694impl Inner {
695    fn is_shutdown(&self) -> bool {
696        self.registry.lock().unwrap().peers.is_none()
697    }
698
699    async fn bind(self: &Arc<Self>, bind: &[PeerAddr]) {
700        let bind = StackAddresses::from(bind);
701
702        // TODO: Would be preferable to only rebind those stacks that actually need rebinding.
703        if !self.gateway.addresses().any_stack_needs_rebind(&bind) {
704            return;
705        }
706
707        // Gateway
708        self.span.in_scope(|| self.gateway.bind(&bind));
709
710        let conn = self.gateway.connectivity();
711
712        // STUN
713        match conn {
714            Connectivity::Full => self.stun_clients.rebind(
715                self.gateway.udp_side_channel_maker_v4().map(|m| m.make()),
716                self.gateway.udp_side_channel_maker_v6().map(|m| m.make()),
717            ),
718            Connectivity::LocalOnly | Connectivity::Disabled => (),
719        }
720
721        // DHT
722        self.rebind_dht(conn);
723
724        // Port forwarding
725        match conn {
726            Connectivity::Full => {
727                let mut state = self.port_forwarder_state.lock().unwrap();
728                if !state.is_disabled(DisableReason::Explicit) {
729                    state.enable(PortMappings::new(&self.port_forwarder, &self.gateway));
730                }
731            }
732            Connectivity::LocalOnly | Connectivity::Disabled => {
733                self.port_forwarder_state
734                    .lock()
735                    .unwrap()
736                    .disable_if_enabled(DisableReason::Implicit);
737            }
738        }
739
740        // Local discovery
741        //
742        // Note: no need to check the Connectivity because local discovery depends only on whether
743        // Gateway is bound.
744        {
745            let mut state = self.local_discovery_state.lock().unwrap();
746            if !state.is_disabled(DisableReason::Explicit) {
747                if let Some(handle) = self.spawn_local_discovery() {
748                    state.enable(handle.into());
749                } else {
750                    state.disable(DisableReason::Implicit);
751                }
752            }
753        }
754
755        // - If we are disabling connectivity, disconnect from all existing peers.
756        // - If we are going from `Full` -> `LocalOnly`, also disconnect from all with the
757        //   assumption that the local ones will be subsequently re-established. Ideally we would
758        //   disconnect only the non-local ones to avoid the reconnect overhead, but the
759        //   implementation is simpler this way and the trade-off doesn't seem to be too bad.
760        // - If we are going to `Full`, keep all existing connections.
761        if matches!(conn, Connectivity::LocalOnly | Connectivity::Disabled) {
762            self.disconnect_all().await;
763        }
764    }
765
766    // Disconnect from all currently connected peers, regardless of their source.
767    async fn disconnect_all(&self) {
768        let Some(peers) = self.registry.lock().unwrap().peers.replace(Slab::default()) else {
769            return;
770        };
771
772        shutdown_peers(peers).await;
773    }
774
775    fn spawn_local_discovery(self: &Arc<Self>) -> Option<AbortHandle> {
776        let addrs = self.gateway.listener_local_addrs();
777        let tcp_port = addrs
778            .iter()
779            .find(|addr| matches!(addr, PeerAddr::Tcp(SocketAddr::V4(_))))
780            .map(|addr| PeerPort::Tcp(addr.port()));
781        let quic_port = addrs
782            .iter()
783            .find(|addr| matches!(addr, PeerAddr::Quic(SocketAddr::V4(_))))
784            .map(|addr| PeerPort::Quic(addr.port()));
785
786        // Arbitrary order of preference.
787        // TODO: Should we support all available?
788        let port = tcp_port.or(quic_port);
789
790        if let Some(port) = port {
791            Some(
792                self.spawn(
793                    self.clone()
794                        .run_local_discovery(port)
795                        .instrument(self.span.clone()),
796                ),
797            )
798        } else {
799            tracing::error!("Not enabling local discovery because there is no IPv4 listener");
800            None
801        }
802    }
803
804    async fn run_local_discovery(self: Arc<Self>, listener_port: PeerPort) {
805        let mut discovery = LocalDiscovery::new(
806            listener_port,
807            self.main_monitor.make_child("LocalDiscovery"),
808        );
809
810        loop {
811            let peer = discovery.recv().await;
812
813            if self.is_shutdown() {
814                break;
815            }
816
817            self.spawn(
818                self.clone()
819                    .handle_peer_found(peer, PeerSource::LocalDiscovery),
820            );
821        }
822    }
823
824    fn start_dht_lookup(&self, info_hash: InfoHash) -> dht::LookupRequest {
825        self.dht_discovery
826            .start_lookup(info_hash, true, self.dht_discovery_tx.clone())
827    }
828
829    fn rebind_dht(&self, conn: Connectivity) {
830        match (conn, self.local_dht_enabled.load(Ordering::Acquire)) {
831            (Connectivity::Full, _) | (Connectivity::LocalOnly, true) => self.dht_discovery.rebind(
832                self.gateway.udp_side_channel_maker_v4(),
833                self.gateway.udp_side_channel_maker_v6(),
834            ),
835            (Connectivity::LocalOnly, false) | (Connectivity::Disabled, _) => {
836                self.dht_discovery.rebind(None, None)
837            }
838        }
839    }
840
841    async fn run_dht(self: Arc<Self>, mut discovery_rx: mpsc::UnboundedReceiver<DhtEvent>) {
842        while let Some(event) = discovery_rx.recv().await {
843            if self.is_shutdown() {
844                break;
845            }
846
847            let peer = match event {
848                DhtEvent::PeerFound(peer) => peer,
849                DhtEvent::RoundEnded => continue,
850            };
851
852            if !self.local_dht_enabled.load(Ordering::Acquire) && peer.initial_addr().is_local() {
853                continue;
854            }
855
856            self.spawn(self.clone().handle_peer_found(peer, PeerSource::Dht));
857        }
858    }
859
860    async fn run_peer_exchange(self: Arc<Self>, mut discovery_rx: mpsc::Receiver<SeenPeer>) {
861        while let Some(peer) = discovery_rx.recv().await {
862            if self.is_shutdown() {
863                break;
864            }
865
866            self.spawn(
867                self.clone()
868                    .handle_peer_found(peer, PeerSource::PeerExchange),
869            );
870        }
871    }
872
873    fn establish_user_provided_connection(self: Arc<Self>, peer: &PeerAddr) {
874        let peer = match self.user_provided_peers.insert(*peer) {
875            Some(peer) => peer,
876            // Already in `user_provided_peers`.
877            None => return,
878        };
879
880        self.spawn(
881            self.clone()
882                .handle_peer_found(peer, PeerSource::UserProvided),
883        );
884    }
885
886    async fn handle_incoming_connections(
887        self: Arc<Self>,
888        mut rx: mpsc::Receiver<(Connection, PeerAddr)>,
889    ) {
890        while let Some((connection, addr)) = rx.recv().await {
891            match self.connections.reserve(addr, PeerSource::Listener) {
892                ReserveResult::Permit(permit) => {
893                    if self.is_shutdown() {
894                        break;
895                    }
896
897                    let this = self.clone();
898
899                    let monitor = self.span.in_scope(|| {
900                        ConnectionMonitor::new(
901                            &self.connections_monitor,
902                            &permit.addr(),
903                            permit.source(),
904                        )
905                    });
906                    monitor.mark_as_connecting(permit.id());
907
908                    self.spawn(async move {
909                        this.handle_connection(connection, permit, &monitor).await;
910                    });
911                }
912                ReserveResult::Occupied(_, _their_source, permit_id) => {
913                    tracing::debug!(?addr, ?permit_id, "dropping accepted duplicate connection");
914                }
915            }
916        }
917    }
918
919    async fn handle_peer_found(self: Arc<Self>, peer: SeenPeer, source: PeerSource) {
920        let create_backoff = || {
921            ExponentialBackoffBuilder::new()
922                .with_initial_interval(Duration::from_millis(100))
923                .with_max_interval(Duration::from_secs(8))
924                .with_max_elapsed_time(None)
925                .build()
926        };
927
928        let mut backoff = create_backoff();
929
930        let mut next_sleep = None;
931
932        loop {
933            let monitor = self.span.in_scope(|| {
934                ConnectionMonitor::new(&self.connections_monitor, peer.initial_addr(), source)
935            });
936
937            // TODO: We should also check whether the user still wants to accept connections from
938            // the given `source` (the preference may have changed in the mean time).
939
940            if self.is_shutdown() {
941                return;
942            }
943
944            let addr = match peer.addr_if_seen() {
945                Some(addr) => *addr,
946                None => return,
947            };
948
949            if self.our_addresses.lock().unwrap().contains(&addr) {
950                // Don't connect to self.
951                return;
952            }
953
954            let permit = match self.connections.reserve(addr, source) {
955                ReserveResult::Permit(permit) => permit,
956                ReserveResult::Occupied(on_release, their_source, connection_id) => {
957                    if source == their_source {
958                        // This is a duplicate from the same source, ignore it.
959                        return;
960                    }
961
962                    // This is a duplicate from a different source, if the other source releases
963                    // it, then we may want to try to keep hold of it.
964                    monitor.mark_as_awaiting_permit();
965                    tracing::debug!(
966                        parent: monitor.span(),
967                        %connection_id,
968                        "Duplicate from different source - awaiting permit"
969                    );
970
971                    on_release.await;
972
973                    next_sleep = None;
974                    backoff = create_backoff();
975
976                    continue;
977                }
978            };
979
980            if let Some(sleep) = next_sleep {
981                tracing::debug!(parent: monitor.span(), "Next connection attempt in {:?}", sleep);
982                tokio::time::sleep(sleep).await;
983            }
984
985            next_sleep = backoff.next_backoff();
986
987            permit.mark_as_connecting();
988            monitor.mark_as_connecting(permit.id());
989            tracing::trace!(parent: monitor.span(), "Connecting");
990
991            let Some(addr) = peer.addr_if_seen() else {
992                break;
993            };
994
995            if !self.addr_filter.apply(addr.socket_addr()) {
996                tracing::debug!("Invalid peer address - discarding");
997                break;
998            }
999
1000            let socket = match self
1001                .gateway
1002                .connect_with_retries(&peer)
1003                .instrument(monitor.span().clone())
1004                .await
1005            {
1006                Some(socket) => socket,
1007                None => break,
1008            };
1009
1010            if !self.handle_connection(socket, permit, &monitor).await {
1011                break;
1012            }
1013        }
1014    }
1015
1016    /// Return true iff the peer is suitable for reconnection.
1017    async fn handle_connection(
1018        &self,
1019        connection: Connection,
1020        permit: ConnectionPermit,
1021        monitor: &ConnectionMonitor,
1022    ) -> bool {
1023        tracing::trace!(parent: monitor.span(), "Handshaking");
1024
1025        permit.mark_as_handshaking();
1026        monitor.mark_as_handshaking();
1027
1028        let handshake_result = perform_handshake(
1029            &connection,
1030            VERSION,
1031            &self.this_runtime_id,
1032            permit.source().direction(),
1033        )
1034        .await;
1035
1036        if let Err(error) = &handshake_result {
1037            tracing::debug!(parent: monitor.span(), ?error, "Handshake failed");
1038        }
1039
1040        let that_runtime_id = match handshake_result {
1041            Ok(writer_id) => writer_id,
1042            Err(HandshakeError::ProtocolVersionMismatch(their_version)) => {
1043                self.on_protocol_mismatch(their_version);
1044                return false;
1045            }
1046            Err(
1047                HandshakeError::Timeout
1048                | HandshakeError::BadMagic
1049                | HandshakeError::Io(_)
1050                | HandshakeError::Connection(_),
1051            ) => return false,
1052        };
1053
1054        // prevent self-connections.
1055        if that_runtime_id == self.this_runtime_id.public() {
1056            tracing::debug!(parent: monitor.span(), "Connection from self, discarding");
1057            self.our_addresses.lock().unwrap().insert(permit.addr());
1058            return false;
1059        }
1060
1061        let closed = connection.closed();
1062
1063        let key = {
1064            let mut registry = self.registry.lock().unwrap();
1065            let registry = &mut *registry;
1066
1067            let Some(peers) = &mut registry.peers else {
1068                // Network has been shut down.
1069                return false;
1070            };
1071
1072            let pex_peer = self.pex_discovery.new_peer();
1073            pex_peer.handle_connection(permit.addr(), permit.source(), permit.released());
1074
1075            let mut peer = monitor.span().in_scope(|| {
1076                MessageBroker::new(
1077                    self.this_runtime_id.public(),
1078                    that_runtime_id,
1079                    connection,
1080                    pex_peer,
1081                    self.peers_monitor.make_child(format!(
1082                        "{} {}",
1083                        permit.source().direction().glyph(),
1084                        permit.addr()
1085                    )),
1086                    self.stats_tracker.bytes.clone(),
1087                    permit.byte_counters(),
1088                )
1089            });
1090
1091            // TODO: for DHT connection we should only link the repository for which we did the
1092            // lookup but make sure we correctly handle edge cases, for example, when we have
1093            // more than one repository shared with the peer.
1094            for (_, holder) in &registry.repos {
1095                peer.create_link(
1096                    holder.vault.clone(),
1097                    &holder.pex,
1098                    holder.request_tracker.clone(),
1099                    holder.choker.clone(),
1100                    holder.stats_tracker.bytes.clone(),
1101                );
1102            }
1103
1104            peers.insert(peer)
1105        };
1106
1107        permit.mark_as_active(that_runtime_id, key);
1108        monitor.mark_as_active(that_runtime_id);
1109
1110        // Wait until the connection gets closed, then remove the `MessageBroker` instance. Using a
1111        // RAII to also remove it in case this function gets cancelled.
1112        let _guard = PeerGuard {
1113            registry: &self.registry,
1114            key,
1115        };
1116
1117        closed.await;
1118
1119        true
1120    }
1121
1122    fn on_protocol_mismatch(&self, their_version: Version) {
1123        self.protocol_versions.send_if_modified(|versions| {
1124            if versions.highest_seen < their_version {
1125                versions.highest_seen = their_version;
1126                true
1127            } else {
1128                false
1129            }
1130        });
1131    }
1132
1133    fn spawn<Fut>(&self, f: Fut) -> AbortHandle
1134    where
1135        Fut: Future<Output = ()> + Send + 'static,
1136    {
1137        // TODO: this `unwrap` is sketchy. Maybe we should simply not spawn if `tasks` can't be
1138        // upgraded?
1139        let tasks = self.tasks.upgrade().unwrap();
1140        let mut tasks = tasks.lock().unwrap();
1141
1142        // IMPORTANT: Drain completed tasks. This is necessary because `JoinSet` doesn't
1143        // automatically remove completed tasks (presumably to not lose their results), and so not
1144        // doing it would cause memory leak.
1145        while tasks.try_join_next().is_some() {}
1146
1147        tasks.spawn(f.instrument(Span::current()))
1148    }
1149}
1150
1151//------------------------------------------------------------------------------
1152
1153// Exchange runtime ids with the peer. Returns their (verified) runtime id.
1154async fn perform_handshake(
1155    connection: &Connection,
1156    this_version: Version,
1157    this_runtime_id: &SecretRuntimeId,
1158    dir: ConnectionDirection,
1159) -> Result<PublicRuntimeId, HandshakeError> {
1160    let result = tokio::time::timeout(std::time::Duration::from_secs(5), async move {
1161        let (mut writer, mut reader) = match dir {
1162            ConnectionDirection::Incoming => connection.incoming().await?,
1163            ConnectionDirection::Outgoing => connection.outgoing().await?,
1164        };
1165
1166        writer.write_all(MAGIC).await?;
1167
1168        // Backward fix for Ouisync *App* versions v0.8.3 (and possibly prior)
1169        {
1170            // Those versions of Ouisync App had a race condition where if some peer with a higher
1171            // protocol version managed to connect and handshake before the app subscribed to
1172            // "higher protocol version" notifications then they would never get the notification.
1173            // On Pixel 9a, and when connecting to a PC, the handshake happened ~100ms prior to
1174            // subscribe. Using 700ms to account for slower devices.
1175            tokio::time::sleep(std::time::Duration::from_millis(700)).await;
1176        }
1177
1178        this_version.write_into(&mut writer).await?;
1179
1180        let mut that_magic = [0; MAGIC.len()];
1181        reader.read_exact(&mut that_magic).await?;
1182
1183        if MAGIC != &that_magic {
1184            return Err(HandshakeError::BadMagic);
1185        }
1186
1187        let that_version = Version::read_from(&mut reader).await?;
1188        if that_version > this_version {
1189            return Err(HandshakeError::ProtocolVersionMismatch(that_version));
1190        }
1191
1192        let that_runtime_id =
1193            runtime_id::exchange(this_runtime_id, &mut writer, &mut reader).await?;
1194
1195        writer.shutdown().await?;
1196
1197        Ok(that_runtime_id)
1198    })
1199    .await;
1200
1201    match result {
1202        Ok(subresult) => subresult,
1203        Err(_) => Err(HandshakeError::Timeout),
1204    }
1205}
1206
1207#[derive(Debug, Error)]
1208enum HandshakeError {
1209    #[error("protocol version mismatch")]
1210    ProtocolVersionMismatch(Version),
1211    #[error("bad magic")]
1212    BadMagic,
1213    #[error("timeout")]
1214    Timeout,
1215    #[error("IO error")]
1216    Io(#[from] io::Error),
1217    #[error("connection error")]
1218    Connection(#[from] ConnectionError),
1219}
1220
1221// RAII guard which when dropped removes the peer from the registry.
1222struct PeerGuard<'a> {
1223    registry: &'a BlockingMutex<Registry>,
1224    key: usize,
1225}
1226
1227impl Drop for PeerGuard<'_> {
1228    fn drop(&mut self) {
1229        if let Some(peers) = &mut self
1230            .registry
1231            .lock()
1232            .unwrap_or_else(|error| error.into_inner())
1233            .peers
1234        {
1235            peers.try_remove(self.key);
1236        }
1237    }
1238}
1239
1240struct PortMappings {
1241    _mappings: Vec<upnp::Mapping>,
1242}
1243
1244impl PortMappings {
1245    fn new(forwarder: &upnp::PortForwarder, gateway: &Gateway) -> Self {
1246        let mappings = gateway
1247            .listener_local_addrs()
1248            .into_iter()
1249            .filter_map(|addr| {
1250                match addr {
1251                    PeerAddr::Quic(SocketAddr::V4(addr)) => {
1252                        Some(forwarder.add_mapping(
1253                            addr.port(), // internal
1254                            addr.port(), // external
1255                            ip::Protocol::Udp,
1256                        ))
1257                    }
1258                    PeerAddr::Tcp(SocketAddr::V4(addr)) => {
1259                        Some(forwarder.add_mapping(
1260                            addr.port(), // internal
1261                            addr.port(), // external
1262                            ip::Protocol::Tcp,
1263                        ))
1264                    }
1265                    PeerAddr::Quic(SocketAddr::V6(_)) | PeerAddr::Tcp(SocketAddr::V6(_)) => {
1266                        // TODO: the ipv6 port typically doesn't need to be port-mapped but it might
1267                        // need to be opened in the firewall ("pinholed"). Consider using UPnP for that
1268                        // as well.
1269                        None
1270                    }
1271                }
1272            })
1273            .collect();
1274
1275        Self {
1276            _mappings: mappings,
1277        }
1278    }
1279}
1280
/// State of a component which is either enabled, holding a payload of type `T`, or disabled for
/// a particular reason.
enum ComponentState<T> {
    Enabled(T),
    Disabled(DisableReason),
}

impl<T> ComponentState<T> {
    /// Creates the state disabled for the given `reason`.
    fn disabled(reason: DisableReason) -> Self {
        Self::Disabled(reason)
    }

    /// Is the component enabled?
    fn is_enabled(&self) -> bool {
        match self {
            Self::Enabled(_) => true,
            Self::Disabled(_) => false,
        }
    }

    /// Is the component disabled for exactly the given `reason`?
    fn is_disabled(&self, reason: DisableReason) -> bool {
        matches!(self, Self::Disabled(current_reason) if *current_reason == reason)
    }

    /// Unconditionally switches to disabled with `reason` (overwriting any previous reason).
    /// Returns the payload if the component was enabled.
    fn disable(&mut self, reason: DisableReason) -> Option<T> {
        mem::replace(self, Self::Disabled(reason)).into_payload()
    }

    /// Like `disable`, but leaves the state (including its current reason) untouched if the
    /// component is already disabled.
    fn disable_if_enabled(&mut self, reason: DisableReason) -> Option<T> {
        if self.is_enabled() {
            self.disable(reason)
        } else {
            None
        }
    }

    /// Switches to enabled with `payload`. Returns the previous payload if the component was
    /// already enabled.
    fn enable(&mut self, payload: T) -> Option<T> {
        mem::replace(self, Self::Enabled(payload)).into_payload()
    }

    // Extracts the payload of the `Enabled` variant, if any.
    fn into_payload(self) -> Option<T> {
        match self {
            Self::Enabled(payload) => Some(payload),
            Self::Disabled(_) => None,
        }
    }
}

/// Why a component is disabled.
#[derive(Eq, PartialEq)]
enum DisableReason {
    /// Disabled implicitly because `Network` was disabled
    Implicit,
    /// Disabled explicitly
    Explicit,
}
1334
1335pub fn repository_info_hash(id: &RepositoryId) -> InfoHash {
1336    // Calculate the info hash by hashing the id with BLAKE3 and taking the first 20 bytes.
1337    // (bittorrent uses SHA-1 but that is less secure).
1338    // `unwrap` is OK because the byte slice has the correct length.
1339    InfoHash::try_from(&id.salted_hash(b"ouisync repository info-hash").as_ref()[..INFO_HASH_LEN])
1340        .unwrap()
1341}
1342
1343async fn shutdown_peers(peers: Slab<MessageBroker>) {
1344    future::join_all(peers.into_iter().map(|(_, peer)| peer.shutdown())).await;
1345}