1 use byteorder::{ByteOrder, NetworkEndian};
2 use core::fmt;
3 
4 use super::{Error, Result};
5 use crate::time::Duration;
6 use crate::wire::ip::checksum;
7 
8 use crate::wire::Ipv4Address;
9 
enum_with_unknown! {
    /// Internet Group Management Protocol v1/v2 message version/type, as carried in the first octet of the packet.
    pub enum Message(u8) {
        // Each discriminant below is the on-wire value of the type octet.
        // Values not listed here are represented by the `Unknown` variant
        // (see the `Display` impl, which prints the raw number for it).

        /// Membership Query (type octet `0x11`)
        MembershipQuery = 0x11,
        /// Version 2 Membership Report (type octet `0x16`)
        MembershipReportV2 = 0x16,
        /// Leave Group (type octet `0x17`)
        LeaveGroup = 0x17,
        /// Version 1 Membership Report (type octet `0x12`)
        MembershipReportV1 = 0x12
    }
}
23 
/// A read/write wrapper around an Internet Group Management Protocol v1/v2 packet buffer.
///
/// The wrapper itself stores nothing but the buffer; every accessor indexes
/// directly into it. Constructing it with `new_unchecked` performs no length
/// validation, so call `check_len` (or use `new_checked`) before using the
/// accessors on untrusted data.
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct Packet<T: AsRef<[u8]>> {
    /// Underlying octet buffer holding the raw IGMP message.
    buffer: T,
}
30 
/// Offsets and byte ranges of the fields inside an IGMP v1/v2 message.
///
/// The last range ends at octet 8, which is the minimum buffer length
/// enforced by `Packet::check_len`.
mod field {
    use crate::wire::field::*;

