xref: /DragonOS/kernel/src/driver/block/virtio_blk.rs (revision 4dd4856f933be0b4624c7f7ffa9e3d0c8c218873)
1 use core::{any::Any, fmt::Debug};
2 
3 use alloc::{
4     collections::LinkedList,
5     string::{String, ToString},
6     sync::{Arc, Weak},
7     vec::Vec,
8 };
9 use bitmap::traits::BitMapOps;
10 use log::error;
11 use system_error::SystemError;
12 use unified_init::macros::unified_init;
13 use virtio_drivers::device::blk::{VirtIOBlk, SECTOR_SIZE};
14 
15 use crate::{
16     driver::{
17         base::{
18             block::{
19                 block_device::{BlockDevName, BlockDevice, BlockId, GeneralBlockRange, LBA_SIZE},
20                 disk_info::Partition,
21                 manager::{block_dev_manager, BlockDevMeta},
22             },
23             class::Class,
24             device::{
25                 bus::Bus,
26                 driver::{Driver, DriverCommonData},
27                 Device, DeviceCommonData, DeviceId, DeviceType, IdTable,
28             },
29             kobject::{KObjType, KObject, KObjectCommonData, KObjectState, LockedKObjectState},
30             kset::KSet,
31         },
32         virtio::{
33             sysfs::{virtio_bus, virtio_device_manager, virtio_driver_manager},
34             transport::VirtIOTransport,
35             virtio_impl::HalImpl,
36             VirtIODevice, VirtIODeviceIndex, VirtIODriver, VirtIODriverCommonData, VirtioDeviceId,
37             VIRTIO_VENDOR_ID,
38         },
39     },
40     exception::{irqdesc::IrqReturn, IrqNumber},
41     filesystem::{kernfs::KernFSInode, mbr::MbrDiskPartionTable},
42     init::initcall::INITCALL_POSTCORE,
43     libs::{
44         rwlock::{RwLockReadGuard, RwLockWriteGuard},
45         spinlock::{SpinLock, SpinLockGuard},
46     },
47 };
48 
/// Base name of the virtio block driver. It is used for the driver/device id tables,
/// as the driver's kobject name, and as the fallback device name when none was set.
const VIRTIO_BLK_BASENAME: &str = "virtio_blk";
50 
51 static mut VIRTIO_BLK_DRIVER: Option<Arc<VirtIOBlkDriver>> = None;
52 
53 #[inline(always)]
54 #[allow(dead_code)]
55 fn virtio_blk_driver() -> Arc<VirtIOBlkDriver> {
56     unsafe { VIRTIO_BLK_DRIVER.as_ref().unwrap().clone() }
57 }
58 
59 /// Get the first virtio block device
60 #[allow(dead_code)]
61 pub fn virtio_blk_0() -> Option<Arc<VirtIOBlkDevice>> {
62     virtio_blk_driver()
63         .devices()
64         .first()
65         .cloned()
66         .map(|dev| dev.arc_any().downcast().unwrap())
67 }
68 
69 pub fn virtio_blk(
70     transport: VirtIOTransport,
71     dev_id: Arc<DeviceId>,
72     dev_parent: Option<Arc<dyn Device>>,
73 ) {
74     let device = VirtIOBlkDevice::new(transport, dev_id);
75     if let Some(device) = device {
76         if let Some(dev_parent) = dev_parent {
77             device.set_dev_parent(Some(Arc::downgrade(&dev_parent)));
78         }
79         virtio_device_manager()
80             .device_add(device.clone() as Arc<dyn VirtIODevice>)
81             .expect("Add virtio blk failed");
82     }
83 }
84 
85 static mut VIRTIOBLK_MANAGER: Option<VirtIOBlkManager> = None;
86 
87 #[inline]
88 fn virtioblk_manager() -> &'static VirtIOBlkManager {
89     unsafe { VIRTIOBLK_MANAGER.as_ref().unwrap() }
90 }
91 
92 #[unified_init(INITCALL_POSTCORE)]
93 fn virtioblk_manager_init() -> Result<(), SystemError> {
94     unsafe {
95         VIRTIOBLK_MANAGER = Some(VirtIOBlkManager::new());
96     }
97     Ok(())
98 }
99 
/// Hands out and reclaims the ids and names used by virtio block devices.
pub struct VirtIOBlkManager {
    /// Allocation state, guarded by a spin lock.
    inner: SpinLock<InnerVirtIOBlkManager>,
}
103 
/// Lock-protected state of `VirtIOBlkManager`.
struct InnerVirtIOBlkManager {
    /// Bit `i` is set while id `i` is allocated.
    id_bmp: bitmap::StaticBitmap<{ VirtIOBlkManager::MAX_DEVICES }>,
    /// Name recorded for each allocated id; `None` for free slots.
    devname: [Option<BlockDevName>; VirtIOBlkManager::MAX_DEVICES],
}
108 
109 impl VirtIOBlkManager {
110     pub const MAX_DEVICES: usize = 25;
111 
112     pub fn new() -> Self {
113         Self {
114             inner: SpinLock::new(InnerVirtIOBlkManager {
115                 id_bmp: bitmap::StaticBitmap::new(),
116                 devname: [const { None }; Self::MAX_DEVICES],
117             }),
118         }
119     }
120 
121     fn inner(&self) -> SpinLockGuard<InnerVirtIOBlkManager> {
122         self.inner.lock()
123     }
124 
125     pub fn alloc_id(&self) -> Option<BlockDevName> {
126         let mut inner = self.inner();
127         let idx = inner.id_bmp.first_false_index()?;
128         inner.id_bmp.set(idx, true);
129         let name = Self::format_name(idx);
130         inner.devname[idx] = Some(name.clone());
131         Some(name)
132     }
133 
134     /// Generate a new block device name like 'vda', 'vdb', etc.
135     fn format_name(id: usize) -> BlockDevName {
136         let x = (b'a' + id as u8) as char;
137         BlockDevName::new(format!("vd{}", x), id)
138     }
139 
140     pub fn free_id(&self, id: usize) {
141         if id >= Self::MAX_DEVICES {
142             return;
143         }
144         self.inner().id_bmp.set(id, false);
145         self.inner().devname[id] = None;
146     }
147 }
148 
/// virtio block device
///
/// Implements `BlockDevice`, `VirtIODevice`, `Device` and `KObject` in this file.
#[derive(Debug)]
#[cast_to([sync] VirtIODevice)]
#[cast_to([sync] Device)]
pub struct VirtIOBlkDevice {
    /// Block-layer metadata, built from the device name allocated in `new`.
    blkdev_meta: BlockDevMeta,
    /// Device id supplied by the caller of `new`.
    dev_id: Arc<DeviceId>,
    /// Mutable state (virtio driver instance, names, common device/kobject data).
    inner: SpinLock<InnerVirtIOBlkDevice>,
    /// KObject state, guarded by its own read/write lock.
    locked_kobj_state: LockedKObjectState,
    /// Weak self reference set up by `Arc::new_cyclic` in `new`.
    self_ref: Weak<Self>,
}

