1 use byteorder::{ByteOrder, NetworkEndian};
2 use core::fmt;
3 
4 use super::{Error, Result};
5 use crate::phy::ChecksumCapabilities;
6 use crate::wire::ip::checksum;
7 use crate::wire::{IpAddress, IpProtocol};
8 
/// A read/write view over a buffer holding a User Datagram Protocol packet.
///
/// The wrapper owns (or borrows, depending on `T`) the underlying octets and
/// exposes typed accessors for the header fields and the payload.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct Packet<T: AsRef<[u8]>> {
    /// Backing storage containing the header followed by the payload.
    buffer: T,
}
15 
16 mod field {
17     #![allow(non_snake_case)]
18 
19     use crate::wire::field::*;
20 
21     pub const SRC_PORT: Field = 0..2;
22     pub const DST_PORT: Field = 2..4;
23     pub const LENGTH: Field = 4..6;
24     pub const CHECKSUM: Field = 6..8;
25 
PAYLOAD(length: u16) -> Field26     pub const fn PAYLOAD(length: u16) -> Field {
27         CHECKSUM.end..(length as usize)
28     }
29 }
30 
/// Length of the UDP header in octets, i.e. the offset at which the checksum
/// field (the last header field) ends.
pub const HEADER_LEN: usize = field::CHECKSUM.end;
32 
33 #[allow(clippy::len_without_is_empty)]
34 impl<T: AsRef<[u8]>> Packet<T> {
35     /// Imbue a raw octet buffer with UDP packet structure.
new_unchecked(buffer: T) -> Packet<T>36     pub const fn new_unchecked(buffer: T) -> Packet<T> {
37         Packet { buffer }
38     }
39 
40     /// Shorthand for a combination of [new_unchecked] and [check_len].
41     ///
42     /// [new_unchecked]: #method.new_unchecked
43     /// [check_len]: #method.check_len
new_checked(buffer: T) -> Result<Packet<T>>44     pub fn new_checked(buffer: T) -> Result<Packet<T>> {
45         let packet = Self::new_unchecked(buffer);
46         packet.check_len()?;
47         Ok(packet)
48     }
49 
50     /// Ensure that no accessor method will panic if called.
51     /// Returns `Err(Error)` if the buffer is too short.
52     /// Returns `Err(Error)` if the length field has a value smaller
53     /// than the header length.
54     ///
55     /// The result of this check is invalidated by calling [set_len].
56     ///
57     /// [set_len]: #method.set_len
check_len(&self) -> Result<()>58     pub fn check_len(&self) -> Result<()> {
59         let buffer_len = self.buffer.as_ref().len();
60         if buffer_len < HEADER_LEN {
61             Err(Error)
62         } else {
63             let field_len = self.len() as usize;
64             if buffer_len < field_len || field_len < HEADER_LEN {
65                 Err(Error)
66             } else {
67                 Ok(())
68             }
69         }
70     }
71 
72     /// Consume the packet, returning the underlying buffer.
into_inner(self) -> T73     pub fn into_inner(self) -> T {
74         self.buffer
75     }
76 
77     /// Return the source port field.
78     #[inline]
src_port(&self) -> u1679     pub fn src_port(&self) -> u16 {
80         let data = self.buffer.as_ref();
81         NetworkEndian::read_u16(&data[field::SRC_PORT])
82     }
83 
84     /// Return the destination port field.
85     #[inline]
dst_port(&self) -> u1686     pub fn dst_port(&self) -> u16 {
87         let data = self.buffer.as_ref();
88         NetworkEndian::read_u16(&data[field::DST_PORT])
89     }
90 
91     /// Return the length field.
92     #[inline]
len(&self) -> u1693     pub fn len(&self) -> u16 {
94         let data = self.buffer.as_ref();
95         NetworkEndian::read_u16(&data[field::LENGTH])
96     }
97 
98     /// Return the checksum field.
99     #[inline]
checksum(&self) -> u16100     pub fn checksum(&self) -> u16 {
101         let data = self.buffer.as_ref();
102         NetworkEndian::read_u16(&data[field::CHECKSUM])
103     }
104 
105     /// Validate the packet checksum.
106     ///
107     /// # Panics
108     /// This function panics unless `src_addr` and `dst_addr` belong to the same family,
109     /// and that family is IPv4 or IPv6.
110     ///
111     /// # Fuzzing
112     /// This function always returns `true` when fuzzing.
verify_checksum(&self, src_addr: &IpAddress, dst_addr: &IpAddress) -> bool113     pub fn verify_checksum(&self, src_addr: &IpAddress, dst_addr: &IpAddress) -> bool {
114         if cfg!(fuzzing) {
115             return true;
116         }
117 
118         // From the RFC:
119         // > An all zero transmitted checksum value means that the transmitter
120         // > generated no checksum (for debugging or for higher level protocols
121         // > that don't care).
122         if self.checksum() == 0 {
123             return true;
124         }
125 
126         let data = self.buffer.as_ref();
127         checksum::combine(&[
128             checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Udp, self.len() as u32),
129             checksum::data(&data[..self.len() as usize]),
130         ]) == !0
131     }
132 }
133 
134 impl<'a, T: AsRef<[u8]> + ?Sized> Packet<&'a T> {
135     /// Return a pointer to the payload.
136     #[inline]
payload(&self) -> &'a [u8]137     pub fn payload(&self) -> &'a [u8] {
138         let length = self.len();
139         let data = self.buffer.as_ref();
140         &data[field::PAYLOAD(length)]
141     }
142 }
143 
144 impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> {
145     /// Set the source port field.
146     #[inline]
set_src_port(&mut self, value: u16)147     pub fn set_src_port(&mut self, value: u16) {
148         let data = self.buffer.as_mut();
149         NetworkEndian::write_u16(&mut data[field::SRC_PORT], value)
150     }
151 
152     /// Set the destination port field.
153     #[inline]
set_dst_port(&mut self, value: u16)154     pub fn set_dst_port(&mut self, value: u16) {
155         let data = self.buffer.as_mut();
156         NetworkEndian::write_u16(&mut data[field::DST_PORT], value)
157     }
158 
159     /// Set the length field.
160     #[inline]
set_len(&mut self, value: u16)161     pub fn set_len(&mut self, value: u16) {
162         let data = self.buffer.as_mut();
163         NetworkEndian::write_u16(&mut data[field::LENGTH], value)
164     }
165 
166     /// Set the checksum field.
167     #[inline]
set_checksum(&mut self, value: u16)168     pub fn set_checksum(&mut self, value: u16) {
169         let data = self.buffer.as_mut();
170         NetworkEndian::write_u16(&mut data[field::CHECKSUM], value)
171     }
172 
173     /// Compute and fill in the header checksum.
174     ///
175     /// # Panics
176     /// This function panics unless `src_addr` and `dst_addr` belong to the same family,
177     /// and that family is IPv4 or IPv6.
fill_checksum(&mut self, src_addr: &IpAddress, dst_addr: &IpAddress)178     pub fn fill_checksum(&mut self, src_addr: &IpAddress, dst_addr: &IpAddress) {
179         self.set_checksum(0);
180         let checksum = {
181             let data = self.buffer.as_ref();
182             !checksum::combine(&[
183                 checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Udp, self.len() as u32),
184                 checksum::data(&data[..self.len() as usize]),
185             ])
186         };
187         // UDP checksum value of 0 means no checksum; if the checksum really is zero,
188         // use all-ones, which indicates that the remote end must verify the checksum.
189         // Arithmetically, RFC 1071 checksums of all-zeroes and all-ones behave identically,
190         // so no action is necessary on the remote end.
191         self.set_checksum(if checksum == 0 { 0xffff } else { checksum })
192     }
193 
194     /// Return a mutable pointer to the payload.
195     #[inline]
payload_mut(&mut self) -> &mut [u8]196     pub fn payload_mut(&mut self) -> &mut [u8] {
197         let length = self.len();
198         let data = self.buffer.as_mut();
199         &mut data[field::PAYLOAD(length)]
200     }
201 }
202 
203 impl<T: AsRef<[u8]>> AsRef<[u8]> for Packet<T> {
as_ref(&self) -> &[u8]204     fn as_ref(&self) -> &[u8] {
205         self.buffer.as_ref()
206     }
207 }
208 
/// A high-level representation of a User Datagram Protocol packet.
///
/// Only the two port numbers are kept; the length and checksum are derived
/// when the representation is emitted.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct Repr {
    /// Source port of the datagram.
    pub src_port: u16,
    /// Destination port of the datagram.
    pub dst_port: u16,
}
216 
217 impl Repr {
218     /// Parse an User Datagram Protocol packet and return a high-level representation.
parse<T>( packet: &Packet<&T>, src_addr: &IpAddress, dst_addr: &IpAddress, checksum_caps: &ChecksumCapabilities, ) -> Result<Repr> where T: AsRef<[u8]> + ?Sized,219     pub fn parse<T>(
220         packet: &Packet<&T>,
221         src_addr: &IpAddress,
222         dst_addr: &IpAddress,
223         checksum_caps: &ChecksumCapabilities,
224     ) -> Result<Repr>
225     where
226         T: AsRef<[u8]> + ?Sized,
227     {
228         // Destination port cannot be omitted (but source port can be).
229         if packet.dst_port() == 0 {
230             return Err(Error);
231         }
232         // Valid checksum is expected...
233         if checksum_caps.udp.rx() && !packet.verify_checksum(src_addr, dst_addr) {
234             match (src_addr, dst_addr) {
235                 // ... except on UDP-over-IPv4, where it can be omitted.
236                 #[cfg(feature = "proto-ipv4")]
237                 (&IpAddress::Ipv4(_), &IpAddress::Ipv4(_)) if packet.checksum() == 0 => (),
238                 _ => return Err(Error),
239             }
240         }
241 
242         Ok(Repr {
243             src_port: packet.src_port(),
244             dst_port: packet.dst_port(),
245         })
246     }
247 
248     /// Return the length of the packet header that will be emitted from this high-level representation.
header_len(&self) -> usize249     pub const fn header_len(&self) -> usize {
250         HEADER_LEN
251     }
252 
253     /// Emit a high-level representation into an User Datagram Protocol packet.
254     ///
255     /// This never calculates the checksum, and is intended for internal-use only,
256     /// not for packets that are going to be actually sent over the network. For
257     /// example, when decompressing 6lowpan.
emit_header<T: ?Sized>(&self, packet: &mut Packet<&mut T>, payload_len: usize) where T: AsRef<[u8]> + AsMut<[u8]>,258     pub(crate) fn emit_header<T: ?Sized>(&self, packet: &mut Packet<&mut T>, payload_len: usize)
259     where
260         T: AsRef<[u8]> + AsMut<[u8]>,
261     {
262         packet.set_src_port(self.src_port);
263         packet.set_dst_port(self.dst_port);
264         packet.set_len((HEADER_LEN + payload_len) as u16);
265         packet.set_checksum(0);
266     }
267 
268     /// Emit a high-level representation into an User Datagram Protocol packet.
emit<T: ?Sized>( &self, packet: &mut Packet<&mut T>, src_addr: &IpAddress, dst_addr: &IpAddress, payload_len: usize, emit_payload: impl FnOnce(&mut [u8]), checksum_caps: &ChecksumCapabilities, ) where T: AsRef<[u8]> + AsMut<[u8]>,269     pub fn emit<T: ?Sized>(
270         &self,
271         packet: &mut Packet<&mut T>,
272         src_addr: &IpAddress,
273         dst_addr: &IpAddress,
274         payload_len: usize,
275         emit_payload: impl FnOnce(&mut [u8]),
276         checksum_caps: &ChecksumCapabilities,
277     ) where
278         T: AsRef<[u8]> + AsMut<[u8]>,
279     {
280         packet.set_src_port(self.src_port);
281         packet.set_dst_port(self.dst_port);
282         packet.set_len((HEADER_LEN + payload_len) as u16);
283         emit_payload(packet.payload_mut());
284 
285         if checksum_caps.udp.tx() {
286             packet.fill_checksum(src_addr, dst_addr)
287         } else {
288             // make sure we get a consistently zeroed checksum,
289             // since implementations might rely on it
290             packet.set_checksum(0);
291         }
292     }
293 }
294 
295 impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result296     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
297         // Cannot use Repr::parse because we don't have the IP addresses.
298         write!(
299             f,
300             "UDP src={} dst={} len={}",
301             self.src_port(),
302             self.dst_port(),
303             self.payload().len()
304         )
305     }
306 }
307 
308 impl fmt::Display for Repr {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result309     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
310         write!(f, "UDP src={} dst={}", self.src_port, self.dst_port)
311     }
312 }
313 
314 use crate::wire::pretty_print::{PrettyIndent, PrettyPrint};
315 
316 impl<T: AsRef<[u8]>> PrettyPrint for Packet<T> {
pretty_print( buffer: &dyn AsRef<[u8]>, f: &mut fmt::Formatter, indent: &mut PrettyIndent, ) -> fmt::Result317     fn pretty_print(
318         buffer: &dyn AsRef<[u8]>,
319         f: &mut fmt::Formatter,
320         indent: &mut PrettyIndent,
321     ) -> fmt::Result {
322         match Packet::new_checked(buffer) {
323             Err(err) => write!(f, "{indent}({err})"),
324             Ok(packet) => write!(f, "{indent}{packet}"),
325         }
326     }
327 }
328 
#[cfg(test)]
mod test {
    use super::*;
    #[cfg(feature = "proto-ipv4")]
    use crate::wire::Ipv4Address;