    /// Offset of the message type octet.
    pub const TYPE: usize = 0;
    /// Offset of the Max Resp Code octet.
    pub const MAX_RESP_CODE: usize = 1;
    /// Byte range of the 16-bit checksum.
    pub const CHECKSUM: Field = 2..4;
    /// Byte range of the 4-octet group address.
    pub const GROUP_ADDRESS: Field = 4..8;
}
39 
40 impl fmt::Display for Message {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result41     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
42         match *self {
43             Message::MembershipQuery => write!(f, "membership query"),
44             Message::MembershipReportV2 => write!(f, "version 2 membership report"),
45             Message::LeaveGroup => write!(f, "leave group"),
46             Message::MembershipReportV1 => write!(f, "version 1 membership report"),
47             Message::Unknown(id) => write!(f, "{id}"),
48         }
49     }
50 }
51 
52 /// Internet Group Management Protocol v1/v2 defined in [RFC 2236].
53 ///
54 /// [RFC 2236]: https://tools.ietf.org/html/rfc2236
55 impl<T: AsRef<[u8]>> Packet<T> {
56     /// Imbue a raw octet buffer with IGMPv2 packet structure.
new_unchecked(buffer: T) -> Packet<T>57     pub const fn new_unchecked(buffer: T) -> Packet<T> {
58         Packet { buffer }
59     }
60 
61     /// Shorthand for a combination of [new_unchecked] and [check_len].
62     ///
63     /// [new_unchecked]: #method.new_unchecked
64     /// [check_len]: #method.check_len
new_checked(buffer: T) -> Result<Packet<T>>65     pub fn new_checked(buffer: T) -> Result<Packet<T>> {
66         let packet = Self::new_unchecked(buffer);
67         packet.check_len()?;
68         Ok(packet)
69     }
70 
71     /// Ensure that no accessor method will panic if called.
72     /// Returns `Err(Error)` if the buffer is too short.
check_len(&self) -> Result<()>73     pub fn check_len(&self) -> Result<()> {
74         let len = self.buffer.as_ref().len();
75         if len < field::GROUP_ADDRESS.end {
76             Err(Error)
77         } else {
78             Ok(())
79         }
80     }
81 
82     /// Consume the packet, returning the underlying buffer.
into_inner(self) -> T83     pub fn into_inner(self) -> T {
84         self.buffer
85     }
86 
87     /// Return the message type field.
88     #[inline]
msg_type(&self) -> Message89     pub fn msg_type(&self) -> Message {
90         let data = self.buffer.as_ref();
91         Message::from(data[field::TYPE])
92     }
93 
94     /// Return the maximum response time, using the encoding specified in
95     /// [RFC 3376]: 4.1.1. Max Resp Code.
96     ///
97     /// [RFC 3376]: https://tools.ietf.org/html/rfc3376
98     #[inline]
max_resp_code(&self) -> u899     pub fn max_resp_code(&self) -> u8 {
100         let data = self.buffer.as_ref();
101         data[field::MAX_RESP_CODE]
102     }
103 
104     /// Return the checksum field.
105     #[inline]
checksum(&self) -> u16106     pub fn checksum(&self) -> u16 {
107         let data = self.buffer.as_ref();
108         NetworkEndian::read_u16(&data[field::CHECKSUM])
109     }
110 
111     /// Return the source address field.
112     #[inline]
group_addr(&self) -> Ipv4Address113     pub fn group_addr(&self) -> Ipv4Address {
114         let data = self.buffer.as_ref();
115         Ipv4Address::from_bytes(&data[field::GROUP_ADDRESS])
116     }
117 
118     /// Validate the header checksum.
119     ///
120     /// # Fuzzing
121     /// This function always returns `true` when fuzzing.
verify_checksum(&self) -> bool122     pub fn verify_checksum(&self) -> bool {
123         if cfg!(fuzzing) {
124             return true;
125         }
126 
127         let data = self.buffer.as_ref();
128         checksum::data(data) == !0
129     }
130 }
131 
132 impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> {
133     /// Set the message type field.
134     #[inline]
set_msg_type(&mut self, value: Message)135     pub fn set_msg_type(&mut self, value: Message) {
136         let data = self.buffer.as_mut();
137         data[field::TYPE] = value.into()
138     }
139 
140     /// Set the maximum response time, using the encoding specified in
141     /// [RFC 3376]: 4.1.1. Max Resp Code.
142     #[inline]
set_max_resp_code(&mut self, value: u8)143     pub fn set_max_resp_code(&mut self, value: u8) {
144         let data = self.buffer.as_mut();
145         data[field::MAX_RESP_CODE] = value;
146     }
147 
148     /// Set the checksum field.
149     #[inline]
set_checksum(&mut self, value: u16)150     pub fn set_checksum(&mut self, value: u16) {
151         let data = self.buffer.as_mut();
152         NetworkEndian::write_u16(&mut data[field::CHECKSUM], value)
153     }
154 
155     /// Set the group address field
156     #[inline]
set_group_address(&mut self, addr: Ipv4Address)157     pub fn set_group_address(&mut self, addr: Ipv4Address) {
158         let data = self.buffer.as_mut();
159         data[field::GROUP_ADDRESS].copy_from_slice(addr.as_bytes());
160     }
161 
162     /// Compute and fill in the header checksum.
fill_checksum(&mut self)163     pub fn fill_checksum(&mut self) {
164         self.set_checksum(0);
165         let checksum = {
166             let data = self.buffer.as_ref();
167             !checksum::data(data)
168         };
169         self.set_checksum(checksum)
170     }
171 }
172 
/// A high-level representation of an Internet Group Management Protocol v1/v2 header.
///
/// Each variant corresponds to one (or, for reports, two) of the message
/// types in `Message`.
#[derive(Debug, PartialEq, Eq, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum Repr {
    /// A Membership Query message.
    MembershipQuery {
        /// Maximum response time, decoded from the Max Resp Code octet.
        max_resp_time: Duration,
        /// Group address carried in the message.
        group_addr: Ipv4Address,
        /// Query version; parsing reports `Version1` when the Max Resp Code is zero.
        version: IgmpVersion,
    },
    /// A Membership Report message (version 1 or version 2).
    MembershipReport {
        /// Group address carried in the message.
        group_addr: Ipv4Address,
        /// Selects between the version 1 and version 2 report message types.
        version: IgmpVersion,
    },
    /// A Leave Group message.
    LeaveGroup {
        /// Group address carried in the message.
        group_addr: Ipv4Address,
    },
}
190 
/// IGMP protocol version of a membership query or membership report.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum IgmpVersion {
    /// IGMPv1
    Version1,
    /// IGMPv2
    Version2,
}
200 
201 impl Repr {
202     /// Parse an Internet Group Management Protocol v1/v2 packet and return
203     /// a high-level representation.
parse<T>(packet: &Packet<&T>) -> Result<Repr> where T: AsRef<[u8]> + ?Sized,204     pub fn parse<T>(packet: &Packet<&T>) -> Result<Repr>
205     where
206         T: AsRef<[u8]> + ?Sized,
207     {
208         // Check if the address is 0.0.0.0 or multicast
209         let addr = packet.group_addr();
210         if !addr.is_unspecified() && !addr.is_multicast() {
211             return Err(Error);
212         }
213 
214         // construct a packet based on the Type field
215         match packet.msg_type() {
216             Message::MembershipQuery => {
217                 let max_resp_time = max_resp_code_to_duration(packet.max_resp_code());
218                 // See RFC 3376: 7.1. Query Version Distinctions
219                 let version = if packet.max_resp_code() == 0 {
220                     IgmpVersion::Version1
221                 } else {
222                     IgmpVersion::Version2
223                 };
224                 Ok(Repr::MembershipQuery {
225                     max_resp_time,
226                     group_addr: addr,
227                     version,
228                 })
229             }
230             Message::MembershipReportV2 => Ok(Repr::MembershipReport {
231                 group_addr: packet.group_addr(),
232                 version: IgmpVersion::Version2,
233             }),
234             Message::LeaveGroup => Ok(Repr::LeaveGroup {
235                 group_addr: packet.group_addr(),
236             }),
237             Message::MembershipReportV1 => {
238                 // for backwards compatibility with IGMPv1
239                 Ok(Repr::MembershipReport {
240                     group_addr: packet.group_addr(),
241                     version: IgmpVersion::Version1,
242                 })
243             }
244             _ => Err(Error),
245         }
246     }
247 
248     /// Return the length of a packet that will be emitted from this high-level representation.
buffer_len(&self) -> usize249     pub const fn buffer_len(&self) -> usize {
250         // always 8 bytes
251         field::GROUP_ADDRESS.end
252     }
253 
254     /// Emit a high-level representation into an Internet Group Management Protocol v2 packet.
emit<T>(&self, packet: &mut Packet<&mut T>) where T: AsRef<[u8]> + AsMut<[u8]> + ?Sized,255     pub fn emit<T>(&self, packet: &mut Packet<&mut T>)
256     where
257         T: AsRef<[u8]> + AsMut<[u8]> + ?Sized,
258     {
259         match *self {
260             Repr::MembershipQuery {
261                 max_resp_time,
262                 group_addr,
263                 version,
264             } => {
265                 packet.set_msg_type(Message::MembershipQuery);
266                 match version {
267                     IgmpVersion::Version1 => packet.set_max_resp_code(0),
268                     IgmpVersion::Version2 => {
269                         packet.set_max_resp_code(duration_to_max_resp_code(max_resp_time))
270                     }
271                 }
272                 packet.set_group_address(group_addr);
273             }
274             Repr::MembershipReport {
275                 group_addr,
276                 version,
277             } => {
278                 match version {
279                     IgmpVersion::Version1 => packet.set_msg_type(Message::MembershipReportV1),
280                     IgmpVersion::Version2 => packet.set_msg_type(Message::MembershipReportV2),
281                 };
282                 packet.set_max_resp_code(0);
283                 packet.set_group_address(group_addr);
284             }
285             Repr::LeaveGroup { group_addr } => {
286                 packet.set_msg_type(Message::LeaveGroup);
287                 packet.set_group_address(group_addr);
288             }
289         }
290 
291         packet.fill_checksum()
292     }
293 }
294 
max_resp_code_to_duration(value: u8) -> Duration295 fn max_resp_code_to_duration(value: u8) -> Duration {
296     let value: u64 = value.into();
297     let decisecs = if value < 128 {
298         value
299     } else {
300         let mant = value & 0xF;
301         let exp = (value >> 4) & 0x7;
302         (mant | 0x10) << (exp + 3)
303     };
304     Duration::from_millis(decisecs * 100)
305 }
306 
duration_to_max_resp_code(duration: Duration) -> u8307 const fn duration_to_max_resp_code(duration: Duration) -> u8 {
308     let decisecs = duration.total_millis() / 100;
309     if decisecs < 128 {
310         decisecs as u8
311     } else if decisecs < 31744 {
312         let mut mant = decisecs >> 3;
313         let mut exp = 0u8;
314         while mant > 0x1F && exp < 0x8 {
315             mant >>= 1;
316             exp += 1;
317         }
318         0x80 | (exp << 4) | (mant as u8 & 0xF)
319     } else {
320         0xFF
321     }
322 }
323 
324 impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result325     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
326         match Repr::parse(self) {
327             Ok(repr) => write!(f, "{repr}"),
328             Err(err) => write!(f, "IGMP ({err})"),
329         }
330     }
331 }
332 
333 impl fmt::Display for Repr {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result334     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
335         match *self {
336             Repr::MembershipQuery {
337                 max_resp_time,
338                 group_addr,
339                 version,
340             } => write!(
341                 f,
342                 "IGMP membership query max_resp_time={max_resp_time} group_addr={group_addr} version={version:?}"
343             ),
344             Repr::MembershipReport {
345                 group_addr,
346                 version,
347             } => write!(
348                 f,
349                 "IGMP membership report group_addr={group_addr} version={version:?}"
350             ),
351             Repr::LeaveGroup { group_addr } => {
352                 write!(f, "IGMP leave group group_addr={group_addr})")
353             }
354         }
355     }
356 }
357 
358 use crate::wire::pretty_print::{PrettyIndent, PrettyPrint};
359 
360 impl<T: AsRef<[u8]>> PrettyPrint for Packet<T> {
pretty_print( buffer: &dyn AsRef<[u8]>, f: &mut fmt::Formatter, indent: &mut PrettyIndent, ) -> fmt::Result361     fn pretty_print(
362         buffer: &dyn AsRef<[u8]>,
363         f: &mut fmt::Formatter,
364         indent: &mut PrettyIndent,
365     ) -> fmt::Result {
366         match Packet::new_checked(buffer) {
367             Err(err) => writeln!(f, "{indent}({err})"),
368             Ok(packet) => writeln!(f, "{indent}{packet}"),
369         }
370     }
371 }
372 
#[cfg(test)]
mod test {
    use super::*;

