1 // SPDX-License-Identifier: GPL-2.0
2 
3 #[cfg(not(kernel))]
4 use proc_macro2 as proc_macro;
5 
6 use crate::helpers::{parse_generics, Generics};
7 use proc_macro::{TokenStream, TokenTree};
8 
derive(input: TokenStream) -> TokenStream9 pub(crate) fn derive(input: TokenStream) -> TokenStream {
10     let (
11         Generics {
12             impl_generics,
13             decl_generics: _,
14             ty_generics,
15         },
16         mut rest,
17     ) = parse_generics(input);
18     // This should be the body of the struct `{...}`.
19     let last = rest.pop();
20     // Now we insert `Zeroable` as a bound for every generic parameter in `impl_generics`.
21     let mut new_impl_generics = Vec::with_capacity(impl_generics.len());
22     // Are we inside of a generic where we want to add `Zeroable`?
23     let mut in_generic = !impl_generics.is_empty();
24     // Have we already inserted `Zeroable`?
25     let mut inserted = false;
26     // Level of `<>` nestings.
27     let mut nested = 0;
28     for tt in impl_generics {
29         match &tt {
30             // If we find a `,`, then we have finished a generic/constant/lifetime parameter.
31             TokenTree::Punct(p) if nested == 0 && p.as_char() == ',' => {
32                 if in_generic && !inserted {
33                     new_impl_generics.extend(quote! { : ::pin_init::Zeroable });
34                 }
35                 in_generic = true;
36                 inserted = false;
37                 new_impl_generics.push(tt);
38             }
39             // If we find `'`, then we are entering a lifetime.
40             TokenTree::Punct(p) if nested == 0 && p.as_char() == '\'' => {
41                 in_generic = false;
42                 new_impl_generics.push(tt);
43             }
44             TokenTree::Punct(p) if nested == 0 && p.as_char() == ':' => {
45                 new_impl_generics.push(tt);
46                 if in_generic {
47                     new_impl_generics.extend(quote! { ::pin_init::Zeroable + });
48                     inserted = true;
49                 }
50             }
51             TokenTree::Punct(p) if p.as_char() == '<' => {
52                 nested += 1;
53                 new_impl_generics.push(tt);
54             }
55             TokenTree::Punct(p) if p.as_char() == '>' => {
56                 assert!(nested > 0);
57                 nested -= 1;
58                 new_impl_generics.push(tt);
59             }
60             _ => new_impl_generics.push(tt),
61         }
62     }
63     assert_eq!(nested, 0);
64     if in_generic && !inserted {
65         new_impl_generics.extend(quote! { : ::pin_init::Zeroable });
66     }
67     quote! {
68         ::pin_init::__derive_zeroable!(
69             parse_input:
70                 @sig(#(#rest)*),
71                 @impl_generics(#(#new_impl_generics)*),
72                 @ty_generics(#(#ty_generics)*),
73                 @body(#last),
74         );
75     }
76 }
77