1use super::{
2 peer_addr::PeerAddr,
3 seen_peers::{SeenPeer, SeenPeers},
4};
5use async_trait::async_trait;
6use btdht::{InfoHash, MainlineDht};
7use chrono::{offset::Local, DateTime};
8use deadlock::{AsyncMutex, BlockingMutex};
9use futures_util::{stream, StreamExt};
10use net::{quic, udp::DatagramSocket};
11use rand::Rng;
12use scoped_task::ScopedJoinHandle;
13use state_monitor::StateMonitor;
14use std::{
15 collections::{hash_map, HashMap, HashSet},
16 future::pending,
17 io,
18 net::{SocketAddr, SocketAddrV4, SocketAddrV6},
19 sync::{
20 atomic::{AtomicU64, Ordering},
21 Arc, OnceLock, Weak,
22 },
23 time::SystemTime,
24};
25use tokio::{
26 select,
27 sync::{mpsc, watch},
28 time::{self, timeout, Duration},
29};
30use tracing::{instrument::Instrument, Span};
31
/// Bootstrap routers handed to the mainline DHT builder when a DHT node is started.
pub const DHT_ROUTERS: &[&str] = &[
    "dht.ouisync.net:6881",
    "router.bittorrent.com:6881",
    "dht.transmissionbt.com:6881",
];

