1 #![allow(unused)]
2 
3 use core::fmt;
4 
5 use managed::{ManagedMap, ManagedSlice};
6 
7 use crate::config::{REASSEMBLY_BUFFER_COUNT, REASSEMBLY_BUFFER_SIZE};
8 use crate::storage::Assembler;
9 use crate::time::{Duration, Instant};
10 
11 #[cfg(feature = "alloc")]
12 type Buffer = alloc::vec::Vec<u8>;
13 #[cfg(not(feature = "alloc"))]
14 type Buffer = [u8; REASSEMBLY_BUFFER_SIZE];
15 
/// Problem when assembling: something was out of bounds (e.g. a fragment did not fit in the
/// reassembly buffer) or a total size conflicting with an earlier one was given.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct AssemblerError;

impl fmt::Display for AssemblerError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("packet assembler: out of bounds")
    }
}
20 
/// Packet assembler set is full: no assembler holds the requested key and none is free.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct AssemblerFullError;

impl fmt::Display for AssemblerFullError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("packet assembler set is full")
    }
}
25 
/// Holds different fragments of one packet, used for assembling fragmented packets.
///
/// The buffer used for the `PacketAssembler` should either be dynamically sized (ex: `Vec<u8>`)
/// or should be statically allocated based upon the MTU of the type of packet being
/// assembled (ex: 1280 for an IPv6 packet).
#[derive(Debug)]
pub struct PacketAssembler<K> {
    /// Key of the packet currently being reassembled; `None` means this assembler is free.
    key: Option<K>,
    /// Reassembled bytes: a growable `Vec<u8>` with the `alloc` feature, otherwise a fixed
    /// array of `REASSEMBLY_BUFFER_SIZE` bytes.
    buffer: Buffer,

