1 // Copyright 2024, Linaro Limited 2 // Author(s): Manos Pitsidianakis <manos.pitsidianakis@linaro.org> 3 // SPDX-License-Identifier: GPL-2.0-or-later 4 5 use proc_macro::TokenStream; 6 use quote::quote; 7 use syn::{ 8 parse_macro_input, parse_quote, punctuated::Punctuated, spanned::Spanned, token::Comma, Data, 9 DeriveInput, Field, Fields, FieldsUnnamed, Ident, Meta, Path, Token, Type, Variant, Visibility, 10 }; 11 12 mod utils; 13 use utils::MacroError; 14 15 fn get_fields<'a>( 16 input: &'a DeriveInput, 17 msg: &str, 18 ) -> Result<&'a Punctuated<Field, Comma>, MacroError> { 19 if let Data::Struct(s) = &input.data { 20 if let Fields::Named(fs) = &s.fields { 21 Ok(&fs.named) 22 } else { 23 Err(MacroError::Message( 24 format!("Named fields required for {}", msg), 25 input.ident.span(), 26 )) 27 } 28 } else { 29 Err(MacroError::Message( 30 format!("Struct required for {}", msg), 31 input.ident.span(), 32 )) 33 } 34 } 35 36 fn get_unnamed_field<'a>(input: &'a DeriveInput, msg: &str) -> Result<&'a Field, MacroError> { 37 if let Data::Struct(s) = &input.data { 38 let unnamed = match &s.fields { 39 Fields::Unnamed(FieldsUnnamed { 40 unnamed: ref fields, 41 .. 42 }) => fields, 43 _ => { 44 return Err(MacroError::Message( 45 format!("Tuple struct required for {}", msg), 46 s.fields.span(), 47 )) 48 } 49 }; 50 if unnamed.len() != 1 { 51 return Err(MacroError::Message( 52 format!("A single field is required for {}", msg), 53 s.fields.span(), 54 )); 55 } 56 Ok(&unnamed[0]) 57 } else { 58 Err(MacroError::Message( 59 format!("Struct required for {}", msg), 60 input.ident.span(), 61 )) 62 } 63 } 64 65 fn is_c_repr(input: &DeriveInput, msg: &str) -> Result<(), MacroError> { 66 let expected = parse_quote! { #[repr(C)] }; 67 68 if input.attrs.iter().any(|attr| attr == &expected) { 69 Ok(()) 70 } else { 71 Err(MacroError::Message( 72 format!("#[repr(C)] required for {}", msg), 73 input.ident.span(), 74 )) 75 } 76 } 77 78 fn is_transparent_repr(input: &DeriveInput, msg: &str) -> Result<(), MacroError> { 79 let expected = parse_quote! { #[repr(transparent)] }; 80 81 if input.attrs.iter().any(|attr| attr == &expected) { 82 Ok(()) 83 } else { 84 Err(MacroError::Message( 85 format!("#[repr(transparent)] required for {}", msg), 86 input.ident.span(), 87 )) 88 } 89 } 90 91 fn derive_object_or_error(input: DeriveInput) -> Result<proc_macro2::TokenStream, MacroError> { 92 is_c_repr(&input, "#[derive(Object)]")?; 93 94 let name = &input.ident; 95 let parent = &get_fields(&input, "#[derive(Object)]")?[0].ident; 96 97 Ok(quote! { 98 ::qemu_api::assert_field_type!(#name, #parent, 99 ::qemu_api::qom::ParentField<<#name as ::qemu_api::qom::ObjectImpl>::ParentType>); 100 101 ::qemu_api::module_init! { 102 MODULE_INIT_QOM => unsafe { 103 ::qemu_api::bindings::type_register_static(&<#name as ::qemu_api::qom::ObjectImpl>::TYPE_INFO); 104 } 105 } 106 }) 107 } 108 109 #[proc_macro_derive(Object)] 110 pub fn derive_object(input: TokenStream) -> TokenStream { 111 let input = parse_macro_input!(input as DeriveInput); 112 let expanded = derive_object_or_error(input).unwrap_or_else(Into::into); 113 114 TokenStream::from(expanded) 115 } 116 117 fn derive_opaque_or_error(input: DeriveInput) -> Result<proc_macro2::TokenStream, MacroError> { 118 is_transparent_repr(&input, "#[derive(Wrapper)]")?; 119 120 let name = &input.ident; 121 let field = &get_unnamed_field(&input, "#[derive(Wrapper)]")?; 122 let typ = &field.ty; 123 124 // TODO: how to add "::qemu_api"? For now, this is only used in the 125 // qemu_api crate so it's not a problem. 126 Ok(quote! { 127 unsafe impl crate::cell::Wrapper for #name { 128 type Wrapped = <#typ as crate::cell::Wrapper>::Wrapped; 129 } 130 impl #name { 131 pub unsafe fn from_raw<'a>(ptr: *mut <Self as crate::cell::Wrapper>::Wrapped) -> &'a Self { 132 let ptr = ::std::ptr::NonNull::new(ptr).unwrap().cast::<Self>(); 133 unsafe { ptr.as_ref() } 134 } 135 136 pub const fn as_mut_ptr(&self) -> *mut <Self as crate::cell::Wrapper>::Wrapped { 137 self.0.as_mut_ptr() 138 } 139 140 pub const fn as_ptr(&self) -> *const <Self as crate::cell::Wrapper>::Wrapped { 141 self.0.as_ptr() 142 } 143 144 pub const fn as_void_ptr(&self) -> *mut ::core::ffi::c_void { 145 self.0.as_void_ptr() 146 } 147 148 pub const fn raw_get(slot: *mut Self) -> *mut <Self as crate::cell::Wrapper>::Wrapped { 149 slot.cast() 150 } 151 } 152 }) 153 } 154 155 #[proc_macro_derive(Wrapper)] 156 pub fn derive_opaque(input: TokenStream) -> TokenStream { 157 let input = parse_macro_input!(input as DeriveInput); 158 let expanded = derive_opaque_or_error(input).unwrap_or_else(Into::into); 159 160 TokenStream::from(expanded) 161 } 162 163 #[rustfmt::skip::macros(quote)] 164 fn derive_offsets_or_error(input: DeriveInput) -> Result<proc_macro2::TokenStream, MacroError> { 165 is_c_repr(&input, "#[derive(offsets)]")?; 166 167 let name = &input.ident; 168 let fields = get_fields(&input, "#[derive(offsets)]")?; 169 let field_names: Vec<&Ident> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect(); 170 let field_types: Vec<&Type> = fields.iter().map(|f| &f.ty).collect(); 171 let field_vis: Vec<&Visibility> = fields.iter().map(|f| &f.vis).collect(); 172 173 Ok(quote! { 174 ::qemu_api::with_offsets! { 175 struct #name { 176 #(#field_vis #field_names: #field_types,)* 177 } 178 } 179 }) 180 } 181 182 #[proc_macro_derive(offsets)] 183 pub fn derive_offsets(input: TokenStream) -> TokenStream { 184 let input = parse_macro_input!(input as DeriveInput); 185 let expanded = derive_offsets_or_error(input).unwrap_or_else(Into::into); 186 187 TokenStream::from(expanded) 188 } 189 190 #[allow(non_snake_case)] 191 fn get_repr_uN(input: &DeriveInput, msg: &str) -> Result<Path, MacroError> { 192 let repr = input.attrs.iter().find(|attr| attr.path().is_ident("repr")); 193 if let Some(repr) = repr { 194 let nested = repr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?; 195 for meta in nested { 196 match meta { 197 Meta::Path(path) if path.is_ident("u8") => return Ok(path), 198 Meta::Path(path) if path.is_ident("u16") => return Ok(path), 199 Meta::Path(path) if path.is_ident("u32") => return Ok(path), 200 Meta::Path(path) if path.is_ident("u64") => return Ok(path), 201 _ => {} 202 } 203 } 204 } 205 206 Err(MacroError::Message( 207 format!("#[repr(u8/u16/u32/u64) required for {}", msg), 208 input.ident.span(), 209 )) 210 } 211 212 fn get_variants(input: &DeriveInput) -> Result<&Punctuated<Variant, Comma>, MacroError> { 213 if let Data::Enum(e) = &input.data { 214 if let Some(v) = e.variants.iter().find(|v| v.fields != Fields::Unit) { 215 return Err(MacroError::Message( 216 "Cannot derive TryInto for enum with non-unit variants.".to_string(), 217 v.fields.span(), 218 )); 219 } 220 Ok(&e.variants) 221 } else { 222 Err(MacroError::Message( 223 "Cannot derive TryInto for union or struct.".to_string(), 224 input.ident.span(), 225 )) 226 } 227 } 228 229 #[rustfmt::skip::macros(quote)] 230 fn derive_tryinto_or_error(input: DeriveInput) -> Result<proc_macro2::TokenStream, MacroError> { 231 let repr = get_repr_uN(&input, "#[derive(TryInto)]")?; 232 233 let name = &input.ident; 234 let variants = get_variants(&input)?; 235 let discriminants: Vec<&Ident> = variants.iter().map(|f| &f.ident).collect(); 236 237 Ok(quote! { 238 impl core::convert::TryFrom<#repr> for #name { 239 type Error = #repr; 240 241 fn try_from(value: #repr) -> Result<Self, Self::Error> { 242 #(const #discriminants: #repr = #name::#discriminants as #repr;)*; 243 match value { 244 #(#discriminants => Ok(Self::#discriminants),)* 245 _ => Err(value), 246 } 247 } 248 } 249 }) 250 } 251 252 #[proc_macro_derive(TryInto)] 253 pub fn derive_tryinto(input: TokenStream) -> TokenStream { 254 let input = parse_macro_input!(input as DeriveInput); 255 let expanded = derive_tryinto_or_error(input).unwrap_or_else(Into::into); 256 257 TokenStream::from(expanded) 258 } 259