/// Lower (inclusive) bound of the random pause between two searches of one info-hash: 3 minutes.
const MIN_DHT_ANNOUNCE_DELAY: Duration = Duration::from_secs(180);
/// Upper (exclusive) bound of the random pause between two searches of one info-hash: 6 minutes.
const MAX_DHT_ANNOUNCE_DELAY: Duration = Duration::from_secs(360);
47
48#[async_trait]
49pub trait DhtContactsStoreTrait: Sync + Send + 'static {
50 async fn load_v4(&self) -> io::Result<HashSet<SocketAddrV4>>;
51 async fn load_v6(&self) -> io::Result<HashSet<SocketAddrV6>>;
52 async fn store_v4(&self, contacts: HashSet<SocketAddrV4>) -> io::Result<()>;
53 async fn store_v6(&self, contacts: HashSet<SocketAddrV6>) -> io::Result<()>;
54}
55
56pub(super) struct DhtDiscovery {
57 v4: BlockingMutex<RestartableDht>,
58 v6: BlockingMutex<RestartableDht>,
59 lookups: Arc<BlockingMutex<Lookups>>,
60 next_id: AtomicU64,
61 main_monitor: StateMonitor,
62 lookups_monitor: StateMonitor,
63 span: Span,
64}
65
66impl DhtDiscovery {
67 pub fn new(
68 socket_maker_v4: Option<quic::SideChannelMaker>,
69 socket_maker_v6: Option<quic::SideChannelMaker>,
70 contacts_store: Option<Arc<dyn DhtContactsStoreTrait>>,
71 monitor: StateMonitor,
72 ) -> Self {
73 let v4 = BlockingMutex::new(RestartableDht::new(socket_maker_v4, contacts_store.clone()));
74 let v6 = BlockingMutex::new(RestartableDht::new(socket_maker_v6, contacts_store));
75
76 let lookups = Arc::new(BlockingMutex::new(HashMap::default()));
77
78 let lookups_monitor = monitor.make_child("lookups");
79
80 Self {
81 v4,
82 v6,
83 lookups,
84 next_id: AtomicU64::new(0),
85 span: Span::current(),
86 main_monitor: monitor,
87 lookups_monitor,
88 }
89 }
90
91 pub fn rebind(
95 &self,
96 socket_maker_v4: Option<quic::SideChannelMaker>,
97 socket_maker_v6: Option<quic::SideChannelMaker>,
98 ) {
99 let mut v4 = self.v4.lock().unwrap();
100 let mut v6 = self.v6.lock().unwrap();
101
102 v4.rebind(socket_maker_v4);
103 v6.rebind(socket_maker_v6);
104
105 let mut lookups = self.lookups.lock().unwrap();
106
107 if lookups.is_empty() {
108 return;
109 }
110
111 let dht_v4 = v4.fetch(&self.main_monitor, &self.span);
112 let dht_v6 = v6.fetch(&self.main_monitor, &self.span);
113
114 for (info_hash, lookup) in &mut *lookups {
115 lookup.restart(
116 dht_v4.clone(),
117 dht_v6.clone(),
118 *info_hash,
119 &self.lookups_monitor,
120 &self.span,
121 );
122 }
123 }
124
125 pub fn start_lookup(
126 &self,
127 info_hash: InfoHash,
128 found_peers_tx: mpsc::UnboundedSender<SeenPeer>,
129 ) -> LookupRequest {
130 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
131
132 let request = LookupRequest {
133 id,
134 info_hash,
135 lookups: Arc::downgrade(&self.lookups),
136 };
137
138 let mut lookups = self.lookups.lock().unwrap();
139
140 match lookups.entry(info_hash) {
141 hash_map::Entry::Occupied(mut entry) => entry.get_mut().add_request(id, found_peers_tx),
142 hash_map::Entry::Vacant(entry) => {
143 let dht_v4 = self
144 .v4
145 .lock()
146 .unwrap()
147 .fetch(&self.main_monitor, &self.span);
148
149 let dht_v6 = self
150 .v6
151 .lock()
152 .unwrap()
153 .fetch(&self.main_monitor, &self.span);
154
155 entry
156 .insert(Lookup::start(
157 dht_v4,
158 dht_v6,
159 info_hash,
160 &self.lookups_monitor,
161 &self.span,
162 ))
163 .add_request(id, found_peers_tx);
164 }
165 }
166
167 request
168 }
169}
170
171struct RestartableDht {
173 socket_maker: Option<quic::SideChannelMaker>,
174 dht: Weak<Option<TaskOrResult<MonitoredDht>>>,
175 contacts_store: Option<Arc<dyn DhtContactsStoreTrait>>,
176}
177
178impl RestartableDht {
179 fn new(
180 socket_maker: Option<quic::SideChannelMaker>,
181 contacts_store: Option<Arc<dyn DhtContactsStoreTrait>>,
182 ) -> Self {
183 Self {
184 socket_maker,
185 dht: Weak::new(),
186 contacts_store,
187 }
188 }
189
190 fn fetch(
193 &mut self,
194 monitor: &StateMonitor,
195 span: &Span,
196 ) -> Arc<Option<TaskOrResult<MonitoredDht>>> {
197 if let Some(dht) = self.dht.upgrade() {
198 dht
199 } else if let Some(maker) = &self.socket_maker {
200 let socket = maker.make();
201 let dht = MonitoredDht::start(socket, monitor, span, self.contacts_store.clone());
202
203 let dht = Arc::new(Some(dht));
204
205 self.dht = Arc::downgrade(&dht);
206
207 dht
208 } else {
209 Arc::new(None)
210 }
211 }
212
213 fn rebind(&mut self, socket_maker: Option<quic::SideChannelMaker>) {
214 self.socket_maker = socket_maker;
215 self.dht = Weak::new();
216 }
217}
218
219struct MonitoredDht {
221 dht: MainlineDht,
222 _monitoring_task: ScopedJoinHandle<()>,
223 _periodic_dht_node_load_task: Option<ScopedJoinHandle<()>>,
224}
225
226impl MonitoredDht {
227 fn start(
228 socket: quic::SideChannel,
229 parent_monitor: &StateMonitor,
230 span: &Span,
231 contacts_store: Option<Arc<dyn DhtContactsStoreTrait>>,
232 ) -> TaskOrResult<Self> {
233 let local_addr = socket.local_addr().unwrap();
235
236 let (is_v4, monitor_name, span) = match local_addr {
237 SocketAddr::V4(_) => (true, "IPv4", tracing::info_span!(parent: span, "DHT/IPv4")),
238 SocketAddr::V6(_) => (false, "IPv6", tracing::info_span!(parent: span, "DHT/IPv6")),
239 };
240
241 let monitor = parent_monitor.make_child(monitor_name);
242
243 TaskOrResult::new(scoped_task::spawn(MonitoredDht::create(
244 is_v4,
245 socket,
246 monitor,
247 span,
248 contacts_store,
249 )))
250 }
251
252 async fn create(
253 is_v4: bool,
254 socket: quic::SideChannel,
255 monitor: StateMonitor,
256 span: Span,
257 contacts_store: Option<Arc<dyn DhtContactsStoreTrait>>,
258 ) -> Self {
259 let mut builder = MainlineDht::builder()
261 .add_routers(DHT_ROUTERS.iter().copied())
262 .set_read_only(false);
263
264 if let Some(contacts_store) = &contacts_store {
265 let initial_contacts = Self::load_initial_contacts(is_v4, &**contacts_store).await;
266
267 for contact in initial_contacts {
268 builder = builder.add_node(contact);
269 }
270 }
271
272 let dht = builder
273 .start(Socket(socket))
274 .unwrap();
277
278 let monitoring_task = {
280 let dht = dht.clone();
281
282 let first_bootstrap = monitor.make_value("first_bootstrap", "in progress");
283 let probe_counter = monitor.make_value("probe_counter", 0);
284 let is_running = monitor.make_value("is_running", false);
285 let bootstrapped = monitor.make_value("bootstrapped", false);
286 let good_nodes = monitor.make_value("good_nodes", 0);
287 let questionable_nodes = monitor.make_value("questionable_nodes", 0);
288 let buckets = monitor.make_value("buckets", 0);
289
290 async move {
291 tracing::info!("bootstrap started");
292
293 if dht.bootstrapped().await {
294 *first_bootstrap.get() = "done";
295 tracing::info!("bootstrap complete");
296 } else {
297 *first_bootstrap.get() = "failed";
298 tracing::error!("bootstrap failed");
299
300 pending::<()>().await;
303 }
304
305 loop {
306 *probe_counter.get() += 1;
307
308 if let Some(state) = dht.get_state().await {
309 *is_running.get() = true;
310 *bootstrapped.get() = true;
311 *good_nodes.get() = state.good_node_count;
312 *questionable_nodes.get() = state.questionable_node_count;
313 *buckets.get() = state.bucket_count;
314 } else {
315 *is_running.get() = false;
316 *bootstrapped.get() = false;
317 *good_nodes.get() = 0;
318 *questionable_nodes.get() = 0;
319 *buckets.get() = 0;
320 }
321
322 time::sleep(Duration::from_secs(5)).await;
323 }
324 }
325 };
326 let monitoring_task = monitoring_task.instrument(span.clone());
327 let monitoring_task = scoped_task::spawn(monitoring_task);
328
329 let _periodic_dht_node_load_task = contacts_store.map(|contacts_store| {
330 scoped_task::spawn(
331 Self::keep_reading_contacts(is_v4, dht.clone(), contacts_store).instrument(span),
332 )
333 });
334
335 Self {
336 dht,
337 _monitoring_task: monitoring_task,
338 _periodic_dht_node_load_task,
339 }
340 }
341
342 async fn keep_reading_contacts(
344 is_v4: bool,
345 dht: MainlineDht,
346 contacts_store: Arc<dyn DhtContactsStoreTrait>,
347 ) {
348 let mut reported_failure = false;
349
350 time::sleep(Duration::from_secs(10)).await;
352
353 loop {
354 let (good, questionable) = match dht.load_contacts().await {
355 Ok((good, questionable)) => (good, questionable),
356 Err(error) => {
357 tracing::warn!("DhtDiscovery stopped reading contacts: {error:?}");
358 break;
359 }
360 };
361
362 let mix = good.union(&questionable);
364
365 if is_v4 {
366 let mix = mix.filter_map(|addr| match addr {
367 SocketAddr::V4(addr) => Some(*addr),
368 SocketAddr::V6(_) => None,
369 });
370
371 match contacts_store.store_v4(mix.collect()).await {
372 Ok(()) => reported_failure = false,
373 Err(error) => {
374 if !reported_failure {
375 reported_failure = true;
376 tracing::error!("DhtDiscovery failed to write contacts {error:?}");
377 }
378 }
379 }
380 } else {
381 let mix = mix.filter_map(|addr| match addr {
382 SocketAddr::V4(_) => None,
383 SocketAddr::V6(addr) => Some(*addr),
384 });
385
386 match contacts_store.store_v6(mix.collect()).await {
387 Ok(()) => reported_failure = false,
388 Err(error) => {
389 if !reported_failure {
390 reported_failure = true;
391 tracing::error!("DhtDiscovery failed to write contacts {error:?}");
392 }
393 }
394 }
395 }
396
397 time::sleep(Duration::from_secs(60)).await;
398 }
399 }
400
401 async fn load_initial_contacts(
402 is_v4: bool,
403 contacts_store: &(impl DhtContactsStoreTrait + ?Sized),
404 ) -> HashSet<SocketAddr> {
405 if is_v4 {
406 match contacts_store.load_v4().await {
407 Ok(contacts) => contacts.iter().cloned().map(SocketAddr::V4).collect(),
408 Err(error) => {
409 tracing::error!("Failed to load DHT IPv4 contacts {:?}", error);
410 Default::default()
411 }
412 }
413 } else {
414 match contacts_store.load_v6().await {
415 Ok(contacts) => contacts.iter().cloned().map(SocketAddr::V6).collect(),
416 Err(error) => {
417 tracing::error!("Failed to load DHT IPv4 contacts {:?}", error);
418 Default::default()
419 }
420 }
421 }
422 }
423}
424
425type Lookups = HashMap<InfoHash, Lookup>;
426
427type RequestId = u64;
428
429pub struct LookupRequest {
430 id: RequestId,
431 info_hash: InfoHash,
432 lookups: Weak<BlockingMutex<Lookups>>,
433}
434
435impl Drop for LookupRequest {
436 fn drop(&mut self) {
437 if let Some(lookups) = self.lookups.upgrade() {
438 let mut lookups = lookups.lock().unwrap();
439
440 let empty = if let Some(lookup) = lookups.get_mut(&self.info_hash) {
441 let mut requests = lookup.requests.lock().unwrap();
442 requests.remove(&self.id);
443 requests.is_empty()
444 } else {
445 false
446 };
447
448 if empty {
449 lookups.remove(&self.info_hash);
450 }
451 }
452 }
453}
454
455struct Lookup {
456 seen_peers: Arc<SeenPeers>,
457 requests: Arc<BlockingMutex<HashMap<RequestId, mpsc::UnboundedSender<SeenPeer>>>>,
458 wake_up_tx: watch::Sender<()>,
459 task: Option<ScopedJoinHandle<()>>,
460}
461
462impl Lookup {
463 fn start(
464 dht_v4: Arc<Option<TaskOrResult<MonitoredDht>>>,
465 dht_v6: Arc<Option<TaskOrResult<MonitoredDht>>>,
466 info_hash: InfoHash,
467 monitor: &StateMonitor,
468 span: &Span,
469 ) -> Self {
470 let (wake_up_tx, mut wake_up_rx) = watch::channel(());
471 wake_up_rx.borrow_and_update();
474
475 let seen_peers = Arc::new(SeenPeers::new());
476 let requests = Arc::new(BlockingMutex::new(HashMap::default()));
477
478 let task = if dht_v4.is_some() || dht_v6.is_some() {
479 Some(Self::start_task(
480 dht_v4,
481 dht_v6,
482 info_hash,
483 seen_peers.clone(),
484 requests.clone(),
485 wake_up_rx,
486 monitor,
487 span,
488 ))
489 } else {
490 None
491 };
492
493 Lookup {
494 seen_peers,
495 requests,
496 wake_up_tx,
497 task,
498 }
499 }
500
501 fn restart(
503 &mut self,
504 dht_v4: Arc<Option<TaskOrResult<MonitoredDht>>>,
505 dht_v6: Arc<Option<TaskOrResult<MonitoredDht>>>,
506 info_hash: InfoHash,
507 monitor: &StateMonitor,
508 span: &Span,
509 ) {
510 if dht_v4.is_none() && dht_v6.is_none() {
511 self.task.take();
512 return;
513 }
514
515 let task = Self::start_task(
516 dht_v4,
517 dht_v6,
518 info_hash,
519 self.seen_peers.clone(),
520 self.requests.clone(),
521 self.wake_up_tx.subscribe(),
522 monitor,
523 span,
524 );
525
526 self.task = Some(task);
527 self.wake_up_tx.send(()).ok();
528 }
529
530 fn add_request(&mut self, id: RequestId, tx: mpsc::UnboundedSender<SeenPeer>) {
531 for peer in self.seen_peers.collect() {
532 tx.send(peer.clone()).unwrap_or(());
533 }
534
535 self.requests.lock().unwrap().insert(id, tx);
536 self.wake_up_tx.send(()).unwrap_or(());
539 }
540
541 #[allow(clippy::too_many_arguments)]
542 fn start_task(
543 dht_v4: Arc<Option<TaskOrResult<MonitoredDht>>>,
544 dht_v6: Arc<Option<TaskOrResult<MonitoredDht>>>,
545 info_hash: InfoHash,
546 seen_peers: Arc<SeenPeers>,
547 requests: Arc<BlockingMutex<HashMap<RequestId, mpsc::UnboundedSender<SeenPeer>>>>,
548 mut wake_up: watch::Receiver<()>,
549 lookups_monitor: &StateMonitor,
550 span: &Span,
551 ) -> ScopedJoinHandle<()> {
552 let monitor = lookups_monitor.make_child(format!("{info_hash:?}"));
553 let state = monitor.make_value("state", "started");
554 let next = monitor.make_value("next", SystemTime::now().into());
555
556 let task = async move {
557 let dht_v4 = match &*dht_v4 {
558 Some(dht) => Some(dht.result().await),
559 None => None,
560 };
561
562 let dht_v6 = match &*dht_v6 {
563 Some(dht) => Some(dht.result().await),
564 None => None,
565 };
566
567 wake_up.changed().await.unwrap_or(());
569
570 loop {
571 seen_peers.start_new_round();
572
573 tracing::debug!(?info_hash, "starting search");
574 *state.get() = "making request";
575
576 let dhts = dht_v4.iter().chain(dht_v6.iter());
578
579 let mut peers = Box::pin(stream::iter(dhts).flat_map(|dht| {
580 stream::once(async {
581 timeout(Duration::from_secs(10), dht.dht.bootstrapped())
582 .await
583 .unwrap_or(false);
584 dht.dht.search(info_hash, true)
585 })
586 .flatten()
587 }));
588
589 *state.get() = "awaiting results";
590
591 while let Some(addr) = peers.next().await {
592 if let Some(peer) = seen_peers.insert(PeerAddr::Quic(addr)) {
593 for tx in requests.lock().unwrap().values() {
594 tx.send(peer.clone()).unwrap_or(());
595 }
596 }
597 }
598
599 let duration =
602 rand::thread_rng().gen_range(MIN_DHT_ANNOUNCE_DELAY..MAX_DHT_ANNOUNCE_DELAY);
603
604 {
605 let time: DateTime<Local> = (SystemTime::now() + duration).into();
606 tracing::debug!(
607 ?info_hash,
608 "search ended. next one scheduled at {} (in {:?})",
609 time.format("%T"),
610 duration
611 );
612
613 *state.get() = "sleeping";
614 *next.get() = time;
615 }
616
617 select! {
618 _ = time::sleep(duration) => (),
619 _ = wake_up.changed() => (),
620 }
621 }
622 };
623 let task = task.instrument(span.clone());
624
625 scoped_task::spawn(task)
626 }
627}
628
629struct Socket(quic::SideChannel);
630
631#[async_trait]
632impl btdht::SocketTrait for Socket {
633 async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> io::Result<()> {
634 self.0.send_to(buf, *target).await?;
635 Ok(())
636 }
637
638 async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
639 self.0.recv_from(buf).await
640 }
641
642 fn local_addr(&self) -> io::Result<SocketAddr> {
643 self.0.local_addr()
644 }
645}
646
647struct TaskOrResult<T> {
648 task: AsyncMutex<Option<ScopedJoinHandle<T>>>,
649 result: OnceLock<T>,
650}
651
652impl<T> TaskOrResult<T> {
653 fn new(task: ScopedJoinHandle<T>) -> Self {
654 Self {
655 task: AsyncMutex::new(Some(task)),
656 result: OnceLock::new(),
657 }
658 }
659
660 async fn result(&self) -> &T {
662 if let Some(result) = self.result.get() {
663 return result;
664 }
665
666 let mut lock = self.task.lock().await;
667
668 if let Some(handle) = lock.take() {
669 assert!(self.result.set(handle.await.unwrap()).is_ok());
672 }
673
674 self.result.get().unwrap()
676 }
677}