xref: /DragonOS/kernel/crates/unified-init/macros/src/lib.rs (revision 81294aa2e6b257f0de5e3c28c3f3c89798330836)
1 extern crate alloc;
2 
3 extern crate quote;
4 use proc_macro::TokenStream;
5 use quote::quote;
6 use syn::{
7     __private::ToTokens,
8     parse::{self, Parse, ParseStream},
9     spanned::Spanned,
10     ItemFn, Path,
11 };
12 use uuid::Uuid;
13 
14 /// 统一初始化宏,
15 /// 用于将函数注册到统一初始化列表中
16 ///
17 /// ## 用法
18 ///
19 /// ```rust
20 /// use system_error::SystemError;
21 /// use unified_init::define_unified_initializer_slice;
22 /// use unified_init_macros::unified_init;
23 ///
24 /// /// 初始化函数都将会被放到这个列表中
25 /// define_unified_initializer_slice!(INITIALIZER_LIST);
26 ///
27 /// #[unified_init(INITIALIZER_LIST)]
28 /// fn init1() -> Result<(), SystemError> {
29 ///    Ok(())
30 /// }
31 ///
32 /// #[unified_init(INITIALIZER_LIST)]
33 /// fn init2() -> Result<(), SystemError> {
34 ///    Ok(())
35 /// }
36 ///
37 /// fn main() {
38 ///     assert_eq!(INITIALIZER_LIST.len(), 2);
39 /// }
40 ///
41 /// ```
42 #[proc_macro_attribute]
43 pub fn unified_init(args: TokenStream, input: TokenStream) -> TokenStream {
44     do_unified_init(args, input)
45         .unwrap_or_else(|e| e.to_compile_error().into())
46         .into()
47 }
48 
49 fn do_unified_init(args: TokenStream, input: TokenStream) -> syn::Result<proc_macro2::TokenStream> {
50     // 解析属性数
51     let attr_arg = syn::parse::<UnifiedInitArg>(args)?;
52     // 获取当前函数
53     let function = syn::parse::<ItemFn>(input)?;
54     // 检查函数签名
55     check_function_signature(&function)?;
56 
57     // 添加#[::linkme::distributed_slice(attr_args.initializer_instance)]属性
58     let target_slice = attr_arg.initializer_instance.get_ident().unwrap();
59 
60     // 在旁边添加一个UnifiedInitializer
61     let initializer =
62         generate_unified_initializer(&function, &target_slice, function.sig.ident.to_string())?;
63 
64     // 拼接
65     let mut output = proc_macro2::TokenStream::new();
66     output.extend(function.into_token_stream());
67     output.extend(initializer);
68 
69     Ok(output)
70 }
71 
72 /// 检查函数签名是否满足要求
73 /// 函数签名应该为
74 ///
75 /// ```rust
76 /// use system_error::SystemError;
77 /// fn xxx() -> Result<(), SystemError> {
78 ///     Ok(())
79 /// }
80 /// ```
81 fn check_function_signature(function: &ItemFn) -> syn::Result<()> {
82     // 检查函数签名
83     if function.sig.inputs.len() != 0 {
84         return Err(syn::Error::new(
85             function.sig.inputs.span(),
86             "Expected no arguments",
87         ));
88     }
89 
90     if let syn::ReturnType::Type(_, ty) = &function.sig.output {
91         // 确认返回类型为 Result<(), SystemError>
92         // 解析类型
93 
94         let output_type: syn::Type = syn::parse2(ty.clone().into_token_stream())?;
95 
96         // 检查类型是否为 Result<(), SystemError>
97         if let syn::Type::Path(type_path) = output_type {
98             if type_path.path.segments.last().unwrap().ident == "Result" {
99                 // 检查泛型参数,看看是否满足 Result<(), SystemError>
100                 if let syn::PathArguments::AngleBracketed(generic_args) =
101                     type_path.path.segments.last().unwrap().arguments.clone()
102                 {
103                     if generic_args.args.len() != 2 {
104                         return Err(syn::Error::new(
105                             generic_args.span(),
106                             "Expected two generic arguments",
107                         ));
108                     }
109 
110                     // 检查第一个泛型参数是否为()
111                     if let syn::GenericArgument::Type(type_arg) = generic_args.args.first().unwrap()
112                     {
113                         if let syn::Type::Tuple(tuple) = type_arg {
114                             if tuple.elems.len() != 0 {
115                                 return Err(syn::Error::new(tuple.span(), "Expected empty tuple"));
116                             }
117                         } else {
118                             return Err(syn::Error::new(type_arg.span(), "Expected empty tuple"));
119                         }
120                     } else {
121                         return Err(syn::Error::new(
122                             generic_args.span(),
123                             "Expected first generic argument to be a type",
124                         ));
125                     }
126 
127                     // 检查第二个泛型参数是否为SystemError
128                     if let syn::GenericArgument::Type(type_arg) = generic_args.args.last().unwrap()
129                     {
130                         if let syn::Type::Path(type_path) = type_arg {
131                             if type_path.path.segments.last().unwrap().ident == "SystemError" {
132                                 // 类型匹配,返回 Ok
133                                 return Ok(());
134                             }
135                         }
136                     } else {
137                         return Err(syn::Error::new(
138                             generic_args.span(),
139                             "Expected second generic argument to be a type",
140                         ));
141                     }
142 
143                     return Err(syn::Error::new(
144                         generic_args.span(),
145                         "Expected second generic argument to be SystemError",
146                     ));
147                 }
148 
149                 return Ok(());
150             }
151         }
152     }
153 
154     Err(syn::Error::new(
155         function.sig.output.span(),
156         "Expected -> Result<(), SystemError>",
157     ))
158 }
159 
160 /// 生成UnifiedInitializer全局变量
161 fn generate_unified_initializer(
162     function: &ItemFn,
163     target_slice: &syn::Ident,
164     raw_initializer_name: String,
165 ) -> syn::Result<proc_macro2::TokenStream> {
166     let initializer_name = format!(
167         "unified_initializer_{}_{}",
168         raw_initializer_name,
169         Uuid::new_v4().to_simple().to_string().to_ascii_uppercase()[..8].to_string()
170     )
171     .to_ascii_uppercase();
172 
173     // 获取函数的全名
174     let initializer_name_ident = syn::Ident::new(&initializer_name, function.sig.ident.span());
175 
176     let function_ident = &function.sig.ident;
177 
178     // 生成UnifiedInitializer
179     let initializer = quote! {
180         #[::linkme::distributed_slice(#target_slice)]
181         static #initializer_name_ident: unified_init::UnifiedInitializer = ::unified_init::UnifiedInitializer::new(#raw_initializer_name, &(#function_ident as ::unified_init::UnifiedInitFunction));
182     };
183 
184     Ok(initializer)
185 }
186 
187 struct UnifiedInitArg {
188     initializer_instance: Path,
189 }
190 
191 impl Parse for UnifiedInitArg {
192     fn parse(input: ParseStream) -> parse::Result<Self> {
193         let mut initializer_instance = None;
194 
195         while !input.is_empty() {
196             if initializer_instance.is_some() {
197                 return Err(parse::Error::new(
198                     input.span(),
199                     "Expected exactly one initializer instance",
200                 ));
201             }
202             // 解析Ident
203             let ident = input.parse::<syn::Ident>()?;
204 
205             // 将Ident转换为Path
206             let initializer = syn::Path::from(ident);
207 
208             initializer_instance = Some(initializer);
209         }
210 
211         if initializer_instance.is_none() {
212             return Err(parse::Error::new(
213                 input.span(),
214                 "Expected exactly one initializer instance",
215             ));
216         }
217 
218         // 判断是否为标识符
219         if initializer_instance.as_ref().unwrap().get_ident().is_none() {
220             return Err(parse::Error::new(
221                 initializer_instance.span(),
222                 "Expected identifier",
223             ));
224         }
225 
226         Ok(UnifiedInitArg {
227             initializer_instance: initializer_instance.unwrap(),
228         })
229     }
230 }
231