Skip to main content

ouisync/network/
stats.rs

1use ouisync_macros::api;
2use pin_project_lite::pin_project;
3use serde::{Deserialize, Serialize};
4use std::{
5    io::{self, IoSlice},
6    pin::Pin,
7    sync::{
8        Arc, Mutex,
9        atomic::{AtomicU64, Ordering},
10    },
11    task::{Context, Poll, ready},
12    time::Instant,
13};
14use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
15
/// Network traffic statistics.
///
/// Byte totals are cumulative. Throughputs are rates in bytes per second, derived from how much
/// those totals changed between two consecutive samples.
#[derive(Default, Clone, Copy, Eq, PartialEq, Debug, Serialize, Deserialize)]
#[api]
pub struct Stats {
    /// Total number of bytes sent.
    pub bytes_tx: u64,
    /// Total number of bytes received.
    pub bytes_rx: u64,
    /// Current send throughput in bytes per second (0 until a second sample has been taken).
    pub throughput_tx: u64,
    /// Current receive throughput in bytes per second (0 until a second sample has been taken).
    pub throughput_rx: u64,
}
29
30#[derive(Default)]
31pub(super) struct StatsTracker {
32    pub bytes: Arc<ByteCounters>,
33    throughput: Mutex<Throughputs>,
34}
35
36impl StatsTracker {
37    pub fn read(&self) -> Stats {
38        let bytes_tx = self.bytes.read_tx();
39        let bytes_rx = self.bytes.read_rx();
40        let now = Instant::now();
41
42        let mut throughput = self.throughput.lock().unwrap();
43        let throughput_tx = throughput.tx.sample(bytes_tx, now);
44        let throughput_rx = throughput.rx.sample(bytes_rx, now);
45
46        Stats {
47            bytes_tx,
48            bytes_rx,
49            throughput_tx,
50            throughput_rx,
51        }
52    }
53}
54
/// Throughput calculators for both traffic directions, guarded together by one lock in
/// `StatsTracker`.
#[derive(Default)]
struct Throughputs {
    // Fed with the sent-bytes total.
    tx: Throughput,
    // Fed with the received-bytes total.
    rx: Throughput,
}
60
61/// Counter of sent/received bytes
62#[derive(Default)]
63pub(super) struct ByteCounters {
64    tx: AtomicU64,
65    rx: AtomicU64,
66}
67
68impl ByteCounters {
69    pub fn increment_tx(&self, by: u64) {
70        self.tx.fetch_add(by, Ordering::Relaxed);
71    }
72
73    pub fn increment_rx(&self, by: u64) {
74        self.rx.fetch_add(by, Ordering::Relaxed);
75    }
76
77    pub fn read_tx(&self) -> u64 {
78        self.tx.load(Ordering::Relaxed)
79    }
80
81    pub fn read_rx(&self) -> u64 {
82        self.rx.load(Ordering::Relaxed)
83    }
84}
85
86/// Throughput caculator
87#[derive(Default)]
88pub(super) struct Throughput {
89    prev: Option<ThroughputSample>,
90}
91
92impl Throughput {
93    /// Returns the current throughput (in bytes per second), given the current total amount of
94    /// bytes (sent or received) and the current time.
95    ///
96    /// Note: For best results, call this in regular intervals (e.g., once per second).
97    pub fn sample(&mut self, bytes: u64, timestamp: Instant) -> u64 {
98        let throughput = if let Some(prev) = self.prev.take() {
99            let millis = timestamp
100                .saturating_duration_since(prev.timestamp)
101                .as_millis()
102                .try_into()
103                .unwrap_or(u64::MAX);
104
105            (bytes.saturating_sub(prev.bytes) * 1000)
106                .checked_div(millis)
107                .unwrap_or(prev.throughput)
108        } else {
109            0
110        };
111
112        self.prev = Some(ThroughputSample {
113            timestamp,
114            bytes,
115            throughput,
116        });
117
118        throughput
119    }
120}
121
122struct ThroughputSample {
123    timestamp: Instant,
124    bytes: u64,
125    throughput: u64,
126}
127
pin_project! {
    /// Wrapper for an IO object (reader or writer) that counts the transferred bytes.
    ///
    /// Successful reads add to the `rx` total and successful writes add to the `tx` total of
    /// the shared `ByteCounters`. Flushing and shutting down are forwarded without counting.
    pub(super) struct Instrumented<T> {
        // Structurally pinned so that `project()` yields a `Pin<&mut T>` through which the
        // inner object's poll methods can be called.
        #[pin]
        inner: T,
        // Shared byte totals updated by the read/write impls.
        counters: Arc<ByteCounters>,
    }
}
136
137impl<T> Instrumented<T> {
138    pub fn new(inner: T, counters: Arc<ByteCounters>) -> Self {
139        Self { inner, counters }
140    }
141}
142
143impl<T> AsyncRead for Instrumented<T>
144where
145    T: AsyncRead,
146{
147    fn poll_read(
148        self: Pin<&mut Self>,
149        cx: &mut Context<'_>,
150        buf: &mut ReadBuf<'_>,
151    ) -> Poll<io::Result<()>> {
152        let this = self.project();
153
154        let before = buf.filled().len();
155        let result = ready!(this.inner.poll_read(cx, buf));
156
157        if result.is_ok() {
158            this.counters
159                .increment_rx(buf.filled().len().saturating_sub(before) as u64);
160        }
161
162        Poll::Ready(result)
163    }
164}
165
166impl<T> AsyncWrite for Instrumented<T>
167where
168    T: AsyncWrite,
169{
170    fn poll_write(
171        self: Pin<&mut Self>,
172        cx: &mut Context<'_>,
173        buf: &[u8],
174    ) -> Poll<Result<usize, io::Error>> {
175        let this = self.project();
176        let result = ready!(this.inner.poll_write(cx, buf));
177
178        if let Ok(n) = result {
179            this.counters.increment_tx(n as u64);
180        }
181
182        Poll::Ready(result)
183    }
184
185    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
186        self.project().inner.poll_flush(cx)
187    }
188
189    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
190        self.project().inner.poll_shutdown(cx)
191    }
192
193    fn poll_write_vectored(
194        self: Pin<&mut Self>,
195        cx: &mut Context<'_>,
196        bufs: &[IoSlice<'_>],
197    ) -> Poll<Result<usize, io::Error>> {
198        let this = self.project();
199        let result = ready!(this.inner.poll_write_vectored(cx, bufs));
200
201        if let Ok(n) = result {
202            this.counters.increment_tx(n as u64);
203        }
204
205        Poll::Ready(result)
206    }
207
208    fn is_write_vectored(&self) -> bool {
209        self.inner.is_write_vectored()
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Feeds each `(bytes, offset_secs, expected)` step, in order, into a fresh [`Throughput`]
    /// and asserts the value it reports. Timestamps are `offset_secs` after a common origin.
    fn assert_samples(steps: &[(u64, u64, u64)]) {
        let origin = Instant::now();
        let mut meter = Throughput::default();

        for (step, &(bytes, offset_secs, expected)) in steps.iter().enumerate() {
            let at = origin + Duration::from_secs(offset_secs);
            assert_eq!(meter.sample(bytes, at), expected, "step {step}");
        }
    }

    #[test]
    fn throughput_sanity_check() {
        assert_samples(&[
            (1024, 0, 0),
            (1024, 1, 0),
            (2 * 1024, 2, 1024),
            (3 * 1024, 3, 1024),
        ]);
    }

    #[test]
    fn throughput_zero_duration() {
        // Readings at an unchanged timestamp repeat the previous throughput but still replace the
        // stored reading, so the next rate is measured from the last of them.
        assert_samples(&[
            (1024, 0, 0),
            (1024, 0, 0),
            (2048, 0, 0),
            (2048, 1, 0),
            (3072, 1, 0),
            (4096, 2, 1024),
            (5120, 2, 1024),
        ]);
    }

    #[test]
    fn throughput_negative_duration() {
        // A timestamp earlier than the previous one repeats the previous throughput.
        assert_samples(&[(1024, 0, 0), (2048, 1, 1024), (3072, 0, 1024)]);
    }

    #[test]
    fn throughput_non_monotonic_bytes() {
        // A shrinking byte total counts as no traffic.
        assert_samples(&[(1024, 0, 0), (2048, 1, 1024), (1024, 2, 0)]);
    }
}