1 // SPDX-License-Identifier: Apache-2.0 OR MIT 2 3 use proc_macro2::{Span, TokenStream}; 4 use quote::{format_ident, quote, quote_spanned}; 5 use syn::{ 6 braced, 7 parse::{End, Parse}, 8 parse_quote, 9 punctuated::Punctuated, 10 spanned::Spanned, 11 token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type, 12 }; 13 14 use crate::diagnostics::{DiagCtxt, ErrorGuaranteed}; 15 16 pub(crate) struct Initializer { 17 attrs: Vec<InitializerAttribute>, 18 this: Option<This>, 19 path: Path, 20 brace_token: token::Brace, 21 fields: Punctuated<InitializerField, Token![,]>, 22 rest: Option<(Token![..], Expr)>, 23 error: Option<(Token![?], Type)>, 24 } 25 26 struct This { 27 _and_token: Token![&], 28 ident: Ident, 29 _in_token: Token![in], 30 } 31 32 struct InitializerField { 33 attrs: Vec<Attribute>, 34 kind: InitializerKind, 35 } 36 37 enum InitializerKind { 38 Value { 39 ident: Ident, 40 value: Option<(Token![:], Expr)>, 41 }, 42 Init { 43 ident: Ident, 44 _left_arrow_token: Token![<-], 45 value: Expr, 46 }, 47 Code { 48 _underscore_token: Token![_], 49 _colon_token: Token![:], 50 block: Block, 51 }, 52 } 53 54 impl InitializerKind { 55 fn ident(&self) -> Option<&Ident> { 56 match self { 57 Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident), 58 Self::Code { .. } => None, 59 } 60 } 61 } 62 63 enum InitializerAttribute { 64 DefaultError(DefaultErrorAttribute), 65 } 66 67 struct DefaultErrorAttribute { 68 ty: Box<Type>, 69 } 70 71 pub(crate) fn expand( 72 Initializer { 73 attrs, 74 this, 75 path, 76 brace_token, 77 fields, 78 rest, 79 error, 80 }: Initializer, 81 default_error: Option<&'static str>, 82 pinned: bool, 83 dcx: &mut DiagCtxt, 84 ) -> Result<TokenStream, ErrorGuaranteed> { 85 let error = error.map_or_else( 86 || { 87 if let Some(default_error) = attrs.iter().fold(None, |acc, attr| { 88 #[expect(irrefutable_let_patterns)] 89 if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr { 90 Some(ty.clone()) 91 } else { 92 acc 93 } 94 }) { 95 default_error 96 } else if let Some(default_error) = default_error { 97 syn::parse_str(default_error).unwrap() 98 } else { 99 dcx.error(brace_token.span.close(), "expected `? <type>` after `}`"); 100 parse_quote!(::core::convert::Infallible) 101 } 102 }, 103 |(_, err)| Box::new(err), 104 ); 105 let slot = format_ident!("slot"); 106 let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned { 107 ( 108 format_ident!("HasPinData"), 109 format_ident!("PinData"), 110 format_ident!("__pin_data"), 111 format_ident!("pin_init_from_closure"), 112 ) 113 } else { 114 ( 115 format_ident!("HasInitData"), 116 format_ident!("InitData"), 117 format_ident!("__init_data"), 118 format_ident!("init_from_closure"), 119 ) 120 }; 121 let init_kind = get_init_kind(rest, dcx); 122 let zeroable_check = match init_kind { 123 InitKind::Normal => quote!(), 124 InitKind::Zeroing => quote! { 125 // The user specified `..Zeroable::zeroed()` at the end of the list of fields. 126 // Therefore we check if the struct implements `Zeroable` and then zero the memory. 127 // This allows us to also remove the check that all fields are present (since we 128 // already set the memory to zero and that is a valid bit pattern). 129 fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T) 130 where T: ::pin_init::Zeroable 131 {} 132 // Ensure that the struct is indeed `Zeroable`. 133 assert_zeroable(#slot); 134 // SAFETY: The type implements `Zeroable` by the check above. 135 unsafe { ::core::ptr::write_bytes(#slot, 0, 1) }; 136 }, 137 }; 138 let this = match this { 139 None => quote!(), 140 Some(This { ident, .. }) => quote! { 141 // Create the `this` so it can be referenced by the user inside of the 142 // expressions creating the individual fields. 143 let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) }; 144 }, 145 }; 146 // `mixed_site` ensures that the data is not accessible to the user-controlled code. 147 let data = Ident::new("__data", Span::mixed_site()); 148 let init_fields = init_fields(&fields, pinned, &data, &slot); 149 let field_check = make_field_check(&fields, init_kind, &path); 150 Ok(quote! {{ 151 // Get the data about fields from the supplied type. 152 // SAFETY: TODO 153 let #data = unsafe { 154 use ::pin_init::__internal::#has_data_trait; 155 // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit 156 // generics (which need to be present with that syntax). 157 #path::#get_data() 158 }; 159 // Ensure that `#data` really is of type `#data` and help with type inference: 160 let init = ::pin_init::__internal::#data_trait::make_closure::<_, #error>( 161 #data, 162 move |slot| { 163 #zeroable_check 164 #this 165 #init_fields 166 #field_check 167 // SAFETY: we are the `init!` macro that is allowed to call this. 168 Ok(unsafe { ::pin_init::__internal::InitOk::new() }) 169 } 170 ); 171 let init = move |slot| -> ::core::result::Result<(), #error> { 172 init(slot).map(|__InitOk| ()) 173 }; 174 // SAFETY: TODO 175 let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) }; 176 // FIXME: this let binding is required to avoid a compiler error (cycle when computing the 177 // opaque type returned by this function) before Rust 1.81. Remove after MSRV bump. 178 #[allow( 179 clippy::let_and_return, 180 reason = "some clippy versions warn about the let binding" 181 )] 182 init 183 }}) 184 } 185 186 enum InitKind { 187 Normal, 188 Zeroing, 189 } 190 191 fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind { 192 let Some((dotdot, expr)) = rest else { 193 return InitKind::Normal; 194 }; 195 match &expr { 196 Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func { 197 Expr::Path(ExprPath { 198 attrs, 199 qself: None, 200 path: 201 Path { 202 leading_colon: None, 203 segments, 204 }, 205 }) if attrs.is_empty() 206 && segments.len() == 2 207 && segments[0].ident == "Zeroable" 208 && segments[0].arguments.is_none() 209 && segments[1].ident == "init_zeroed" 210 && segments[1].arguments.is_none() => 211 { 212 return InitKind::Zeroing; 213 } 214 _ => {} 215 }, 216 _ => {} 217 } 218 dcx.error( 219 dotdot.span().join(expr.span()).unwrap_or(expr.span()), 220 "expected nothing or `..Zeroable::init_zeroed()`.", 221 ); 222 InitKind::Normal 223 } 224 225 /// Generate the code that initializes the fields of the struct using the initializers in `field`. 226 fn init_fields( 227 fields: &Punctuated<InitializerField, Token![,]>, 228 pinned: bool, 229 data: &Ident, 230 slot: &Ident, 231 ) -> TokenStream { 232 let mut guards = vec![]; 233 let mut guard_attrs = vec![]; 234 let mut res = TokenStream::new(); 235 for InitializerField { attrs, kind } in fields { 236 let cfgs = { 237 let mut cfgs = attrs.clone(); 238 cfgs.retain(|attr| attr.path().is_ident("cfg")); 239 cfgs 240 }; 241 let init = match kind { 242 InitializerKind::Value { ident, value } => { 243 let mut value_ident = ident.clone(); 244 let value_prep = value.as_ref().map(|value| &value.1).map(|value| { 245 // Setting the span of `value_ident` to `value`'s span improves error messages 246 // when the type of `value` is wrong. 247 value_ident.set_span(value.span()); 248 quote!(let #value_ident = #value;) 249 }); 250 // Again span for better diagnostics 251 let write = quote_spanned!(ident.span()=> ::core::ptr::write); 252 // NOTE: the field accessor ensures that the initialized field is properly aligned. 253 // Unaligned fields will cause the compiler to emit E0793. We do not support 254 // unaligned fields since `Init::__init` requires an aligned pointer; the call to 255 // `ptr::write` below has the same requirement. 256 let accessor = if pinned { 257 let project_ident = format_ident!("__project_{ident}"); 258 quote! { 259 // SAFETY: TODO 260 unsafe { #data.#project_ident(&mut (*#slot).#ident) } 261 } 262 } else { 263 quote! { 264 // SAFETY: TODO 265 unsafe { &mut (*#slot).#ident } 266 } 267 }; 268 quote! { 269 #(#attrs)* 270 { 271 #value_prep 272 // SAFETY: TODO 273 unsafe { #write(&raw mut (*#slot).#ident, #value_ident) }; 274 } 275 #(#cfgs)* 276 #[allow(unused_variables)] 277 let #ident = #accessor; 278 } 279 } 280 InitializerKind::Init { ident, value, .. } => { 281 // Again span for better diagnostics 282 let init = format_ident!("init", span = value.span()); 283 // NOTE: the field accessor ensures that the initialized field is properly aligned. 284 // Unaligned fields will cause the compiler to emit E0793. We do not support 285 // unaligned fields since `Init::__init` requires an aligned pointer; the call to 286 // `ptr::write` below has the same requirement. 287 let (value_init, accessor) = if pinned { 288 let project_ident = format_ident!("__project_{ident}"); 289 ( 290 quote! { 291 // SAFETY: 292 // - `slot` is valid, because we are inside of an initializer closure, we 293 // return when an error/panic occurs. 294 // - We also use `#data` to require the correct trait (`Init` or `PinInit`) 295 // for `#ident`. 296 unsafe { #data.#ident(&raw mut (*#slot).#ident, #init)? }; 297 }, 298 quote! { 299 // SAFETY: TODO 300 unsafe { #data.#project_ident(&mut (*#slot).#ident) } 301 }, 302 ) 303 } else { 304 ( 305 quote! { 306 // SAFETY: `slot` is valid, because we are inside of an initializer 307 // closure, we return when an error/panic occurs. 308 unsafe { 309 ::pin_init::Init::__init( 310 #init, 311 &raw mut (*#slot).#ident, 312 )? 313 }; 314 }, 315 quote! { 316 // SAFETY: TODO 317 unsafe { &mut (*#slot).#ident } 318 }, 319 ) 320 }; 321 quote! { 322 #(#attrs)* 323 { 324 let #init = #value; 325 #value_init 326 } 327 #(#cfgs)* 328 #[allow(unused_variables)] 329 let #ident = #accessor; 330 } 331 } 332 InitializerKind::Code { block: value, .. } => quote! { 333 #(#attrs)* 334 #[allow(unused_braces)] 335 #value 336 }, 337 }; 338 res.extend(init); 339 if let Some(ident) = kind.ident() { 340 // `mixed_site` ensures that the guard is not accessible to the user-controlled code. 341 let guard = format_ident!("__{ident}_guard", span = Span::mixed_site()); 342 res.extend(quote! { 343 #(#cfgs)* 344 // Create the drop guard: 345 // 346 // We rely on macro hygiene to make it impossible for users to access this local 347 // variable. 348 // SAFETY: We forget the guard later when initialization has succeeded. 349 let #guard = unsafe { 350 ::pin_init::__internal::DropGuard::new( 351 &raw mut (*slot).#ident 352 ) 353 }; 354 }); 355 guards.push(guard); 356 guard_attrs.push(cfgs); 357 } 358 } 359 quote! { 360 #res 361 // If execution reaches this point, all fields have been initialized. Therefore we can now 362 // dismiss the guards by forgetting them. 363 #( 364 #(#guard_attrs)* 365 ::core::mem::forget(#guards); 366 )* 367 } 368 } 369 370 /// Generate the check for ensuring that every field has been initialized. 371 fn make_field_check( 372 fields: &Punctuated<InitializerField, Token![,]>, 373 init_kind: InitKind, 374 path: &Path, 375 ) -> TokenStream { 376 let field_attrs = fields 377 .iter() 378 .filter_map(|f| f.kind.ident().map(|_| &f.attrs)); 379 let field_name = fields.iter().filter_map(|f| f.kind.ident()); 380 match init_kind { 381 InitKind::Normal => quote! { 382 // We use unreachable code to ensure that all fields have been mentioned exactly once, 383 // this struct initializer will still be type-checked and complain with a very natural 384 // error message if a field is forgotten/mentioned more than once. 385 #[allow(unreachable_code, clippy::diverging_sub_expression)] 386 // SAFETY: this code is never executed. 387 let _ = || unsafe { 388 ::core::ptr::write(slot, #path { 389 #( 390 #(#field_attrs)* 391 #field_name: ::core::panic!(), 392 )* 393 }) 394 }; 395 }, 396 InitKind::Zeroing => quote! { 397 // We use unreachable code to ensure that all fields have been mentioned at most once. 398 // Since the user specified `..Zeroable::zeroed()` at the end, all missing fields will 399 // be zeroed. This struct initializer will still be type-checked and complain with a 400 // very natural error message if a field is mentioned more than once, or doesn't exist. 401 #[allow(unreachable_code, clippy::diverging_sub_expression, unused_assignments)] 402 // SAFETY: this code is never executed. 403 let _ = || unsafe { 404 ::core::ptr::write(slot, #path { 405 #( 406 #(#field_attrs)* 407 #field_name: ::core::panic!(), 408 )* 409 ..::core::mem::zeroed() 410 }) 411 }; 412 }, 413 } 414 } 415 416 impl Parse for Initializer { 417 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 418 let attrs = input.call(Attribute::parse_outer)?; 419 let this = input.peek(Token![&]).then(|| input.parse()).transpose()?; 420 let path = input.parse()?; 421 let content; 422 let brace_token = braced!(content in input); 423 let mut fields = Punctuated::new(); 424 loop { 425 let lh = content.lookahead1(); 426 if lh.peek(End) || lh.peek(Token![..]) { 427 break; 428 } else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) { 429 fields.push_value(content.parse()?); 430 let lh = content.lookahead1(); 431 if lh.peek(End) { 432 break; 433 } else if lh.peek(Token![,]) { 434 fields.push_punct(content.parse()?); 435 } else { 436 return Err(lh.error()); 437 } 438 } else { 439 return Err(lh.error()); 440 } 441 } 442 let rest = content 443 .peek(Token![..]) 444 .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?))) 445 .transpose()?; 446 let error = input 447 .peek(Token![?]) 448 .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?))) 449 .transpose()?; 450 let attrs = attrs 451 .into_iter() 452 .map(|a| { 453 if a.path().is_ident("default_error") { 454 a.parse_args::<DefaultErrorAttribute>() 455 .map(InitializerAttribute::DefaultError) 456 } else { 457 Err(syn::Error::new_spanned(a, "unknown initializer attribute")) 458 } 459 }) 460 .collect::<Result<Vec<_>, _>>()?; 461 Ok(Self { 462 attrs, 463 this, 464 path, 465 brace_token, 466 fields, 467 rest, 468 error, 469 }) 470 } 471 } 472 473 impl Parse for DefaultErrorAttribute { 474 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 475 Ok(Self { ty: input.parse()? }) 476 } 477 } 478 479 impl Parse for This { 480 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 481 Ok(Self { 482 _and_token: input.parse()?, 483 ident: input.parse()?, 484 _in_token: input.parse()?, 485 }) 486 } 487 } 488 489 impl Parse for InitializerField { 490 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 491 let attrs = input.call(Attribute::parse_outer)?; 492 Ok(Self { 493 attrs, 494 kind: input.parse()?, 495 }) 496 } 497 } 498 499 impl Parse for InitializerKind { 500 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 501 let lh = input.lookahead1(); 502 if lh.peek(Token![_]) { 503 Ok(Self::Code { 504 _underscore_token: input.parse()?, 505 _colon_token: input.parse()?, 506 block: input.parse()?, 507 }) 508 } else if lh.peek(Ident) { 509 let ident = input.parse()?; 510 let lh = input.lookahead1(); 511 if lh.peek(Token![<-]) { 512 Ok(Self::Init { 513 ident, 514 _left_arrow_token: input.parse()?, 515 value: input.parse()?, 516 }) 517 } else if lh.peek(Token![:]) { 518 Ok(Self::Value { 519 ident, 520 value: Some((input.parse()?, input.parse()?)), 521 }) 522 } else if lh.peek(Token![,]) || lh.peek(End) { 523 Ok(Self::Value { ident, value: None }) 524 } else { 525 Err(lh.error()) 526 } 527 } else { 528 Err(lh.error()) 529 } 530 } 531 } 532