// NOTE(review): these manual impls assert thread-safety the compiler did not derive;
// presumably sound because all mutable state sits behind locks — confirm.
unsafe impl Send for VirtIOBlkDevice {}
unsafe impl Sync for VirtIOBlkDevice {}
163 
164 impl VirtIOBlkDevice {
165     pub fn new(transport: VirtIOTransport, dev_id: Arc<DeviceId>) -> Option<Arc<Self>> {
166         let devname = virtioblk_manager().alloc_id()?;
167         let irq = transport.irq().map(|irq| IrqNumber::new(irq.data()));
168         let device_inner = VirtIOBlk::<HalImpl, VirtIOTransport>::new(transport);
169         if let Err(e) = device_inner {
170             error!("VirtIOBlkDevice '{dev_id:?}' create failed: {:?}", e);
171             return None;
172         }
173 
174         let mut device_inner: VirtIOBlk<HalImpl, VirtIOTransport> = device_inner.unwrap();
175         device_inner.enable_interrupts();
176         let dev = Arc::new_cyclic(|self_ref| Self {
177             blkdev_meta: BlockDevMeta::new(devname),
178             self_ref: self_ref.clone(),
179             dev_id,
180             locked_kobj_state: LockedKObjectState::default(),
181             inner: SpinLock::new(InnerVirtIOBlkDevice {
182                 device_inner,
183                 name: None,
184                 virtio_index: None,
185                 device_common: DeviceCommonData::default(),
186                 kobject_common: KObjectCommonData::default(),
187                 irq,
188             }),
189         });
190 
191         Some(dev)
192     }
193 
194     fn inner(&self) -> SpinLockGuard<InnerVirtIOBlkDevice> {
195         self.inner.lock()
196     }
197 }
198 
199 impl BlockDevice for VirtIOBlkDevice {
200     fn dev_name(&self) -> &BlockDevName {
201         &self.blkdev_meta.devname
202     }
203 
204     fn blkdev_meta(&self) -> &BlockDevMeta {
205         &self.blkdev_meta
206     }
207 
208     fn disk_range(&self) -> GeneralBlockRange {
209         let inner = self.inner();
210         let blocks = inner.device_inner.capacity() as usize * SECTOR_SIZE / LBA_SIZE;
211         drop(inner);
212         log::debug!(
213             "VirtIOBlkDevice '{:?}' disk_range: 0..{}",
214             self.dev_name(),
215             blocks
216         );
217         GeneralBlockRange::new(0, blocks).unwrap()
218     }
219 
220     fn read_at_sync(
221         &self,
222         lba_id_start: BlockId,
223         count: usize,
224         buf: &mut [u8],
225     ) -> Result<usize, SystemError> {
226         let mut inner = self.inner();
227 
228         inner
229             .device_inner
230             .read_blocks(lba_id_start, &mut buf[..count * LBA_SIZE])
231             .map_err(|e| {
232                 error!(
233                     "VirtIOBlkDevice '{:?}' read_at_sync failed: {:?}",
234                     self.dev_id, e
235                 );
236                 SystemError::EIO
237             })?;
238 
239         Ok(count)
240     }
241 
242     fn write_at_sync(
243         &self,
244         lba_id_start: BlockId,
245         count: usize,
246         buf: &[u8],
247     ) -> Result<usize, SystemError> {
248         self.inner()
249             .device_inner
250             .write_blocks(lba_id_start, &buf[..count * LBA_SIZE])
251             .map_err(|_| SystemError::EIO)?;
252         Ok(count)
253     }
254 
255     fn sync(&self) -> Result<(), SystemError> {
256         Ok(())
257     }
258 
259     fn blk_size_log2(&self) -> u8 {
260         9
261     }
262 
263     fn as_any_ref(&self) -> &dyn Any {
264         self
265     }
266 
267     fn device(&self) -> Arc<dyn Device> {
268         self.self_ref.upgrade().unwrap()
269     }
270 
271     fn block_size(&self) -> usize {
272         todo!()
273     }
274 
275     fn partitions(&self) -> Vec<Arc<Partition>> {
276         let device = self.self_ref.upgrade().unwrap() as Arc<dyn BlockDevice>;
277         let mbr_table = MbrDiskPartionTable::from_disk(device.clone())
278             .expect("Failed to get MBR partition table");
279         mbr_table.partitions(Arc::downgrade(&device))
280     }
281 }
282 
/// Lock-protected mutable state of a `VirtIOBlkDevice`.
struct InnerVirtIOBlkDevice {
    /// The `virtio_drivers` block device instance that performs the actual I/O.
    device_inner: VirtIOBlk<HalImpl, VirtIOTransport>,
    /// Device name set through `set_device_name`; `None` until then.
    name: Option<String>,
    /// Index set through `set_virtio_device_index`; `None` until then.
    virtio_index: Option<VirtIODeviceIndex>,
    /// Common device-model data (bus, class, driver, parent, match flag).
    device_common: DeviceCommonData,
    /// Common kobject data (inode, parent, kset, kobject type).
    kobject_common: KObjectCommonData,
    /// IRQ number taken from the transport at construction, if it reported one.
    irq: Option<IrqNumber>,
}
291 
292 impl Debug for InnerVirtIOBlkDevice {
293     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
294         f.debug_struct("InnerVirtIOBlkDevice").finish()
295     }
296 }
297 
298 impl VirtIODevice for VirtIOBlkDevice {
299     fn irq(&self) -> Option<IrqNumber> {
300         self.inner().irq
301     }
302 
303     fn handle_irq(
304         &self,
305         _irq: crate::exception::IrqNumber,
306     ) -> Result<IrqReturn, system_error::SystemError> {
307         // todo: handle virtio blk irq
308         Ok(crate::exception::irqdesc::IrqReturn::Handled)
309     }
310 
311     fn dev_id(&self) -> &Arc<DeviceId> {
312         &self.dev_id
313     }
314 
315     fn set_device_name(&self, name: String) {
316         self.inner().name = Some(name);
317     }
318 
319     fn device_name(&self) -> String {
320         self.inner()
321             .name
322             .clone()
323             .unwrap_or_else(|| VIRTIO_BLK_BASENAME.to_string())
324     }
325 
326     fn set_virtio_device_index(&self, index: VirtIODeviceIndex) {
327         self.inner().virtio_index = Some(index);
328     }
329 
330     fn virtio_device_index(&self) -> Option<VirtIODeviceIndex> {
331         self.inner().virtio_index
332     }
333 
334     fn device_type_id(&self) -> u32 {
335         virtio_drivers::transport::DeviceType::Block as u32
336     }
337 
338     fn vendor(&self) -> u32 {
339         VIRTIO_VENDOR_ID.into()
340     }
341 }
342 
343 impl Device for VirtIOBlkDevice {
344     fn dev_type(&self) -> DeviceType {
345         DeviceType::Net
346     }
347 
348     fn id_table(&self) -> IdTable {
349         IdTable::new(VIRTIO_BLK_BASENAME.to_string(), None)
350     }
351 
352     fn bus(&self) -> Option<Weak<dyn Bus>> {
353         self.inner().device_common.bus.clone()
354     }
355 
356     fn set_bus(&self, bus: Option<Weak<dyn Bus>>) {
357         self.inner().device_common.bus = bus;
358     }
359 
360     fn class(&self) -> Option<Arc<dyn Class>> {
361         let mut guard = self.inner();
362         let r = guard.device_common.class.clone()?.upgrade();
363         if r.is_none() {
364             guard.device_common.class = None;
365         }
366 
367         return r;
368     }
369 
370     fn set_class(&self, class: Option<Weak<dyn Class>>) {
371         self.inner().device_common.class = class;
372     }
373 
374     fn driver(&self) -> Option<Arc<dyn Driver>> {
375         let r = self.inner().device_common.driver.clone()?.upgrade();
376         if r.is_none() {
377             self.inner().device_common.driver = None;
378         }
379 
380         return r;
381     }
382 
383     fn set_driver(&self, driver: Option<Weak<dyn Driver>>) {
384         self.inner().device_common.driver = driver;
385     }
386 
387     fn is_dead(&self) -> bool {
388         false
389     }
390 
391     fn can_match(&self) -> bool {
392         self.inner().device_common.can_match
393     }
394 
395     fn set_can_match(&self, can_match: bool) {
396         self.inner().device_common.can_match = can_match;
397     }
398 
399     fn state_synced(&self) -> bool {
400         true
401     }
402 
403     fn dev_parent(&self) -> Option<Weak<dyn Device>> {
404         self.inner().device_common.get_parent_weak_or_clear()
405     }
406 
407     fn set_dev_parent(&self, parent: Option<Weak<dyn Device>>) {
408         self.inner().device_common.parent = parent;
409     }
410 }
411 
412 impl KObject for VirtIOBlkDevice {
413     fn as_any_ref(&self) -> &dyn Any {
414         self
415     }
416 
417     fn set_inode(&self, inode: Option<Arc<KernFSInode>>) {
418         self.inner().kobject_common.kern_inode = inode;
419     }
420 
421     fn inode(&self) -> Option<Arc<KernFSInode>> {
422         self.inner().kobject_common.kern_inode.clone()
423     }
424 
425     fn parent(&self) -> Option<Weak<dyn KObject>> {
426         self.inner().kobject_common.parent.clone()
427     }
428 
429     fn set_parent(&self, parent: Option<Weak<dyn KObject>>) {
430         self.inner().kobject_common.parent = parent;
431     }
432 
433     fn kset(&self) -> Option<Arc<KSet>> {
434         self.inner().kobject_common.kset.clone()
435     }
436 
437     fn set_kset(&self, kset: Option<Arc<KSet>>) {
438         self.inner().kobject_common.kset = kset;
439     }
440 
441     fn kobj_type(&self) -> Option<&'static dyn KObjType> {
442         self.inner().kobject_common.kobj_type
443     }
444 
445     fn name(&self) -> String {
446         self.device_name()
447     }
448 
449     fn set_name(&self, _name: String) {
450         // do nothing
451     }
452 
453     fn kobj_state(&self) -> RwLockReadGuard<KObjectState> {
454         self.locked_kobj_state.read()
455     }
456 
457     fn kobj_state_mut(&self) -> RwLockWriteGuard<KObjectState> {
458         self.locked_kobj_state.write()
459     }
460 
461     fn set_kobj_state(&self, state: KObjectState) {
462         *self.locked_kobj_state.write() = state;
463     }
464 
465     fn set_kobj_type(&self, ktype: Option<&'static dyn KObjType>) {
466         self.inner().kobject_common.kobj_type = ktype;
467     }
468 }
469 
470 #[unified_init(INITCALL_POSTCORE)]
471 fn virtio_blk_driver_init() -> Result<(), SystemError> {
472     let driver = VirtIOBlkDriver::new();
473     virtio_driver_manager()
474         .register(driver.clone() as Arc<dyn VirtIODriver>)
475         .expect("Add virtio net driver failed");
476     unsafe {
477         VIRTIO_BLK_DRIVER = Some(driver);
478     }
479 
480     return Ok(());
481 }
482 
/// Driver object for virtio block devices.
#[derive(Debug)]
#[cast_to([sync] VirtIODriver)]
#[cast_to([sync] Driver)]
struct VirtIOBlkDriver {
    /// Mutable driver state (virtio id table, attached devices, kobject data).
    inner: SpinLock<InnerVirtIOBlkDriver>,
    /// KObject state, guarded by its own read/write lock.
    kobj_state: LockedKObjectState,
}
490 
491 impl VirtIOBlkDriver {
492     pub fn new() -> Arc<Self> {
493         let inner = InnerVirtIOBlkDriver {
494             virtio_driver_common: VirtIODriverCommonData::default(),
495             driver_common: DriverCommonData::default(),
496             kobj_common: KObjectCommonData::default(),
497         };
498 
499         let id_table = VirtioDeviceId::new(
500             virtio_drivers::transport::DeviceType::Block as u32,
501             VIRTIO_VENDOR_ID.into(),
502         );
503         let result = VirtIOBlkDriver {
504             inner: SpinLock::new(inner),
505             kobj_state: LockedKObjectState::default(),
506         };
507         result.add_virtio_id(id_table);
508 
509         return Arc::new(result);
510     }
511 
512     fn inner(&self) -> SpinLockGuard<InnerVirtIOBlkDriver> {
513         return self.inner.lock();
514     }
515 }
516 
/// Lock-protected mutable state of `VirtIOBlkDriver`.
#[derive(Debug)]
struct InnerVirtIOBlkDriver {
    /// Virtio-specific driver data; holds the virtio id table.
    virtio_driver_common: VirtIODriverCommonData,
    /// Common driver data; holds the list of attached devices.
    driver_common: DriverCommonData,
    /// Common kobject data (inode, parent, kset, kobject type).
    kobj_common: KObjectCommonData,
}
523 
524 impl VirtIODriver for VirtIOBlkDriver {
525     fn probe(&self, device: &Arc<dyn VirtIODevice>) -> Result<(), SystemError> {
526         let dev = device
527             .clone()
528             .arc_any()
529             .downcast::<VirtIOBlkDevice>()
530             .map_err(|_| {
531                 error!(
532                 "VirtIOBlkDriver::probe() failed: device is not a VirtIO block device. Device: '{:?}'",
533                 device.name()
534             );
535                 SystemError::EINVAL
536             })?;
537 
538         block_dev_manager().register(dev as Arc<dyn BlockDevice>)?;
539         return Ok(());
540     }
541 
542     fn virtio_id_table(&self) -> LinkedList<crate::driver::virtio::VirtioDeviceId> {
543         self.inner().virtio_driver_common.id_table.clone()
544     }
545 
546     fn add_virtio_id(&self, id: VirtioDeviceId) {
547         self.inner().virtio_driver_common.id_table.push_back(id);
548     }
549 }
550 
551 impl Driver for VirtIOBlkDriver {
552     fn id_table(&self) -> Option<IdTable> {
553         Some(IdTable::new(VIRTIO_BLK_BASENAME.to_string(), None))
554     }
555 
556     fn add_device(&self, device: Arc<dyn Device>) {
557         let iface = device
558             .arc_any()
559             .downcast::<VirtIOBlkDevice>()
560             .expect("VirtIOBlkDriver::add_device() failed: device is not a VirtIOBlkDevice");
561 
562         self.inner()
563             .driver_common
564             .devices
565             .push(iface as Arc<dyn Device>);
566     }
567 
568     fn delete_device(&self, device: &Arc<dyn Device>) {
569         let _iface = device
570             .clone()
571             .arc_any()
572             .downcast::<VirtIOBlkDevice>()
573             .expect("VirtIOBlkDriver::delete_device() failed: device is not a VirtIOBlkDevice");
574 
575         let mut guard = self.inner();
576         let index = guard
577             .driver_common
578             .devices
579             .iter()
580             .position(|dev| Arc::ptr_eq(device, dev))
581             .expect("VirtIOBlkDriver::delete_device() failed: device not found");
582 
583         guard.driver_common.devices.remove(index);
584     }
585 
586     fn devices(&self) -> Vec<Arc<dyn Device>> {
587         self.inner().driver_common.devices.clone()
588     }
589 
590     fn bus(&self) -> Option<Weak<dyn Bus>> {
591         Some(Arc::downgrade(&virtio_bus()) as Weak<dyn Bus>)
592     }
593 
594     fn set_bus(&self, _bus: Option<Weak<dyn Bus>>) {
595         // do nothing
596     }
597 }
598 
599 impl KObject for VirtIOBlkDriver {
600     fn as_any_ref(&self) -> &dyn Any {
601         self
602     }
603 
604     fn set_inode(&self, inode: Option<Arc<KernFSInode>>) {
605         self.inner().kobj_common.kern_inode = inode;
606     }
607 
608     fn inode(&self) -> Option<Arc<KernFSInode>> {
609         self.inner().kobj_common.kern_inode.clone()
610     }
611 
612     fn parent(&self) -> Option<Weak<dyn KObject>> {
613         self.inner().kobj_common.parent.clone()
614     }
615 
616     fn set_parent(&self, parent: Option<Weak<dyn KObject>>) {
617         self.inner().kobj_common.parent = parent;
618     }
619 
620     fn kset(&self) -> Option<Arc<KSet>> {
621         self.inner().kobj_common.kset.clone()
622     }
623 
624     fn set_kset(&self, kset: Option<Arc<KSet>>) {
625         self.inner().kobj_common.kset = kset;
626     }
627 
628     fn kobj_type(&self) -> Option<&'static dyn KObjType> {
629         self.inner().kobj_common.kobj_type
630     }
631 
632     fn set_kobj_type(&self, ktype: Option<&'static dyn KObjType>) {
633         self.inner().kobj_common.kobj_type = ktype;
634     }
635 
636     fn name(&self) -> String {
637         VIRTIO_BLK_BASENAME.to_string()
638     }
639 
640     fn set_name(&self, _name: String) {
641         // do nothing
642     }
643 
644     fn kobj_state(&self) -> RwLockReadGuard<KObjectState> {
645         self.kobj_state.read()
646     }
647 
648     fn kobj_state_mut(&self) -> RwLockWriteGuard<KObjectState> {
649         self.kobj_state.write()
650     }
651 
652     fn set_kobj_state(&self, state: KObjectState) {
653         *self.kobj_state.write() = state;
654     }
655 }
656