xref: /DragonOS/kernel/src/driver/net/virtio_net.rs (revision 7cc4a02c7ff7bafd798b185beb7b0c2986b9f32f)
1 use core::{
2     cell::UnsafeCell,
3     fmt::Debug,
4     ops::{Deref, DerefMut},
5 };
6 
7 use alloc::{string::String, sync::Arc};
8 use smoltcp::{phy, wire};
9 use virtio_drivers::{device::net::VirtIONet, transport::Transport};
10 
11 use crate::{
12     driver::{virtio::virtio_impl::HalImpl, Driver},
13     kerror, kinfo,
14     libs::spinlock::SpinLock,
15     net::{generate_iface_id, NET_DRIVERS},
16     syscall::SystemError,
17     time::Instant,
18 };
19 
20 use super::NetDriver;
21 
22 /// @brief Virtio网络设备驱动(加锁)
23 pub struct VirtioNICDriver<T: Transport> {
24     pub inner: Arc<SpinLock<VirtIONet<HalImpl, T, 2>>>,
25 }
26 
27 impl<T: Transport> Clone for VirtioNICDriver<T> {
28     fn clone(&self) -> Self {
29         return VirtioNICDriver {
30             inner: self.inner.clone(),
31         };
32     }
33 }
34 
/// Wrapper around the NIC driver that allows obtaining a mutable reference to it.
///
/// Because of smoltcp's design, `poll` needs a mutable reference to the device, and the
/// tokens' `consume` methods also need mutable access. This wrapper exists so that the
/// driver can be handed out mutably without taking a second lock (avoiding double locking).
///
/// NOTE(review): this type hands out `&mut VirtioNICDriver<T>` through `&self`
/// (`force_get_mut`) and is declared `Send`/`Sync` without any synchronisation of its own.
/// Soundness therefore depends entirely on callers serialising access. In this file the
/// only `force_get_mut` call is in `VirtioInterface::poll`, made while the `iface` spin lock
/// is held — confirm that no other path uses the driver concurrently.
struct VirtioNICDriverWrapper<T: Transport>(UnsafeCell<VirtioNICDriver<T>>);
// SAFETY (assumed, see NOTE above): all mutable access is serialised by an external lock.
unsafe impl<T: Transport> Send for VirtioNICDriverWrapper<T> {}
// SAFETY (assumed, see NOTE above): all mutable access is serialised by an external lock.
unsafe impl<T: Transport> Sync for VirtioNICDriverWrapper<T> {}

impl<T: Transport> Deref for VirtioNICDriverWrapper<T> {
    type Target = VirtioNICDriver<T>;
    fn deref(&self) -> &Self::Target {
        // SAFETY: only sound while no `&mut` obtained via `force_get_mut` is alive.
        unsafe { &*self.0.get() }
    }
}
impl<T: Transport> DerefMut for VirtioNICDriverWrapper<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: `&mut self` already guarantees exclusive access to the cell.
        unsafe { &mut *self.0.get() }
    }
}

