xref: /DragonOS/user/apps/test-for-robustfutex/src/main.rs (revision 06560afa2aa4db352526f4be8b6262719b8b3eac)
1*06560afaShmt extern crate libc;
2*06560afaShmt extern crate syscalls;
3*06560afaShmt 
4*06560afaShmt use std::{
5*06560afaShmt     ffi::c_void,
6*06560afaShmt     mem::{self, size_of},
7*06560afaShmt     process,
8*06560afaShmt     ptr::{self, NonNull},
9*06560afaShmt     sync::atomic::{AtomicI32, Ordering},
10*06560afaShmt     thread,
11*06560afaShmt     time::Duration,
12*06560afaShmt };
13*06560afaShmt 
14*06560afaShmt use syscalls::{
15*06560afaShmt     syscall0, syscall2, syscall3, syscall6,
16*06560afaShmt     Sysno::{futex, get_robust_list, gettid, set_robust_list},
17*06560afaShmt };
18*06560afaShmt 
19*06560afaShmt use libc::{
20*06560afaShmt     c_int, mmap, perror, EXIT_FAILURE, MAP_ANONYMOUS, MAP_FAILED, MAP_SHARED, PROT_READ, PROT_WRITE,
21*06560afaShmt };
22*06560afaShmt 
/// `futex(2)` operation code: sleep while the futex word still equals the expected value.
const FUTEX_WAIT: usize = 0x0;
/// `futex(2)` operation code: wake threads sleeping on the futex word.
const FUTEX_WAKE: usize = 0x1;
25*06560afaShmt 
/// Thin wrapper around a raw pointer to an array of `u32` futex words.
///
/// All `offset` arguments count `u32` slots (raw-pointer `offset` semantics), not bytes.
/// The methods perform unchecked raw-pointer reads and writes; the caller must keep
/// `addr + offset` in bounds and suitably aligned.
#[derive(Clone, Copy, Debug)]
struct Futex {
    addr: *mut u32,
}

impl Futex {
    /// Wraps `addr` without dereferencing it.
    pub fn new(addr: *mut u32) -> Self {
        Self { addr }
    }

    /// Returns the address `offset` slots away from the base pointer.
    pub fn get_addr(&self, offset: isize) -> *mut u32 {
        // SAFETY: the caller keeps the result inside the same allocation/mapping.
        unsafe { self.addr.offset(offset) }
    }

    /// Reads the word stored `offset` slots from the base pointer.
    pub fn get_val(&self, offset: isize) -> u32 {
        let slot = self.get_addr(offset);
        // SAFETY: the caller guarantees `slot` is valid and aligned for reads.
        unsafe { slot.read() }
    }

    /// Overwrites the word stored `offset` slots from the base pointer with `val`.
    pub fn set_val(&self, val: u32, offset: isize) {
        let slot = self.get_addr(offset);
        // SAFETY: the caller guarantees `slot` is valid and aligned for writes.
        unsafe { slot.write(val) }
    }
}

// These impls only let the wrapper be moved/shared into spawned threads; the type itself
// performs no synchronization, so concurrent access must be coordinated by the caller.
unsafe impl Send for Futex {}
unsafe impl Sync for Futex {}
54*06560afaShmt 
/// Thin wrapper around a raw pointer to an array of `i32` token/counter slots.
///
/// `offset` arguments count `i32` slots, not bytes. Reads and writes are unchecked;
/// the caller must keep `addr + offset` valid and aligned.
#[derive(Clone, Copy, Debug)]
struct Lock {
    addr: *mut i32,
}

impl Lock {
    /// Wraps `addr` without dereferencing it.
    pub fn new(addr: *mut i32) -> Self {
        Self { addr }
    }

    /// Reads the slot `offset` positions from the base pointer.
    pub fn get_val(&self, offset: isize) -> i32 {
        // SAFETY: the caller guarantees the slot is in bounds and readable.
        let slot = unsafe { self.addr.offset(offset) };
        unsafe { slot.read() }
    }

