xref: /DragonOS/kernel/src/driver/virtio/virtio.rs (revision fae6e9ade46a52976ad5d099643d51cc20876448)
1 use super::mmio::virtio_probe_mmio;
2 use super::transport_pci::PciTransport;
3 use super::virtio_impl::HalImpl;
4 use crate::driver::base::device::bus::Bus;
5 use crate::driver::base::device::{Device, DeviceId};
6 use crate::driver::block::virtio_blk::virtio_blk;
7 use crate::driver::net::virtio_net::virtio_net;
8 use crate::driver::pci::pci::{
9     get_pci_device_structures_mut_by_vendor_id, PciDeviceStructure,
10     PciDeviceStructureGeneralDevice, PCI_DEVICE_LINKEDLIST,
11 };
12 use crate::driver::pci::subsys::pci_bus;
13 use crate::driver::virtio::transport::VirtIOTransport;
14 use crate::libs::rwlock::RwLockWriteGuard;
15 
16 use alloc::string::String;
17 use alloc::sync::Arc;
18 use alloc::vec::Vec;
19 use alloc::{boxed::Box, collections::LinkedList};
20 use log::{debug, error, warn};
21 use virtio_drivers::transport::{DeviceType, Transport};
22 
23 ///@brief 寻找并加载所有virtio设备的驱动(目前只有virtio-net,但其他virtio设备也可添加)
24 pub fn virtio_probe() {
25     #[cfg(not(target_arch = "riscv64"))]
26     virtio_probe_pci();
27     virtio_probe_mmio();
28 }
29 
30 #[allow(dead_code)]
31 fn virtio_probe_pci() {
32     let mut list = PCI_DEVICE_LINKEDLIST.write();
33     let virtio_list = virtio_device_search(&mut list);
34     for virtio_device in virtio_list {
35         let dev_id = virtio_device.common_header.device_id;
36         let dev_id = DeviceId::new(None, Some(format!("{dev_id}"))).unwrap();
37         match PciTransport::new::<HalImpl>(virtio_device, dev_id.clone()) {
38             Ok(mut transport) => {
39                 debug!(
40                     "Detected virtio PCI device with device type {:?}, features {:#018x}",
41                     transport.device_type(),
42                     transport.read_device_features(),
43                 );
44                 let transport = VirtIOTransport::Pci(transport);
45                 // 这里暂时通过设备名称在sysfs中查找设备,但是我感觉用设备ID更好
46                 let bus = pci_bus() as Arc<dyn Bus>;
47                 let name: String = virtio_device.common_header.bus_device_function.into();
48                 let pci_raw_device = bus.find_device_by_name(name.as_str());
49                 virtio_device_init(transport, dev_id, pci_raw_device);
50             }
51             Err(err) => {
52                 error!("Pci transport create failed because of error: {}", err);
53             }
54         }
55     }
56 }
57 
58 ///@brief 为virtio设备寻找对应的驱动进行初始化
59 pub(super) fn virtio_device_init(
60     transport: VirtIOTransport,
61     dev_id: Arc<DeviceId>,
62     dev_parent: Option<Arc<dyn Device>>,
63 ) {
64     match transport.device_type() {
65         DeviceType::Block => virtio_blk(transport, dev_id, dev_parent),
66         DeviceType::GPU => {
67             warn!("Not support virtio_gpu device for now");
68         }
69         DeviceType::Input => {
70             warn!("Not support virtio_input device for now");
71         }
72         DeviceType::Network => virtio_net(transport, dev_id, dev_parent),
73         t => {
74             warn!("Unrecognized virtio device: {:?}", t);
75         }
76     }
77 }
78 
79 /// # virtio_device_search - 在给定的PCI设备列表中搜索符合特定标准的virtio设备
80 ///
81 /// 该函数搜索一个PCI设备列表,找到所有由特定厂商ID(0x1AF4)和设备ID范围(0x1000至0x103F)定义的virtio设备。
82 ///
83 /// ## 参数
84 ///
85 /// - list: &'a mut RwLockWriteGuard<'_, LinkedList<Box<dyn PciDeviceStructure>>> - 一个可写的PCI设备结构列表的互斥锁。
86 ///
87 /// ## 返回值
88 ///
89 /// 返回一个包含所有找到的virtio设备的数组
90 fn virtio_device_search<'a>(
91     list: &'a mut RwLockWriteGuard<'_, LinkedList<Box<dyn PciDeviceStructure>>>,
92 ) -> Vec<&'a mut PciDeviceStructureGeneralDevice> {
93     let mut virtio_list = Vec::new();
94     let result = get_pci_device_structures_mut_by_vendor_id(list, 0x1AF4);
95 
96     for device in result {
97         let standard_device = device.as_standard_device_mut().unwrap();
98         let header = &standard_device.common_header;
99         if header.device_id >= 0x1000 && header.device_id <= 0x103F {
100             virtio_list.push(standard_device);
101         }
102     }
103 
104     return virtio_list;
105 }
106