    #[cfg(feature = "proto-ipv4")]
    const SRC_ADDR: Ipv4Address = Ipv4Address([192, 168, 1, 1]);
    #[cfg(feature = "proto-ipv4")]
    const DST_ADDR: Ipv4Address = Ipv4Address([192, 168, 1, 2]);

    // Reference datagram: ports 48896 -> 53, length 12, checksum 0x124d,
    // followed by the four payload octets.
    #[cfg(feature = "proto-ipv4")]
    static PACKET_BYTES: [u8; 12] = [
        0xbf, 0x00, 0x00, 0x35, 0x00, 0x0c, 0x12, 0x4d, 0xaa, 0x00, 0x00, 0xff,
    ];

    // Same datagram with the checksum field zeroed.
    #[cfg(feature = "proto-ipv4")]
    static NO_CHECKSUM_PACKET: [u8; 12] = [
        0xbf, 0x00, 0x00, 0x35, 0x00, 0x0c, 0x00, 0x00, 0xaa, 0x00, 0x00, 0xff,
    ];

    #[cfg(feature = "proto-ipv4")]
    static PAYLOAD_BYTES: [u8; 4] = [0xaa, 0x00, 0x00, 0xff];

    /// Source and destination addresses shared by the checksum tests.
    #[cfg(feature = "proto-ipv4")]
    fn addresses() -> (IpAddress, IpAddress) {
        (SRC_ADDR.into(), DST_ADDR.into())
    }

