xref: /linux/rust/pin-init/internal/src/init.rs (revision 26ff969926a08eee069767ddbbbc301adbcd9676)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 use proc_macro2::{Span, TokenStream};
4 use quote::{format_ident, quote, quote_spanned};
5 use syn::{
6     braced,
7     parse::{End, Parse},
8     parse_quote,
9     punctuated::Punctuated,
10     spanned::Spanned,
11     token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
12 };
13 
14 use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
15 
/// Parsed form of an initializer macro invocation:
/// `#[attr]* (&this in)? Path { fields.. (..rest)? } (? ErrorType)?`.
pub(crate) struct Initializer {
    /// Outer attributes placed before the initializer; only `#[default_error(..)]` is accepted.
    attrs: Vec<InitializerAttribute>,
    /// Optional `&ident in` prefix naming a pointer to the struct being initialized.
    this: Option<This>,
    /// Path of the struct that is being initialized.
    path: Path,
    /// Braces around the field list; the closing brace's span is used for diagnostics.
    brace_token: token::Brace,
    /// The comma-separated field initializers inside the braces.
    fields: Punctuated<InitializerField, Token![,]>,
    /// Optional trailing `..<expr>`; only `..Zeroable::init_zeroed()` is accepted.
    rest: Option<(Token![..], Expr)>,
    /// Optional explicit error type, written as `? <type>` after the closing brace.
    error: Option<(Token![?], Type)>,
}
25 
/// The optional `&ident in` prefix of an initializer.
struct This {
    _and_token: Token![&],
    /// Name under which the generated code binds a `NonNull` pointer to the slot.
    ident: Ident,
    _in_token: Token![in],
}
31 
/// A single entry of the field list together with its outer attributes.
struct InitializerField {
    /// Outer attributes of the entry. All of them are emitted on the generated initialization
    /// code; `cfg` attributes are additionally repeated on the generated binding and drop guard.
    attrs: Vec<Attribute>,
    /// What kind of entry this is and its payload.
    kind: InitializerKind,
}
36 
/// The different forms an entry of the field list can take.
enum InitializerKind {
    /// `ident: expr` or the shorthand `ident`: the field is written by value.
    Value {
        ident: Ident,
        /// `None` for the shorthand form, where the value is the variable named `ident`.
        value: Option<(Token![:], Expr)>,
    },
    /// `ident <- expr`: the field is initialized in place by the initializer `expr`.
    Init {
        ident: Ident,
        _left_arrow_token: Token![<-],
        value: Expr,
    },
    /// `_: { ... }`: a block of user code that is emitted between field initializations.
    Code {
        _underscore_token: Token![_],
        _colon_token: Token![:],
        block: Block,
    },
}
53 
54 impl InitializerKind {
55     fn ident(&self) -> Option<&Ident> {
56         match self {
57             Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident),
58             Self::Code { .. } => None,
59         }
60     }
61 }
62 
/// Attributes that are accepted in front of an initializer.
enum InitializerAttribute {
    /// `#[default_error(<type>)]`: error type to use when no `? <type>` is written.
    DefaultError(DefaultErrorAttribute),
}
66 
/// Arguments of the `#[default_error(<type>)]` attribute.
struct DefaultErrorAttribute {
    /// The error type named inside the attribute's parentheses.
    ty: Box<Type>,
}
70 
71 pub(crate) fn expand(
72     Initializer {
73         attrs,
74         this,
75         path,
76         brace_token,
77         fields,
78         rest,
79         error,
80     }: Initializer,
81     default_error: Option<&'static str>,
82     pinned: bool,
83     dcx: &mut DiagCtxt,
84 ) -> Result<TokenStream, ErrorGuaranteed> {
85     let error = error.map_or_else(
86         || {
87             if let Some(default_error) = attrs.iter().fold(None, |acc, attr| {
88                 #[expect(irrefutable_let_patterns)]
89                 if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr {
90                     Some(ty.clone())
91                 } else {
92                     acc
93                 }
94             }) {
95                 default_error
96             } else if let Some(default_error) = default_error {
97                 syn::parse_str(default_error).unwrap()
98             } else {
99                 dcx.error(brace_token.span.close(), "expected `? <type>` after `}`");
100                 parse_quote!(::core::convert::Infallible)
101             }
102         },
103         |(_, err)| Box::new(err),
104     );
105     let slot = format_ident!("slot");
106     let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned {
107         (
108             format_ident!("HasPinData"),
109             format_ident!("PinData"),
110             format_ident!("__pin_data"),
111             format_ident!("pin_init_from_closure"),
112         )
113     } else {
114         (
115             format_ident!("HasInitData"),
116             format_ident!("InitData"),
117             format_ident!("__init_data"),
118             format_ident!("init_from_closure"),
119         )
120     };
121     let init_kind = get_init_kind(rest, dcx);
122     let zeroable_check = match init_kind {
123         InitKind::Normal => quote!(),
124         InitKind::Zeroing => quote! {
125             // The user specified `..Zeroable::zeroed()` at the end of the list of fields.
126             // Therefore we check if the struct implements `Zeroable` and then zero the memory.
127             // This allows us to also remove the check that all fields are present (since we
128             // already set the memory to zero and that is a valid bit pattern).
129             fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T)
130             where T: ::pin_init::Zeroable
131             {}
132             // Ensure that the struct is indeed `Zeroable`.
133             assert_zeroable(#slot);
134             // SAFETY: The type implements `Zeroable` by the check above.
135             unsafe { ::core::ptr::write_bytes(#slot, 0, 1) };
136         },
137     };
138     let this = match this {
139         None => quote!(),
140         Some(This { ident, .. }) => quote! {
141             // Create the `this` so it can be referenced by the user inside of the
142             // expressions creating the individual fields.
143             let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) };
144         },
145     };
146     // `mixed_site` ensures that the data is not accessible to the user-controlled code.
147     let data = Ident::new("__data", Span::mixed_site());
148     let init_fields = init_fields(&fields, pinned, &data, &slot);
149     let field_check = make_field_check(&fields, init_kind, &path);
150     Ok(quote! {{
151         // Get the data about fields from the supplied type.
152         // SAFETY: TODO
153         let #data = unsafe {
154             use ::pin_init::__internal::#has_data_trait;
155             // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit
156             // generics (which need to be present with that syntax).
157             #path::#get_data()
158         };
159         // Ensure that `#data` really is of type `#data` and help with type inference:
160         let init = ::pin_init::__internal::#data_trait::make_closure::<_, #error>(
161             #data,
162             move |slot| {
163                 #zeroable_check
164                 #this
165                 #init_fields
166                 #field_check
167                 // SAFETY: we are the `init!` macro that is allowed to call this.
168                 Ok(unsafe { ::pin_init::__internal::InitOk::new() })
169             }
170         );
171         let init = move |slot| -> ::core::result::Result<(), #error> {
172             init(slot).map(|__InitOk| ())
173         };
174         // SAFETY: TODO
175         let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) };
176         // FIXME: this let binding is required to avoid a compiler error (cycle when computing the
177         // opaque type returned by this function) before Rust 1.81. Remove after MSRV bump.
178         #[allow(
179             clippy::let_and_return,
180             reason = "some clippy versions warn about the let binding"
181         )]
182         init
183     }})
184 }
185 
/// How fields that are not mentioned in the initializer are treated.
enum InitKind {
    /// No `..` tail was given: the generated field check requires every field to be mentioned.
    Normal,
    /// `..Zeroable::init_zeroed()` was given: the memory is zeroed first, so fields may be
    /// omitted.
    Zeroing,
}
190 
191 fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind {
192     let Some((dotdot, expr)) = rest else {
193         return InitKind::Normal;
194     };
195     match &expr {
196         Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func {
197             Expr::Path(ExprPath {
198                 attrs,
199                 qself: None,
200                 path:
201                     Path {
202                         leading_colon: None,
203                         segments,
204                     },
205             }) if attrs.is_empty()
206                 && segments.len() == 2
207                 && segments[0].ident == "Zeroable"
208                 && segments[0].arguments.is_none()
209                 && segments[1].ident == "init_zeroed"
210                 && segments[1].arguments.is_none() =>
211             {
212                 return InitKind::Zeroing;
213             }
214             _ => {}
215         },
216         _ => {}
217     }
218     dcx.error(
219         dotdot.span().join(expr.span()).unwrap_or(expr.span()),
220         "expected nothing or `..Zeroable::init_zeroed()`.",
221     );
222     InitKind::Normal
223 }
224 
225 /// Generate the code that initializes the fields of the struct using the initializers in `field`.
226 fn init_fields(
227     fields: &Punctuated<InitializerField, Token![,]>,
228     pinned: bool,
229     data: &Ident,
230     slot: &Ident,
231 ) -> TokenStream {
232     let mut guards = vec![];
233     let mut guard_attrs = vec![];
234     let mut res = TokenStream::new();
235     for InitializerField { attrs, kind } in fields {
236         let cfgs = {
237             let mut cfgs = attrs.clone();
238             cfgs.retain(|attr| attr.path().is_ident("cfg"));
239             cfgs
240         };
241         let init = match kind {
242             InitializerKind::Value { ident, value } => {
243                 let mut value_ident = ident.clone();
244                 let value_prep = value.as_ref().map(|value| &value.1).map(|value| {
245                     // Setting the span of `value_ident` to `value`'s span improves error messages
246                     // when the type of `value` is wrong.
247                     value_ident.set_span(value.span());
248                     quote!(let #value_ident = #value;)
249                 });
250                 // Again span for better diagnostics
251                 let write = quote_spanned!(ident.span()=> ::core::ptr::write);
252                 // NOTE: the field accessor ensures that the initialized field is properly aligned.
253                 // Unaligned fields will cause the compiler to emit E0793. We do not support
254                 // unaligned fields since `Init::__init` requires an aligned pointer; the call to
255                 // `ptr::write` below has the same requirement.
256                 let accessor = if pinned {
257                     let project_ident = format_ident!("__project_{ident}");
258                     quote! {
259                         // SAFETY: TODO
260                         unsafe { #data.#project_ident(&mut (*#slot).#ident) }
261                     }
262                 } else {
263                     quote! {
264                         // SAFETY: TODO
265                         unsafe { &mut (*#slot).#ident }
266                     }
267                 };
268                 quote! {
269                     #(#attrs)*
270                     {
271                         #value_prep
272                         // SAFETY: TODO
273                         unsafe { #write(&raw mut (*#slot).#ident, #value_ident) };
274                     }
275                     #(#cfgs)*
276                     #[allow(unused_variables)]
277                     let #ident = #accessor;
278                 }
279             }
280             InitializerKind::Init { ident, value, .. } => {
281                 // Again span for better diagnostics
282                 let init = format_ident!("init", span = value.span());
283                 // NOTE: the field accessor ensures that the initialized field is properly aligned.
284                 // Unaligned fields will cause the compiler to emit E0793. We do not support
285                 // unaligned fields since `Init::__init` requires an aligned pointer; the call to
286                 // `ptr::write` below has the same requirement.
287                 let (value_init, accessor) = if pinned {
288                     let project_ident = format_ident!("__project_{ident}");
289                     (
290                         quote! {
291                             // SAFETY:
292                             // - `slot` is valid, because we are inside of an initializer closure, we
293                             //   return when an error/panic occurs.
294                             // - We also use `#data` to require the correct trait (`Init` or `PinInit`)
295                             //   for `#ident`.
296                             unsafe { #data.#ident(&raw mut (*#slot).#ident, #init)? };
297                         },
298                         quote! {
299                             // SAFETY: TODO
300                             unsafe { #data.#project_ident(&mut (*#slot).#ident) }
301                         },
302                     )
303                 } else {
304                     (
305                         quote! {
306                             // SAFETY: `slot` is valid, because we are inside of an initializer
307                             // closure, we return when an error/panic occurs.
308                             unsafe {
309                                 ::pin_init::Init::__init(
310                                     #init,
311                                     &raw mut (*#slot).#ident,
312                                 )?
313                             };
314                         },
315                         quote! {
316                             // SAFETY: TODO
317                             unsafe { &mut (*#slot).#ident }
318                         },
319                     )
320                 };
321                 quote! {
322                     #(#attrs)*
323                     {
324                         let #init = #value;
325                         #value_init
326                     }
327                     #(#cfgs)*
328                     #[allow(unused_variables)]
329                     let #ident = #accessor;
330                 }
331             }
332             InitializerKind::Code { block: value, .. } => quote! {
333                 #(#attrs)*
334                 #[allow(unused_braces)]
335                 #value
336             },
337         };
338         res.extend(init);
339         if let Some(ident) = kind.ident() {
340             // `mixed_site` ensures that the guard is not accessible to the user-controlled code.
341             let guard = format_ident!("__{ident}_guard", span = Span::mixed_site());
342             res.extend(quote! {
343                 #(#cfgs)*
344                 // Create the drop guard:
345                 //
346                 // We rely on macro hygiene to make it impossible for users to access this local
347                 // variable.
348                 // SAFETY: We forget the guard later when initialization has succeeded.
349                 let #guard = unsafe {
350                     ::pin_init::__internal::DropGuard::new(
351                         &raw mut (*slot).#ident
352                     )
353                 };
354             });
355             guards.push(guard);
356             guard_attrs.push(cfgs);
357         }
358     }
359     quote! {
360         #res
361         // If execution reaches this point, all fields have been initialized. Therefore we can now
362         // dismiss the guards by forgetting them.
363         #(
364             #(#guard_attrs)*
365             ::core::mem::forget(#guards);
366         )*
367     }
368 }
369 
370 /// Generate the check for ensuring that every field has been initialized.
371 fn make_field_check(
372     fields: &Punctuated<InitializerField, Token![,]>,
373     init_kind: InitKind,
374     path: &Path,
375 ) -> TokenStream {
376     let field_attrs = fields
377         .iter()
378         .filter_map(|f| f.kind.ident().map(|_| &f.attrs));
379     let field_name = fields.iter().filter_map(|f| f.kind.ident());
380     match init_kind {
381         InitKind::Normal => quote! {
382             // We use unreachable code to ensure that all fields have been mentioned exactly once,
383             // this struct initializer will still be type-checked and complain with a very natural
384             // error message if a field is forgotten/mentioned more than once.
385             #[allow(unreachable_code, clippy::diverging_sub_expression)]
386             // SAFETY: this code is never executed.
387             let _ = || unsafe {
388                 ::core::ptr::write(slot, #path {
389                     #(
390                         #(#field_attrs)*
391                         #field_name: ::core::panic!(),
392                     )*
393                 })
394             };
395         },
396         InitKind::Zeroing => quote! {
397             // We use unreachable code to ensure that all fields have been mentioned at most once.
398             // Since the user specified `..Zeroable::zeroed()` at the end, all missing fields will
399             // be zeroed. This struct initializer will still be type-checked and complain with a
400             // very natural error message if a field is mentioned more than once, or doesn't exist.
401             #[allow(unreachable_code, clippy::diverging_sub_expression, unused_assignments)]
402             // SAFETY: this code is never executed.
403             let _ = || unsafe {
404                 ::core::ptr::write(slot, #path {
405                     #(
406                         #(#field_attrs)*
407                         #field_name: ::core::panic!(),
408                     )*
409                     ..::core::mem::zeroed()
410                 })
411             };
412         },
413     }
414 }
415 
416 impl Parse for Initializer {
417     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
418         let attrs = input.call(Attribute::parse_outer)?;
419         let this = input.peek(Token![&]).then(|| input.parse()).transpose()?;
420         let path = input.parse()?;
421         let content;
422         let brace_token = braced!(content in input);
423         let mut fields = Punctuated::new();
424         loop {
425             let lh = content.lookahead1();
426             if lh.peek(End) || lh.peek(Token![..]) {
427                 break;
428             } else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) {
429                 fields.push_value(content.parse()?);
430                 let lh = content.lookahead1();
431                 if lh.peek(End) {
432                     break;
433                 } else if lh.peek(Token![,]) {
434                     fields.push_punct(content.parse()?);
435                 } else {
436                     return Err(lh.error());
437                 }
438             } else {
439                 return Err(lh.error());
440             }
441         }
442         let rest = content
443             .peek(Token![..])
444             .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?)))
445             .transpose()?;
446         let error = input
447             .peek(Token![?])
448             .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?)))
449             .transpose()?;
450         let attrs = attrs
451             .into_iter()
452             .map(|a| {
453                 if a.path().is_ident("default_error") {
454                     a.parse_args::<DefaultErrorAttribute>()
455                         .map(InitializerAttribute::DefaultError)
456                 } else {
457                     Err(syn::Error::new_spanned(a, "unknown initializer attribute"))
458                 }
459             })
460             .collect::<Result<Vec<_>, _>>()?;
461         Ok(Self {
462             attrs,
463             this,
464             path,
465             brace_token,
466             fields,
467             rest,
468             error,
469         })
470     }
471 }
472 
473 impl Parse for DefaultErrorAttribute {
474     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
475         Ok(Self { ty: input.parse()? })
476     }
477 }
478 
479 impl Parse for This {
480     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
481         Ok(Self {
482             _and_token: input.parse()?,
483             ident: input.parse()?,
484             _in_token: input.parse()?,
485         })
486     }
487 }
488 
489 impl Parse for InitializerField {
490     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
491         let attrs = input.call(Attribute::parse_outer)?;
492         Ok(Self {
493             attrs,
494             kind: input.parse()?,
495         })
496     }
497 }
498 
499 impl Parse for InitializerKind {
500     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
501         let lh = input.lookahead1();
502         if lh.peek(Token![_]) {
503             Ok(Self::Code {
504                 _underscore_token: input.parse()?,
505                 _colon_token: input.parse()?,
506                 block: input.parse()?,
507             })
508         } else if lh.peek(Ident) {
509             let ident = input.parse()?;
510             let lh = input.lookahead1();
511             if lh.peek(Token![<-]) {
512                 Ok(Self::Init {
513                     ident,
514                     _left_arrow_token: input.parse()?,
515                     value: input.parse()?,
516                 })
517             } else if lh.peek(Token![:]) {
518                 Ok(Self::Value {
519                     ident,
520                     value: Some((input.parse()?, input.parse()?)),
521                 })
522             } else if lh.peek(Token![,]) || lh.peek(End) {
523                 Ok(Self::Value { ident, value: None })
524             } else {
525                 Err(lh.error())
526             }
527         } else {
528             Err(lh.error())
529         }
530     }
531 }
532