1use ouisync_macros::api;
2use pin_project_lite::pin_project;
3use serde::{Deserialize, Serialize};
4use std::{
5 io::{self, IoSlice},
6 pin::Pin,
7 sync::{
8 atomic::{AtomicU64, Ordering},
9 Arc, Mutex,
10 },
11 task::{ready, Context, Poll},
12 time::Instant,
13};
14use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
15
16#[derive(Default, Clone, Copy, Eq, PartialEq, Debug, Serialize, Deserialize)]
18#[api]
19pub struct Stats {
20 pub bytes_tx: u64,
22 pub bytes_rx: u64,
24 pub throughput_tx: u64,
26 pub throughput_rx: u64,
28}
29
30#[derive(Default)]
31pub(super) struct StatsTracker {
32 pub bytes: Arc<ByteCounters>,
33 throughput: Mutex<Throughputs>,
34}
35
36impl StatsTracker {
37 pub fn read(&self) -> Stats {
38 let bytes_tx = self.bytes.read_tx();
39 let bytes_rx = self.bytes.read_rx();
40 let now = Instant::now();
41
42 let mut throughput = self.throughput.lock().unwrap();
43 let throughput_tx = throughput.tx.sample(bytes_tx, now);
44 let throughput_rx = throughput.rx.sample(bytes_rx, now);
45
46 Stats {
47 bytes_tx,
48 bytes_rx,
49 throughput_tx,
50 throughput_rx,
51 }
52 }
53}
54
55#[derive(Default)]
56struct Throughputs {
57 tx: Throughput,
58 rx: Throughput,
59}
60
61#[derive(Default)]
63pub(super) struct ByteCounters {
64 tx: AtomicU64,
65 rx: AtomicU64,
66}
67
68impl ByteCounters {
69 pub fn increment_tx(&self, by: u64) {
70 self.tx.fetch_add(by, Ordering::Relaxed);
71 }
72
73 pub fn increment_rx(&self, by: u64) {
74 self.rx.fetch_add(by, Ordering::Relaxed);
75 }
76
77 pub fn read_tx(&self) -> u64 {
78 self.tx.load(Ordering::Relaxed)
79 }
80
81 pub fn read_rx(&self) -> u64 {
82 self.rx.load(Ordering::Relaxed)
83 }
84}
85
86#[derive(Default)]
88pub(super) struct Throughput {
89 prev: Option<ThroughputSample>,
90}
91
92impl Throughput {
93 pub fn sample(&mut self, bytes: u64, timestamp: Instant) -> u64 {
98 let throughput = if let Some(prev) = self.prev.take() {
99 let millis = timestamp
100 .saturating_duration_since(prev.timestamp)
101 .as_millis()
102 .try_into()
103 .unwrap_or(u64::MAX);
104
105 if millis == 0 {
106 prev.throughput
107 } else {
108 bytes.saturating_sub(prev.bytes) * 1000 / millis
109 }
110 } else {
111 0
112 };
113
114 self.prev = Some(ThroughputSample {
115 timestamp,
116 bytes,
117 throughput,
118 });
119
120 throughput
121 }
122}
123
/// One reading remembered between calls to [`Throughput::sample`].
struct ThroughputSample {
    /// When the reading was taken.
    timestamp: Instant,
    /// Running byte total at that moment.
    bytes: u64,
    /// Throughput (bytes per second) reported for this reading. It is returned again when the
    /// next reading arrives with no measurable time elapsed.
    throughput: u64,
}
129
130pin_project! {
131 pub(super) struct Instrumented<T> {
133 #[pin]
134 inner: T,
135 counters: Arc<ByteCounters>,
136 }
137}
138
139impl<T> Instrumented<T> {
140 pub fn new(inner: T, counters: Arc<ByteCounters>) -> Self {
141 Self { inner, counters }
142 }
143}
144
145impl<T> AsyncRead for Instrumented<T>
146where
147 T: AsyncRead,
148{
149 fn poll_read(
150 self: Pin<&mut Self>,
151 cx: &mut Context<'_>,
152 buf: &mut ReadBuf<'_>,
153 ) -> Poll<io::Result<()>> {
154 let this = self.project();
155
156 let before = buf.filled().len();
157 let result = ready!(this.inner.poll_read(cx, buf));
158
159 if result.is_ok() {
160 this.counters
161 .increment_rx(buf.filled().len().saturating_sub(before) as u64);
162 }
163
164 Poll::Ready(result)
165 }
166}
167
168impl<T> AsyncWrite for Instrumented<T>
169where
170 T: AsyncWrite,
171{
172 fn poll_write(
173 self: Pin<&mut Self>,
174 cx: &mut Context<'_>,
175 buf: &[u8],
176 ) -> Poll<Result<usize, io::Error>> {
177 let this = self.project();
178 let result = ready!(this.inner.poll_write(cx, buf));
179
180 if let Ok(n) = result {
181 this.counters.increment_tx(n as u64);
182 }
183
184 Poll::Ready(result)
185 }
186
187 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
188 self.project().inner.poll_flush(cx)
189 }
190
191 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
192 self.project().inner.poll_shutdown(cx)
193 }
194
195 fn poll_write_vectored(
196 self: Pin<&mut Self>,
197 cx: &mut Context<'_>,
198 bufs: &[IoSlice<'_>],
199 ) -> Poll<Result<usize, io::Error>> {
200 let this = self.project();
201 let result = ready!(this.inner.poll_write_vectored(cx, bufs));
202
203 if let Ok(n) = result {
204 this.counters.increment_tx(n as u64);
205 }
206
207 Poll::Ready(result)
208 }
209
210 fn is_write_vectored(&self) -> bool {
211 self.inner.is_write_vectored()
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Feeds `steps` into a fresh [`Throughput`] and checks every reported value.
    ///
    /// Each step is `(running byte total, seconds after the start, expected throughput)`.
    fn check(steps: &[(u64, u64, u64)]) {
        let mut throughput = Throughput::default();
        let start = Instant::now();

        for (index, &(bytes, secs, expected)) in steps.iter().enumerate() {
            let actual = throughput.sample(bytes, start + Duration::from_secs(secs));
            assert_eq!(
                actual, expected,
                "step #{}: {} bytes at +{}s",
                index, bytes, secs
            );
        }
    }

    #[test]
    fn throughput_sanity_check() {
        check(&[
            (1024, 0, 0),
            (1024, 1, 0),
            (2 * 1024, 2, 1024),
            (3 * 1024, 3, 1024),
        ]);
    }

    #[test]
    fn throughput_zero_duration() {
        check(&[
            // Repeated samples at the very first instant.
            (1024, 0, 0),
            (1024, 0, 0),
            (2048, 0, 0),
            // One second later, then again with no time elapsed.
            (2048, 1, 0),
            (3072, 1, 0),
            // Another second later, then again with no time elapsed.
            (4096, 2, 1024),
            (5120, 2, 1024),
        ]);
    }

    #[test]
    fn throughput_negative_duration() {
        // The last sample's timestamp lies before the previous one.
        check(&[(1024, 0, 0), (2048, 1, 1024), (3072, 0, 1024)]);
    }

    #[test]
    fn throughput_non_monotonic_bytes() {
        // The last sample's byte total is lower than the previous one.
        check(&[(1024, 0, 0), (2048, 1, 1024), (1024, 2, 0)]);
    }
}