    /// Bookkeeping of which byte ranges of `buffer` have been received so far.
    assembler: Assembler,
    /// Total length of the packet once known; `None` until `set_total_size` is called.
    total_size: Option<usize>,
    /// Once the current time is past this instant, `PacketAssemblerSet::remove_expired`
    /// frees this assembler.
    expires_at: Instant,
}
40 
41 impl<K> PacketAssembler<K> {
42     /// Create a new empty buffer for fragments.
new() -> Self43     pub const fn new() -> Self {
44         Self {
45             key: None,
46 
47             #[cfg(feature = "alloc")]
48             buffer: Buffer::new(),
49             #[cfg(not(feature = "alloc"))]
50             buffer: [0u8; REASSEMBLY_BUFFER_SIZE],
51 
52             assembler: Assembler::new(),
53             total_size: None,
54             expires_at: Instant::ZERO,
55         }
56     }
57 
reset(&mut self)58     pub(crate) fn reset(&mut self) {
59         self.key = None;
60         self.assembler.clear();
61         self.total_size = None;
62         self.expires_at = Instant::ZERO;
63     }
64 
65     /// Set the total size of the packet assembler.
set_total_size(&mut self, size: usize) -> Result<(), AssemblerError>66     pub(crate) fn set_total_size(&mut self, size: usize) -> Result<(), AssemblerError> {
67         if let Some(old_size) = self.total_size {
68             if old_size != size {
69                 return Err(AssemblerError);
70             }
71         }
72 
73         #[cfg(not(feature = "alloc"))]
74         if self.buffer.len() < size {
75             return Err(AssemblerError);
76         }
77 
78         #[cfg(feature = "alloc")]
79         if self.buffer.len() < size {
80             self.buffer.resize(size, 0);
81         }
82 
83         self.total_size = Some(size);
84         Ok(())
85     }
86 
87     /// Return the instant when the assembler expires.
expires_at(&self) -> Instant88     pub(crate) fn expires_at(&self) -> Instant {
89         self.expires_at
90     }
91 
add_with( &mut self, offset: usize, f: impl Fn(&mut [u8]) -> Result<usize, AssemblerError>, ) -> Result<(), AssemblerError>92     pub(crate) fn add_with(
93         &mut self,
94         offset: usize,
95         f: impl Fn(&mut [u8]) -> Result<usize, AssemblerError>,
96     ) -> Result<(), AssemblerError> {
97         if self.buffer.len() < offset {
98             return Err(AssemblerError);
99         }
100 
101         let len = f(&mut self.buffer[offset..])?;
102         assert!(offset + len <= self.buffer.len());
103 
104         net_debug!(
105             "frag assembler: receiving {} octets at offset {}",
106             len,
107             offset
108         );
109 
110         self.assembler.add(offset, len);
111         Ok(())
112     }
113 
114     /// Add a fragment into the packet that is being reassembled.
115     ///
116     /// # Errors
117     ///
118     /// - Returns [`Error::PacketAssemblerBufferTooSmall`] when trying to add data into the buffer at a non-existing
119     /// place.
add(&mut self, data: &[u8], offset: usize) -> Result<(), AssemblerError>120     pub(crate) fn add(&mut self, data: &[u8], offset: usize) -> Result<(), AssemblerError> {
121         #[cfg(not(feature = "alloc"))]
122         if self.buffer.len() < offset + data.len() {
123             return Err(AssemblerError);
124         }
125 
126         #[cfg(feature = "alloc")]
127         if self.buffer.len() < offset + data.len() {
128             self.buffer.resize(offset + data.len(), 0);
129         }
130 
131         let len = data.len();
132         self.buffer[offset..][..len].copy_from_slice(data);
133 
134         net_debug!(
135             "frag assembler: receiving {} octets at offset {}",
136             len,
137             offset
138         );
139 
140         self.assembler.add(offset, data.len());
141         Ok(())
142     }
143 
144     /// Get an immutable slice of the underlying packet data, if reassembly complete.
145     /// This will mark the assembler as empty, so that it can be reused.
assemble(&mut self) -> Option<&'_ [u8]>146     pub(crate) fn assemble(&mut self) -> Option<&'_ [u8]> {
147         if !self.is_complete() {
148             return None;
149         }
150 
151         // NOTE: we can unwrap because `is_complete` already checks this.
152         let total_size = self.total_size.unwrap();
153         self.reset();
154         Some(&self.buffer[..total_size])
155     }
156 
157     /// Returns `true` when all fragments have been received, otherwise `false`.
is_complete(&self) -> bool158     pub(crate) fn is_complete(&self) -> bool {
159         self.total_size == Some(self.assembler.peek_front())
160     }
161 
162     /// Returns `true` when the packet assembler is free to use.
is_free(&self) -> bool163     fn is_free(&self) -> bool {
164         self.key.is_none()
165     }
166 }
167 
/// Set holding multiple [`PacketAssembler`]s.
///
/// Each packet being reassembled, identified by its key `K`, occupies one assembler of the set.
#[derive(Debug)]
pub struct PacketAssemblerSet<K: Eq + Copy> {
    /// Fixed pool of `REASSEMBLY_BUFFER_COUNT` assemblers; one whose key is `None` is free.
    assemblers: [PacketAssembler<K>; REASSEMBLY_BUFFER_COUNT],
}
173 
174 impl<K: Eq + Copy> PacketAssemblerSet<K> {
175     const NEW_PA: PacketAssembler<K> = PacketAssembler::new();
176 
177     /// Create a new set of packet assemblers.
new() -> Self178     pub fn new() -> Self {
179         Self {
180             assemblers: [Self::NEW_PA; REASSEMBLY_BUFFER_COUNT],
181         }
182     }
183 
184     /// Get a [`PacketAssembler`] for a specific key.
185     ///
186     /// If it doesn't exist, it is created, with the `expires_at` timestamp.
187     ///
188     /// If the assembler set is full, in which case an error is returned.
get( &mut self, key: &K, expires_at: Instant, ) -> Result<&mut PacketAssembler<K>, AssemblerFullError>189     pub(crate) fn get(
190         &mut self,
191         key: &K,
192         expires_at: Instant,
193     ) -> Result<&mut PacketAssembler<K>, AssemblerFullError> {
194         let mut empty_slot = None;
195         for slot in &mut self.assemblers {
196             if slot.key.as_ref() == Some(key) {
197                 return Ok(slot);
198             }
199             if slot.is_free() {
200                 empty_slot = Some(slot)
201             }
202         }
203 
204         let slot = empty_slot.ok_or(AssemblerFullError)?;
205         slot.key = Some(*key);
206         slot.expires_at = expires_at;
207         Ok(slot)
208     }
209 
210     /// Remove all [`PacketAssembler`]s that are expired.
remove_expired(&mut self, timestamp: Instant)211     pub fn remove_expired(&mut self, timestamp: Instant) {
212         for frag in &mut self.assemblers {
213             if !frag.is_free() && frag.expires_at < timestamp {
214                 frag.reset();
215             }
216         }
217     }
218 }
219 
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
    struct Key {
        id: usize,
    }

    #[test]
    fn packet_assembler_overlap() {
        let mut p_assembler = PacketAssembler::<Key>::new();

        p_assembler.set_total_size(5).unwrap();

        let data = b"Rust";
        p_assembler.add(&data[..], 0).unwrap();
        // The overlapping second fragment overwrites bytes 1..4 of the first one.
        p_assembler.add(&data[..], 1).unwrap();

        assert_eq!(p_assembler.assemble(), Some(&b"RRust"[..]))
    }

    #[test]
    fn packet_assembler_assemble() {
        let mut p_assembler = PacketAssembler::<Key>::new();

        let data = b"Hello World!";

        p_assembler.set_total_size(data.len()).unwrap();

        p_assembler.add(b"Hello ", 0).unwrap();
        assert_eq!(p_assembler.assemble(), None);

        p_assembler.add(b"World!", b"Hello ".len()).unwrap();

        assert_eq!(p_assembler.assemble(), Some(&b"Hello World!"[..]));
    }

    #[test]
    fn packet_assembler_out_of_order_assemble() {
        let mut p_assembler = PacketAssembler::<Key>::new();

        let data = b"Hello World!";

        p_assembler.set_total_size(data.len()).unwrap();

        p_assembler.add(b"World!", b"Hello ".len()).unwrap();
        assert_eq!(p_assembler.assemble(), None);

        p_assembler.add(b"Hello ", 0).unwrap();

        assert_eq!(p_assembler.assemble(), Some(&b"Hello World!"[..]));
    }

    #[test]
    fn packet_assembler_set() {
        let key = Key { id: 1 };

        let mut set = PacketAssemblerSet::new();

        assert!(set.get(&key, Instant::ZERO).is_ok());
    }

    #[test]
    fn packet_assembler_set_full() {
        let mut set = PacketAssemblerSet::new();
        for i in 0..REASSEMBLY_BUFFER_COUNT {
            set.get(&Key { id: i }, Instant::ZERO).unwrap();
        }
        // This key is guaranteed not to be in the set, whatever the configured buffer count.
        let missing = Key {
            id: REASSEMBLY_BUFFER_COUNT,
        };
        assert!(set.get(&missing, Instant::ZERO).is_err());
    }

    #[test]
    fn packet_assembler_set_assembling_many() {
        let mut set = PacketAssemblerSet::new();

        let key = Key { id: 0 };
        let assr = set.get(&key, Instant::ZERO).unwrap();
        assert_eq!(assr.assemble(), None);
        assr.set_total_size(0).unwrap();
        assr.assemble().unwrap();

        // Test that `.assemble()` effectively deletes it.
        let assr = set.get(&key, Instant::ZERO).unwrap();
        assert_eq!(assr.assemble(), None);
        assr.set_total_size(0).unwrap();
        assr.assemble().unwrap();

        let key = Key { id: 1 };
        let assr = set.get(&key, Instant::ZERO).unwrap();
        assr.set_total_size(0).unwrap();
        assr.assemble().unwrap();

        let key = Key { id: 2 };
        let assr = set.get(&key, Instant::ZERO).unwrap();
        assr.set_total_size(0).unwrap();
        assr.assemble().unwrap();

        let key = Key { id: 2 };
        let assr = set.get(&key, Instant::ZERO).unwrap();
        assr.set_total_size(2).unwrap();
        assr.add(&[0x00], 0).unwrap();
        assert_eq!(assr.assemble(), None);
        let assr = set.get(&key, Instant::ZERO).unwrap();
        assr.add(&[0x01], 1).unwrap();
        assert_eq!(assr.assemble(), Some(&[0x00, 0x01][..]));
    }
}
328