ouisync/network/
stats.rs

1use ouisync_macros::api;
2use pin_project_lite::pin_project;
3use serde::{Deserialize, Serialize};
4use std::{
5    io::{self, IoSlice},
6    pin::Pin,
7    sync::{
8        atomic::{AtomicU64, Ordering},
9        Arc, Mutex,
10    },
11    task::{ready, Context, Poll},
12    time::Instant,
13};
14use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
15
/// Network traffic statistics.
///
/// A plain-data snapshot holding cumulative byte totals and throughput values.
/// All fields are zero in the `Default` value.
#[derive(Default, Clone, Copy, Eq, PartialEq, Debug, Serialize, Deserialize)]
#[api]
pub struct Stats {
    /// Total number of bytes sent.
    pub bytes_tx: u64,
    /// Total number of bytes received.
    pub bytes_rx: u64,
    /// Current send throughput in bytes per second.
    pub throughput_tx: u64,
    /// Current receive throughput in bytes per second.
    pub throughput_rx: u64,
}
29
30#[derive(Default)]
31pub(super) struct StatsTracker {
32    pub bytes: Arc<ByteCounters>,
33    throughput: Mutex<Throughputs>,
34}
35
36impl StatsTracker {
37    pub fn read(&self) -> Stats {
38        let bytes_tx = self.bytes.read_tx();
39        let bytes_rx = self.bytes.read_rx();
40        let now = Instant::now();
41
42        let mut throughput = self.throughput.lock().unwrap();
43        let throughput_tx = throughput.tx.sample(bytes_tx, now);
44        let throughput_rx = throughput.rx.sample(bytes_rx, now);
45
46        Stats {
47            bytes_tx,
48            bytes_rx,
49            throughput_tx,
50            throughput_rx,
51        }
52    }
53}
54
55#[derive(Default)]
56struct Throughputs {
57    tx: Throughput,
58    rx: Throughput,
59}
60
61/// Counter of sent/received bytes
62#[derive(Default)]
63pub(super) struct ByteCounters {
64    tx: AtomicU64,
65    rx: AtomicU64,
66}
67
68impl ByteCounters {
69    pub fn increment_tx(&self, by: u64) {
70        self.tx.fetch_add(by, Ordering::Relaxed);
71    }
72
73    pub fn increment_rx(&self, by: u64) {
74        self.rx.fetch_add(by, Ordering::Relaxed);
75    }
76
77    pub fn read_tx(&self) -> u64 {
78        self.tx.load(Ordering::Relaxed)
79    }
80
81    pub fn read_rx(&self) -> u64 {
82        self.rx.load(Ordering::Relaxed)
83    }
84}
85
86/// Throughput caculator
87#[derive(Default)]
88pub(super) struct Throughput {
89    prev: Option<ThroughputSample>,
90}
91
92impl Throughput {
93    /// Returns the current throughput (in bytes per second), given the current total amount of
94    /// bytes (sent or received) and the current time.
95    ///
96    /// Note: For best results, call this in regular intervals (e.g., once per second).
97    pub fn sample(&mut self, bytes: u64, timestamp: Instant) -> u64 {
98        let throughput = if let Some(prev) = self.prev.take() {
99            let millis = timestamp
100                .saturating_duration_since(prev.timestamp)
101                .as_millis()
102                .try_into()
103                .unwrap_or(u64::MAX);
104
105            if millis == 0 {
106                prev.throughput
107            } else {
108                bytes.saturating_sub(prev.bytes) * 1000 / millis
109            }
110        } else {
111            0
112        };
113
114        self.prev = Some(ThroughputSample {
115            timestamp,
116            bytes,
117            throughput,
118        });
119
120        throughput
121    }
122}
123
124struct ThroughputSample {
125    timestamp: Instant,
126    bytes: u64,
127    throughput: u64,
128}
129
pin_project! {
    /// Wrapper for an IO object (reader or writer) that counts the transferred bytes.
    ///
    /// IO calls are forwarded to the wrapped object and the number of bytes successfully
    /// read or written is added to the shared counters.
    pub(super) struct Instrumented<T> {
        // The wrapped IO object; `#[pin]` gives the trait impls pinned access to it.
        #[pin]
        inner: T,
        // Counters incremented by the `AsyncRead` / `AsyncWrite` implementations.
        counters: Arc<ByteCounters>,
    }
}
138
139impl<T> Instrumented<T> {
140    pub fn new(inner: T, counters: Arc<ByteCounters>) -> Self {
141        Self { inner, counters }
142    }
143}
144
145impl<T> AsyncRead for Instrumented<T>
146where
147    T: AsyncRead,
148{
149    fn poll_read(
150        self: Pin<&mut Self>,
151        cx: &mut Context<'_>,
152        buf: &mut ReadBuf<'_>,
153    ) -> Poll<io::Result<()>> {
154        let this = self.project();
155
156        let before = buf.filled().len();
157        let result = ready!(this.inner.poll_read(cx, buf));
158
159        if result.is_ok() {
160            this.counters
161                .increment_rx(buf.filled().len().saturating_sub(before) as u64);
162        }
163
164        Poll::Ready(result)
165    }
166}
167
168impl<T> AsyncWrite for Instrumented<T>
169where
170    T: AsyncWrite,
171{
172    fn poll_write(
173        self: Pin<&mut Self>,
174        cx: &mut Context<'_>,
175        buf: &[u8],
176    ) -> Poll<Result<usize, io::Error>> {
177        let this = self.project();
178        let result = ready!(this.inner.poll_write(cx, buf));
179
180        if let Ok(n) = result {
181            this.counters.increment_tx(n as u64);
182        }
183
184        Poll::Ready(result)
185    }
186
187    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
188        self.project().inner.poll_flush(cx)
189    }
190
191    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
192        self.project().inner.poll_shutdown(cx)
193    }
194
195    fn poll_write_vectored(
196        self: Pin<&mut Self>,
197        cx: &mut Context<'_>,
198        bufs: &[IoSlice<'_>],
199    ) -> Poll<Result<usize, io::Error>> {
200        let this = self.project();
201        let result = ready!(this.inner.poll_write_vectored(cx, bufs));
202
203        if let Ok(n) = result {
204            this.counters.increment_tx(n as u64);
205        }
206
207        Poll::Ready(result)
208    }
209
210    fn is_write_vectored(&self) -> bool {
211        self.inner.is_write_vectored()
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Feeds `steps` into a fresh [`Throughput`] and checks every returned value.
    ///
    /// Each step is `(total_bytes, seconds_after_start, expected_throughput)`.
    fn check(steps: &[(u64, u64, u64)]) {
        let mut throughput = Throughput::default();
        let start = Instant::now();

        for (index, &(bytes, offset_secs, expected)) in steps.iter().enumerate() {
            let timestamp = start + Duration::from_secs(offset_secs);
            assert_eq!(throughput.sample(bytes, timestamp), expected, "step #{}", index);
        }
    }

    #[test]
    fn throughput_sanity_check() {
        check(&[
            (1024, 0, 0),
            (1024, 1, 0),
            (2 * 1024, 2, 1024),
            (3 * 1024, 3, 1024),
        ]);
    }

    #[test]
    fn throughput_zero_duration() {
        check(&[
            // No time has passed: the previous value (initially 0) is repeated.
            (1024, 0, 0),
            (1024, 0, 0),
            (2048, 0, 0),
            (2048, 1, 0),
            (3072, 1, 0),
            (4096, 2, 1024),
            (5120, 2, 1024),
        ]);
    }

    #[test]
    fn throughput_negative_duration() {
        check(&[
            (1024, 0, 0),
            (2048, 1, 1024),
            // The timestamp goes backwards: the previous value is repeated.
            (3072, 0, 1024),
        ]);
    }

    #[test]
    fn throughput_non_monotonic_bytes() {
        check(&[
            (1024, 0, 0),
            (2048, 1, 1024),
            // The byte total goes backwards: treated as no traffic.
            (1024, 2, 0),
        ]);
    }
}