1mod lookup_stream;
2mod monitored;
3mod restartable;
4
5pub use lookup_stream::DhtLookupStream;
6use slab::Slab;
7
8use crate::{network::dht::restartable::ObservableDht, sync::WatchSenderExt};
9
10use super::{
11 peer_addr::PeerAddr,
12 seen_peers::{SeenPeer, SeenPeers},
13};
14use async_trait::async_trait;
15use btdht::{InfoHash, MainlineDht};
16use chrono::{DateTime, offset::Local};
17use deadlock::BlockingMutex;
18use futures_util::{Stream, StreamExt, future, stream};
19use net::quic;
20use rand::Rng;
21use restartable::RestartableDht;
22use state_monitor::StateMonitor;
23use std::{
24 collections::{HashMap, HashSet},
25 io,
26 net::{SocketAddr, SocketAddrV4, SocketAddrV6},
27 pin::pin,
28 sync::{Arc, Weak},
29 time::{Duration, SystemTime},
30};
31use tokio::{
32 select,
33 sync::{mpsc, watch},
34 task, time,
35};
36use tracing::{Span, instrument::Instrument};
37
/// Default DHT routers, each given as a `host:port` string.
///
/// `DhtDiscovery::new` configures both the IPv4 and the IPv6 DHT with this list.
pub const DEFAULT_DHT_ROUTERS: &[&str] = &[
    "dht.ouisync.net:6881",
    "router.bittorrent.com:6881",
    "dht.transmissionbt.com:6881",
];
44
/// Lower bound (inclusive) of the randomly chosen pause between two consecutive search rounds
/// of the same lookup.
const MIN_DHT_ANNOUNCE_DELAY: Duration = Duration::from_secs(3 * 60);
/// Upper bound (exclusive) of the randomly chosen pause between two consecutive search rounds
/// of the same lookup.
const MAX_DHT_ANNOUNCE_DELAY: Duration = Duration::from_secs(6 * 60);

