1use serde::{Deserialize, Deserializer, Serialize, Serializer, de::Error};
2use std::{
3 fmt,
4 net::{IpAddr, SocketAddr},
5 str::FromStr,
6};
7
8#[derive(Clone, Copy, Serialize, Deserialize, Debug)]
9pub enum PeerPort {
10 Tcp(u16),
11 Quic(u16),
12}
13
/// A peer's socket address tagged with the transport protocol used to reach it.
///
/// The textual form is `"<proto>/<ip>:<port>"`, for example `tcp/192.0.2.0:12481` or
/// `quic/[::]:0`.
///
/// The variant order matters: the derived `Ord` sorts every `Tcp` address before every `Quic`
/// address.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum PeerAddr {
    /// Address reached over TCP.
    Tcp(SocketAddr),
    /// Address reached over QUIC.
    Quic(SocketAddr),
}
19
20impl PeerAddr {
21 pub fn socket_addr(&self) -> &SocketAddr {
22 match self {
23 Self::Tcp(addr) => addr,
24 Self::Quic(addr) => addr,
25 }
26 }
27
28 pub fn ip(&self) -> IpAddr {
29 self.socket_addr().ip()
30 }
31
32 pub fn port(&self) -> u16 {
33 self.socket_addr().port()
34 }
35
36 pub fn set_port(&mut self, port: u16) {
37 match self {
38 Self::Tcp(addr) => addr.set_port(port),
39 Self::Quic(addr) => addr.set_port(port),
40 }
41 }
42
43 pub fn peer_port(&self) -> PeerPort {
44 match self {
45 Self::Tcp(addr) => PeerPort::Tcp(addr.port()),
46 Self::Quic(addr) => PeerPort::Quic(addr.port()),
47 }
48 }
49
50 pub fn is_quic(&self) -> bool {
51 match self {
52 Self::Tcp(_) => false,
53 Self::Quic(_) => true,
54 }
55 }
56
57 pub fn is_tcp(&self) -> bool {
58 match self {
59 Self::Tcp(_) => true,
60 Self::Quic(_) => false,
61 }
62 }
63
64 pub fn is_local(&self) -> bool {
65 match self.socket_addr().ip() {
66 IpAddr::V4(addr) => addr.is_private() || addr.is_loopback() || addr.is_link_local(),
67 IpAddr::V6(addr) => {
68 addr.is_loopback() || addr.is_unicast_link_local() || addr.is_unique_local()
69 }
70 }
71 }
72}
73
74impl FromStr for PeerAddr {
75 type Err = String;
76
77 fn from_str(s: &str) -> Result<Self, Self::Err> {
78 let (proto, addr) = match s.split_once('/') {
79 Some((proto, addr)) => (proto, addr),
80 None => {
81 return Err(format!("Could not find '/' delimiter in the address {s:?}"));
82 }
83 };
84
85 let addr = match SocketAddr::from_str(addr) {
86 Ok(addr) => addr,
87 Err(_) => return Err(format!("Failed to parse IP:PORT {addr:?}")),
88 };
89
90 if proto.eq_ignore_ascii_case("tcp") {
91 Ok(PeerAddr::Tcp(addr))
92 } else if proto.eq_ignore_ascii_case("quic") {
93 Ok(PeerAddr::Quic(addr))
94 } else {
95 Err(format!("Unrecognized protocol {proto:?} in {s:?}"))
96 }
97 }
98}
99
100impl fmt::Display for PeerAddr {
101 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102 match self {
103 Self::Tcp(addr) => write!(f, "tcp/{addr}"),
104 Self::Quic(addr) => write!(f, "quic/{addr}"),
105 }
106 }
107}
108
109impl Serialize for PeerAddr {
110 fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
111 where
112 S: Serializer,
113 {
114 if s.is_human_readable() {
115 self.to_string().serialize(s)
116 } else {
117 SerdeProxy::serialize(self, s)
118 }
119 }
120}
121
122impl<'de> Deserialize<'de> for PeerAddr {
123 fn deserialize<D>(d: D) -> Result<Self, D::Error>
124 where
125 D: Deserializer<'de>,
126 {
127 if d.is_human_readable() {
128 <&str>::deserialize(d)?.parse().map_err(D::Error::custom)
129 } else {
130 SerdeProxy::deserialize(d)
131 }
132 }
133}
134
/// Mirror of [`PeerAddr`] used with `#[serde(remote = "PeerAddr")]`.
///
/// The remote derive generates `SerdeProxy::serialize(&PeerAddr, _)` and
/// `SerdeProxy::deserialize(_) -> PeerAddr`. `PeerAddr`'s own `Serialize`/`Deserialize` impls
/// call these for non-human-readable (binary) formats.
///
/// Keep the variant order unchanged: the variant index is part of the binary encoding. The
/// bincode test expects `Tcp` to encode as index 0 and `Quic` as index 1.
#[derive(Serialize, Deserialize)]
#[serde(remote = "PeerAddr")]
enum SerdeProxy {
    // `SerdeProxy` itself is never constructed; the generated code works on `PeerAddr`
    // directly. Without these attributes rustc would flag the fields as unused.
    Tcp(#[allow(dead_code)] SocketAddr),
    Quic(#[allow(dead_code)] SocketAddr),
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    #[test]
    fn parse() {
        for (orig, expected) in [
            (
                PeerAddr::Tcp((Ipv4Addr::UNSPECIFIED, 0).into()),
                "tcp/0.0.0.0:0",
            ),
            (
                PeerAddr::Tcp((Ipv6Addr::UNSPECIFIED, 0).into()),
                "tcp/[::]:0",
            ),
            (
                PeerAddr::Quic((Ipv4Addr::UNSPECIFIED, 0).into()),
                "quic/0.0.0.0:0",
            ),
            (
                PeerAddr::Quic((Ipv6Addr::UNSPECIFIED, 0).into()),
                "quic/[::]:0",
            ),
        ] {
            assert_eq!(orig.to_string(), expected);
            assert_eq!(expected.parse::<PeerAddr>().unwrap(), orig);
        }
    }

    // The protocol prefix is matched ASCII case-insensitively.
    #[test]
    fn parse_is_case_insensitive() {
        assert_eq!(
            "TCP/127.0.0.1:1".parse::<PeerAddr>().unwrap(),
            PeerAddr::Tcp((Ipv4Addr::LOCALHOST, 1).into())
        );
        assert_eq!(
            "Quic/[::1]:2".parse::<PeerAddr>().unwrap(),
            PeerAddr::Quic((Ipv6Addr::LOCALHOST, 2).into())
        );
    }

    // Missing delimiter, unknown protocol, missing port and hostnames are all rejected.
    #[test]
    fn parse_errors() {
        for input in [
            "",
            "tcp",
            "127.0.0.1:1",
            "udp/127.0.0.1:1",
            "tcp/",
            "tcp/127.0.0.1",
            "tcp/localhost:1",
        ] {
            assert!(
                input.parse::<PeerAddr>().is_err(),
                "{input:?} should not parse"
            );
        }
    }

    #[test]
    fn is_local() {
        for (addr, expected) in [
            (PeerAddr::Tcp(([127, 0, 0, 1], 1).into()), true),
            (PeerAddr::Tcp(([192, 168, 1, 2], 1).into()), true),
            (PeerAddr::Quic(([169, 254, 0, 1], 1).into()), true),
            (PeerAddr::Tcp(([8, 8, 8, 8], 1).into()), false),
            (PeerAddr::Quic((Ipv6Addr::LOCALHOST, 1).into()), true),
            (PeerAddr::Tcp(([0xfe80, 0, 0, 0, 0, 0, 0, 1], 1).into()), true),
            (PeerAddr::Tcp(([0xfd00, 0, 0, 0, 0, 0, 0, 1], 1).into()), true),
            (PeerAddr::Quic(([0x2001, 0xdb8, 0, 0, 0, 0, 0, 1], 1).into()), false),
        ] {
            assert_eq!(addr.is_local(), expected, "{addr}");
        }
    }

    #[test]
    fn serialize_binary() {
        for (orig, expected) in [
            (
                PeerAddr::Tcp(([192, 0, 2, 0], 12481).into()),
                "0000000000000000c0000200c130",
            ),
            (
                PeerAddr::Quic(([0x2001, 0xdb8, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5], 24816).into()),
                "010000000100000020010db8000000010002000300040005f060",
            ),
        ] {
            assert_eq!(hex::encode(bincode::serialize(&orig).unwrap()), expected);
            assert_eq!(
                bincode::deserialize::<PeerAddr>(&hex::decode(expected).unwrap()).unwrap(),
                orig
            );
        }
    }

    // Human-readable formats use the `Display` string and must round-trip through it.
    #[test]
    fn serialize_human_readable() {
        for addr in [
            PeerAddr::Tcp(([192, 0, 2, 0], 12481).into()),
            PeerAddr::Quic(([0x2001, 0xdb8, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5], 24816).into()),
        ] {
            let expected = addr.to_string();
            let actual = serde_json::to_string(&addr).unwrap();
            assert_eq!(actual, format!("\"{expected}\""));
            assert_eq!(serde_json::from_str::<PeerAddr>(&actual).unwrap(), addr);
        }
    }
}