1 extern crate alloc;
2
3 extern crate quote;
4 use proc_macro::TokenStream;
5 use quote::quote;
6 use syn::{
7 __private::ToTokens,
8 parse::{self, Parse, ParseStream},
9 spanned::Spanned,
10 ItemFn, Path,
11 };
12 use uuid::Uuid;
13
14 /// 统一初始化宏,
15 /// 用于将函数注册到统一初始化列表中
16 ///
17 /// ## 用法
18 ///
19 /// ```rust
20 /// use system_error::SystemError;
21 /// use unified_init::define_unified_initializer_slice;
22 /// use unified_init_macros::unified_init;
23 ///
24 /// /// 初始化函数都将会被放到这个列表中
25 /// define_unified_initializer_slice!(INITIALIZER_LIST);
26 ///
27 /// #[unified_init(INITIALIZER_LIST)]
28 /// fn init1() -> Result<(), SystemError> {
29 /// Ok(())
30 /// }
31 ///
32 /// #[unified_init(INITIALIZER_LIST)]
33 /// fn init2() -> Result<(), SystemError> {
34 /// Ok(())
35 /// }
36 ///
37 /// fn main() {
38 /// assert_eq!(INITIALIZER_LIST.len(), 2);
39 /// }
40 ///
41 /// ```
42 #[proc_macro_attribute]
unified_init(args: TokenStream, input: TokenStream) -> TokenStream43 pub fn unified_init(args: TokenStream, input: TokenStream) -> TokenStream {
44 do_unified_init(args, input)
45 .unwrap_or_else(|e| e.to_compile_error())
46 .into()
47 }
48
do_unified_init(args: TokenStream, input: TokenStream) -> syn::Result<proc_macro2::TokenStream>49 fn do_unified_init(args: TokenStream, input: TokenStream) -> syn::Result<proc_macro2::TokenStream> {
50 // 解析属性数
51 let attr_arg = syn::parse::<UnifiedInitArg>(args)?;
52 // 获取当前函数
53 let function = syn::parse::<ItemFn>(input)?;
54 // 检查函数签名
55 check_function_signature(&function)?;
56
57 // 添加#[::linkme::distributed_slice(attr_args.initializer_instance)]属性
58 let target_slice = attr_arg.initializer_instance.get_ident().unwrap();
59
60 // 在旁边添加一个UnifiedInitializer
61 let initializer =
62 generate_unified_initializer(&function, target_slice, function.sig.ident.to_string())?;
63
64 // 拼接
65 let mut output = proc_macro2::TokenStream::new();
66 output.extend(function.into_token_stream());
67 output.extend(initializer);
68
69 Ok(output)
70 }
71
72 /// 检查函数签名是否满足要求
73 /// 函数签名应该为
74 ///
75 /// ```rust
76 /// use system_error::SystemError;
77 /// fn xxx() -> Result<(), SystemError> {
78 /// Ok(())
79 /// }
80 /// ```
check_function_signature(function: &ItemFn) -> syn::Result<()>81 fn check_function_signature(function: &ItemFn) -> syn::Result<()> {
82 // 检查函数签名
83 if !function.sig.inputs.is_empty() {
84 return Err(syn::Error::new(
85 function.sig.inputs.span(),
86 "Expected no arguments",
87 ));
88 }
89
90 if let syn::ReturnType::Type(_, ty) = &function.sig.output {
91 // 确认返回类型为 Result<(), SystemError>
92 // 解析类型
93
94 let output_type: syn::Type = syn::parse2(ty.clone().into_token_stream())?;
95
96 // 检查类型是否为 Result<(), SystemError>
97 if let syn::Type::Path(type_path) = output_type {
98 if type_path.path.segments.last().unwrap().ident == "Result" {
99 // 检查泛型参数,看看是否满足 Result<(), SystemError>
100 if let syn::PathArguments::AngleBracketed(generic_args) =
101 type_path.path.segments.last().unwrap().arguments.clone()
102 {
103 if generic_args.args.len() != 2 {
104 return Err(syn::Error::new(
105 generic_args.span(),
106 "Expected two generic arguments",
107 ));
108 }
109
110 // 检查第一个泛型参数是否为()
111 if let syn::GenericArgument::Type(type_arg) = generic_args.args.first().unwrap()
112 {
113 if let syn::Type::Tuple(tuple) = type_arg {
114 if !tuple.elems.is_empty() {
115 return Err(syn::Error::new(tuple.span(), "Expected empty tuple"));
116 }
117 } else {
118 return Err(syn::Error::new(type_arg.span(), "Expected empty tuple"));
119 }
120 } else {
121 return Err(syn::Error::new(
122 generic_args.span(),
123 "Expected first generic argument to be a type",
124 ));
125 }
126
127 // 检查第二个泛型参数是否为SystemError
128 if let syn::GenericArgument::Type(type_arg) = generic_args.args.last().unwrap()
129 {
130 if let syn::Type::Path(type_path) = type_arg {
131 if type_path.path.segments.last().unwrap().ident == "SystemError" {
132 // 类型匹配,返回 Ok
133 return Ok(());
134 }
135 }
136 } else {
137 return Err(syn::Error::new(
138 generic_args.span(),
139 "Expected second generic argument to be a type",
140 ));
141 }
142
143 return Err(syn::Error::new(
144 generic_args.span(),
145 "Expected second generic argument to be SystemError",
146 ));
147 }
148
149 return Ok(());
150 }
151 }
152 }
153
154 Err(syn::Error::new(
155 function.sig.output.span(),
156 "Expected -> Result<(), SystemError>",
157 ))
158 }
159
160 /// 生成UnifiedInitializer全局变量
generate_unified_initializer( function: &ItemFn, target_slice: &syn::Ident, raw_initializer_name: String, ) -> syn::Result<proc_macro2::TokenStream>161 fn generate_unified_initializer(
162 function: &ItemFn,
163 target_slice: &syn::Ident,
164 raw_initializer_name: String,
165 ) -> syn::Result<proc_macro2::TokenStream> {
166 let initializer_name = format!(
167 "unified_initializer_{}_{}",
168 raw_initializer_name,
169 &Uuid::new_v4().to_simple().to_string().to_ascii_uppercase()[..8]
170 )
171 .to_ascii_uppercase();
172
173 // 获取函数的全名
174 let initializer_name_ident = syn::Ident::new(&initializer_name, function.sig.ident.span());
175
176 let function_ident = &function.sig.ident;
177
178 // 生成UnifiedInitializer
179 let initializer = quote! {
180 #[::linkme::distributed_slice(#target_slice)]
181 static #initializer_name_ident: unified_init::UnifiedInitializer = ::unified_init::UnifiedInitializer::new(#raw_initializer_name, &(#function_ident as ::unified_init::UnifiedInitFunction));
182 };
183
184 Ok(initializer)
185 }
186
187 struct UnifiedInitArg {
188 initializer_instance: Path,
189 }
190
191 impl Parse for UnifiedInitArg {
parse(input: ParseStream) -> parse::Result<Self>192 fn parse(input: ParseStream) -> parse::Result<Self> {
193 let mut initializer_instance = None;
194
195 while !input.is_empty() {
196 if initializer_instance.is_some() {
197 return Err(parse::Error::new(
198 input.span(),
199 "Expected exactly one initializer instance",
200 ));
201 }
202 // 解析Ident
203 let ident = input.parse::<syn::Ident>()?;
204
205 // 将Ident转换为Path
206 let initializer = syn::Path::from(ident);
207
208 initializer_instance = Some(initializer);
209 }
210
211 if initializer_instance.is_none() {
212 return Err(parse::Error::new(
213 input.span(),
214 "Expected exactly one initializer instance",
215 ));
216 }
217
218 // 判断是否为标识符
219 if initializer_instance.as_ref().unwrap().get_ident().is_none() {
220 return Err(parse::Error::new(
221 initializer_instance.span(),
222 "Expected identifier",
223 ));
224 }
225
226 Ok(UnifiedInitArg {
227 initializer_instance: initializer_instance.unwrap(),
228 })
229 }
230 }
231