1 // SPDX-License-Identifier: GPL-2.0
2
3 use crate::helpers::{parse_generics, Generics};
4 use proc_macro::{TokenStream, TokenTree};
5
derive(input: TokenStream) -> TokenStream6 pub(crate) fn derive(input: TokenStream) -> TokenStream {
7 let (
8 Generics {
9 impl_generics,
10 ty_generics,
11 },
12 mut rest,
13 ) = parse_generics(input);
14 // This should be the body of the struct `{...}`.
15 let last = rest.pop();
16 // Now we insert `Zeroable` as a bound for every generic parameter in `impl_generics`.
17 let mut new_impl_generics = Vec::with_capacity(impl_generics.len());
18 // Are we inside of a generic where we want to add `Zeroable`?
19 let mut in_generic = !impl_generics.is_empty();
20 // Have we already inserted `Zeroable`?
21 let mut inserted = false;
22 // Level of `<>` nestings.
23 let mut nested = 0;
24 for tt in impl_generics {
25 match &tt {
26 // If we find a `,`, then we have finished a generic/constant/lifetime parameter.
27 TokenTree::Punct(p) if nested == 0 && p.as_char() == ',' => {
28 if in_generic && !inserted {
29 new_impl_generics.extend(quote! { : ::kernel::init::Zeroable });
30 }
31 in_generic = true;
32 inserted = false;
33 new_impl_generics.push(tt);
34 }
35 // If we find `'`, then we are entering a lifetime.
36 TokenTree::Punct(p) if nested == 0 && p.as_char() == '\'' => {
37 in_generic = false;
38 new_impl_generics.push(tt);
39 }
40 TokenTree::Punct(p) if nested == 0 && p.as_char() == ':' => {
41 new_impl_generics.push(tt);
42 if in_generic {
43 new_impl_generics.extend(quote! { ::kernel::init::Zeroable + });
44 inserted = true;
45 }
46 }
47 TokenTree::Punct(p) if p.as_char() == '<' => {
48 nested += 1;
49 new_impl_generics.push(tt);
50 }
51 TokenTree::Punct(p) if p.as_char() == '>' => {
52 assert!(nested > 0);
53 nested -= 1;
54 new_impl_generics.push(tt);
55 }
56 _ => new_impl_generics.push(tt),
57 }
58 }
59 assert_eq!(nested, 0);
60 if in_generic && !inserted {
61 new_impl_generics.extend(quote! { : ::kernel::init::Zeroable });
62 }
63 quote! {
64 ::kernel::__derive_zeroable!(
65 parse_input:
66 @sig(#(#rest)*),
67 @impl_generics(#(#new_impl_generics)*),
68 @ty_generics(#(#ty_generics)*),
69 @body(#last),
70 );
71 }
72 }
73