1 #![allow(unused)] 2 3 use core::fmt; 4 5 use managed::{ManagedMap, ManagedSlice}; 6 7 use crate::config::{REASSEMBLY_BUFFER_COUNT, REASSEMBLY_BUFFER_SIZE}; 8 use crate::storage::Assembler; 9 use crate::time::{Duration, Instant}; 10 11 #[cfg(feature = "alloc")] 12 type Buffer = alloc::vec::Vec<u8>; 13 #[cfg(not(feature = "alloc"))] 14 type Buffer = [u8; REASSEMBLY_BUFFER_SIZE]; 15 16 /// Problem when assembling: something was out of bounds. 17 #[derive(Copy, Clone, PartialEq, Eq, Debug)] 18 #[cfg_attr(feature = "defmt", derive(defmt::Format))] 19 pub struct AssemblerError; 20 21 /// Packet assembler is full 22 #[derive(Copy, Clone, PartialEq, Eq, Debug)] 23 #[cfg_attr(feature = "defmt", derive(defmt::Format))] 24 pub struct AssemblerFullError; 25 26 /// Holds different fragments of one packet, used for assembling fragmented packets. 27 /// 28 /// The buffer used for the `PacketAssembler` should either be dynamically sized (ex: Vec<u8>) 29 /// or should be statically allocated based upon the MTU of the type of packet being 30 /// assembled (ex: 1280 for a IPv6 frame). 31 #[derive(Debug)] 32 pub struct PacketAssembler<K> { 33 key: Option<K>, 34 buffer: Buffer, 35 36 assembler: Assembler, 37 total_size: Option<usize>, 38 expires_at: Instant, 39 } 40 41 impl<K> PacketAssembler<K> { 42 /// Create a new empty buffer for fragments. new() -> Self43 pub const fn new() -> Self { 44 Self { 45 key: None, 46 47 #[cfg(feature = "alloc")] 48 buffer: Buffer::new(), 49 #[cfg(not(feature = "alloc"))] 50 buffer: [0u8; REASSEMBLY_BUFFER_SIZE], 51 52 assembler: Assembler::new(), 53 total_size: None, 54 expires_at: Instant::ZERO, 55 } 56 } 57 reset(&mut self)58 pub(crate) fn reset(&mut self) { 59 self.key = None; 60 self.assembler.clear(); 61 self.total_size = None; 62 self.expires_at = Instant::ZERO; 63 } 64 65 /// Set the total size of the packet assembler. set_total_size(&mut self, size: usize) -> Result<(), AssemblerError>66 pub(crate) fn set_total_size(&mut self, size: usize) -> Result<(), AssemblerError> { 67 if let Some(old_size) = self.total_size { 68 if old_size != size { 69 return Err(AssemblerError); 70 } 71 } 72 73 #[cfg(not(feature = "alloc"))] 74 if self.buffer.len() < size { 75 return Err(AssemblerError); 76 } 77 78 #[cfg(feature = "alloc")] 79 if self.buffer.len() < size { 80 self.buffer.resize(size, 0); 81 } 82 83 self.total_size = Some(size); 84 Ok(()) 85 } 86 87 /// Return the instant when the assembler expires. expires_at(&self) -> Instant88 pub(crate) fn expires_at(&self) -> Instant { 89 self.expires_at 90 } 91 add_with( &mut self, offset: usize, f: impl Fn(&mut [u8]) -> Result<usize, AssemblerError>, ) -> Result<(), AssemblerError>92 pub(crate) fn add_with( 93 &mut self, 94 offset: usize, 95 f: impl Fn(&mut [u8]) -> Result<usize, AssemblerError>, 96 ) -> Result<(), AssemblerError> { 97 if self.buffer.len() < offset { 98 return Err(AssemblerError); 99 } 100 101 let len = f(&mut self.buffer[offset..])?; 102 assert!(offset + len <= self.buffer.len()); 103 104 net_debug!( 105 "frag assembler: receiving {} octets at offset {}", 106 len, 107 offset 108 ); 109 110 self.assembler.add(offset, len); 111 Ok(()) 112 } 113 114 /// Add a fragment into the packet that is being reassembled. 115 /// 116 /// # Errors 117 /// 118 /// - Returns [`Error::PacketAssemblerBufferTooSmall`] when trying to add data into the buffer at a non-existing 119 /// place. add(&mut self, data: &[u8], offset: usize) -> Result<(), AssemblerError>120 pub(crate) fn add(&mut self, data: &[u8], offset: usize) -> Result<(), AssemblerError> { 121 #[cfg(not(feature = "alloc"))] 122 if self.buffer.len() < offset + data.len() { 123 return Err(AssemblerError); 124 } 125 126 #[cfg(feature = "alloc")] 127 if self.buffer.len() < offset + data.len() { 128 self.buffer.resize(offset + data.len(), 0); 129 } 130 131 let len = data.len(); 132 self.buffer[offset..][..len].copy_from_slice(data); 133 134 net_debug!( 135 "frag assembler: receiving {} octets at offset {}", 136 len, 137 offset 138 ); 139 140 self.assembler.add(offset, data.len()); 141 Ok(()) 142 } 143 144 /// Get an immutable slice of the underlying packet data, if reassembly complete. 145 /// This will mark the assembler as empty, so that it can be reused. assemble(&mut self) -> Option<&'_ [u8]>146 pub(crate) fn assemble(&mut self) -> Option<&'_ [u8]> { 147 if !self.is_complete() { 148 return None; 149 } 150 151 // NOTE: we can unwrap because `is_complete` already checks this. 152 let total_size = self.total_size.unwrap(); 153 self.reset(); 154 Some(&self.buffer[..total_size]) 155 } 156 157 /// Returns `true` when all fragments have been received, otherwise `false`. is_complete(&self) -> bool158 pub(crate) fn is_complete(&self) -> bool { 159 self.total_size == Some(self.assembler.peek_front()) 160 } 161 162 /// Returns `true` when the packet assembler is free to use. is_free(&self) -> bool163 fn is_free(&self) -> bool { 164 self.key.is_none() 165 } 166 } 167 168 /// Set holding multiple [`PacketAssembler`]. 169 #[derive(Debug)] 170 pub struct PacketAssemblerSet<K: Eq + Copy> { 171 assemblers: [PacketAssembler<K>; REASSEMBLY_BUFFER_COUNT], 172 } 173 174 impl<K: Eq + Copy> PacketAssemblerSet<K> { 175 const NEW_PA: PacketAssembler<K> = PacketAssembler::new(); 176 177 /// Create a new set of packet assemblers. new() -> Self178 pub fn new() -> Self { 179 Self { 180 assemblers: [Self::NEW_PA; REASSEMBLY_BUFFER_COUNT], 181 } 182 } 183 184 /// Get a [`PacketAssembler`] for a specific key. 185 /// 186 /// If it doesn't exist, it is created, with the `expires_at` timestamp. 187 /// 188 /// If the assembler set is full, in which case an error is returned. get( &mut self, key: &K, expires_at: Instant, ) -> Result<&mut PacketAssembler<K>, AssemblerFullError>189 pub(crate) fn get( 190 &mut self, 191 key: &K, 192 expires_at: Instant, 193 ) -> Result<&mut PacketAssembler<K>, AssemblerFullError> { 194 let mut empty_slot = None; 195 for slot in &mut self.assemblers { 196 if slot.key.as_ref() == Some(key) { 197 return Ok(slot); 198 } 199 if slot.is_free() { 200 empty_slot = Some(slot) 201 } 202 } 203 204 let slot = empty_slot.ok_or(AssemblerFullError)?; 205 slot.key = Some(*key); 206 slot.expires_at = expires_at; 207 Ok(slot) 208 } 209 210 /// Remove all [`PacketAssembler`]s that are expired. remove_expired(&mut self, timestamp: Instant)211 pub fn remove_expired(&mut self, timestamp: Instant) { 212 for frag in &mut self.assemblers { 213 if !frag.is_free() && frag.expires_at < timestamp { 214 frag.reset(); 215 } 216 } 217 } 218 } 219 220 #[cfg(test)] 221 mod tests { 222 use super::*; 223 224 #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] 225 struct Key { 226 id: usize, 227 } 228 229 #[test] packet_assembler_overlap()230 fn packet_assembler_overlap() { 231 let mut p_assembler = PacketAssembler::<Key>::new(); 232 233 p_assembler.set_total_size(5).unwrap(); 234 235 let data = b"Rust"; 236 p_assembler.add(&data[..], 0); 237 p_assembler.add(&data[..], 1); 238 239 assert_eq!(p_assembler.assemble(), Some(&b"RRust"[..])) 240 } 241 242 #[test] packet_assembler_assemble()243 fn packet_assembler_assemble() { 244 let mut p_assembler = PacketAssembler::<Key>::new(); 245 246 let data = b"Hello World!"; 247 248 p_assembler.set_total_size(data.len()).unwrap(); 249 250 p_assembler.add(b"Hello ", 0).unwrap(); 251 assert_eq!(p_assembler.assemble(), None); 252 253 p_assembler.add(b"World!", b"Hello ".len()).unwrap(); 254 255 assert_eq!(p_assembler.assemble(), Some(&b"Hello World!"[..])); 256 } 257 258 #[test] packet_assembler_out_of_order_assemble()259 fn packet_assembler_out_of_order_assemble() { 260 let mut p_assembler = PacketAssembler::<Key>::new(); 261 262 let data = b"Hello World!"; 263 264 p_assembler.set_total_size(data.len()).unwrap(); 265 266 p_assembler.add(b"World!", b"Hello ".len()).unwrap(); 267 assert_eq!(p_assembler.assemble(), None); 268 269 p_assembler.add(b"Hello ", 0).unwrap(); 270 271 assert_eq!(p_assembler.assemble(), Some(&b"Hello World!"[..])); 272 } 273 274 #[test] packet_assembler_set()275 fn packet_assembler_set() { 276 let key = Key { id: 1 }; 277 278 let mut set = PacketAssemblerSet::new(); 279 280 assert!(set.get(&key, Instant::ZERO).is_ok()); 281 } 282 283 #[test] packet_assembler_set_full()284 fn packet_assembler_set_full() { 285 let mut set = PacketAssemblerSet::new(); 286 for i in 0..REASSEMBLY_BUFFER_COUNT { 287 set.get(&Key { id: i }, Instant::ZERO).unwrap(); 288 } 289 assert!(set.get(&Key { id: 4 }, Instant::ZERO).is_err()); 290 } 291 292 #[test] packet_assembler_set_assembling_many()293 fn packet_assembler_set_assembling_many() { 294 let mut set = PacketAssemblerSet::new(); 295 296 let key = Key { id: 0 }; 297 let assr = set.get(&key, Instant::ZERO).unwrap(); 298 assert_eq!(assr.assemble(), None); 299 assr.set_total_size(0).unwrap(); 300 assr.assemble().unwrap(); 301 302 // Test that `.assemble()` effectively deletes it. 303 let assr = set.get(&key, Instant::ZERO).unwrap(); 304 assert_eq!(assr.assemble(), None); 305 assr.set_total_size(0).unwrap(); 306 assr.assemble().unwrap(); 307 308 let key = Key { id: 1 }; 309 let assr = set.get(&key, Instant::ZERO).unwrap(); 310 assr.set_total_size(0).unwrap(); 311 assr.assemble().unwrap(); 312 313 let key = Key { id: 2 }; 314 let assr = set.get(&key, Instant::ZERO).unwrap(); 315 assr.set_total_size(0).unwrap(); 316 assr.assemble().unwrap(); 317 318 let key = Key { id: 2 }; 319 let assr = set.get(&key, Instant::ZERO).unwrap(); 320 assr.set_total_size(2).unwrap(); 321 assr.add(&[0x00], 0).unwrap(); 322 assert_eq!(assr.assemble(), None); 323 let assr = set.get(&key, Instant::ZERO).unwrap(); 324 assr.add(&[0x01], 1).unwrap(); 325 assert_eq!(assr.assemble(), Some(&[0x00, 0x01][..])); 326 } 327 } 328