    // Reference packets, including their valid checksums.
    static LEAVE_PACKET_BYTES: [u8; 8] = [0x17, 0x00, 0x02, 0x69, 0xe0, 0x00, 0x06, 0x96];
    static REPORT_PACKET_BYTES: [u8; 8] = [0x16, 0x00, 0x08, 0xda, 0xe1, 0x00, 0x00, 0x25];

    /// Every field of the reference Leave Group packet reads back correctly.
    #[test]
    fn test_leave_group_deconstruct() {
        let leave = Packet::new_unchecked(&LEAVE_PACKET_BYTES[..]);
        let expected_group = Ipv4Address::from_bytes(&[224, 0, 6, 150]);

        assert_eq!(leave.msg_type(), Message::LeaveGroup);
        assert_eq!(leave.max_resp_code(), 0);
        assert_eq!(leave.checksum(), 0x269);
        assert_eq!(leave.group_addr(), expected_group);
        assert!(leave.verify_checksum());
    }

    /// Every field of the reference v2 Membership Report reads back correctly.
    #[test]
    fn test_report_deconstruct() {
        let report = Packet::new_unchecked(&REPORT_PACKET_BYTES[..]);
        let expected_group = Ipv4Address::from_bytes(&[225, 0, 0, 37]);

        assert_eq!(report.msg_type(), Message::MembershipReportV2);
        assert_eq!(report.max_resp_code(), 0);
        assert_eq!(report.checksum(), 0x08da);
        assert_eq!(report.group_addr(), expected_group);
        assert!(report.verify_checksum());
    }

