xref: /DragonOS/kernel/src/driver/net/virtio_net.rs (revision fba5623183378da6b120caafca120615328efa2e)
1 use core::{
2     cell::UnsafeCell,
3     fmt::Debug,
4     ops::{Deref, DerefMut},
5 };
6 
7 use alloc::{string::String, sync::Arc};
8 use smoltcp::{phy, wire};
9 use virtio_drivers::{device::net::VirtIONet, transport::Transport};
10 
11 use crate::{
12     driver::{
13         base::device::{driver::DriverError, Device, DevicePrivateData, DeviceResource, IdTable},
14         virtio::virtio_impl::HalImpl,
15         Driver,
16     },
17     kerror, kinfo,
18     libs::spinlock::SpinLock,
19     net::{generate_iface_id, NET_DRIVERS},
20     syscall::SystemError,
21     time::Instant,
22 };
23 
24 use super::NetDriver;
25 
26 /// @brief Virtio网络设备驱动(加锁)
27 pub struct VirtioNICDriver<T: Transport> {
28     pub inner: Arc<SpinLock<VirtIONet<HalImpl, T, 2>>>,
29 }
30 
31 impl<T: Transport> Clone for VirtioNICDriver<T> {
32     fn clone(&self) -> Self {
33         return VirtioNICDriver {
34             inner: self.inner.clone(),
35         };
36     }
37 }
38 
39 /// @brief 网卡驱动的包裹器,这是为了获取网卡驱动的可变引用而设计的。
40 /// 由于smoltcp的设计,导致需要在poll的时候获取网卡驱动的可变引用,
41 /// 同时需要在token的consume里面获取可变引用。为了避免双重加锁,所以需要这个包裹器。
42 struct VirtioNICDriverWrapper<T: Transport>(UnsafeCell<VirtioNICDriver<T>>);
43 unsafe impl<T: Transport> Send for VirtioNICDriverWrapper<T> {}
44 unsafe impl<T: Transport> Sync for VirtioNICDriverWrapper<T> {}
45 
46 impl<T: Transport> Deref for VirtioNICDriverWrapper<T> {
47     type Target = VirtioNICDriver<T>;
48     fn deref(&self) -> &Self::Target {
49         unsafe { &*self.0.get() }
50     }
51 }
52 impl<T: Transport> DerefMut for VirtioNICDriverWrapper<T> {
53     fn deref_mut(&mut self) -> &mut Self::Target {
54         unsafe { &mut *self.0.get() }
55     }
56 }
57 
58 impl<T: Transport> VirtioNICDriverWrapper<T> {
59     fn force_get_mut(&self) -> &mut VirtioNICDriver<T> {
60         unsafe { &mut *self.0.get() }
61     }
62 }
63 
64 impl<T: Transport> Debug for VirtioNICDriver<T> {
65     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
66         f.debug_struct("VirtioNICDriver").finish()
67     }
68 }
69 
70 pub struct VirtioInterface<T: Transport> {
71     driver: VirtioNICDriverWrapper<T>,
72     iface_id: usize,
73     iface: SpinLock<smoltcp::iface::Interface>,
74     name: String,
75 }
76 
77 impl<T: Transport> Debug for VirtioInterface<T> {
78     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
79         f.debug_struct("VirtioInterface")
80             .field("driver", self.driver.deref())
81             .field("iface_id", &self.iface_id)
82             .field("iface", &"smoltcp::iface::Interface")
83             .field("name", &self.name)
84             .finish()
85     }
86 }
87 
88 impl<T: Transport> VirtioInterface<T> {
89     pub fn new(mut driver: VirtioNICDriver<T>) -> Arc<Self> {
90         let iface_id = generate_iface_id();
91         let mut iface_config = smoltcp::iface::Config::new();
92 
93         // todo: 随机设定这个值。
94         // 参见 https://docs.rs/smoltcp/latest/smoltcp/iface/struct.Config.html#structfield.random_seed
95         iface_config.random_seed = 12345;
96 
97         iface_config.hardware_addr = Some(wire::HardwareAddress::Ethernet(
98             smoltcp::wire::EthernetAddress(driver.inner.lock().mac_address()),
99         ));
100         let iface = smoltcp::iface::Interface::new(iface_config, &mut driver);
101 
102         let driver: VirtioNICDriverWrapper<T> = VirtioNICDriverWrapper(UnsafeCell::new(driver));
103         let result = Arc::new(VirtioInterface {
104             driver,
105             iface_id,
106             iface: SpinLock::new(iface),
107             name: format!("eth{}", iface_id),
108         });
109 
110         return result;
111     }
112 }
113 
114 impl<T: 'static + Transport> VirtioNICDriver<T> {
115     pub fn new(driver_net: VirtIONet<HalImpl, T, 2>) -> Self {
116         let mut iface_config = smoltcp::iface::Config::new();
117 
118         // todo: 随机设定这个值。
119         // 参见 https://docs.rs/smoltcp/latest/smoltcp/iface/struct.Config.html#structfield.random_seed
120         iface_config.random_seed = 12345;
121 
122         iface_config.hardware_addr = Some(wire::HardwareAddress::Ethernet(
123             smoltcp::wire::EthernetAddress(driver_net.mac_address()),
124         ));
125 
126         let inner: Arc<SpinLock<VirtIONet<HalImpl, T, 2>>> = Arc::new(SpinLock::new(driver_net));
127         let result = VirtioNICDriver { inner };
128         return result;
129     }
130 }
131 
132 pub struct VirtioNetToken<T: Transport> {
133     driver: VirtioNICDriver<T>,
134     rx_buffer: Option<virtio_drivers::device::net::RxBuffer>,
135 }
136 
137 impl<'a, T: Transport> VirtioNetToken<T> {
138     pub fn new(
139         driver: VirtioNICDriver<T>,
140         rx_buffer: Option<virtio_drivers::device::net::RxBuffer>,
141     ) -> Self {
142         return Self { driver, rx_buffer };
143     }
144 }
145 
146 impl<T: Transport> phy::Device for VirtioNICDriver<T> {
147     type RxToken<'a> = VirtioNetToken<T> where Self: 'a;
148     type TxToken<'a> = VirtioNetToken<T> where Self: 'a;
149 
150     fn receive(
151         &mut self,
152         _timestamp: smoltcp::time::Instant,
153     ) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
154         match self.inner.lock().receive() {
155             Ok(buf) => Some((
156                 VirtioNetToken::new(self.clone(), Some(buf)),
157                 VirtioNetToken::new(self.clone(), None),
158             )),
159             Err(virtio_drivers::Error::NotReady) => None,
160             Err(err) => panic!("VirtIO receive failed: {}", err),
161         }
162     }
163 
164     fn transmit(&mut self, _timestamp: smoltcp::time::Instant) -> Option<Self::TxToken<'_>> {
165         // kdebug!("VirtioNet: transmit");
166         if self.inner.lock().can_send() {
167             // kdebug!("VirtioNet: can send");
168             return Some(VirtioNetToken::new(self.clone(), None));
169         } else {
170             // kdebug!("VirtioNet: can not send");
171             return None;
172         }
173     }
174 
175     fn capabilities(&self) -> phy::DeviceCapabilities {
176         let mut caps = phy::DeviceCapabilities::default();
177         // 网卡的最大传输单元. 请与IP层的MTU进行区分。这个值应当是网卡的最大传输单元,而不是IP层的MTU。
178         caps.max_transmission_unit = 2000;
179         /*
180            Maximum burst size, in terms of MTU.
181            The network device is unable to send or receive bursts large than the value returned by this function.
182            If None, there is no fixed limit on burst size, e.g. if network buffers are dynamically allocated.
183         */
184         caps.max_burst_size = Some(1);
185         return caps;
186     }
187 }
188 
189 impl<T: Transport> phy::TxToken for VirtioNetToken<T> {
190     fn consume<R, F>(self, len: usize, f: F) -> R
191     where
192         F: FnOnce(&mut [u8]) -> R,
193     {
194         // // 为了线程安全,这里需要对VirtioNet进行加【写锁】,以保证对设备的互斥访问。
195 
196         let mut driver_net = self.driver.inner.lock();
197         let mut tx_buf = driver_net.new_tx_buffer(len);
198         let result = f(tx_buf.packet_mut());
199         driver_net.send(tx_buf).expect("virtio_net send failed");
200         return result;
201     }
202 }
203 
204 impl<T: Transport> phy::RxToken for VirtioNetToken<T> {
205     fn consume<R, F>(self, f: F) -> R
206     where
207         F: FnOnce(&mut [u8]) -> R,
208     {
209         // 为了线程安全,这里需要对VirtioNet进行加【写锁】,以保证对设备的互斥访问。
210         let mut rx_buf = self.rx_buffer.unwrap();
211         let result = f(rx_buf.packet_mut());
212         self.driver
213             .inner
214             .lock()
215             .recycle_rx_buffer(rx_buf)
216             .expect("virtio_net recv failed");
217         result
218     }
219 }
220 
221 /// @brief virtio-net 驱动的初始化与测试
222 pub fn virtio_net<T: Transport + 'static>(transport: T) {
223     let driver_net: VirtIONet<HalImpl, T, 2> =
224         match VirtIONet::<HalImpl, T, 2>::new(transport, 4096) {
225             Ok(net) => net,
226             Err(_) => {
227                 kerror!("VirtIONet init failed");
228                 return;
229             }
230         };
231     let mac = smoltcp::wire::EthernetAddress::from_bytes(&driver_net.mac_address());
232     let driver: VirtioNICDriver<T> = VirtioNICDriver::new(driver_net);
233     let iface = VirtioInterface::new(driver);
234     // 将网卡的接口信息注册到全局的网卡接口信息表中
235     NET_DRIVERS.write().insert(iface.nic_id(), iface.clone());
236     kinfo!(
237         "Virtio-net driver init successfully!\tNetDevID: [{}], MAC: [{}]",
238         iface.name(),
239         mac
240     );
241 }
242 
243 impl<T: Transport> Driver for VirtioInterface<T> {
244     fn as_any_ref(&'static self) -> &'static dyn core::any::Any {
245         self
246     }
247 
248     fn probe(&self, _data: &DevicePrivateData) -> Result<(), DriverError> {
249         todo!()
250     }
251 
252     fn load(
253         &self,
254         _data: DevicePrivateData,
255         _resource: Option<DeviceResource>,
256     ) -> Result<Arc<dyn Device>, DriverError> {
257         todo!()
258     }
259 
260     fn id_table(&self) -> IdTable {
261         todo!()
262     }
263 
264     fn set_sys_info(&self, _sys_info: Option<Arc<dyn crate::filesystem::vfs::IndexNode>>) {
265         todo!()
266     }
267 
268     fn sys_info(&self) -> Option<Arc<dyn crate::filesystem::vfs::IndexNode>> {
269         todo!()
270     }
271 }
272 
273 impl<T: Transport> NetDriver for VirtioInterface<T> {
274     fn mac(&self) -> smoltcp::wire::EthernetAddress {
275         let mac: [u8; 6] = self.driver.inner.lock().mac_address();
276         return smoltcp::wire::EthernetAddress::from_bytes(&mac);
277     }
278 
279     #[inline]
280     fn nic_id(&self) -> usize {
281         return self.iface_id;
282     }
283 
284     #[inline]
285     fn name(&self) -> String {
286         return self.name.clone();
287     }
288 
289     fn update_ip_addrs(&self, ip_addrs: &[wire::IpCidr]) -> Result<(), SystemError> {
290         if ip_addrs.len() != 1 {
291             return Err(SystemError::EINVAL);
292         }
293 
294         self.iface.lock().update_ip_addrs(|addrs| {
295             let dest = addrs.iter_mut().next();
296             if let None = dest {
297                 addrs.push(ip_addrs[0]).expect("Push ipCidr failed: full");
298             } else {
299                 let dest = dest.unwrap();
300                 *dest = ip_addrs[0];
301             }
302         });
303         return Ok(());
304     }
305 
306     fn poll(
307         &self,
308         sockets: &mut smoltcp::iface::SocketSet,
309     ) -> Result<(), crate::syscall::SystemError> {
310         let timestamp: smoltcp::time::Instant = Instant::now().into();
311         let mut guard = self.iface.lock();
312         let poll_res = guard.poll(timestamp, self.driver.force_get_mut(), sockets);
313         // todo: notify!!!
314         // kdebug!("Virtio Interface poll:{poll_res}");
315         if poll_res {
316             return Ok(());
317         }
318         return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
319     }
320 
321     #[inline(always)]
322     fn inner_iface(&self) -> &SpinLock<smoltcp::iface::Interface> {
323         return &self.iface;
324     }
325     // fn as_any_ref(&'static self) -> &'static dyn core::any::Any {
326     //     return self;
327     // }
328 }
329 
330 // 向编译器保证,VirtioNICDriver在线程之间是安全的.
331 // 由于smoltcp只会在token内真正操作网卡设备,并且在VirtioNetToken的consume
332 // 方法内,会对VirtioNet进行加【写锁】,因此,能够保证对设备操作的的互斥访问,
333 // 因此VirtioNICDriver在线程之间是安全的。
334 // unsafe impl<T: Transport> Sync for VirtioNICDriver<T> {}
335 // unsafe impl<T: Transport> Send for VirtioNICDriver<T> {}
336