    /// Representation matching `PACKET_BYTES`.
    #[cfg(feature = "proto-ipv4")]
    fn packet_repr() -> Repr {
        Repr {
            src_port: 48896,
            dst_port: 53,
        }
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_deconstruct() {
        let (src, dst) = addresses();
        let packet = Packet::new_unchecked(&PACKET_BYTES[..]);
        assert_eq!(packet.src_port(), 48896);
        assert_eq!(packet.dst_port(), 53);
        assert_eq!(packet.len(), 12);
        assert_eq!(packet.checksum(), 0x124d);
        assert_eq!(packet.payload(), &PAYLOAD_BYTES[..]);
        assert!(packet.verify_checksum(&src, &dst));
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_construct() {
        let (src, dst) = addresses();
        let mut storage = vec![0xa5; 12];
        let mut packet = Packet::new_unchecked(&mut storage);
        packet.set_src_port(48896);
        packet.set_dst_port(53);
        packet.set_len(12);
        // Start from a non-zero checksum so that `fill_checksum` has to reset it.
        packet.set_checksum(0xffff);
        packet.payload_mut().copy_from_slice(&PAYLOAD_BYTES[..]);
        packet.fill_checksum(&src, &dst);
        assert_eq!(&*packet.into_inner(), &PACKET_BYTES[..]);
    }

    #[test]
    fn test_impossible_len() {
        let mut storage = vec![0; 12];
        let mut packet = Packet::new_unchecked(&mut storage);
        // A length field smaller than the header must be rejected.
        packet.set_len(4);
        assert_eq!(packet.check_len(), Err(Error));
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_zero_checksum() {
        let (src, dst) = addresses();
        let mut storage = vec![0; 8];
        let mut packet = Packet::new_unchecked(&mut storage);
        packet.set_src_port(1);
        packet.set_dst_port(31881);
        packet.set_len(8);
        packet.fill_checksum(&src, &dst);
        // A computed checksum of zero is stored as all-ones.
        assert_eq!(packet.checksum(), 0xffff);
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_no_checksum() {
        let (src, dst) = addresses();
        let mut storage = vec![0; 8];
        let mut packet = Packet::new_unchecked(&mut storage);
        packet.set_src_port(1);
        packet.set_dst_port(31881);
        packet.set_len(8);
        packet.set_checksum(0);
        assert!(packet.verify_checksum(&src, &dst));
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_parse() {
        let (src, dst) = addresses();
        let packet = Packet::new_unchecked(&PACKET_BYTES[..]);
        let parsed = Repr::parse(&packet, &src, &dst, &ChecksumCapabilities::default());
        assert_eq!(parsed, Ok(packet_repr()));
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_emit() {
        let (src, dst) = addresses();
        let repr = packet_repr();
        let mut storage = vec![0xa5; repr.header_len() + PAYLOAD_BYTES.len()];
        let mut packet = Packet::new_unchecked(&mut storage);
        repr.emit(
            &mut packet,
            &src,
            &dst,
            PAYLOAD_BYTES.len(),
            |payload| payload.copy_from_slice(&PAYLOAD_BYTES),
            &ChecksumCapabilities::default(),
        );
        assert_eq!(&*packet.into_inner(), &PACKET_BYTES[..]);
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_checksum_omitted() {
        let (src, dst) = addresses();
        let packet = Packet::new_unchecked(&NO_CHECKSUM_PACKET[..]);
        let parsed = Repr::parse(&packet, &src, &dst, &ChecksumCapabilities::default());
        assert_eq!(parsed, Ok(packet_repr()));
    }
}
464