xref: /DragonOS/kernel/src/driver/net/virtio_net.rs (revision 4b0170bd6bb374d0e9699a0076cc23b976ad6db7)
1 use core::{
2     cell::UnsafeCell,
3     fmt::Debug,
4     ops::{Deref, DerefMut},
5 };
6 
7 use alloc::{
8     string::String,
9     sync::{Arc, Weak},
10 };
11 use smoltcp::{phy, wire};
12 use virtio_drivers::{device::net::VirtIONet, transport::Transport};
13 
14 use super::NetDriver;
15 use crate::{
16     driver::{
17         base::{
18             device::{bus::Bus, driver::Driver, Device, DeviceId, IdTable},
19             kobject::{KObjType, KObject, KObjectState},
20         },
21         virtio::{irq::virtio_irq_manager, virtio_impl::HalImpl, VirtIODevice},
22     },
23     exception::{irqdesc::IrqReturn, IrqNumber},
24     kerror, kinfo,
25     libs::spinlock::SpinLock,
26     net::{generate_iface_id, net_core::poll_ifaces_try_lock_onetime, NET_DRIVERS},
27     time::Instant,
28 };
29 use system_error::SystemError;
30 
31 /// @brief Virtio网络设备驱动(加锁)
32 pub struct VirtioNICDriver<T: Transport> {
33     pub inner: Arc<SpinLock<VirtIONet<HalImpl, T, 2>>>,
34 }
35 
36 impl<T: Transport> Clone for VirtioNICDriver<T> {
37     fn clone(&self) -> Self {
38         return VirtioNICDriver {
39             inner: self.inner.clone(),
40         };
41     }
42 }
43 
44 /// 网卡驱动的包裹器,这是为了获取网卡驱动的可变引用而设计的。
45 ///
46 /// 由于smoltcp的设计,导致需要在poll的时候获取网卡驱动的可变引用,
47 /// 同时需要在token的consume里面获取可变引用。为了避免双重加锁,所以需要这个包裹器。
48 struct VirtioNICDriverWrapper<T: Transport>(UnsafeCell<VirtioNICDriver<T>>);
49 unsafe impl<T: Transport> Send for VirtioNICDriverWrapper<T> {}
50 unsafe impl<T: Transport> Sync for VirtioNICDriverWrapper<T> {}
51 
52 impl<T: Transport> Deref for VirtioNICDriverWrapper<T> {
53     type Target = VirtioNICDriver<T>;
54     fn deref(&self) -> &Self::Target {
55         unsafe { &*self.0.get() }
56     }
57 }
58 impl<T: Transport> DerefMut for VirtioNICDriverWrapper<T> {
59     fn deref_mut(&mut self) -> &mut Self::Target {
60         unsafe { &mut *self.0.get() }
61     }
62 }
63 
64 #[allow(clippy::mut_from_ref)]
65 impl<T: Transport> VirtioNICDriverWrapper<T> {
66     fn force_get_mut(&self) -> &mut VirtioNICDriver<T> {
67         unsafe { &mut *self.0.get() }
68     }
69 }
70 
71 impl<T: Transport> Debug for VirtioNICDriver<T> {
72     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
73         f.debug_struct("VirtioNICDriver").finish()
74     }
75 }
76 
77 pub struct VirtioInterface<T: Transport> {
78     driver: VirtioNICDriverWrapper<T>,
79     iface_id: usize,
80     iface: SpinLock<smoltcp::iface::Interface>,
81     name: String,
82     dev_id: Arc<DeviceId>,
83 }
84 
85 impl<T: Transport> Debug for VirtioInterface<T> {
86     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
87         f.debug_struct("VirtioInterface")
88             .field("driver", self.driver.deref())
89             .field("iface_id", &self.iface_id)
90             .field("iface", &"smoltcp::iface::Interface")
91             .field("name", &self.name)
92             .finish()
93     }
94 }
95 
96 impl<T: Transport> VirtioInterface<T> {
97     pub fn new(mut driver: VirtioNICDriver<T>, dev_id: Arc<DeviceId>) -> Arc<Self> {
98         let iface_id = generate_iface_id();
99         let mut iface_config = smoltcp::iface::Config::new();
100 
101         // todo: 随机设定这个值。
102         // 参见 https://docs.rs/smoltcp/latest/smoltcp/iface/struct.Config.html#structfield.random_seed
103         iface_config.random_seed = 12345;
104 
105         iface_config.hardware_addr = Some(wire::HardwareAddress::Ethernet(
106             smoltcp::wire::EthernetAddress(driver.inner.lock().mac_address()),
107         ));
108         let iface = smoltcp::iface::Interface::new(iface_config, &mut driver);
109 
110         let driver: VirtioNICDriverWrapper<T> = VirtioNICDriverWrapper(UnsafeCell::new(driver));
111         let result = Arc::new(VirtioInterface {
112             driver,
113             iface_id,
114             iface: SpinLock::new(iface),
115             name: format!("eth{}", iface_id),
116             dev_id,
117         });
118 
119         return result;
120     }
121 }
122 
123 impl<T: Transport + 'static> VirtIODevice for VirtioInterface<T> {
124     fn handle_irq(&self, _irq: IrqNumber) -> Result<IrqReturn, SystemError> {
125         poll_ifaces_try_lock_onetime().ok();
126         return Ok(IrqReturn::Handled);
127     }
128 
129     fn dev_id(&self) -> &Arc<DeviceId> {
130         return &self.dev_id;
131     }
132 }
133 
134 impl<T: Transport> Drop for VirtioInterface<T> {
135     fn drop(&mut self) {
136         // 从全局的网卡接口信息表中删除这个网卡的接口信息
137         NET_DRIVERS.write_irqsave().remove(&self.iface_id);
138     }
139 }
140 
141 impl<T: 'static + Transport> VirtioNICDriver<T> {
142     pub fn new(driver_net: VirtIONet<HalImpl, T, 2>) -> Self {
143         let mut iface_config = smoltcp::iface::Config::new();
144 
145         // todo: 随机设定这个值。
146         // 参见 https://docs.rs/smoltcp/latest/smoltcp/iface/struct.Config.html#structfield.random_seed
147         iface_config.random_seed = 12345;
148 
149         iface_config.hardware_addr = Some(wire::HardwareAddress::Ethernet(
150             smoltcp::wire::EthernetAddress(driver_net.mac_address()),
151         ));
152 
153         let inner: Arc<SpinLock<VirtIONet<HalImpl, T, 2>>> = Arc::new(SpinLock::new(driver_net));
154         let result = VirtioNICDriver { inner };
155         return result;
156     }
157 }
158 
159 pub struct VirtioNetToken<T: Transport> {
160     driver: VirtioNICDriver<T>,
161     rx_buffer: Option<virtio_drivers::device::net::RxBuffer>,
162 }
163 
164 impl<T: Transport> VirtioNetToken<T> {
165     pub fn new(
166         driver: VirtioNICDriver<T>,
167         rx_buffer: Option<virtio_drivers::device::net::RxBuffer>,
168     ) -> Self {
169         return Self { driver, rx_buffer };
170     }
171 }
172 
173 impl<T: Transport> phy::Device for VirtioNICDriver<T> {
174     type RxToken<'a> = VirtioNetToken<T> where Self: 'a;
175     type TxToken<'a> = VirtioNetToken<T> where Self: 'a;
176 
177     fn receive(
178         &mut self,
179         _timestamp: smoltcp::time::Instant,
180     ) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
181         match self.inner.lock().receive() {
182             Ok(buf) => Some((
183                 VirtioNetToken::new(self.clone(), Some(buf)),
184                 VirtioNetToken::new(self.clone(), None),
185             )),
186             Err(virtio_drivers::Error::NotReady) => None,
187             Err(err) => panic!("VirtIO receive failed: {}", err),
188         }
189     }
190 
191     fn transmit(&mut self, _timestamp: smoltcp::time::Instant) -> Option<Self::TxToken<'_>> {
192         // kdebug!("VirtioNet: transmit");
193         if self.inner.lock_irqsave().can_send() {
194             // kdebug!("VirtioNet: can send");
195             return Some(VirtioNetToken::new(self.clone(), None));
196         } else {
197             // kdebug!("VirtioNet: can not send");
198             return None;
199         }
200     }
201 
202     fn capabilities(&self) -> phy::DeviceCapabilities {
203         let mut caps = phy::DeviceCapabilities::default();
204         // 网卡的最大传输单元. 请与IP层的MTU进行区分。这个值应当是网卡的最大传输单元,而不是IP层的MTU。
205         caps.max_transmission_unit = 2000;
206         /*
207            Maximum burst size, in terms of MTU.
208            The network device is unable to send or receive bursts large than the value returned by this function.
209            If None, there is no fixed limit on burst size, e.g. if network buffers are dynamically allocated.
210         */
211         caps.max_burst_size = Some(1);
212         return caps;
213     }
214 }
215 
216 impl<T: Transport> phy::TxToken for VirtioNetToken<T> {
217     fn consume<R, F>(self, len: usize, f: F) -> R
218     where
219         F: FnOnce(&mut [u8]) -> R,
220     {
221         // // 为了线程安全,这里需要对VirtioNet进行加【写锁】,以保证对设备的互斥访问。
222 
223         let mut driver_net = self.driver.inner.lock();
224         let mut tx_buf = driver_net.new_tx_buffer(len);
225         let result = f(tx_buf.packet_mut());
226         driver_net.send(tx_buf).expect("virtio_net send failed");
227         return result;
228     }
229 }
230 
231 impl<T: Transport> phy::RxToken for VirtioNetToken<T> {
232     fn consume<R, F>(self, f: F) -> R
233     where
234         F: FnOnce(&mut [u8]) -> R,
235     {
236         // 为了线程安全,这里需要对VirtioNet进行加【写锁】,以保证对设备的互斥访问。
237         let mut rx_buf = self.rx_buffer.unwrap();
238         let result = f(rx_buf.packet_mut());
239         self.driver
240             .inner
241             .lock()
242             .recycle_rx_buffer(rx_buf)
243             .expect("virtio_net recv failed");
244         result
245     }
246 }
247 
248 /// @brief virtio-net 驱动的初始化与测试
249 pub fn virtio_net<T: Transport + 'static>(transport: T, dev_id: Arc<DeviceId>) {
250     let driver_net: VirtIONet<HalImpl, T, 2> =
251         match VirtIONet::<HalImpl, T, 2>::new(transport, 4096) {
252             Ok(net) => net,
253             Err(_) => {
254                 kerror!("VirtIONet init failed");
255                 return;
256             }
257         };
258     let mac = smoltcp::wire::EthernetAddress::from_bytes(&driver_net.mac_address());
259     let driver: VirtioNICDriver<T> = VirtioNICDriver::new(driver_net);
260     let iface = VirtioInterface::new(driver, dev_id);
261     let name = iface.name.clone();
262     // 将网卡的接口信息注册到全局的网卡接口信息表中
263     NET_DRIVERS
264         .write_irqsave()
265         .insert(iface.nic_id(), iface.clone());
266 
267     virtio_irq_manager()
268         .register_device(iface.clone())
269         .expect("Register virtio net failed");
270     kinfo!(
271         "Virtio-net driver init successfully!\tNetDevID: [{}], MAC: [{}]",
272         name,
273         mac
274     );
275 }
276 
277 impl<T: Transport + 'static> Driver for VirtioInterface<T> {
278     fn id_table(&self) -> Option<IdTable> {
279         todo!()
280     }
281 
282     fn add_device(&self, _device: Arc<dyn Device>) {
283         todo!()
284     }
285 
286     fn delete_device(&self, _device: &Arc<dyn Device>) {
287         todo!()
288     }
289 
290     fn devices(&self) -> alloc::vec::Vec<Arc<dyn Device>> {
291         todo!()
292     }
293 
294     fn bus(&self) -> Option<Weak<dyn Bus>> {
295         todo!()
296     }
297 
298     fn set_bus(&self, _bus: Option<Weak<dyn Bus>>) {
299         todo!()
300     }
301 }
302 
303 impl<T: Transport + 'static> NetDriver for VirtioInterface<T> {
304     fn mac(&self) -> smoltcp::wire::EthernetAddress {
305         let mac: [u8; 6] = self.driver.inner.lock().mac_address();
306         return smoltcp::wire::EthernetAddress::from_bytes(&mac);
307     }
308 
309     #[inline]
310     fn nic_id(&self) -> usize {
311         return self.iface_id;
312     }
313 
314     #[inline]
315     fn name(&self) -> String {
316         return self.name.clone();
317     }
318 
319     fn update_ip_addrs(&self, ip_addrs: &[wire::IpCidr]) -> Result<(), SystemError> {
320         if ip_addrs.len() != 1 {
321             return Err(SystemError::EINVAL);
322         }
323 
324         self.iface.lock().update_ip_addrs(|addrs| {
325             let dest = addrs.iter_mut().next();
326 
327             if let Some(dest) = dest {
328                 *dest = ip_addrs[0];
329             } else {
330                 addrs.push(ip_addrs[0]).expect("Push ipCidr failed: full");
331             }
332         });
333         return Ok(());
334     }
335 
336     fn poll(&self, sockets: &mut smoltcp::iface::SocketSet) -> Result<(), SystemError> {
337         let timestamp: smoltcp::time::Instant = Instant::now().into();
338         let mut guard = self.iface.lock();
339         let poll_res = guard.poll(timestamp, self.driver.force_get_mut(), sockets);
340         // todo: notify!!!
341         // kdebug!("Virtio Interface poll:{poll_res}");
342         if poll_res {
343             return Ok(());
344         }
345         return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
346     }
347 
348     #[inline(always)]
349     fn inner_iface(&self) -> &SpinLock<smoltcp::iface::Interface> {
350         return &self.iface;
351     }
352     // fn as_any_ref(&'static self) -> &'static dyn core::any::Any {
353     //     return self;
354     // }
355 }
356 
357 impl<T: Transport + 'static> KObject for VirtioInterface<T> {
358     fn as_any_ref(&self) -> &dyn core::any::Any {
359         self
360     }
361 
362     fn set_inode(&self, _inode: Option<Arc<crate::filesystem::kernfs::KernFSInode>>) {
363         todo!()
364     }
365 
366     fn inode(&self) -> Option<Arc<crate::filesystem::kernfs::KernFSInode>> {
367         todo!()
368     }
369 
370     fn parent(&self) -> Option<alloc::sync::Weak<dyn KObject>> {
371         todo!()
372     }
373 
374     fn set_parent(&self, _parent: Option<alloc::sync::Weak<dyn KObject>>) {
375         todo!()
376     }
377 
378     fn kset(&self) -> Option<Arc<crate::driver::base::kset::KSet>> {
379         todo!()
380     }
381 
382     fn set_kset(&self, _kset: Option<Arc<crate::driver::base::kset::KSet>>) {
383         todo!()
384     }
385 
386     fn kobj_type(&self) -> Option<&'static dyn crate::driver::base::kobject::KObjType> {
387         todo!()
388     }
389 
390     fn name(&self) -> String {
391         self.name.clone()
392     }
393 
394     fn set_name(&self, _name: String) {
395         todo!()
396     }
397 
398     fn kobj_state(
399         &self,
400     ) -> crate::libs::rwlock::RwLockReadGuard<crate::driver::base::kobject::KObjectState> {
401         todo!()
402     }
403 
404     fn kobj_state_mut(
405         &self,
406     ) -> crate::libs::rwlock::RwLockWriteGuard<crate::driver::base::kobject::KObjectState> {
407         todo!()
408     }
409 
410     fn set_kobj_state(&self, _state: KObjectState) {
411         todo!()
412     }
413 
414     fn set_kobj_type(&self, _ktype: Option<&'static dyn KObjType>) {
415         todo!()
416     }
417 }
418 
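// Why these impls would be sound (they are currently commented out): every access to the
// underlying `VirtIONet` goes through the `inner` spinlock. That lock is taken in
// `receive`/`transmit` and in the `consume` methods of `VirtioNetToken`, so operations on
// the device are mutually exclusive. `VirtioNICDriver` could therefore be shared and sent
// between threads.
// unsafe impl<T: Transport> Sync for VirtioNICDriver<T> {}
// unsafe impl<T: Transport> Send for VirtioNICDriver<T> {}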
425