xref: /DragonOS/kernel/src/driver/net/virtio_net.rs (revision 13776c114b15c406b1e0aaeeb71812ea6e471d2e)
1*13776c11Slogin use core::{
2*13776c11Slogin     cell::UnsafeCell,
3*13776c11Slogin     fmt::Debug,
4*13776c11Slogin     ops::{Deref, DerefMut},
5*13776c11Slogin };
6*13776c11Slogin 
7*13776c11Slogin use alloc::{string::String, sync::Arc};
8*13776c11Slogin use smoltcp::{phy, wire};
9*13776c11Slogin use virtio_drivers::{device::net::VirtIONet, transport::Transport};
10*13776c11Slogin 
11*13776c11Slogin use crate::{
12*13776c11Slogin     driver::{virtio::virtio_impl::HalImpl, Driver},
13*13776c11Slogin     kdebug, kerror, kinfo,
14*13776c11Slogin     libs::spinlock::SpinLock,
15*13776c11Slogin     net::{generate_iface_id, NET_DRIVERS},
16*13776c11Slogin     syscall::SystemError,
17*13776c11Slogin     time::Instant,
18*13776c11Slogin };
19*13776c11Slogin 
20*13776c11Slogin use super::NetDriver;
21*13776c11Slogin 
22*13776c11Slogin /// @brief Virtio网络设备驱动(加锁)
23*13776c11Slogin pub struct VirtioNICDriver<T: Transport> {
24*13776c11Slogin     pub inner: Arc<SpinLock<VirtIONet<HalImpl, T, 2>>>,
25*13776c11Slogin }
26*13776c11Slogin 
27*13776c11Slogin impl<T: Transport> Clone for VirtioNICDriver<T> {
28*13776c11Slogin     fn clone(&self) -> Self {
29*13776c11Slogin         return VirtioNICDriver {
30*13776c11Slogin             inner: self.inner.clone(),
31*13776c11Slogin         };
32*13776c11Slogin     }
33*13776c11Slogin }
34*13776c11Slogin 
35*13776c11Slogin /// @brief 网卡驱动的包裹器,这是为了获取网卡驱动的可变引用而设计的。
36*13776c11Slogin /// 由于smoltcp的设计,导致需要在poll的时候获取网卡驱动的可变引用,
37*13776c11Slogin /// 同时需要在token的consume里面获取可变引用。为了避免双重加锁,所以需要这个包裹器。
38*13776c11Slogin struct VirtioNICDriverWrapper<T: Transport>(UnsafeCell<VirtioNICDriver<T>>);
39*13776c11Slogin unsafe impl<T: Transport> Send for VirtioNICDriverWrapper<T> {}
40*13776c11Slogin unsafe impl<T: Transport> Sync for VirtioNICDriverWrapper<T> {}
41*13776c11Slogin 
42*13776c11Slogin impl<T: Transport> Deref for VirtioNICDriverWrapper<T> {
43*13776c11Slogin     type Target = VirtioNICDriver<T>;
44*13776c11Slogin     fn deref(&self) -> &Self::Target {
45*13776c11Slogin         unsafe { &*self.0.get() }
46*13776c11Slogin     }
47*13776c11Slogin }
48*13776c11Slogin impl<T: Transport> DerefMut for VirtioNICDriverWrapper<T> {
49*13776c11Slogin     fn deref_mut(&mut self) -> &mut Self::Target {
50*13776c11Slogin         unsafe { &mut *self.0.get() }
51*13776c11Slogin     }
52*13776c11Slogin }
53*13776c11Slogin 
54*13776c11Slogin impl<T: Transport> VirtioNICDriverWrapper<T> {
55*13776c11Slogin     fn force_get_mut(&self) -> &mut VirtioNICDriver<T> {
56*13776c11Slogin         unsafe { &mut *self.0.get() }
57*13776c11Slogin     }
58*13776c11Slogin }
59*13776c11Slogin 
60*13776c11Slogin impl<T: Transport> Debug for VirtioNICDriver<T> {
61*13776c11Slogin     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
62*13776c11Slogin         f.debug_struct("VirtioNICDriver").finish()
63*13776c11Slogin     }
64*13776c11Slogin }
65*13776c11Slogin 
66*13776c11Slogin pub struct VirtioInterface<T: Transport> {
67*13776c11Slogin     driver: VirtioNICDriverWrapper<T>,
68*13776c11Slogin     iface_id: usize,
69*13776c11Slogin     iface: SpinLock<smoltcp::iface::Interface>,
70*13776c11Slogin     name: String,
71*13776c11Slogin }
72*13776c11Slogin 
73*13776c11Slogin impl<T: Transport> Debug for VirtioInterface<T> {
74*13776c11Slogin     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
75*13776c11Slogin         f.debug_struct("VirtioInterface")
76*13776c11Slogin             .field("driver", self.driver.deref())
77*13776c11Slogin             .field("iface_id", &self.iface_id)
78*13776c11Slogin             .field("iface", &"smoltcp::iface::Interface")
79*13776c11Slogin             .field("name", &self.name)
80*13776c11Slogin             .finish()
81*13776c11Slogin     }
82*13776c11Slogin }
83*13776c11Slogin 
84*13776c11Slogin impl<T: Transport> VirtioInterface<T> {
85*13776c11Slogin     pub fn new(mut driver: VirtioNICDriver<T>) -> Arc<Self> {
86*13776c11Slogin         let iface_id = generate_iface_id();
87*13776c11Slogin         let mut iface_config = smoltcp::iface::Config::new();
88*13776c11Slogin 
89*13776c11Slogin         // todo: 随机设定这个值。
90*13776c11Slogin         // 参见 https://docs.rs/smoltcp/latest/smoltcp/iface/struct.Config.html#structfield.random_seed
91*13776c11Slogin         iface_config.random_seed = 12345;
92*13776c11Slogin 
93*13776c11Slogin         iface_config.hardware_addr = Some(wire::HardwareAddress::Ethernet(
94*13776c11Slogin             smoltcp::wire::EthernetAddress(driver.inner.lock().mac_address()),
95*13776c11Slogin         ));
96*13776c11Slogin         let iface = smoltcp::iface::Interface::new(iface_config, &mut driver);
97*13776c11Slogin 
98*13776c11Slogin         let driver: VirtioNICDriverWrapper<T> = VirtioNICDriverWrapper(UnsafeCell::new(driver));
99*13776c11Slogin         let result = Arc::new(VirtioInterface {
100*13776c11Slogin             driver,
101*13776c11Slogin             iface_id,
102*13776c11Slogin             iface: SpinLock::new(iface),
103*13776c11Slogin             name: format!("eth{}", iface_id),
104*13776c11Slogin         });
105*13776c11Slogin 
106*13776c11Slogin         return result;
107*13776c11Slogin     }
108*13776c11Slogin }
109*13776c11Slogin 
110*13776c11Slogin impl<T: 'static + Transport> VirtioNICDriver<T> {
111*13776c11Slogin     pub fn new(driver_net: VirtIONet<HalImpl, T, 2>) -> Self {
112*13776c11Slogin         let mut iface_config = smoltcp::iface::Config::new();
113*13776c11Slogin 
114*13776c11Slogin         // todo: 随机设定这个值。
115*13776c11Slogin         // 参见 https://docs.rs/smoltcp/latest/smoltcp/iface/struct.Config.html#structfield.random_seed
116*13776c11Slogin         iface_config.random_seed = 12345;
117*13776c11Slogin 
118*13776c11Slogin         iface_config.hardware_addr = Some(wire::HardwareAddress::Ethernet(
119*13776c11Slogin             smoltcp::wire::EthernetAddress(driver_net.mac_address()),
120*13776c11Slogin         ));
121*13776c11Slogin 
122*13776c11Slogin         let inner: Arc<SpinLock<VirtIONet<HalImpl, T, 2>>> = Arc::new(SpinLock::new(driver_net));
123*13776c11Slogin         let result = VirtioNICDriver { inner };
124*13776c11Slogin         return result;
125*13776c11Slogin     }
126*13776c11Slogin }
127*13776c11Slogin 
128*13776c11Slogin pub struct VirtioNetToken<T: Transport> {
129*13776c11Slogin     driver: VirtioNICDriver<T>,
130*13776c11Slogin     rx_buffer: Option<virtio_drivers::device::net::RxBuffer>,
131*13776c11Slogin }
132*13776c11Slogin 
133*13776c11Slogin impl<'a, T: Transport> VirtioNetToken<T> {
134*13776c11Slogin     pub fn new(
135*13776c11Slogin         driver: VirtioNICDriver<T>,
136*13776c11Slogin         rx_buffer: Option<virtio_drivers::device::net::RxBuffer>,
137*13776c11Slogin     ) -> Self {
138*13776c11Slogin         return Self { driver, rx_buffer };
139*13776c11Slogin     }
140*13776c11Slogin }
141*13776c11Slogin 
142*13776c11Slogin impl<T: Transport> phy::Device for VirtioNICDriver<T> {
143*13776c11Slogin     type RxToken<'a> = VirtioNetToken<T> where Self: 'a;
144*13776c11Slogin     type TxToken<'a> = VirtioNetToken<T> where Self: 'a;
145*13776c11Slogin 
146*13776c11Slogin     fn receive(
147*13776c11Slogin         &mut self,
148*13776c11Slogin         _timestamp: smoltcp::time::Instant,
149*13776c11Slogin     ) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
150*13776c11Slogin         match self.inner.lock().receive() {
151*13776c11Slogin             Ok(buf) => Some((
152*13776c11Slogin                 VirtioNetToken::new(self.clone(), Some(buf)),
153*13776c11Slogin                 VirtioNetToken::new(self.clone(), None),
154*13776c11Slogin             )),
155*13776c11Slogin             Err(virtio_drivers::Error::NotReady) => None,
156*13776c11Slogin             Err(err) => panic!("VirtIO receive failed: {}", err),
157*13776c11Slogin         }
158*13776c11Slogin     }
159*13776c11Slogin 
160*13776c11Slogin     fn transmit(&mut self, _timestamp: smoltcp::time::Instant) -> Option<Self::TxToken<'_>> {
161*13776c11Slogin         // kdebug!("VirtioNet: transmit");
162*13776c11Slogin         if self.inner.lock().can_send() {
163*13776c11Slogin             // kdebug!("VirtioNet: can send");
164*13776c11Slogin             return Some(VirtioNetToken::new(self.clone(), None));
165*13776c11Slogin         } else {
166*13776c11Slogin             // kdebug!("VirtioNet: can not send");
167*13776c11Slogin             return None;
168*13776c11Slogin         }
169*13776c11Slogin     }
170*13776c11Slogin 
171*13776c11Slogin     fn capabilities(&self) -> phy::DeviceCapabilities {
172*13776c11Slogin         let mut caps = phy::DeviceCapabilities::default();
173*13776c11Slogin         // 网卡的最大传输单元. 请与IP层的MTU进行区分。这个值应当是网卡的最大传输单元,而不是IP层的MTU。
174*13776c11Slogin         caps.max_transmission_unit = 2000;
175*13776c11Slogin         /*
176*13776c11Slogin            Maximum burst size, in terms of MTU.
177*13776c11Slogin            The network device is unable to send or receive bursts large than the value returned by this function.
178*13776c11Slogin            If None, there is no fixed limit on burst size, e.g. if network buffers are dynamically allocated.
179*13776c11Slogin         */
180*13776c11Slogin         caps.max_burst_size = Some(1);
181*13776c11Slogin         return caps;
182*13776c11Slogin     }
183*13776c11Slogin }
184*13776c11Slogin 
185*13776c11Slogin impl<T: Transport> phy::TxToken for VirtioNetToken<T> {
186*13776c11Slogin     fn consume<R, F>(self, len: usize, f: F) -> R
187*13776c11Slogin     where
188*13776c11Slogin         F: FnOnce(&mut [u8]) -> R,
189*13776c11Slogin     {
190*13776c11Slogin         // // 为了线程安全,这里需要对VirtioNet进行加【写锁】,以保证对设备的互斥访问。
191*13776c11Slogin 
192*13776c11Slogin         let mut driver_net = self.driver.inner.lock();
193*13776c11Slogin         let mut tx_buf = driver_net.new_tx_buffer(len);
194*13776c11Slogin         let result = f(tx_buf.packet_mut());
195*13776c11Slogin         driver_net.send(tx_buf).expect("virtio_net send failed");
196*13776c11Slogin         return result;
197*13776c11Slogin     }
198*13776c11Slogin }
199*13776c11Slogin 
200*13776c11Slogin impl<T: Transport> phy::RxToken for VirtioNetToken<T> {
201*13776c11Slogin     fn consume<R, F>(self, f: F) -> R
202*13776c11Slogin     where
203*13776c11Slogin         F: FnOnce(&mut [u8]) -> R,
204*13776c11Slogin     {
205*13776c11Slogin         // 为了线程安全,这里需要对VirtioNet进行加【写锁】,以保证对设备的互斥访问。
206*13776c11Slogin         let mut rx_buf = self.rx_buffer.unwrap();
207*13776c11Slogin         let result = f(rx_buf.packet_mut());
208*13776c11Slogin         self.driver
209*13776c11Slogin             .inner
210*13776c11Slogin             .lock()
211*13776c11Slogin             .recycle_rx_buffer(rx_buf)
212*13776c11Slogin             .expect("virtio_net recv failed");
213*13776c11Slogin         result
214*13776c11Slogin     }
215*13776c11Slogin }
216*13776c11Slogin 
217*13776c11Slogin /// @brief virtio-net 驱动的初始化与测试
218*13776c11Slogin pub fn virtio_net<T: Transport + 'static>(transport: T) {
219*13776c11Slogin     let driver_net: VirtIONet<HalImpl, T, 2> =
220*13776c11Slogin         match VirtIONet::<HalImpl, T, 2>::new(transport, 4096) {
221*13776c11Slogin             Ok(net) => net,
222*13776c11Slogin             Err(_) => {
223*13776c11Slogin                 kerror!("VirtIONet init failed");
224*13776c11Slogin                 return;
225*13776c11Slogin             }
226*13776c11Slogin         };
227*13776c11Slogin     let mac = smoltcp::wire::EthernetAddress::from_bytes(&driver_net.mac_address());
228*13776c11Slogin     let driver: VirtioNICDriver<T> = VirtioNICDriver::new(driver_net);
229*13776c11Slogin     let iface = VirtioInterface::new(driver);
230*13776c11Slogin     // 将网卡的接口信息注册到全局的网卡接口信息表中
231*13776c11Slogin     NET_DRIVERS.write().insert(iface.nic_id(), iface.clone());
232*13776c11Slogin     kinfo!(
233*13776c11Slogin         "Virtio-net driver init successfully!\tNetDevID: [{}], MAC: [{}]",
234*13776c11Slogin         iface.name(),
235*13776c11Slogin         mac
236*13776c11Slogin     );
237*13776c11Slogin }
238*13776c11Slogin 
239*13776c11Slogin impl<T: Transport> Driver for VirtioInterface<T> {
240*13776c11Slogin     fn as_any_ref(&'static self) -> &'static dyn core::any::Any {
241*13776c11Slogin         self
242*13776c11Slogin     }
243*13776c11Slogin }
244*13776c11Slogin 
245*13776c11Slogin impl<T: Transport> NetDriver for VirtioInterface<T> {
246*13776c11Slogin     fn mac(&self) -> smoltcp::wire::EthernetAddress {
247*13776c11Slogin         let mac: [u8; 6] = self.driver.inner.lock().mac_address();
248*13776c11Slogin         return smoltcp::wire::EthernetAddress::from_bytes(&mac);
249*13776c11Slogin     }
250*13776c11Slogin 
251*13776c11Slogin     #[inline]
252*13776c11Slogin     fn nic_id(&self) -> usize {
253*13776c11Slogin         return self.iface_id;
254*13776c11Slogin     }
255*13776c11Slogin 
256*13776c11Slogin     #[inline]
257*13776c11Slogin     fn name(&self) -> String {
258*13776c11Slogin         return self.name.clone();
259*13776c11Slogin     }
260*13776c11Slogin 
261*13776c11Slogin     fn update_ip_addrs(&self, ip_addrs: &[wire::IpCidr]) -> Result<(), SystemError> {
262*13776c11Slogin         if ip_addrs.len() != 1 {
263*13776c11Slogin             return Err(SystemError::EINVAL);
264*13776c11Slogin         }
265*13776c11Slogin 
266*13776c11Slogin         self.iface.lock().update_ip_addrs(|addrs| {
267*13776c11Slogin             let dest = addrs.iter_mut().next();
268*13776c11Slogin             if let None = dest {
269*13776c11Slogin                 addrs.push(ip_addrs[0]).expect("Push ipCidr failed: full");
270*13776c11Slogin             } else {
271*13776c11Slogin                 let dest = dest.unwrap();
272*13776c11Slogin                 *dest = ip_addrs[0];
273*13776c11Slogin             }
274*13776c11Slogin         });
275*13776c11Slogin         return Ok(());
276*13776c11Slogin     }
277*13776c11Slogin 
278*13776c11Slogin     fn poll(
279*13776c11Slogin         &self,
280*13776c11Slogin         sockets: &mut smoltcp::iface::SocketSet,
281*13776c11Slogin     ) -> Result<(), crate::syscall::SystemError> {
282*13776c11Slogin         let timestamp: smoltcp::time::Instant = Instant::now().into();
283*13776c11Slogin         let mut guard = self.iface.lock();
284*13776c11Slogin         let poll_res = guard.poll(timestamp, self.driver.force_get_mut(), sockets);
285*13776c11Slogin         // todo: notify!!!
286*13776c11Slogin         // kdebug!("Virtio Interface poll:{poll_res}");
287*13776c11Slogin         if poll_res {
288*13776c11Slogin             return Ok(());
289*13776c11Slogin         }
290*13776c11Slogin         return Err(SystemError::EAGAIN);
291*13776c11Slogin     }
292*13776c11Slogin 
293*13776c11Slogin     #[inline(always)]
294*13776c11Slogin     fn inner_iface(&self) -> &SpinLock<smoltcp::iface::Interface> {
295*13776c11Slogin         return &self.iface;
296*13776c11Slogin     }
297*13776c11Slogin     // fn as_any_ref(&'static self) -> &'static dyn core::any::Any {
298*13776c11Slogin     //     return self;
299*13776c11Slogin     // }
300*13776c11Slogin }
301*13776c11Slogin 
302*13776c11Slogin // 向编译器保证,VirtioNICDriver在线程之间是安全的.
303*13776c11Slogin // 由于smoltcp只会在token内真正操作网卡设备,并且在VirtioNetToken的consume
304*13776c11Slogin // 方法内,会对VirtioNet进行加【写锁】,因此,能够保证对设备操作的的互斥访问,
305*13776c11Slogin // 因此VirtioNICDriver在线程之间是安全的。
306*13776c11Slogin // unsafe impl<T: Transport> Sync for VirtioNICDriver<T> {}
307*13776c11Slogin // unsafe impl<T: Transport> Send for VirtioNICDriver<T> {}
308