1use ouisync_macros::api;
2use pin_project_lite::pin_project;
3use serde::{Deserialize, Serialize};
4use std::{
5 io::{self, IoSlice},
6 pin::Pin,
7 sync::{
8 Arc, Mutex,
9 atomic::{AtomicU64, Ordering},
10 },
11 task::{Context, Poll, ready},
12 time::Instant,
13};
14use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
15
/// Snapshot of network traffic statistics.
// NOTE(review): field order likely matters for the serialized / `#[api]`-generated
// representation — do not reorder fields without checking consumers.
#[derive(Default, Clone, Copy, Eq, PartialEq, Debug, Serialize, Deserialize)]
#[api]
pub struct Stats {
    /// Total number of bytes sent.
    pub bytes_tx: u64,
    /// Total number of bytes received.
    pub bytes_rx: u64,
    /// Send throughput, in bytes per second.
    pub throughput_tx: u64,
    /// Receive throughput, in bytes per second.
    pub throughput_rx: u64,
}
29
30#[derive(Default)]
31pub(super) struct StatsTracker {
32 pub bytes: Arc<ByteCounters>,
33 throughput: Mutex<Throughputs>,
34}
35
36impl StatsTracker {
37 pub fn read(&self) -> Stats {
38 let bytes_tx = self.bytes.read_tx();
39 let bytes_rx = self.bytes.read_rx();
40 let now = Instant::now();
41
42 let mut throughput = self.throughput.lock().unwrap();
43 let throughput_tx = throughput.tx.sample(bytes_tx, now);
44 let throughput_rx = throughput.rx.sample(bytes_rx, now);
45
46 Stats {
47 bytes_tx,
48 bytes_rx,
49 throughput_tx,
50 throughput_rx,
51 }
52 }
53}
54
55#[derive(Default)]
56struct Throughputs {
57 tx: Throughput,
58 rx: Throughput,
59}
60
61#[derive(Default)]
63pub(super) struct ByteCounters {
64 tx: AtomicU64,
65 rx: AtomicU64,
66}
67
68impl ByteCounters {
69 pub fn increment_tx(&self, by: u64) {
70 self.tx.fetch_add(by, Ordering::Relaxed);
71 }
72
73 pub fn increment_rx(&self, by: u64) {
74 self.rx.fetch_add(by, Ordering::Relaxed);
75 }
76
77 pub fn read_tx(&self) -> u64 {
78 self.tx.load(Ordering::Relaxed)
79 }
80
81 pub fn read_rx(&self) -> u64 {
82 self.rx.load(Ordering::Relaxed)
83 }
84}
85
86#[derive(Default)]
88pub(super) struct Throughput {
89 prev: Option<ThroughputSample>,
90}
91
92impl Throughput {
93 pub fn sample(&mut self, bytes: u64, timestamp: Instant) -> u64 {
98 let throughput = if let Some(prev) = self.prev.take() {
99 let millis = timestamp
100 .saturating_duration_since(prev.timestamp)
101 .as_millis()
102 .try_into()
103 .unwrap_or(u64::MAX);
104
105 (bytes.saturating_sub(prev.bytes) * 1000)
106 .checked_div(millis)
107 .unwrap_or(prev.throughput)
108 } else {
109 0
110 };
111
112 self.prev = Some(ThroughputSample {
113 timestamp,
114 bytes,
115 throughput,
116 });
117
118 throughput
119 }
120}
121
122struct ThroughputSample {
123 timestamp: Instant,
124 bytes: u64,
125 throughput: u64,
126}
127
// Stream wrapper that forwards all I/O to `inner` and adds the number of bytes successfully
// read / written to the shared `counters` (see the `AsyncRead` / `AsyncWrite` impls for
// `Instrumented`). Plain `//` comments are used inside `pin_project!` because its grammar only
// accepts a limited set of attributes on fields.
pin_project! {
    pub(super) struct Instrumented<T> {
        // The wrapped stream. `#[pin]` makes it structurally pinned so it can be polled
        // through the `Pin<&mut Self>` receivers of the I/O trait methods.
        #[pin]
        inner: T,
        // Byte counters incremented after each successful read / write on `inner`.
        counters: Arc<ByteCounters>,
    }
}
136
137impl<T> Instrumented<T> {
138 pub fn new(inner: T, counters: Arc<ByteCounters>) -> Self {
139 Self { inner, counters }
140 }
141}
142
143impl<T> AsyncRead for Instrumented<T>
144where
145 T: AsyncRead,
146{
147 fn poll_read(
148 self: Pin<&mut Self>,
149 cx: &mut Context<'_>,
150 buf: &mut ReadBuf<'_>,
151 ) -> Poll<io::Result<()>> {
152 let this = self.project();
153
154 let before = buf.filled().len();
155 let result = ready!(this.inner.poll_read(cx, buf));
156
157 if result.is_ok() {
158 this.counters
159 .increment_rx(buf.filled().len().saturating_sub(before) as u64);
160 }
161
162 Poll::Ready(result)
163 }
164}
165
166impl<T> AsyncWrite for Instrumented<T>
167where
168 T: AsyncWrite,
169{
170 fn poll_write(
171 self: Pin<&mut Self>,
172 cx: &mut Context<'_>,
173 buf: &[u8],
174 ) -> Poll<Result<usize, io::Error>> {
175 let this = self.project();
176 let result = ready!(this.inner.poll_write(cx, buf));
177
178 if let Ok(n) = result {
179 this.counters.increment_tx(n as u64);
180 }
181
182 Poll::Ready(result)
183 }
184
185 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
186 self.project().inner.poll_flush(cx)
187 }
188
189 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
190 self.project().inner.poll_shutdown(cx)
191 }
192
193 fn poll_write_vectored(
194 self: Pin<&mut Self>,
195 cx: &mut Context<'_>,
196 bufs: &[IoSlice<'_>],
197 ) -> Poll<Result<usize, io::Error>> {
198 let this = self.project();
199 let result = ready!(this.inner.poll_write_vectored(cx, bufs));
200
201 if let Ok(n) = result {
202 this.counters.increment_tx(n as u64);
203 }
204
205 Poll::Ready(result)
206 }
207
208 fn is_write_vectored(&self) -> bool {
209 self.inner.is_write_vectored()
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Feeds `(bytes, seconds after start, expected throughput)` steps into a fresh
    /// `Throughput` and asserts that each sample returns the expected value.
    fn check(steps: &[(u64, u64, u64)]) {
        let mut throughput = Throughput::default();
        let start = Instant::now();

        for (index, &(bytes, secs, expected)) in steps.iter().enumerate() {
            let timestamp = start + Duration::from_secs(secs);
            assert_eq!(
                throughput.sample(bytes, timestamp),
                expected,
                "sample #{index}"
            );
        }
    }

    #[test]
    fn throughput_sanity_check() {
        check(&[
            (1024, 0, 0),
            (1024, 1, 0),
            (2 * 1024, 2, 1024),
            (3 * 1024, 3, 1024),
        ]);
    }

    // Samples taken at the same instant keep reporting the previous throughput.
    #[test]
    fn throughput_zero_duration() {
        check(&[
            (1024, 0, 0),
            (1024, 0, 0),
            (2048, 0, 0),
            (2048, 1, 0),
            (3072, 1, 0),
            (4096, 2, 1024),
            (5120, 2, 1024),
        ]);
    }

    // A timestamp earlier than the previous sample keeps the previous throughput.
    #[test]
    fn throughput_negative_duration() {
        check(&[(1024, 0, 0), (2048, 1, 1024), (3072, 0, 1024)]);
    }

    // A byte counter that goes backwards yields zero throughput.
    #[test]
    fn throughput_non_monotonic_bytes() {
        check(&[(1024, 0, 0), (2048, 1, 1024), (1024, 2, 0)]);
    }
}