    /// Building a Leave Group packet over a dirty buffer yields the reference bytes.
    #[test]
    fn test_leave_construct() {
        let mut scratch = vec![0xa5; 8];
        let mut leave = Packet::new_unchecked(&mut scratch);
        leave.set_msg_type(Message::LeaveGroup);
        leave.set_max_resp_code(0);
        leave.set_group_address(Ipv4Address::from_bytes(&[224, 0, 6, 150]));
        leave.fill_checksum();
        assert_eq!(leave.into_inner().as_slice(), &LEAVE_PACKET_BYTES[..]);
    }

    /// Building a v2 report over a dirty buffer yields the reference bytes.
    #[test]
    fn test_report_construct() {
        let mut scratch = vec![0xa5; 8];
        let mut report = Packet::new_unchecked(&mut scratch);
        report.set_msg_type(Message::MembershipReportV2);
        report.set_max_resp_code(0);
        report.set_group_address(Ipv4Address::from_bytes(&[225, 0, 0, 37]));
        report.fill_checksum();
        assert_eq!(report.into_inner().as_slice(), &REPORT_PACKET_BYTES[..]);
    }

    /// Decoding then re-encoding any Max Resp Code gives back the same octet.
    #[test]
    fn max_resp_time_to_duration_and_back() {
        for code in 0..=u8::MAX {
            let decoded = max_resp_code_to_duration(code);
            let reencoded = duration_to_max_resp_code(decoded);
            assert!(code == reencoded);
        }
    }

    /// Durations past the largest encodable value saturate to `0xFF`.
    #[test]
    fn duration_to_max_resp_time_max() {
        for decisecs in 31744..65536 {
            let too_long = Duration::from_millis(decisecs * 100);
            assert_eq!(duration_to_max_resp_code(too_long), 0xFF);
        }
    }
}
446