impl<T: Transport> VirtioNICDriverWrapper<T> {
    /// Returns a mutable reference to the wrapped driver from a shared reference.
    ///
    /// NOTE(review): callers must guarantee that no other reference (shared or mutable)
    /// to the driver is used while the returned reference is alive.
    fn force_get_mut(&self) -> &mut VirtioNICDriver<T> {
        // SAFETY: exclusivity is the caller's responsibility (see doc comment).
        unsafe { &mut *self.0.get() }
    }
}
59 
60 impl<T: Transport> Debug for VirtioNICDriver<T> {
61     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
62         f.debug_struct("VirtioNICDriver").finish()
63     }
64 }
65 
66 pub struct VirtioInterface<T: Transport> {
67     driver: VirtioNICDriverWrapper<T>,
68     iface_id: usize,
69     iface: SpinLock<smoltcp::iface::Interface>,
70     name: String,
71 }
72 
73 impl<T: Transport> Debug for VirtioInterface<T> {
74     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
75         f.debug_struct("VirtioInterface")
76             .field("driver", self.driver.deref())
77             .field("iface_id", &self.iface_id)
78             .field("iface", &"smoltcp::iface::Interface")
79             .field("name", &self.name)
80             .finish()
81     }
82 }
83 
84 impl<T: Transport> VirtioInterface<T> {
85     pub fn new(mut driver: VirtioNICDriver<T>) -> Arc<Self> {
86         let iface_id = generate_iface_id();
87         let mut iface_config = smoltcp::iface::Config::new();
88 
89         // todo: 随机设定这个值。
90         // 参见 https://docs.rs/smoltcp/latest/smoltcp/iface/struct.Config.html#structfield.random_seed
91         iface_config.random_seed = 12345;
92 
93         iface_config.hardware_addr = Some(wire::HardwareAddress::Ethernet(
94             smoltcp::wire::EthernetAddress(driver.inner.lock().mac_address()),
95         ));
96         let iface = smoltcp::iface::Interface::new(iface_config, &mut driver);
97 
98         let driver: VirtioNICDriverWrapper<T> = VirtioNICDriverWrapper(UnsafeCell::new(driver));
99         let result = Arc::new(VirtioInterface {
100             driver,
101             iface_id,
102             iface: SpinLock::new(iface),
103             name: format!("eth{}", iface_id),
104         });
105 
106         return result;
107     }
108 }
109 
110 impl<T: 'static + Transport> VirtioNICDriver<T> {
111     pub fn new(driver_net: VirtIONet<HalImpl, T, 2>) -> Self {
112         let mut iface_config = smoltcp::iface::Config::new();
113 
114         // todo: 随机设定这个值。
115         // 参见 https://docs.rs/smoltcp/latest/smoltcp/iface/struct.Config.html#structfield.random_seed
116         iface_config.random_seed = 12345;
117 
118         iface_config.hardware_addr = Some(wire::HardwareAddress::Ethernet(
119             smoltcp::wire::EthernetAddress(driver_net.mac_address()),
120         ));
121 
122         let inner: Arc<SpinLock<VirtIONet<HalImpl, T, 2>>> = Arc::new(SpinLock::new(driver_net));
123         let result = VirtioNICDriver { inner };
124         return result;
125     }
126 }
127 
128 pub struct VirtioNetToken<T: Transport> {
129     driver: VirtioNICDriver<T>,
130     rx_buffer: Option<virtio_drivers::device::net::RxBuffer>,
131 }
132 
133 impl<'a, T: Transport> VirtioNetToken<T> {
134     pub fn new(
135         driver: VirtioNICDriver<T>,
136         rx_buffer: Option<virtio_drivers::device::net::RxBuffer>,
137     ) -> Self {
138         return Self { driver, rx_buffer };
139     }
140 }
141 
142 impl<T: Transport> phy::Device for VirtioNICDriver<T> {
143     type RxToken<'a> = VirtioNetToken<T> where Self: 'a;
144     type TxToken<'a> = VirtioNetToken<T> where Self: 'a;
145 
146     fn receive(
147         &mut self,
148         _timestamp: smoltcp::time::Instant,
149     ) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
150         match self.inner.lock().receive() {
151             Ok(buf) => Some((
152                 VirtioNetToken::new(self.clone(), Some(buf)),
153                 VirtioNetToken::new(self.clone(), None),
154             )),
155             Err(virtio_drivers::Error::NotReady) => None,
156             Err(err) => panic!("VirtIO receive failed: {}", err),
157         }
158     }
159 
160     fn transmit(&mut self, _timestamp: smoltcp::time::Instant) -> Option<Self::TxToken<'_>> {
161         // kdebug!("VirtioNet: transmit");
162         if self.inner.lock().can_send() {
163             // kdebug!("VirtioNet: can send");
164             return Some(VirtioNetToken::new(self.clone(), None));
165         } else {
166             // kdebug!("VirtioNet: can not send");
167             return None;
168         }
169     }
170 
171     fn capabilities(&self) -> phy::DeviceCapabilities {
172         let mut caps = phy::DeviceCapabilities::default();
173         // 网卡的最大传输单元. 请与IP层的MTU进行区分。这个值应当是网卡的最大传输单元,而不是IP层的MTU。
174         caps.max_transmission_unit = 2000;
175         /*
176            Maximum burst size, in terms of MTU.
177            The network device is unable to send or receive bursts large than the value returned by this function.
178            If None, there is no fixed limit on burst size, e.g. if network buffers are dynamically allocated.
179         */
180         caps.max_burst_size = Some(1);
181         return caps;
182     }
183 }
184 
185 impl<T: Transport> phy::TxToken for VirtioNetToken<T> {
186     fn consume<R, F>(self, len: usize, f: F) -> R
187     where
188         F: FnOnce(&mut [u8]) -> R,
189     {
190         // // 为了线程安全,这里需要对VirtioNet进行加【写锁】,以保证对设备的互斥访问。
191 
192         let mut driver_net = self.driver.inner.lock();
193         let mut tx_buf = driver_net.new_tx_buffer(len);
194         let result = f(tx_buf.packet_mut());
195         driver_net.send(tx_buf).expect("virtio_net send failed");
196         return result;
197     }
198 }
199 
200 impl<T: Transport> phy::RxToken for VirtioNetToken<T> {
201     fn consume<R, F>(self, f: F) -> R
202     where
203         F: FnOnce(&mut [u8]) -> R,
204     {
205         // 为了线程安全,这里需要对VirtioNet进行加【写锁】,以保证对设备的互斥访问。
206         let mut rx_buf = self.rx_buffer.unwrap();
207         let result = f(rx_buf.packet_mut());
208         self.driver
209             .inner
210             .lock()
211             .recycle_rx_buffer(rx_buf)
212             .expect("virtio_net recv failed");
213         result
214     }
215 }
216 
217 /// @brief virtio-net 驱动的初始化与测试
218 pub fn virtio_net<T: Transport + 'static>(transport: T) {
219     let driver_net: VirtIONet<HalImpl, T, 2> =
220         match VirtIONet::<HalImpl, T, 2>::new(transport, 4096) {
221             Ok(net) => net,
222             Err(_) => {
223                 kerror!("VirtIONet init failed");
224                 return;
225             }
226         };
227     let mac = smoltcp::wire::EthernetAddress::from_bytes(&driver_net.mac_address());
228     let driver: VirtioNICDriver<T> = VirtioNICDriver::new(driver_net);
229     let iface = VirtioInterface::new(driver);
230     // 将网卡的接口信息注册到全局的网卡接口信息表中
231     NET_DRIVERS.write().insert(iface.nic_id(), iface.clone());
232     kinfo!(
233         "Virtio-net driver init successfully!\tNetDevID: [{}], MAC: [{}]",
234         iface.name(),
235         mac
236     );
237 }
238 
239 impl<T: Transport> Driver for VirtioInterface<T> {
240     fn as_any_ref(&'static self) -> &'static dyn core::any::Any {
241         self
242     }
243 }
244 
245 impl<T: Transport> NetDriver for VirtioInterface<T> {
246     fn mac(&self) -> smoltcp::wire::EthernetAddress {
247         let mac: [u8; 6] = self.driver.inner.lock().mac_address();
248         return smoltcp::wire::EthernetAddress::from_bytes(&mac);
249     }
250 
251     #[inline]
252     fn nic_id(&self) -> usize {
253         return self.iface_id;
254     }
255 
256     #[inline]
257     fn name(&self) -> String {
258         return self.name.clone();
259     }
260 
261     fn update_ip_addrs(&self, ip_addrs: &[wire::IpCidr]) -> Result<(), SystemError> {
262         if ip_addrs.len() != 1 {
263             return Err(SystemError::EINVAL);
264         }
265 
266         self.iface.lock().update_ip_addrs(|addrs| {
267             let dest = addrs.iter_mut().next();
268             if let None = dest {
269                 addrs.push(ip_addrs[0]).expect("Push ipCidr failed: full");
270             } else {
271                 let dest = dest.unwrap();
272                 *dest = ip_addrs[0];
273             }
274         });
275         return Ok(());
276     }
277 
278     fn poll(
279         &self,
280         sockets: &mut smoltcp::iface::SocketSet,
281     ) -> Result<(), crate::syscall::SystemError> {
282         let timestamp: smoltcp::time::Instant = Instant::now().into();
283         let mut guard = self.iface.lock();
284         let poll_res = guard.poll(timestamp, self.driver.force_get_mut(), sockets);
285         // todo: notify!!!
286         // kdebug!("Virtio Interface poll:{poll_res}");
287         if poll_res {
288             return Ok(());
289         }
290         return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
291     }
292 
293     #[inline(always)]
294     fn inner_iface(&self) -> &SpinLock<smoltcp::iface::Interface> {
295         return &self.iface;
296     }
297     // fn as_any_ref(&'static self) -> &'static dyn core::any::Any {
298     //     return self;
299     // }
300 }
301 
// 向编译器保证,VirtioNICDriver在线程之间是安全的。
// 由于smoltcp只会在token内真正操作网卡设备,并且在VirtioNetToken的consume
// 方法内会对VirtioNet进行加【写锁】,因此能够保证对设备操作的互斥访问,
// 所以VirtioNICDriver在线程之间是安全的。
// unsafe impl<T: Transport> Sync for VirtioNICDriver<T> {}
// unsafe impl<T: Transport> Send for VirtioNICDriver<T> {}
308