1 // Copyright 2024, Linaro Limited 2 // Author(s): Manos Pitsidianakis <manos.pitsidianakis@linaro.org> 3 // SPDX-License-Identifier: GPL-2.0-or-later 4 5 use proc_macro::TokenStream; 6 use quote::quote; 7 use syn::{ 8 parse_macro_input, parse_quote, punctuated::Punctuated, spanned::Spanned, token::Comma, Data, 9 DeriveInput, Field, Fields, Ident, Meta, Path, Token, Type, Variant, Visibility, 10 }; 11 12 mod utils; 13 use utils::MacroError; 14 15 fn get_fields<'a>( 16 input: &'a DeriveInput, 17 msg: &str, 18 ) -> Result<&'a Punctuated<Field, Comma>, MacroError> { 19 if let Data::Struct(s) = &input.data { 20 if let Fields::Named(fs) = &s.fields { 21 Ok(&fs.named) 22 } else { 23 Err(MacroError::Message( 24 format!("Named fields required for {}", msg), 25 input.ident.span(), 26 )) 27 } 28 } else { 29 Err(MacroError::Message( 30 format!("Struct required for {}", msg), 31 input.ident.span(), 32 )) 33 } 34 } 35 36 fn is_c_repr(input: &DeriveInput, msg: &str) -> Result<(), MacroError> { 37 let expected = parse_quote! { #[repr(C)] }; 38 39 if input.attrs.iter().any(|attr| attr == &expected) { 40 Ok(()) 41 } else { 42 Err(MacroError::Message( 43 format!("#[repr(C)] required for {}", msg), 44 input.ident.span(), 45 )) 46 } 47 } 48 49 fn derive_object_or_error(input: DeriveInput) -> Result<proc_macro2::TokenStream, MacroError> { 50 is_c_repr(&input, "#[derive(Object)]")?; 51 52 let name = &input.ident; 53 let parent = &get_fields(&input, "#[derive(Object)]")?[0].ident; 54 55 Ok(quote! { 56 ::qemu_api::assert_field_type!(#name, #parent, 57 ::qemu_api::qom::ParentField<<#name as ::qemu_api::qom::ObjectImpl>::ParentType>); 58 59 ::qemu_api::module_init! { 60 MODULE_INIT_QOM => unsafe { 61 ::qemu_api::bindings::type_register_static(&<#name as ::qemu_api::qom::ObjectImpl>::TYPE_INFO); 62 } 63 } 64 }) 65 } 66 67 #[proc_macro_derive(Object)] 68 pub fn derive_object(input: TokenStream) -> TokenStream { 69 let input = parse_macro_input!(input as DeriveInput); 70 let expanded = derive_object_or_error(input).unwrap_or_else(Into::into); 71 72 TokenStream::from(expanded) 73 } 74 75 #[rustfmt::skip::macros(quote)] 76 fn derive_offsets_or_error(input: DeriveInput) -> Result<proc_macro2::TokenStream, MacroError> { 77 is_c_repr(&input, "#[derive(offsets)]")?; 78 79 let name = &input.ident; 80 let fields = get_fields(&input, "#[derive(offsets)]")?; 81 let field_names: Vec<&Ident> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect(); 82 let field_types: Vec<&Type> = fields.iter().map(|f| &f.ty).collect(); 83 let field_vis: Vec<&Visibility> = fields.iter().map(|f| &f.vis).collect(); 84 85 Ok(quote! { 86 ::qemu_api::with_offsets! { 87 struct #name { 88 #(#field_vis #field_names: #field_types,)* 89 } 90 } 91 }) 92 } 93 94 #[proc_macro_derive(offsets)] 95 pub fn derive_offsets(input: TokenStream) -> TokenStream { 96 let input = parse_macro_input!(input as DeriveInput); 97 let expanded = derive_offsets_or_error(input).unwrap_or_else(Into::into); 98 99 TokenStream::from(expanded) 100 } 101 102 #[allow(non_snake_case)] 103 fn get_repr_uN(input: &DeriveInput, msg: &str) -> Result<Path, MacroError> { 104 let repr = input.attrs.iter().find(|attr| attr.path().is_ident("repr")); 105 if let Some(repr) = repr { 106 let nested = repr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?; 107 for meta in nested { 108 match meta { 109 Meta::Path(path) if path.is_ident("u8") => return Ok(path), 110 Meta::Path(path) if path.is_ident("u16") => return Ok(path), 111 Meta::Path(path) if path.is_ident("u32") => return Ok(path), 112 Meta::Path(path) if path.is_ident("u64") => return Ok(path), 113 _ => {} 114 } 115 } 116 } 117 118 Err(MacroError::Message( 119 format!("#[repr(u8/u16/u32/u64) required for {}", msg), 120 input.ident.span(), 121 )) 122 } 123 124 fn get_variants(input: &DeriveInput) -> Result<&Punctuated<Variant, Comma>, MacroError> { 125 if let Data::Enum(e) = &input.data { 126 if let Some(v) = e.variants.iter().find(|v| v.fields != Fields::Unit) { 127 return Err(MacroError::Message( 128 "Cannot derive TryInto for enum with non-unit variants.".to_string(), 129 v.fields.span(), 130 )); 131 } 132 Ok(&e.variants) 133 } else { 134 Err(MacroError::Message( 135 "Cannot derive TryInto for union or struct.".to_string(), 136 input.ident.span(), 137 )) 138 } 139 } 140 141 #[rustfmt::skip::macros(quote)] 142 fn derive_tryinto_or_error(input: DeriveInput) -> Result<proc_macro2::TokenStream, MacroError> { 143 let repr = get_repr_uN(&input, "#[derive(TryInto)]")?; 144 145 let name = &input.ident; 146 let variants = get_variants(&input)?; 147 let discriminants: Vec<&Ident> = variants.iter().map(|f| &f.ident).collect(); 148 149 Ok(quote! { 150 impl core::convert::TryFrom<#repr> for #name { 151 type Error = #repr; 152 153 fn try_from(value: #repr) -> Result<Self, Self::Error> { 154 #(const #discriminants: #repr = #name::#discriminants as #repr;)*; 155 match value { 156 #(#discriminants => Ok(Self::#discriminants),)* 157 _ => Err(value), 158 } 159 } 160 } 161 }) 162 } 163 164 #[proc_macro_derive(TryInto)] 165 pub fn derive_tryinto(input: TokenStream) -> TokenStream { 166 let input = parse_macro_input!(input as DeriveInput); 167 let expanded = derive_tryinto_or_error(input).unwrap_or_else(Into::into); 168 169 TokenStream::from(expanded) 170 } 171