xref: /DADK/dadk-user/src/scheduler/mod.rs (revision 70352fd6b1ba6ff2dca344d5c30e0e9b76b5e6b4)
1 use std::{
2     collections::{BTreeMap, HashMap},
3     fmt::Debug,
4     path::PathBuf,
5     process::exit,
6     sync::{
7         atomic::{AtomicI32, Ordering},
8         Arc, Mutex, RwLock,
9     },
10     thread::ThreadId,
11 };
12 
13 use log::{error, info};
14 
15 use crate::{
16     console::Action,
17     context::DadkUserExecuteContext,
18     executor::{target::Target, Executor},
19     parser::task::DADKTask,
20 };
21 
22 use self::task_deque::TASK_DEQUE;
23 
24 pub mod task_deque;
25 #[cfg(test)]
26 mod tests;
27 
28 lazy_static! {
29     // 线程id与任务实体id映射表
30     pub static ref TID_EID: Mutex<HashMap<ThreadId,i32>> = Mutex::new(HashMap::new());
31 }
32 
33 /// # 调度实体内部结构
34 #[derive(Debug, Clone)]
35 pub struct InnerEntity {
36     /// 任务ID
37     id: i32,
38     file_path: PathBuf,
39     /// 任务
40     task: DADKTask,
41     /// 入度
42     indegree: usize,
43     /// 子节点
44     children: Vec<Arc<SchedEntity>>,
45     /// target管理
46     target: Option<Target>,
47 }
48 
49 /// # 调度实体
50 #[derive(Debug)]
51 pub struct SchedEntity {
52     inner: Mutex<InnerEntity>,
53 }
54 
55 impl PartialEq for SchedEntity {
56     fn eq(&self, other: &Self) -> bool {
57         self.inner.lock().unwrap().id == other.inner.lock().unwrap().id
58     }
59 }
60 
61 impl SchedEntity {
62     #[allow(dead_code)]
63     pub fn id(&self) -> i32 {
64         self.inner.lock().unwrap().id
65     }
66 
67     #[allow(dead_code)]
68     pub fn file_path(&self) -> PathBuf {
69         self.inner.lock().unwrap().file_path.clone()
70     }
71 
72     #[allow(dead_code)]
73     pub fn task(&self) -> DADKTask {
74         self.inner.lock().unwrap().task.clone()
75     }
76 
77     /// 入度加1
78     pub fn add_indegree(&self) {
79         self.inner.lock().unwrap().indegree += 1;
80     }
81 
82     /// 入度减1
83     pub fn sub_indegree(&self) -> usize {
84         self.inner.lock().unwrap().indegree -= 1;
85         return self.inner.lock().unwrap().indegree;
86     }
87 
88     /// 增加子节点
89     pub fn add_child(&self, entity: Arc<SchedEntity>) {
90         self.inner.lock().unwrap().children.push(entity);
91     }
92 
93     /// 获取入度
94     pub fn indegree(&self) -> usize {
95         self.inner.lock().unwrap().indegree
96     }
97 
98     /// 获取target
99     pub fn target(&self) -> Option<Target> {
100         self.inner.lock().unwrap().target.clone()
101     }
102 
103     /// 当前任务完成后,所有子节点入度减1
104     ///
105     /// ## 参数
106     ///
107     /// 无
108     ///
109     /// ## 返回值
110     ///
111     /// 所有入度为0的子节点集合
112     pub fn sub_children_indegree(&self) -> Vec<Arc<SchedEntity>> {
113         let mut zero_child = Vec::new();
114         let children = &self.inner.lock().unwrap().children;
115         for child in children.iter() {
116             if child.sub_indegree() == 0 {
117                 zero_child.push(child.clone());
118             }
119         }
120         return zero_child;
121     }
122 }
123 
124 /// # 调度实体列表
125 ///
126 /// 用于存储所有的调度实体
127 #[derive(Debug)]
128 pub struct SchedEntities {
129     /// 任务ID到调度实体的映射
130     id2entity: RwLock<BTreeMap<i32, Arc<SchedEntity>>>,
131 }
132 
133 impl SchedEntities {
134     pub fn new() -> Self {
135         Self {
136             id2entity: RwLock::new(BTreeMap::new()),
137         }
138     }
139 
140     pub fn add(&mut self, entity: Arc<SchedEntity>) {
141         self.id2entity
142             .write()
143             .unwrap()
144             .insert(entity.id(), entity.clone());
145     }
146 
147     #[allow(dead_code)]
148     pub fn get(&self, id: i32) -> Option<Arc<SchedEntity>> {
149         self.id2entity.read().unwrap().get(&id).cloned()
150     }
151 
152     pub fn get_by_name_version(&self, name: &str, version: &str) -> Option<Arc<SchedEntity>> {
153         for e in self.id2entity.read().unwrap().iter() {
154             if e.1.task().name_version_env() == DADKTask::name_version_uppercase(name, version) {
155                 return Some(e.1.clone());
156             }
157         }
158         return None;
159     }
160 
161     pub fn entities(&self) -> Vec<Arc<SchedEntity>> {
162         let mut v = Vec::new();
163         for e in self.id2entity.read().unwrap().iter() {
164             v.push(e.1.clone());
165         }
166         return v;
167     }
168 
169     pub fn id2entity(&self) -> BTreeMap<i32, Arc<SchedEntity>> {
170         self.id2entity.read().unwrap().clone()
171     }
172 
173     #[allow(dead_code)]
174     pub fn len(&self) -> usize {
175         self.id2entity.read().unwrap().len()
176     }
177 
178     #[allow(dead_code)]
179     pub fn is_empty(&self) -> bool {
180         self.id2entity.read().unwrap().is_empty()
181     }
182 
183     #[allow(dead_code)]
184     pub fn clear(&mut self) {
185         self.id2entity.write().unwrap().clear();
186     }
187 
188     pub fn topo_sort(&self) -> Vec<Arc<SchedEntity>> {
189         let mut result = Vec::new();
190         let mut visited = BTreeMap::new();
191         let btree = self.id2entity.write().unwrap().clone();
192         for entity in btree.iter() {
193             if !visited.contains_key(entity.0) {
194                 let r = self.dfs(entity.1, &mut visited, &mut result);
195                 if r.is_err() {
196                     let err = r.unwrap_err();
197                     error!("{}", err.display());
198                     println!("Please fix the errors above and try again.");
199                     std::process::exit(1);
200                 }
201             }
202         }
203         return result;
204     }
205 
206     fn dfs(
207         &self,
208         entity: &Arc<SchedEntity>,
209         visited: &mut BTreeMap<i32, bool>,
210         result: &mut Vec<Arc<SchedEntity>>,
211     ) -> Result<(), DependencyCycleError> {
212         visited.insert(entity.id(), false);
213         for dep in entity.task().depends.iter() {
214             if let Some(dep_entity) = self.get_by_name_version(&dep.name, &dep.version) {
215                 let guard = self.id2entity.write().unwrap();
216                 let e = guard.get(&entity.id()).unwrap();
217                 let d = guard.get(&dep_entity.id()).unwrap();
218                 e.add_indegree();
219                 d.add_child(e.clone());
220                 if let Some(&false) = visited.get(&dep_entity.id()) {
221                     // 输出完整环形依赖
222                     let mut err = DependencyCycleError::new(dep_entity.clone());
223 
224                     err.add(entity.clone(), dep_entity);
225                     return Err(err);
226                 }
227                 if !visited.contains_key(&dep_entity.id()) {
228                     drop(guard);
229                     let r = self.dfs(&dep_entity, visited, result);
230                     if r.is_err() {
231                         let mut err: DependencyCycleError = r.unwrap_err();
232                         // 如果错误已经停止传播,则直接返回
233                         if err.stop_propagation {
234                             return Err(err);
235                         }
236                         // 如果当前实体是错误的起始实体,则停止传播
237                         if entity == &err.head_entity {
238                             err.stop_propagation();
239                         }
240                         err.add(entity.clone(), dep_entity);
241                         return Err(err);
242                     }
243                 }
244             } else {
245                 error!(
246                     "Dependency not found: {} -> {}",
247                     entity.task().name_version(),
248                     dep.name_version()
249                 );
250                 std::process::exit(1);
251             }
252         }
253         visited.insert(entity.id(), true);
254         result.push(entity.clone());
255         return Ok(());
256     }
257 }
258 
259 /// # 任务调度器
260 #[derive(Debug)]
261 pub struct Scheduler {
262     /// DragonOS sysroot在主机上的路径
263     sysroot_dir: PathBuf,
264     /// 要执行的操作
265     action: Action,
266     /// 调度实体列表
267     target: SchedEntities,
268     /// dadk执行的上下文
269     context: Arc<DadkUserExecuteContext>,
270 }
271 
272 pub enum SchedulerError {
273     TaskError(String),
274     /// 不是当前正在编译的目标架构
275     InvalidTargetArch(String),
276     DependencyNotFound(Arc<SchedEntity>, String),
277     RunError(String),
278 }
279 
280 impl Debug for SchedulerError {
281     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
282         match self {
283             Self::TaskError(arg0) => {
284                 write!(f, "TaskError: {}", arg0)
285             }
286             SchedulerError::DependencyNotFound(current, msg) => {
287                 write!(
288                     f,
289                     "For task {}, dependency not found: {}. Please check file: {}",
290                     current.task().name_version(),
291                     msg,
292                     current.file_path().display()
293                 )
294             }
295             SchedulerError::RunError(msg) => {
296                 write!(f, "RunError: {}", msg)
297             }
298             SchedulerError::InvalidTargetArch(msg) => {
299                 write!(f, "InvalidTargetArch: {}", msg)
300             }
301         }
302     }
303 }
304 
305 impl Scheduler {
306     pub fn new(
307         context: Arc<DadkUserExecuteContext>,
308         dragonos_dir: PathBuf,
309         action: Action,
310         tasks: Vec<(PathBuf, DADKTask)>,
311     ) -> Result<Self, SchedulerError> {
312         let entities = SchedEntities::new();
313 
314         let mut scheduler = Scheduler {
315             sysroot_dir: dragonos_dir,
316             action,
317             target: entities,
318             context,
319         };
320 
321         let r = scheduler.add_tasks(tasks);
322         if r.is_err() {
323             error!("Error while adding tasks: {:?}", r);
324             return Err(r.err().unwrap());
325         }
326 
327         return Ok(scheduler);
328     }
329 
330     /// # 添加多个任务
331     ///
332     /// 添加任务到调度器中,如果任务已经存在,则返回错误
333     pub fn add_tasks(&mut self, tasks: Vec<(PathBuf, DADKTask)>) -> Result<(), SchedulerError> {
334         for task in tasks {
335             let e = self.add_task(task.0, task.1);
336             if e.is_err() {
337                 if let Err(SchedulerError::InvalidTargetArch(_)) = &e {
338                     continue;
339                 }
340                 e?;
341             }
342         }
343 
344         return Ok(());
345     }
346 
347     /// # 任务是否匹配当前目标架构
348     pub fn task_arch_matched(&self, task: &DADKTask) -> bool {
349         task.target_arch.contains(self.context.target_arch())
350     }
351 
352     /// # 添加一个任务
353     ///
354     /// 添加任务到调度器中,如果任务已经存在,则返回错误
355     pub fn add_task(
356         &mut self,
357         path: PathBuf,
358         task: DADKTask,
359     ) -> Result<Arc<SchedEntity>, SchedulerError> {
360         if !self.task_arch_matched(&task) {
361             return Err(SchedulerError::InvalidTargetArch(format!(
362                 "Task {} is not for target arch: {:?}",
363                 task.name_version(),
364                 self.context.target_arch()
365             )));
366         }
367 
368         let id: i32 = self.generate_task_id();
369         let indegree: usize = 0;
370         let children = Vec::new();
371         let target = self.generate_task_target(&path, &task.rust_target)?;
372         let entity = Arc::new(SchedEntity {
373             inner: Mutex::new(InnerEntity {
374                 id,
375                 task,
376                 file_path: path.clone(),
377                 indegree,
378                 children,
379                 target,
380             }),
381         });
382         let name_version = (entity.task().name.clone(), entity.task().version.clone());
383 
384         if self
385             .target
386             .get_by_name_version(&name_version.0, &name_version.1)
387             .is_some()
388         {
389             return Err(SchedulerError::TaskError(format!(
390                 "Task with name [{}] and version [{}] already exists. Config file: {}",
391                 name_version.0,
392                 name_version.1,
393                 path.display()
394             )));
395         }
396 
397         self.target.add(entity.clone());
398 
399         info!("Task added: {}", entity.task().name_version());
400         return Ok(entity);
401     }
402 
403     fn generate_task_id(&self) -> i32 {
404         static TASK_ID: AtomicI32 = AtomicI32::new(0);
405         return TASK_ID.fetch_add(1, Ordering::SeqCst);
406     }
407 
408     fn generate_task_target(
409         &self,
410         path: &PathBuf,
411         rust_target: &Option<String>,
412     ) -> Result<Option<Target>, SchedulerError> {
413         if let Some(rust_target) = rust_target {
414             // 如果rust_target字段不为none,说明需要target管理
415             // 获取dadk任务路径,用于生成临时dadk文件名
416             let file_str = path.as_path().to_str().unwrap();
417             let tmp_dadk_path = Target::tmp_dadk(file_str);
418             let tmp_dadk_str = tmp_dadk_path.as_path().to_str().unwrap();
419 
420             if Target::is_user_target(rust_target) {
421                 // 如果target文件是用户自己的
422                 if let Ok(target_path) = Target::user_target_path(rust_target) {
423                     let target_path_str = target_path.as_path().to_str().unwrap();
424                     let index = target_path_str.rfind('/').unwrap();
425                     let target_name = target_path_str[index + 1..].to_string();
426                     let tmp_target = PathBuf::from(format!("{}{}", tmp_dadk_str, target_name));
427                     return Ok(Some(Target::new(tmp_target)));
428                 } else {
429                     return Err(SchedulerError::TaskError(
430                         "The path of target file is invalid.".to_string(),
431                     ));
432                 }
433             } else {
434                 // 如果target文件是内置的
435                 let tmp_target = PathBuf::from(format!("{}{}.json", tmp_dadk_str, rust_target));
436                 return Ok(Some(Target::new(tmp_target)));
437             }
438         }
439         return Ok(None);
440     }
441 
442     /// # 执行调度器中的所有任务
443     pub fn run(&self) -> Result<(), SchedulerError> {
444         // 准备全局环境变量
445         crate::executor::prepare_env(&self.target, &self.context)
446             .map_err(|e| SchedulerError::RunError(format!("{:?}", e)))?;
447 
448         match self.action {
449             Action::Build | Action::Install => {
450                 self.run_with_topo_sort()?;
451             }
452             Action::Clean(_) => self.run_without_topo_sort()?,
453             _ => unimplemented!(),
454         }
455 
456         return Ok(());
457     }
458 
459     /// Action需要按照拓扑序执行
460     ///
461     /// Action::Build | Action::Install
462     fn run_with_topo_sort(&self) -> Result<(), SchedulerError> {
463         // 检查是否有不存在的依赖
464         let r = self.check_not_exists_dependency();
465         if r.is_err() {
466             error!("Error while checking tasks: {:?}", r);
467             return r;
468         }
469 
470         // 对调度实体进行拓扑排序
471         let r: Vec<Arc<SchedEntity>> = self.target.topo_sort();
472 
473         let action = self.action.clone();
474         let dragonos_dir = self.sysroot_dir.clone();
475         let id2entity = self.target.id2entity();
476         let count = r.len();
477 
478         // 启动守护线程
479         let handler = std::thread::spawn(move || {
480             Self::build_install_daemon(action, dragonos_dir, id2entity, count, &r)
481         });
482 
483         handler.join().expect("Could not join deamon");
484 
485         return Ok(());
486     }
487 
488     /// Action不需要按照拓扑序执行
489     fn run_without_topo_sort(&self) -> Result<(), SchedulerError> {
490         // 启动守护线程
491         let action = self.action.clone();
492         let dragonos_dir = self.sysroot_dir.clone();
493         let mut r = self.target.entities();
494         let handler = std::thread::spawn(move || {
495             Self::clean_daemon(action, dragonos_dir, &mut r);
496         });
497 
498         handler.join().expect("Could not join deamon");
499         return Ok(());
500     }
501 
502     pub fn execute(action: Action, dragonos_dir: PathBuf, entity: Arc<SchedEntity>) {
503         let mut executor = Executor::new(entity.clone(), action.clone(), dragonos_dir.clone())
504             .map_err(|e| {
505                 error!(
506                     "Error while creating executor for task {} : {:?}",
507                     entity.task().name_version(),
508                     e
509                 );
510                 exit(-1);
511             })
512             .unwrap();
513 
514         executor
515             .execute()
516             .map_err(|e| {
517                 error!(
518                     "Error while executing task {} : {:?}",
519                     entity.task().name_version(),
520                     e
521                 );
522                 exit(-1);
523             })
524             .unwrap();
525     }
526 
527     /// 构建和安装DADK任务的守护线程
528     ///
529     /// ## 参数
530     ///
531     /// - `action` : 要执行的操作
532     /// - `dragonos_dir` : DragonOS sysroot在主机上的路径
533     /// - `id2entity` : DADK任务id与实体映射表
534     /// - `count` : 当前剩余任务数
535     /// - `r` : 总任务实体表
536     ///
537     /// ## 返回值
538     ///
539     /// 无
540     pub fn build_install_daemon(
541         action: Action,
542         dragonos_dir: PathBuf,
543         id2entity: BTreeMap<i32, Arc<SchedEntity>>,
544         mut count: usize,
545         r: &Vec<Arc<SchedEntity>>,
546     ) {
547         let mut guard = TASK_DEQUE.lock().unwrap();
548         // 初始化0入度的任务实体
549         let mut zero_entity: Vec<Arc<SchedEntity>> = Vec::new();
550         for e in r.iter() {
551             if e.indegree() == 0 {
552                 zero_entity.push(e.clone());
553             }
554         }
555 
556         while count > 0 {
557             // 将入度为0的任务实体加入任务队列中,直至没有入度为0的任务实体 或 任务队列满了
558             while !zero_entity.is_empty()
559                 && guard.build_install_task(
560                     action.clone(),
561                     dragonos_dir.clone(),
562                     zero_entity.last().unwrap().clone(),
563                 )
564             {
565                 zero_entity.pop();
566             }
567 
568             let queue = guard.queue_mut();
569             // 如果任务线程已完成,将其从任务队列中删除,并把它的子节点入度减1,如果有0入度子节点,则加入zero_entity,后续可以加入任务队列中
570             queue.retain(|x| {
571                 if x.is_finished() {
572                     count -= 1;
573                     let tid = x.thread().id();
574                     let eid = *TID_EID.lock().unwrap().get(&tid).unwrap();
575                     let entity = id2entity.get(&eid).unwrap();
576                     let zero = entity.sub_children_indegree();
577                     for e in zero.iter() {
578                         zero_entity.push(e.clone());
579                     }
580                     return false;
581                 }
582                 return true;
583             })
584         }
585     }
586 
587     /// 清理DADK任务的守护线程
588     ///
589     /// ## 参数
590     ///
591     /// - `action` : 要执行的操作
592     /// - `dragonos_dir` : DragonOS sysroot在主机上的路径
593     /// - `r` : 总任务实体表
594     ///
595     /// ## 返回值
596     ///
597     /// 无
598     pub fn clean_daemon(action: Action, dragonos_dir: PathBuf, r: &mut Vec<Arc<SchedEntity>>) {
599         let mut guard = TASK_DEQUE.lock().unwrap();
600         while !guard.queue().is_empty() && !r.is_empty() {
601             guard.clean_task(action, dragonos_dir.clone(), r.pop().unwrap().clone());
602         }
603     }
604 
605     /// # 检查是否有不存在的依赖
606     ///
607     /// 如果某个任务的dependency中的任务不存在,则返回错误
608     fn check_not_exists_dependency(&self) -> Result<(), SchedulerError> {
609         for entity in self.target.entities().iter() {
610             for dependency in entity.task().depends.iter() {
611                 let name_version = (dependency.name.clone(), dependency.version.clone());
612                 if !self
613                     .target
614                     .get_by_name_version(&name_version.0, &name_version.1)
615                     .is_some()
616                 {
617                     return Err(SchedulerError::DependencyNotFound(
618                         entity.clone(),
619                         format!("name:{}, version:{}", name_version.0, name_version.1,),
620                     ));
621                 }
622             }
623         }
624 
625         return Ok(());
626     }
627 }
628 
629 /// # 环形依赖错误路径
630 ///
631 /// 本结构体用于在回溯过程中记录环形依赖的路径。
632 ///
633 /// 例如,假设有如下依赖关系:
634 ///
635 /// ```text
636 /// A -> B -> C -> D -> A
637 /// ```
638 ///
639 /// 则在DFS回溯过程中,会依次记录如下路径:
640 ///
641 /// ```text
642 /// D -> A
643 /// C -> D
644 /// B -> C
645 /// A -> B
646 pub struct DependencyCycleError {
647     /// # 起始实体
648     /// 本错误的起始实体,即环形依赖的起点
649     head_entity: Arc<SchedEntity>,
650     /// 是否停止传播
651     stop_propagation: bool,
652     /// 依赖关系
653     dependencies: Vec<(Arc<SchedEntity>, Arc<SchedEntity>)>,
654 }
655 
656 impl DependencyCycleError {
657     pub fn new(head_entity: Arc<SchedEntity>) -> Self {
658         Self {
659             head_entity,
660             stop_propagation: false,
661             dependencies: Vec::new(),
662         }
663     }
664 
665     pub fn add(&mut self, current: Arc<SchedEntity>, dependency: Arc<SchedEntity>) {
666         self.dependencies.push((current, dependency));
667     }
668 
669     pub fn stop_propagation(&mut self) {
670         self.stop_propagation = true;
671     }
672 
673     #[allow(dead_code)]
674     pub fn dependencies(&self) -> &Vec<(Arc<SchedEntity>, Arc<SchedEntity>)> {
675         &self.dependencies
676     }
677 
678     pub fn display(&self) -> String {
679         let mut tmp = self.dependencies.clone();
680         tmp.reverse();
681 
682         let mut ret = format!("Dependency cycle detected: \nStart ->\n");
683         for (current, dep) in tmp.iter() {
684             ret.push_str(&format!(
685                 "->\t{} ({})\t--depends-->\t{} ({})\n",
686                 current.task().name_version(),
687                 current.file_path().display(),
688                 dep.task().name_version(),
689                 dep.file_path().display()
690             ));
691         }
692         ret.push_str("-> End");
693         return ret;
694     }
695 }
696