/// How long a search waits for the DHT to report being bootstrapped before it goes ahead with
/// the search regardless.
const BOOTSTRAP_TIMEOUT: Duration = Duration::from_secs(10);
55
/// Asynchronous storage for sets of DHT contact addresses, kept separately per IP family.
///
/// An optional implementation is handed to each `RestartableDht` when `DhtDiscovery` is
/// constructed.
// NOTE(review): presumably used to persist known contacts across restarts - confirm against
// `RestartableDht`.
#[async_trait]
pub trait DhtContactsStoreTrait: Sync + Send + 'static {
    /// Loads the stored IPv4 contacts.
    async fn load_v4(&self) -> io::Result<HashSet<SocketAddrV4>>;
    /// Loads the stored IPv6 contacts.
    async fn load_v6(&self) -> io::Result<HashSet<SocketAddrV6>>;
    /// Stores the given set of IPv4 contacts.
    async fn store_v4(&self, contacts: HashSet<SocketAddrV4>) -> io::Result<()>;
    /// Stores the given set of IPv6 contacts.
    async fn store_v6(&self, contacts: HashSet<SocketAddrV6>) -> io::Result<()>;
}
66
/// Event delivered to every request registered on a lookup.
pub(super) enum DhtEvent {
    /// The search yielded a peer that was newly inserted into the lookup's set of seen peers.
    PeerFound(SeenPeer),
    /// The current search round has produced all of its results.
    RoundEnded,
}
71
/// Peer discovery over the DHT, run separately for IPv4 and IPv6.
pub(super) struct DhtDiscovery {
    /// The IPv4 DHT.
    v4: RestartableDht,
    /// The IPv6 DHT.
    v6: RestartableDht,
    /// Running lookups keyed by info hash. Shared (weakly) with every `LookupRequest` so that
    /// dropping a request can unregister it.
    lookups: Arc<BlockingMutex<LookupMap>>,
    /// Parent monitor node under which every lookup creates its own child node.
    lookups_monitor: StateMonitor,
    /// Span that was current when this instance was created; lookup tasks are instrumented
    /// with it.
    span: Span,
}
79
80impl DhtDiscovery {
81 pub fn new(
82 socket_maker_v4: Option<quic::SideChannelMaker>,
83 socket_maker_v6: Option<quic::SideChannelMaker>,
84 contacts: Option<Arc<dyn DhtContactsStoreTrait>>,
85 monitor: StateMonitor,
86 ) -> Self {
87 let routers: HashSet<String> = DEFAULT_DHT_ROUTERS
88 .iter()
89 .copied()
90 .map(Into::into)
91 .collect();
92
93 let v4 = RestartableDht::new(contacts.clone(), monitor.clone());
94 v4.bind(socket_maker_v4);
95 v4.set_routers(routers.clone());
96
97 let v6 = RestartableDht::new(contacts, monitor.clone());
98 v6.bind(socket_maker_v6);
99 v6.set_routers(routers);
100
101 let lookups = Arc::new(BlockingMutex::new(HashMap::default()));
102
103 let lookups_monitor = monitor.make_child("lookups");
104
105 Self {
106 v4,
107 v6,
108 lookups,
109 lookups_monitor,
110 span: Span::current(),
111 }
112 }
113
114 pub fn rebind(
117 &self,
118 socket_maker_v4: Option<quic::SideChannelMaker>,
119 socket_maker_v6: Option<quic::SideChannelMaker>,
120 ) {
121 self.v4.bind(socket_maker_v4);
122 self.v6.bind(socket_maker_v6);
123 }
124
125 pub fn set_routers(&self, routers: HashSet<String>) {
127 self.v4.set_routers(routers.clone());
128 self.v6.set_routers(routers);
129 }
130
131 pub fn routers(&self) -> HashSet<String> {
133 self.v4
134 .routers()
135 .into_iter()
136 .chain(self.v6.routers())
137 .collect()
138 }
139
140 pub fn start_lookup(
143 &self,
144 info_hash: InfoHash,
145 announce: bool,
146 event_tx: mpsc::UnboundedSender<DhtEvent>,
147 ) -> LookupRequest {
148 let key = self
149 .lookups
150 .lock()
151 .unwrap()
152 .entry(info_hash)
153 .or_insert_with(|| {
154 let v4 = self.v4.observe();
155 let v6 = self.v6.observe();
156 Lookup::start(v4, v6, info_hash, &self.lookups_monitor, self.span.clone())
157 })
158 .add_request(announce, event_tx);
159
160 LookupRequest {
161 lookups: Arc::downgrade(&self.lookups),
162 info_hash,
163 key,
164 }
165 }
166
167 pub async fn pin(&self) -> DhtPin {
171 let v4 = self.v4.observe();
172 let v6 = self.v6.observe();
173
174 future::join(v4.started_or_disabled(), v6.started_or_disabled()).await;
175
176 DhtPin { _v4: v4, _v6: v6 }
177 }
178}
179
/// Handle to one request registered on a lookup.
///
/// Dropping it unregisters the request; when it was the last request of its lookup, the lookup
/// itself is removed from the map as well.
pub struct LookupRequest {
    /// Weak reference to the shared lookup map, so an outstanding request does not keep the
    /// map alive.
    lookups: Weak<BlockingMutex<LookupMap>>,
    /// Info hash of the lookup this request belongs to.
    info_hash: InfoHash,
    /// Key of this request in the lookup's slab of requests.
    key: usize,
}
185
186impl Drop for LookupRequest {
187 fn drop(&mut self) {
188 if let Some(lookups) = self.lookups.upgrade() {
189 let mut lookups = lookups.lock().unwrap();
190
191 let empty = if let Some(lookup) = lookups.get_mut(&self.info_hash) {
192 lookup.requests_tx.send_modify_return(|requests| {
193 requests.remove(self.key);
194 requests.is_empty()
195 })
196 } else {
197 false
198 };
199
200 if empty {
201 lookups.remove(&self.info_hash);
202 }
203 }
204 }
205}
206
/// Per-request data stored in a lookup's slab of requests.
struct RequestData {
    /// Channel on which the lookup task delivers `DhtEvent`s to this request.
    event_tx: mpsc::UnboundedSender<DhtEvent>,
    /// Whether this request wants the search to announce. A round announces when at least one
    /// registered request has this set.
    announce: bool,
}
211
/// Handle to the background task that searches the DHT for one info hash.
struct Lookup {
    /// Sender half of the watch channel holding the currently registered requests. The task
    /// owns the receiver half and returns once it observes the channel as closed.
    requests_tx: watch::Sender<Slab<RequestData>>,
}
215
216impl Lookup {
217 fn start(
218 v4: ObservableDht,
219 v6: ObservableDht,
220 info_hash: InfoHash,
221 monitor: &StateMonitor,
222 span: Span,
223 ) -> Self {
224 let (requests_tx, requests_rx) = watch::channel(Slab::new());
225 let monitor = monitor.make_child(format!("{info_hash:?}"));
226
227 task::spawn(Self::run(v4, v6, info_hash, requests_rx, monitor).instrument(span));
228
229 Self { requests_tx }
230 }
231
232 fn add_request(&mut self, announce: bool, event_tx: mpsc::UnboundedSender<DhtEvent>) -> usize {
233 self.requests_tx
234 .send_modify_return(|requests| requests.insert(RequestData { event_tx, announce }))
235 }
236
237 #[allow(clippy::too_many_arguments)]
238 async fn run(
239 dht_v4: ObservableDht,
240 dht_v6: ObservableDht,
241 info_hash: InfoHash,
242 mut requests_rx: watch::Receiver<Slab<RequestData>>,
243 monitor: StateMonitor,
244 ) {
245 let seen_peers = SeenPeers::new();
246
247 let state = monitor.make_value("state", "started");
248 let next = monitor.make_value("next", SystemTime::now().into());
249
250 loop {
251 if requests_rx
253 .wait_for(|requests| !requests.is_empty())
254 .await
255 .is_err()
256 {
257 return;
258 }
259
260 let dhts = async {
262 loop {
263 select! {
264 _ = dht_v4.enabled() => (),
265 _ = dht_v6.enabled() => (),
266 }
267
268 let (v4, v6) =
269 future::join(dht_v4.started_or_disabled(), dht_v6.started_or_disabled())
270 .await;
271
272 if v4.is_some() || v6.is_some() {
273 break (v4, v6);
274 }
275 }
276 };
277
278 let (dht_v4, dht_v6) = select! {
279 dhts = dhts => dhts,
280 _ = closed(&mut requests_rx) => return,
281 };
282
283 seen_peers.start_new_round();
284
285 tracing::debug!(
286 ?info_hash,
287 v4 = dht_v4.is_some(),
288 v6 = dht_v6.is_some(),
289 "starting search"
290 );
291
292 let announce = requests_rx
295 .borrow()
296 .iter()
297 .any(|(_, request)| request.announce);
298
299 let peers_v4 = search(dht_v4.as_ref(), info_hash, announce);
300 let peers_v6 = search(dht_v6.as_ref(), info_hash, announce);
301 let mut peers = pin!(stream::select(peers_v4, peers_v6));
302
303 *state.get() = "awaiting results";
304
305 loop {
306 let addr = select! {
307 addr = peers.next() => {
308 if let Some(addr) = addr {
309 addr
310 } else {
311 break;
312 }
313 }
314 _ = closed(&mut requests_rx) => return,
315 };
316
317 if let Some(peer) = seen_peers.insert(PeerAddr::Quic(addr)) {
318 for (_, request) in requests_rx.borrow().iter() {
319 request
320 .event_tx
321 .send(DhtEvent::PeerFound(peer.clone()))
322 .ok();
323 }
324 }
325 }
326
327 for (_, request) in requests_rx.borrow().iter() {
328 request.event_tx.send(DhtEvent::RoundEnded).ok();
329 }
330
331 let duration =
333 rand::thread_rng().gen_range(MIN_DHT_ANNOUNCE_DELAY..MAX_DHT_ANNOUNCE_DELAY);
334
335 {
336 let time: DateTime<Local> = (SystemTime::now() + duration).into();
337 tracing::debug!(
338 ?info_hash,
339 "search ended. next one scheduled at {} (in {:?})",
340 time.format("%T"),
341 duration
342 );
343
344 *state.get() = "sleeping";
345 *next.get() = time;
346 }
347
348 select! {
349 _ = time::sleep(duration) => (),
350 _ = closed(&mut requests_rx) => return,
351 }
352 }
353 }
354}
355
/// Running lookups keyed by the info hash they search for.
type LookupMap = HashMap<InfoHash, Lookup>;
357
/// Holds an observer of the IPv4 and of the IPv6 DHT. Returned by `DhtDiscovery::pin`.
// NOTE(review): presumably the DHTs are kept running for as long as an observer exists -
// confirm against `ObservableDht`.
pub struct DhtPin {
    /// Observer of the IPv4 DHT; never read, held only for as long as the pin lives.
    _v4: ObservableDht,
    /// Observer of the IPv6 DHT; never read, held only for as long as the pin lives.
    _v6: ObservableDht,
}
364
365async fn closed<T>(rx: &mut watch::Receiver<T>) {
367 while rx.changed().await.is_ok() {}
368}
369
370fn search<'a>(
373 dht: Option<&'a MainlineDht>,
374 info_hash: InfoHash,
375 announce: bool,
376) -> impl Stream<Item = SocketAddr> + 'a {
377 stream::iter(dht)
378 .then(move |dht| async move {
379 time::timeout(BOOTSTRAP_TIMEOUT, dht.bootstrapped())
380 .await
381 .ok();
382
383 dht.search(info_hash, announce)
384 })
385 .flatten()
386}