xref: /linux/rust/pin-init/internal/src/helpers.rs (revision 4f9786035f9e519db41375818e1d0b5f20da2f10)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 #[cfg(not(kernel))]
4 use proc_macro2 as proc_macro;
5 
6 use proc_macro::{TokenStream, TokenTree};
7 
/// The generic parameter list of an item, split into the three renderings needed when generating
/// code for that item.
///
/// All three fields describe the same parameters; they differ only in how much of each parameter
/// (bounds, default values) is kept. The per-field docs say where each rendering belongs.
///
/// # Examples
///
/// ```rust,ignore
/// # let input = todo!();
/// let (Generics { decl_generics, impl_generics, ty_generics }, rest) = parse_generics(input);
/// quote! {
///     struct Foo<$($decl_generics)*> {
///         // ...
///     }
///
///     impl<$impl_generics> Foo<$ty_generics> {
///         fn foo() {
///             // ...
///         }
///     }
/// }
/// ```
pub(crate) struct Generics {
    /// Declaration rendering: bounds and default values are both kept
    /// (e.g. `T: Clone, const N: usize = 0`).
    ///
    /// Splice this into an item definition such as `struct Foo<$decl_generics> ...` (the same
    /// applies to `union` and `enum`).
    pub(crate) decl_generics: Vec<TokenTree>,
    /// `impl` rendering: bounds are kept, default values are dropped
    /// (e.g. `T: Clone, const N: usize`).
    ///
    /// Splice this into an `impl` header such as `impl<$impl_generics> Trait for ...`.
    pub(crate) impl_generics: Vec<TokenTree>,
    /// Usage rendering: only the parameter names, with neither bounds nor defaults (e.g. `T, N`).
    ///
    /// Splice this wherever the declared type is named, such as `Foo<$ty_generics>`.
    pub(crate) ty_generics: Vec<TokenTree>,
}
44 
45 /// Parses the given `TokenStream` into `Generics` and the rest.
46 ///
47 /// The generics are not present in the rest, but a where clause might remain.
parse_generics(input: TokenStream) -> (Generics, Vec<TokenTree>)48 pub(crate) fn parse_generics(input: TokenStream) -> (Generics, Vec<TokenTree>) {
49     // The generics with bounds and default values.
50     let mut decl_generics = vec![];
51     // `impl_generics`, the declared generics with their bounds.
52     let mut impl_generics = vec![];
53     // Only the names of the generics, without any bounds.
54     let mut ty_generics = vec![];
55     // Tokens not related to the generics e.g. the `where` token and definition.
56     let mut rest = vec![];
57     // The current level of `<`.
58     let mut nesting = 0;
59     let mut toks = input.into_iter();
60     // If we are at the beginning of a generic parameter.
61     let mut at_start = true;
62     let mut skip_until_comma = false;
63     while let Some(tt) = toks.next() {
64         if nesting == 1 && matches!(&tt, TokenTree::Punct(p) if p.as_char() == '>') {
65             // Found the end of the generics.
66             break;
67         } else if nesting >= 1 {
68             decl_generics.push(tt.clone());
69         }
70         match tt.clone() {
71             TokenTree::Punct(p) if p.as_char() == '<' => {
72                 if nesting >= 1 && !skip_until_comma {
73                     // This is inside of the generics and part of some bound.
74                     impl_generics.push(tt);
75                 }
76                 nesting += 1;
77             }
78             TokenTree::Punct(p) if p.as_char() == '>' => {
79                 // This is a parsing error, so we just end it here.
80                 if nesting == 0 {
81                     break;
82                 } else {
83                     nesting -= 1;
84                     if nesting >= 1 && !skip_until_comma {
85                         // We are still inside of the generics and part of some bound.
86                         impl_generics.push(tt);
87                     }
88                 }
89             }
90             TokenTree::Punct(p) if skip_until_comma && p.as_char() == ',' => {
91                 if nesting == 1 {
92                     impl_generics.push(tt.clone());
93                     impl_generics.push(tt);
94                     skip_until_comma = false;
95                 }
96             }
97             _ if !skip_until_comma => {
98                 match nesting {
99                     // If we haven't entered the generics yet, we still want to keep these tokens.
100                     0 => rest.push(tt),
101                     1 => {
102                         // Here depending on the token, it might be a generic variable name.
103                         match tt.clone() {
104                             TokenTree::Ident(i) if at_start && i.to_string() == "const" => {
105                                 let Some(name) = toks.next() else {
106                                     // Parsing error.
107                                     break;
108                                 };
109                                 impl_generics.push(tt);
110                                 impl_generics.push(name.clone());
111                                 ty_generics.push(name.clone());
112                                 decl_generics.push(name);
113                                 at_start = false;
114                             }
115                             TokenTree::Ident(_) if at_start => {
116                                 impl_generics.push(tt.clone());
117                                 ty_generics.push(tt);
118                                 at_start = false;
119                             }
120                             TokenTree::Punct(p) if p.as_char() == ',' => {
121                                 impl_generics.push(tt.clone());
122                                 ty_generics.push(tt);
123                                 at_start = true;
124                             }
125                             // Lifetimes begin with `'`.
126                             TokenTree::Punct(p) if p.as_char() == '\'' && at_start => {
127                                 impl_generics.push(tt.clone());
128                                 ty_generics.push(tt);
129                             }
130                             // Generics can have default values, we skip these.
131                             TokenTree::Punct(p) if p.as_char() == '=' => {
132                                 skip_until_comma = true;
133                             }
134                             _ => impl_generics.push(tt),
135                         }
136                     }
137                     _ => impl_generics.push(tt),
138                 }
139             }
140             _ => {}
141         }
142     }
143     rest.extend(toks);
144     (
145         Generics {
146             impl_generics,
147             decl_generics,
148             ty_generics,
149         },
150         rest,
151     )
152 }
153