ouisync/network/dht_discovery.rs

use super::{
    peer_addr::PeerAddr,
    seen_peers::{SeenPeer, SeenPeers},
};
use async_trait::async_trait;
use btdht::{InfoHash, MainlineDht};
use chrono::{offset::Local, DateTime};
use deadlock::{AsyncMutex, BlockingMutex};
use futures_util::{stream, StreamExt};
use net::{quic, udp::DatagramSocket};
use rand::Rng;
use scoped_task::ScopedJoinHandle;
use state_monitor::StateMonitor;
use std::{
    collections::{hash_map, HashMap, HashSet},
    future::pending,
    io,
    net::{SocketAddr, SocketAddrV4, SocketAddrV6},
    sync::{
        atomic::{AtomicU64, Ordering},
        Arc, OnceLock, Weak,
    },
    time::SystemTime,
};
use tokio::{
    select,
    sync::{mpsc, watch},
    time::{self, timeout, Duration},
};
use tracing::{instrument::Instrument, Span};

// Hardcoded DHT routers to bootstrap the DHT against.
// TODO: add this to `NetworkOptions` so it can be overridden by the user.
pub const DHT_ROUTERS: &[&str] = &[
    "dht.ouisync.net:6881",
    "router.bittorrent.com:6881",
    "dht.transmissionbt.com:6881",
];

// Interval for the delay before a repository is re-announced on the DHT. The actual delay is a
// uniformly random value from this interval.
// BEP5 indicates that "After 15 minutes of inactivity, a node becomes questionable", so try not
// to get too close to that value to avoid DHT churn. However, too frequent updates may cause
// other nodes to put us on a blacklist.
const MIN_DHT_ANNOUNCE_DELAY: Duration = Duration::from_secs(3 * 60);
const MAX_DHT_ANNOUNCE_DELAY: Duration = Duration::from_secs(6 * 60);

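/// Persistent store for known DHT contacts. Loading previously seen contacts on startup speeds
/// up bootstrap and reduces the load on the hardcoded routers.
///
/// A minimal in-memory sketch of an implementation (illustrative only; `MemoryStore` is
/// hypothetical and a real implementation would persist to disk):
///
/// ```ignore
/// #[derive(Default)]
/// struct MemoryStore {
///     v4: std::sync::Mutex<HashSet<SocketAddrV4>>,
///     v6: std::sync::Mutex<HashSet<SocketAddrV6>>,
/// }
///
/// #[async_trait]
/// impl DhtContactsStoreTrait for MemoryStore {
///     async fn load_v4(&self) -> io::Result<HashSet<SocketAddrV4>> {
///         Ok(self.v4.lock().unwrap().clone())
///     }
///
///     async fn load_v6(&self) -> io::Result<HashSet<SocketAddrV6>> {
///         Ok(self.v6.lock().unwrap().clone())
///     }
///
///     async fn store_v4(&self, contacts: HashSet<SocketAddrV4>) -> io::Result<()> {
///         *self.v4.lock().unwrap() = contacts;
///         Ok(())
///     }
///
///     async fn store_v6(&self, contacts: HashSet<SocketAddrV6>) -> io::Result<()> {
///         *self.v6.lock().unwrap() = contacts;
///         Ok(())
///     }
/// }
/// ```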
#[async_trait]
pub trait DhtContactsStoreTrait: Sync + Send + 'static {
    async fn load_v4(&self) -> io::Result<HashSet<SocketAddrV4>>;
    async fn load_v6(&self) -> io::Result<HashSet<SocketAddrV6>>;
    async fn store_v4(&self, contacts: HashSet<SocketAddrV4>) -> io::Result<()>;
    async fn store_v6(&self, contacts: HashSet<SocketAddrV6>) -> io::Result<()>;
}

pub(super) struct DhtDiscovery {
    v4: BlockingMutex<RestartableDht>,
    v6: BlockingMutex<RestartableDht>,
    lookups: Arc<BlockingMutex<Lookups>>,
    next_id: AtomicU64,
    main_monitor: StateMonitor,
    lookups_monitor: StateMonitor,
    span: Span,
}

impl DhtDiscovery {
    pub fn new(
        socket_maker_v4: Option<quic::SideChannelMaker>,
        socket_maker_v6: Option<quic::SideChannelMaker>,
        contacts_store: Option<Arc<dyn DhtContactsStoreTrait>>,
        monitor: StateMonitor,
    ) -> Self {
        let v4 = BlockingMutex::new(RestartableDht::new(socket_maker_v4, contacts_store.clone()));
        let v6 = BlockingMutex::new(RestartableDht::new(socket_maker_v6, contacts_store));

        let lookups = Arc::new(BlockingMutex::new(HashMap::default()));

        let lookups_monitor = monitor.make_child("lookups");

        Self {
            v4,
            v6,
            lookups,
            next_id: AtomicU64::new(0),
            span: Span::current(),
            main_monitor: monitor,
            lookups_monitor,
        }
    }

    // Bind new sockets to the DHT instances. If there are any ongoing lookups, the current DHTs
    // are terminated, new DHTs with the new sockets are created and the lookups are restarted on
    // those new DHTs.
    pub fn rebind(
        &self,
        socket_maker_v4: Option<quic::SideChannelMaker>,
        socket_maker_v6: Option<quic::SideChannelMaker>,
    ) {
        let mut v4 = self.v4.lock().unwrap();
        let mut v6 = self.v6.lock().unwrap();

        v4.rebind(socket_maker_v4);
        v6.rebind(socket_maker_v6);

        let mut lookups = self.lookups.lock().unwrap();

        if lookups.is_empty() {
            return;
        }

        let dht_v4 = v4.fetch(&self.main_monitor, &self.span);
        let dht_v6 = v6.fetch(&self.main_monitor, &self.span);

        for (info_hash, lookup) in &mut *lookups {
            lookup.restart(
                dht_v4.clone(),
                dht_v6.clone(),
                *info_hash,
                &self.lookups_monitor,
                &self.span,
            );
        }
    }

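    /// Starts a lookup (search + announce) for the given info hash. Discovered peers are sent
    /// to `found_peers_tx`. The lookup is kept alive by the returned `LookupRequest`; when the
    /// last request for a given info hash is dropped, the lookup stops.
    ///
    /// A minimal usage sketch (illustrative; `discovery` and `info_hash` are assumed to be in
    /// scope):
    ///
    /// ```ignore
    /// let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
    /// let request = discovery.start_lookup(info_hash, tx);
    ///
    /// while let Some(peer) = rx.recv().await {
    ///     // Connect to `peer` here.
    /// }
    ///
    /// drop(request); // stops the lookup if this was its last request
    /// ```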
    pub fn start_lookup(
        &self,
        info_hash: InfoHash,
        found_peers_tx: mpsc::UnboundedSender<SeenPeer>,
    ) -> LookupRequest {
        let id = self.next_id.fetch_add(1, Ordering::Relaxed);

        let request = LookupRequest {
            id,
            info_hash,
            lookups: Arc::downgrade(&self.lookups),
        };

        let mut lookups = self.lookups.lock().unwrap();

        match lookups.entry(info_hash) {
            hash_map::Entry::Occupied(mut entry) => entry.get_mut().add_request(id, found_peers_tx),
            hash_map::Entry::Vacant(entry) => {
                let dht_v4 = self
                    .v4
                    .lock()
                    .unwrap()
                    .fetch(&self.main_monitor, &self.span);

                let dht_v6 = self
                    .v6
                    .lock()
                    .unwrap()
                    .fetch(&self.main_monitor, &self.span);

                entry
                    .insert(Lookup::start(
                        dht_v4,
                        dht_v6,
                        info_hash,
                        &self.lookups_monitor,
                        &self.span,
                    ))
                    .add_request(id, found_peers_tx);
            }
        }

        request
    }
}

// Wrapper for a DHT instance that can be stopped and restarted at any point.
struct RestartableDht {
    socket_maker: Option<quic::SideChannelMaker>,
    dht: Weak<Option<TaskOrResult<MonitoredDht>>>,
    contacts_store: Option<Arc<dyn DhtContactsStoreTrait>>,
}

impl RestartableDht {
    fn new(
        socket_maker: Option<quic::SideChannelMaker>,
        contacts_store: Option<Arc<dyn DhtContactsStoreTrait>>,
    ) -> Self {
        Self {
            socket_maker,
            dht: Weak::new(),
            contacts_store,
        }
    }

    // Retrieve a shared pointer to a running DHT instance if there is one already or start a new
    // one. When all such pointers are dropped, the underlying DHT is terminated.
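    //
    // A minimal usage sketch (illustrative; `monitor` and `span` are assumed to be in scope):
    //
    //     let dht_a = restartable.fetch(&monitor, &span); // starts the DHT
    //     let dht_b = restartable.fetch(&monitor, &span); // reuses the running instance
    //     drop(dht_a);
    //     drop(dht_b); // the last pointer is gone, so the DHT is terminated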
    fn fetch(
        &mut self,
        monitor: &StateMonitor,
        span: &Span,
    ) -> Arc<Option<TaskOrResult<MonitoredDht>>> {
        if let Some(dht) = self.dht.upgrade() {
            dht
        } else if let Some(maker) = &self.socket_maker {
            let socket = maker.make();
            let dht = MonitoredDht::start(socket, monitor, span, self.contacts_store.clone());

            let dht = Arc::new(Some(dht));

            self.dht = Arc::downgrade(&dht);

            dht
        } else {
            Arc::new(None)
        }
    }

    fn rebind(&mut self, socket_maker: Option<quic::SideChannelMaker>) {
        self.socket_maker = socket_maker;
        self.dht = Weak::new();
    }
}

// Wrapper for a DHT instance that periodically outputs its state to the provided StateMonitor.
struct MonitoredDht {
    dht: MainlineDht,
    _monitoring_task: ScopedJoinHandle<()>,
    _periodic_dht_node_load_task: Option<ScopedJoinHandle<()>>,
}

impl MonitoredDht {
    fn start(
        socket: quic::SideChannel,
        parent_monitor: &StateMonitor,
        span: &Span,
        contacts_store: Option<Arc<dyn DhtContactsStoreTrait>>,
    ) -> TaskOrResult<Self> {
        // TODO: Unwrap
        let local_addr = socket.local_addr().unwrap();

        let (is_v4, monitor_name, span) = match local_addr {
            SocketAddr::V4(_) => (true, "IPv4", tracing::info_span!(parent: span, "DHT/IPv4")),
            SocketAddr::V6(_) => (false, "IPv6", tracing::info_span!(parent: span, "DHT/IPv6")),
        };

        let monitor = parent_monitor.make_child(monitor_name);

        TaskOrResult::new(scoped_task::spawn(MonitoredDht::create(
            is_v4,
            socket,
            monitor,
            span,
            contacts_store,
        )))
    }

    async fn create(
        is_v4: bool,
        socket: quic::SideChannel,
        monitor: StateMonitor,
        span: Span,
        contacts_store: Option<Arc<dyn DhtContactsStoreTrait>>,
    ) -> Self {
        // TODO: load the DHT state from a previous save if it exists.
        let mut builder = MainlineDht::builder()
            .add_routers(DHT_ROUTERS.iter().copied())
            .set_read_only(false);

        if let Some(contacts_store) = &contacts_store {
            let initial_contacts = Self::load_initial_contacts(is_v4, &**contacts_store).await;

            for contact in initial_contacts {
                builder = builder.add_node(contact);
            }
        }

        let dht = builder
            .start(Socket(socket))
            // TODO: `start` only fails if the socket has been closed. That shouldn't be the case
            // here, but better to check.
            .unwrap();

        // Spawn a task to monitor the DHT status.
        let monitoring_task = {
            let dht = dht.clone();

            let first_bootstrap = monitor.make_value("first_bootstrap", "in progress");
            let probe_counter = monitor.make_value("probe_counter", 0);
            let is_running = monitor.make_value("is_running", false);
            let bootstrapped = monitor.make_value("bootstrapped", false);
            let good_nodes = monitor.make_value("good_nodes", 0);
            let questionable_nodes = monitor.make_value("questionable_nodes", 0);
            let buckets = monitor.make_value("buckets", 0);

            async move {
                tracing::info!("bootstrap started");

                if dht.bootstrapped().await {
                    *first_bootstrap.get() = "done";
                    tracing::info!("bootstrap complete");
                } else {
                    *first_bootstrap.get() = "failed";
                    tracing::error!("bootstrap failed");

                    // Don't `return`, instead halt here so that the `first_bootstrap` monitored value
                    // is preserved for the user to see.
                    pending::<()>().await;
                }

                loop {
                    *probe_counter.get() += 1;

                    if let Some(state) = dht.get_state().await {
                        *is_running.get() = true;
                        *bootstrapped.get() = true;
                        *good_nodes.get() = state.good_node_count;
                        *questionable_nodes.get() = state.questionable_node_count;
                        *buckets.get() = state.bucket_count;
                    } else {
                        *is_running.get() = false;
                        *bootstrapped.get() = false;
                        *good_nodes.get() = 0;
                        *questionable_nodes.get() = 0;
                        *buckets.get() = 0;
                    }

                    time::sleep(Duration::from_secs(5)).await;
                }
            }
        };
        let monitoring_task = monitoring_task.instrument(span.clone());
        let monitoring_task = scoped_task::spawn(monitoring_task);

        let _periodic_dht_node_load_task = contacts_store.map(|contacts_store| {
            scoped_task::spawn(
                Self::keep_reading_contacts(is_v4, dht.clone(), contacts_store).instrument(span),
            )
        });

        Self {
            dht,
            _monitoring_task: monitoring_task,
            _periodic_dht_node_load_task,
        }
    }

    /// Periodically read contacts from the `dht` and store them in `contacts_store`.
    async fn keep_reading_contacts(
        is_v4: bool,
        dht: MainlineDht,
        contacts_store: Arc<dyn DhtContactsStoreTrait>,
    ) {
        let mut reported_failure = false;

        // Give `MainlineDht` a chance to bootstrap.
        time::sleep(Duration::from_secs(10)).await;

        loop {
            let (good, questionable) = match dht.load_contacts().await {
                Ok((good, questionable)) => (good, questionable),
                Err(error) => {
                    tracing::warn!("DhtDiscovery stopped reading contacts: {error:?}");
                    break;
                }
            };

            // TODO: Make use of the information about which contacts are good and which are
            // questionable.
            let mix = good.union(&questionable);

            if is_v4 {
                let mix = mix.filter_map(|addr| match addr {
                    SocketAddr::V4(addr) => Some(*addr),
                    SocketAddr::V6(_) => None,
                });

                match contacts_store.store_v4(mix.collect()).await {
                    Ok(()) => reported_failure = false,
                    Err(error) => {
                        if !reported_failure {
                            reported_failure = true;
                            tracing::error!("DhtDiscovery failed to write contacts: {error:?}");
                        }
                    }
                }
            } else {
                let mix = mix.filter_map(|addr| match addr {
                    SocketAddr::V4(_) => None,
                    SocketAddr::V6(addr) => Some(*addr),
                });

                match contacts_store.store_v6(mix.collect()).await {
                    Ok(()) => reported_failure = false,
                    Err(error) => {
                        if !reported_failure {
                            reported_failure = true;
                            tracing::error!("DhtDiscovery failed to write contacts: {error:?}");
                        }
                    }
                }
            }

            time::sleep(Duration::from_secs(60)).await;
        }
    }

    async fn load_initial_contacts(
        is_v4: bool,
        contacts_store: &(impl DhtContactsStoreTrait + ?Sized),
    ) -> HashSet<SocketAddr> {
        if is_v4 {
            match contacts_store.load_v4().await {
                Ok(contacts) => contacts.iter().cloned().map(SocketAddr::V4).collect(),
                Err(error) => {
                    tracing::error!("Failed to load DHT IPv4 contacts {:?}", error);
                    Default::default()
                }
            }
        } else {
            match contacts_store.load_v6().await {
                Ok(contacts) => contacts.iter().cloned().map(SocketAddr::V6).collect(),
                Err(error) => {
                    tracing::error!("Failed to load DHT IPv6 contacts {:?}", error);
                    Default::default()
                }
            }
        }
    }
}

type Lookups = HashMap<InfoHash, Lookup>;

type RequestId = u64;

pub struct LookupRequest {
    id: RequestId,
    info_hash: InfoHash,
    lookups: Weak<BlockingMutex<Lookups>>,
}

impl Drop for LookupRequest {
    fn drop(&mut self) {
        if let Some(lookups) = self.lookups.upgrade() {
            let mut lookups = lookups.lock().unwrap();

            let empty = if let Some(lookup) = lookups.get_mut(&self.info_hash) {
                let mut requests = lookup.requests.lock().unwrap();
                requests.remove(&self.id);
                requests.is_empty()
            } else {
                false
            };

            if empty {
                lookups.remove(&self.info_hash);
            }
        }
    }
}

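// A single DHT lookup (search + announce) shared by all `LookupRequest`s for the same info
// hash. The background task repeats the search periodically; `wake_up_tx` lets a newly added
// request trigger the next round early.
//
// The wake-up relies on `tokio::sync::watch` semantics. A standalone sketch (independent of
// this module):
//
//     let (tx, mut rx) = tokio::sync::watch::channel(());
//     rx.borrow_and_update();      // mark the initial value as seen
//     tx.send(()).ok();            // notify the waiter
//     rx.changed().await.unwrap(); // resolves because a new value arrived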
struct Lookup {
    seen_peers: Arc<SeenPeers>,
    requests: Arc<BlockingMutex<HashMap<RequestId, mpsc::UnboundedSender<SeenPeer>>>>,
    wake_up_tx: watch::Sender<()>,
    task: Option<ScopedJoinHandle<()>>,
}

impl Lookup {
    fn start(
        dht_v4: Arc<Option<TaskOrResult<MonitoredDht>>>,
        dht_v6: Arc<Option<TaskOrResult<MonitoredDht>>>,
        info_hash: InfoHash,
        monitor: &StateMonitor,
        span: &Span,
    ) -> Self {
        let (wake_up_tx, mut wake_up_rx) = watch::channel(());
        // Mark the initial value as seen so the change notification is not triggered immediately
        // but only when we create the first request.
        wake_up_rx.borrow_and_update();

        let seen_peers = Arc::new(SeenPeers::new());
        let requests = Arc::new(BlockingMutex::new(HashMap::default()));

        let task = if dht_v4.is_some() || dht_v6.is_some() {
            Some(Self::start_task(
                dht_v4,
                dht_v6,
                info_hash,
                seen_peers.clone(),
                requests.clone(),
                wake_up_rx,
                monitor,
                span,
            ))
        } else {
            None
        };

        Lookup {
            seen_peers,
            requests,
            wake_up_tx,
            task,
        }
    }

    // Restart this same lookup on different DHT instances.
    fn restart(
        &mut self,
        dht_v4: Arc<Option<TaskOrResult<MonitoredDht>>>,
        dht_v6: Arc<Option<TaskOrResult<MonitoredDht>>>,
        info_hash: InfoHash,
        monitor: &StateMonitor,
        span: &Span,
    ) {
        if dht_v4.is_none() && dht_v6.is_none() {
            self.task.take();
            return;
        }

        let task = Self::start_task(
            dht_v4,
            dht_v6,
            info_hash,
            self.seen_peers.clone(),
            self.requests.clone(),
            self.wake_up_tx.subscribe(),
            monitor,
            span,
        );

        self.task = Some(task);
        self.wake_up_tx.send(()).ok();
    }

    fn add_request(&mut self, id: RequestId, tx: mpsc::UnboundedSender<SeenPeer>) {
        for peer in self.seen_peers.collect() {
            tx.send(peer.clone()).unwrap_or(());
        }

        self.requests.lock().unwrap().insert(id, tx);
        // `unwrap_or` because if the network is down, there should be no tasks that listen to this
        // wake-up request.
        self.wake_up_tx.send(()).unwrap_or(());
    }

    #[allow(clippy::too_many_arguments)]
    fn start_task(
        dht_v4: Arc<Option<TaskOrResult<MonitoredDht>>>,
        dht_v6: Arc<Option<TaskOrResult<MonitoredDht>>>,
        info_hash: InfoHash,
        seen_peers: Arc<SeenPeers>,
        requests: Arc<BlockingMutex<HashMap<RequestId, mpsc::UnboundedSender<SeenPeer>>>>,
        mut wake_up: watch::Receiver<()>,
        lookups_monitor: &StateMonitor,
        span: &Span,
    ) -> ScopedJoinHandle<()> {
        let monitor = lookups_monitor.make_child(format!("{info_hash:?}"));
        let state = monitor.make_value("state", "started");
        let next = monitor.make_value("next", SystemTime::now().into());

        let task = async move {
            let dht_v4 = match &*dht_v4 {
                Some(dht) => Some(dht.result().await),
                None => None,
            };

            let dht_v6 = match &*dht_v6 {
                Some(dht) => Some(dht.result().await),
                None => None,
            };

            // Wait for the first request to be created.
            wake_up.changed().await.unwrap_or(());

            loop {
                seen_peers.start_new_round();

                tracing::debug!(?info_hash, "starting search");
                *state.get() = "making request";

                // Find peers for the repo and also announce that we have it.
                let dhts = dht_v4.iter().chain(dht_v6.iter());

                let mut peers = Box::pin(stream::iter(dhts).flat_map(|dht| {
                    stream::once(async {
                        timeout(Duration::from_secs(10), dht.dht.bootstrapped())
                            .await
                            .unwrap_or(false);
                        dht.dht.search(info_hash, true)
                    })
                    .flatten()
                }));

                *state.get() = "awaiting results";

                while let Some(addr) = peers.next().await {
                    if let Some(peer) = seen_peers.insert(PeerAddr::Quic(addr)) {
                        for tx in requests.lock().unwrap().values() {
                            tx.send(peer.clone()).unwrap_or(());
                        }
                    }
                }

                // Sleep a random duration before the next search, but wake up if there is a new
                // request.
                let duration =
                    rand::thread_rng().gen_range(MIN_DHT_ANNOUNCE_DELAY..MAX_DHT_ANNOUNCE_DELAY);

                {
                    let time: DateTime<Local> = (SystemTime::now() + duration).into();
                    tracing::debug!(
                        ?info_hash,
                        "search ended. next one scheduled at {} (in {:?})",
                        time.format("%T"),
                        duration
                    );

                    *state.get() = "sleeping";
                    *next.get() = time;
                }

                select! {
                    _ = time::sleep(duration) => (),
                    _ = wake_up.changed() => (),
                }
            }
        };
        let task = task.instrument(span.clone());

        scoped_task::spawn(task)
    }
}

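// Adapter that lets `btdht` drive the QUIC side channel as if it were a plain UDP socket.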
struct Socket(quic::SideChannel);

#[async_trait]
impl btdht::SocketTrait for Socket {
    async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> io::Result<()> {
        self.0.send_to(buf, *target).await?;
        Ok(())
    }

    async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
        self.0.recv_from(buf).await
    }

    fn local_addr(&self) -> io::Result<SocketAddr> {
        self.0.local_addr()
    }
}

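// Holds either the join handle of a task producing a `T` or, once that task has completed, the
// cached result. The first call to `result()` awaits the task; subsequent calls return the
// cached value.
//
// A minimal usage sketch (illustrative):
//
//     let value = TaskOrResult::new(scoped_task::spawn(async { 42 }));
//     assert_eq!(*value.result().await, 42); // awaits the spawned task
//     assert_eq!(*value.result().await, 42); // returns the cached result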
struct TaskOrResult<T> {
    task: AsyncMutex<Option<ScopedJoinHandle<T>>>,
    result: OnceLock<T>,
}

impl<T> TaskOrResult<T> {
    fn new(task: ScopedJoinHandle<T>) -> Self {
        Self {
            task: AsyncMutex::new(Some(task)),
            result: OnceLock::new(),
        }
    }

    // Note that this function is not cancel-safe.
    async fn result(&self) -> &T {
        if let Some(result) = self.result.get() {
            return result;
        }

        let mut lock = self.task.lock().await;

        if let Some(handle) = lock.take() {
            // The unwrap is OK for the same reason we unwrap `BlockingMutex::lock()`s.
            // The assert is OK because we can await the handle only once.
            assert!(self.result.set(handle.await.unwrap()).is_ok());
        }

        // Unwrap is OK because we ensured that `result` holds a value.
        self.result.get().unwrap()
    }
}
677}