    /// Stores `val` into the slot `offset` positions from the base pointer.
    pub fn set_val(&self, val: i32, offset: isize) {
        // SAFETY: the caller guarantees the slot is in bounds and writable.
        let slot = unsafe { self.addr.offset(offset) };
        unsafe { slot.write(val) }
    }
}

// Allow moving/sharing the wrapper across threads; synchronization is the caller's job.
unsafe impl Send for Lock {}
unsafe impl Sync for Lock {}
78*06560afaShmt 
/// One link of a robust futex list, mirroring Linux's
/// `struct robust_list { struct robust_list *next; }`.
///
/// `#[repr(C)]` pins the layout: this memory is read directly by the kernel.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
struct RobustList {
    next: *const RobustList,
}

/// Per-thread robust list head passed to `set_robust_list`, mirroring Linux's
/// `struct robust_list_head { list; futex_offset; list_op_pending; }` in that field order.
///
/// `#[repr(C)]` is required: without it Rust is free to reorder or pad the fields, and the
/// kernel would read the wrong words.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
struct RobustListHead {
    /// First link of the list (`list.next`).
    list: RobustList,
    /// Tells the kernel where the futex word lives relative to each list entry, so user
    /// space can lay out its data structures freely without the kernel hard-coding any
    /// particular offset. In this program the leading slots of the `futexes` mapping hold
    /// the list links (`list.next`) and later slots hold the actual futex values; this
    /// field is the offset from the former to the latter.
    #[allow(dead_code)] // never read from Rust; only consumed by the kernel
    futex_offset: isize,
    /// Potential race: entries are added to / removed from the list only after the lock
    /// is acquired, which leaves a small window in which a thread may die and leave the
    /// lock dangling. To cover that, user space also maintains this `list_op_pending`
    /// field, which lets the kernel clean up when a thread dies after acquiring the lock
    /// but before linking it into the list; it is cleared once the list add/remove is
    /// done. Not exercised by this program. (According to the original comment, the
    /// kernel implementation essentially performs one wake-up on the `list_op_pending`
    /// address if there are waiters.)
    #[allow(dead_code)] // never read from Rust; only consumed by the kernel
    list_op_pending: *const RobustList,
}
100*06560afaShmt 
error_handle(msg: &str) -> !101*06560afaShmt fn error_handle(msg: &str) -> ! {
102*06560afaShmt     unsafe { perror(msg.as_ptr() as *const i8) };
103*06560afaShmt     process::exit(EXIT_FAILURE)
104*06560afaShmt }
105*06560afaShmt 
futex_wait(futexes: Futex, thread: &str, offset_futex: isize, lock: Lock, offset_count: isize)106*06560afaShmt fn futex_wait(futexes: Futex, thread: &str, offset_futex: isize, lock: Lock, offset_count: isize) {
107*06560afaShmt     loop {
108*06560afaShmt         let atomic_count = AtomicI32::new(lock.get_val(offset_count));
109*06560afaShmt         if atomic_count
110*06560afaShmt             .compare_exchange(1, 0, Ordering::SeqCst, Ordering::SeqCst)
111*06560afaShmt             .is_ok()
112*06560afaShmt         {
113*06560afaShmt             lock.set_val(0, offset_count);
114*06560afaShmt 
115*06560afaShmt             // 设置futex锁当前被哪个线程占用
116*06560afaShmt             let tid = unsafe { syscall0(gettid).unwrap() as u32 };
117*06560afaShmt             futexes.set_val(futexes.get_val(offset_futex) | tid, offset_futex);
118*06560afaShmt 
119*06560afaShmt             break;
120*06560afaShmt         }
121*06560afaShmt 
122*06560afaShmt         println!("{} wating...", thread);
123*06560afaShmt         let futex_val = futexes.get_val(offset_futex);
124*06560afaShmt         futexes.set_val(futex_val | 0x8000_0000, offset_futex);
125*06560afaShmt         let ret = unsafe {
126*06560afaShmt             syscall6(
127*06560afaShmt                 futex,
128*06560afaShmt                 futexes.get_addr(offset_futex) as usize,
129*06560afaShmt                 FUTEX_WAIT,
130*06560afaShmt                 futexes.get_val(offset_futex) as usize,
131*06560afaShmt                 0,
132*06560afaShmt                 0,
133*06560afaShmt                 0,
134*06560afaShmt             )
135*06560afaShmt         };
136*06560afaShmt         if ret.is_err() {
137*06560afaShmt             error_handle("futex_wait failed");
138*06560afaShmt         }
139*06560afaShmt 
140*06560afaShmt         // 被唤醒后释放锁
141*06560afaShmt         let atomic_count = AtomicI32::new(lock.get_val(offset_count));
142*06560afaShmt         if atomic_count
143*06560afaShmt             .compare_exchange(0, 1, Ordering::SeqCst, Ordering::SeqCst)
144*06560afaShmt             .is_ok()
145*06560afaShmt         {
146*06560afaShmt             lock.set_val(1, offset_count);
147*06560afaShmt 
148*06560afaShmt             // 释放futex锁,不被任何线程占用
149*06560afaShmt             futexes.set_val(futexes.get_val(offset_futex) & 0xc000_0000, offset_futex);
150*06560afaShmt 
151*06560afaShmt             break;
152*06560afaShmt         }
153*06560afaShmt     }
154*06560afaShmt }
155*06560afaShmt 
futex_wake(futexes: Futex, thread: &str, offset_futex: isize, lock: Lock, offset_count: isize)156*06560afaShmt fn futex_wake(futexes: Futex, thread: &str, offset_futex: isize, lock: Lock, offset_count: isize) {
157*06560afaShmt     let atomic_count = AtomicI32::new(lock.get_val(offset_count));
158*06560afaShmt     if atomic_count
159*06560afaShmt         .compare_exchange(0, 1, Ordering::SeqCst, Ordering::SeqCst)
160*06560afaShmt         .is_ok()
161*06560afaShmt     {
162*06560afaShmt         lock.set_val(1, offset_count);
163*06560afaShmt 
164*06560afaShmt         // 释放futex锁,不被任何线程占用
165*06560afaShmt         futexes.set_val(futexes.get_val(offset_futex) & 0xc000_0000, offset_futex);
166*06560afaShmt 
167*06560afaShmt         // 如果没有线程/进程在等这个futex,则不必唤醒, 释放改锁即可
168*06560afaShmt         let futex_val = futexes.get_val(offset_futex);
169*06560afaShmt         if futex_val & 0x8000_0000 == 0 {
170*06560afaShmt             return;
171*06560afaShmt         }
172*06560afaShmt 
173*06560afaShmt         futexes.set_val(futex_val & !(1 << 31), offset_futex);
174*06560afaShmt         let ret = unsafe {
175*06560afaShmt             syscall6(
176*06560afaShmt                 futex,
177*06560afaShmt                 futexes.get_addr(offset_futex) as usize,
178*06560afaShmt                 FUTEX_WAKE,
179*06560afaShmt                 1,
180*06560afaShmt                 0,
181*06560afaShmt                 0,
182*06560afaShmt                 0,
183*06560afaShmt             )
184*06560afaShmt         };
185*06560afaShmt         if ret.is_err() {
186*06560afaShmt             error_handle("futex wake failed");
187*06560afaShmt         }
188*06560afaShmt         println!("{} waked", thread);
189*06560afaShmt     }
190*06560afaShmt }
191*06560afaShmt 
set_list(futexes: Futex)192*06560afaShmt fn set_list(futexes: Futex) {
193*06560afaShmt     let head = RobustListHead {
194*06560afaShmt         list: RobustList { next: ptr::null() },
195*06560afaShmt         futex_offset: 44,
196*06560afaShmt         list_op_pending: ptr::null(),
197*06560afaShmt     };
198*06560afaShmt     let head = NonNull::from(&head).as_ptr();
199*06560afaShmt     unsafe {
200*06560afaShmt         // 加入第一个futex
201*06560afaShmt         let head_ref_mut = &mut *head;
202*06560afaShmt         head_ref_mut.list.next = futexes.get_addr(0) as *const RobustList;
203*06560afaShmt 
204*06560afaShmt         // 加入第二个futex
205*06560afaShmt         let list_2 = NonNull::from(&*head_ref_mut.list.next).as_ptr();
206*06560afaShmt         let list_2_ref_mut = &mut *list_2;
207*06560afaShmt         list_2_ref_mut.next = futexes.get_addr(1) as *const RobustList;
208*06560afaShmt 
209*06560afaShmt         //println!("robust list next: {:?}", (*head).list.next );
210*06560afaShmt         //println!("robust list next next: {:?}", (*(*head).list.next).next );
211*06560afaShmt 
212*06560afaShmt         // 向内核注册robust list
213*06560afaShmt         let len = mem::size_of::<*mut RobustListHead>();
214*06560afaShmt         let ret = syscall2(set_robust_list, head as usize, len);
215*06560afaShmt         if ret.is_err() {
216*06560afaShmt             println!("failed to set_robust_list, ret = {:?}", ret);
217*06560afaShmt         }
218*06560afaShmt     }
219*06560afaShmt }
220*06560afaShmt 
main()221*06560afaShmt fn main() {
222*06560afaShmt     test01();
223*06560afaShmt 
224*06560afaShmt     println!("-------------");
225*06560afaShmt 
226*06560afaShmt     test02();
227*06560afaShmt 
228*06560afaShmt     println!("-------------");
229*06560afaShmt }
230*06560afaShmt 
231*06560afaShmt //测试set_robust_list和get_robust_list两个系统调用是否能正常使用
test01()232*06560afaShmt fn test01() {
233*06560afaShmt     // 创建robust list 头指针
234*06560afaShmt     let head = RobustListHead {
235*06560afaShmt         list: RobustList { next: ptr::null() },
236*06560afaShmt         futex_offset: 8,
237*06560afaShmt         list_op_pending: ptr::null(),
238*06560afaShmt     };
239*06560afaShmt     let head = NonNull::from(&head).as_ptr();
240*06560afaShmt 
241*06560afaShmt     let futexes = unsafe {
242*06560afaShmt         mmap(
243*06560afaShmt             ptr::null_mut::<c_void>(),
244*06560afaShmt             (size_of::<c_int>() * 2) as libc::size_t,
245*06560afaShmt             PROT_READ | PROT_WRITE,
246*06560afaShmt             MAP_ANONYMOUS | MAP_SHARED,
247*06560afaShmt             -1,
248*06560afaShmt             0,
249*06560afaShmt         ) as *mut u32
250*06560afaShmt     };
251*06560afaShmt     if futexes == MAP_FAILED as *mut u32 {
252*06560afaShmt         error_handle("futexes_addr mmap failed");
253*06560afaShmt     }
254*06560afaShmt 
255*06560afaShmt     unsafe {
256*06560afaShmt         futexes.offset(11).write(0x0000_0000);
257*06560afaShmt         futexes.offset(12).write(0x8000_0000);
258*06560afaShmt         println!("futex1 next addr: {:?}", futexes.offset(0));
259*06560afaShmt         println!("futex2 next addr: {:?}", futexes.offset(1));
260*06560afaShmt         println!("futex1 val addr: {:?}", futexes.offset(11));
261*06560afaShmt         println!("futex2 val addr: {:?}", futexes.offset(12));
262*06560afaShmt         println!("futex1 val: {:#x?}", futexes.offset(11).read());
263*06560afaShmt         println!("futex2 val: {:#x?}", futexes.offset(12).read());
264*06560afaShmt     }
265*06560afaShmt 
266*06560afaShmt     // 打印注册之前的robust list
267*06560afaShmt     println!("robust list next(get behind): {:?}", &unsafe { *head });
268*06560afaShmt 
269*06560afaShmt     unsafe {
270*06560afaShmt         let head_ref_mut = &mut *head;
271*06560afaShmt         head_ref_mut.list.next = futexes.offset(0) as *const RobustList;
272*06560afaShmt         let list_2 = NonNull::from(&*head_ref_mut.list.next).as_ptr();
273*06560afaShmt         let list_2_ref_mut = &mut *list_2;
274*06560afaShmt         list_2_ref_mut.next = futexes.offset(1) as *const RobustList;
275*06560afaShmt         println!("robust list next addr: {:?}", (*head).list.next);
276*06560afaShmt         println!(
277*06560afaShmt             "robust list next next addr: {:?}",
278*06560afaShmt             (*(*head).list.next).next
279*06560afaShmt         );
280*06560afaShmt     }
281*06560afaShmt 
282*06560afaShmt     unsafe {
283*06560afaShmt         let len = mem::size_of::<*mut RobustListHead>();
284*06560afaShmt         let ret = syscall2(set_robust_list, head as usize, len);
285*06560afaShmt         if ret.is_err() {
286*06560afaShmt             println!("failed to set_robust_list, ret = {:?}", ret);
287*06560afaShmt         }
288*06560afaShmt     }
289*06560afaShmt 
290*06560afaShmt     println!("get before, set after: {:?}", head);
291*06560afaShmt     println!("get before, set after: {:?}", &unsafe { *head });
292*06560afaShmt     unsafe {
293*06560afaShmt         let len: usize = 0;
294*06560afaShmt         println!("len = {}", len);
295*06560afaShmt         let len_ptr = NonNull::from(&len).as_ptr();
296*06560afaShmt         let ret = syscall3(get_robust_list, 0, head as usize, len_ptr as usize);
297*06560afaShmt         println!("get len = {}", len);
298*06560afaShmt         if ret.is_err() {
299*06560afaShmt             println!("failed to get_robust_list, ret = {:?}", ret);
300*06560afaShmt         }
301*06560afaShmt 
302*06560afaShmt         println!("futex1 val: {:#x}", futexes.offset(11).read());
303*06560afaShmt         println!("futex2 val: {:#x}", futexes.offset(12).read());
304*06560afaShmt         println!("robust list next: {:?}", futexes.offset(0));
305*06560afaShmt         println!("robust list next next: {:#x?}", futexes.offset(0).read());
306*06560afaShmt     }
307*06560afaShmt     println!("robust list head(get after): {:?}", head);
308*06560afaShmt     println!("robust list next(get after): {:?}", &unsafe { *head });
309*06560afaShmt }
310*06560afaShmt 
311*06560afaShmt //测试一个线程异常退出时futex的robustness(多线程测试,目前futex还不支持多进程)
test02()312*06560afaShmt fn test02() {
313*06560afaShmt     let futexes = unsafe {
314*06560afaShmt         mmap(
315*06560afaShmt             ptr::null_mut::<c_void>(),
316*06560afaShmt             (size_of::<c_int>() * 2) as libc::size_t,
317*06560afaShmt             PROT_READ | PROT_WRITE,
318*06560afaShmt             MAP_ANONYMOUS | MAP_SHARED,
319*06560afaShmt             -1,
320*06560afaShmt             0,
321*06560afaShmt         ) as *mut u32
322*06560afaShmt     };
323*06560afaShmt     if futexes == MAP_FAILED as *mut u32 {
324*06560afaShmt         error_handle("mmap failed");
325*06560afaShmt     }
326*06560afaShmt     let count = unsafe {
327*06560afaShmt         mmap(
328*06560afaShmt             ptr::null_mut::<c_void>(),
329*06560afaShmt             (size_of::<c_int>() * 2) as libc::size_t,
330*06560afaShmt             PROT_READ | PROT_WRITE,
331*06560afaShmt             MAP_ANONYMOUS | MAP_SHARED,
332*06560afaShmt             -1,
333*06560afaShmt             0,
334*06560afaShmt         ) as *mut i32
335*06560afaShmt     };
336*06560afaShmt     if count == MAP_FAILED as *mut i32 {
337*06560afaShmt         error_handle("mmap failed");
338*06560afaShmt     }
339*06560afaShmt 
340*06560afaShmt     unsafe {
341*06560afaShmt         // 在这个示例中,第一段和第二段地址放入robust list,第11段地址和第12段地址存放futex val
342*06560afaShmt         futexes.offset(11).write(0x0000_0000);
343*06560afaShmt         futexes.offset(12).write(0x0000_0000);
344*06560afaShmt         println!("futex1 next addr: {:?}", futexes.offset(0));
345*06560afaShmt         println!("futex2 next addr: {:?}", futexes.offset(1));
346*06560afaShmt         println!("futex1 val addr: {:?}", futexes.offset(11));
347*06560afaShmt         println!("futex2 val addr: {:?}", futexes.offset(12));
348*06560afaShmt         println!("futex1 val: {:#x?}", futexes.offset(11).read());
349*06560afaShmt         println!("futex2 val: {:#x?}", futexes.offset(12).read());
350*06560afaShmt 
351*06560afaShmt         count.offset(0).write(1);
352*06560afaShmt         count.offset(1).write(0);
353*06560afaShmt         println!("count1 val: {:?}", count.offset(0).read());
354*06560afaShmt         println!("count2 val: {:?}", count.offset(1).read());
355*06560afaShmt     }
356*06560afaShmt 
357*06560afaShmt     let futexes = Futex::new(futexes);
358*06560afaShmt     let locks = Lock::new(count);
359*06560afaShmt 
360*06560afaShmt     // tid1 = 7
361*06560afaShmt     let thread1 = thread::spawn(move || {
362*06560afaShmt         set_list(futexes);
363*06560afaShmt         thread::sleep(Duration::from_secs(2));
364*06560afaShmt         for i in 0..2 {
365*06560afaShmt             futex_wait(futexes, "thread1", 11, locks, 0);
366*06560afaShmt             println!("thread1 times: {}", i);
367*06560afaShmt             thread::sleep(Duration::from_secs(3));
368*06560afaShmt 
369*06560afaShmt             let tid = unsafe { syscall0(gettid).unwrap() as u32 };
370*06560afaShmt             futexes.set_val(futexes.get_val(12) | tid, 12);
371*06560afaShmt 
372*06560afaShmt             if i == 1 {
373*06560afaShmt                 // 让thread1异常退出,从而无法唤醒thread2,检测robustness
374*06560afaShmt                 println!("Thread1 exiting early due to simulated error.");
375*06560afaShmt                 return;
376*06560afaShmt             }
377*06560afaShmt             futex_wake(futexes, "thread2", 12, locks, 1);
378*06560afaShmt         }
379*06560afaShmt     });
380*06560afaShmt 
381*06560afaShmt     // tid2 = 6
382*06560afaShmt     set_list(futexes);
383*06560afaShmt     for i in 0..2 {
384*06560afaShmt         futex_wait(futexes, "thread2", 12, locks, 1);
385*06560afaShmt         println!("thread2 times: {}", i);
386*06560afaShmt 
387*06560afaShmt         let tid = unsafe { syscall0(gettid).unwrap() as u32 };
388*06560afaShmt         futexes.set_val(futexes.get_val(11) | tid, 11);
389*06560afaShmt 
390*06560afaShmt         futex_wake(futexes, "thread1", 11, locks, 0);
391*06560afaShmt     }
392*06560afaShmt 
393*06560afaShmt     thread1.join().unwrap();
394*06560afaShmt }
395