xref: /qemu/rust/qemu-api-macros/src/lib.rs (revision 897c68fb795cf03b89b6688a6f945d68a765c3e4)
1 // Copyright 2024, Linaro Limited
2 // Author(s): Manos Pitsidianakis <manos.pitsidianakis@linaro.org>
3 // SPDX-License-Identifier: GPL-2.0-or-later
4 
5 use proc_macro::TokenStream;
6 use quote::quote;
7 use syn::{
8     parse_macro_input, parse_quote, punctuated::Punctuated, spanned::Spanned, token::Comma, Data,
9     DeriveInput, Field, Fields, FieldsUnnamed, Ident, Meta, Path, Token, Type, Variant, Visibility,
10 };
11 
12 mod utils;
13 use utils::MacroError;
14 
15 fn get_fields<'a>(
16     input: &'a DeriveInput,
17     msg: &str,
18 ) -> Result<&'a Punctuated<Field, Comma>, MacroError> {
19     if let Data::Struct(s) = &input.data {
20         if let Fields::Named(fs) = &s.fields {
21             Ok(&fs.named)
22         } else {
23             Err(MacroError::Message(
24                 format!("Named fields required for {}", msg),
25                 input.ident.span(),
26             ))
27         }
28     } else {
29         Err(MacroError::Message(
30             format!("Struct required for {}", msg),
31             input.ident.span(),
32         ))
33     }
34 }
35 
36 fn get_unnamed_field<'a>(input: &'a DeriveInput, msg: &str) -> Result<&'a Field, MacroError> {
37     if let Data::Struct(s) = &input.data {
38         let unnamed = match &s.fields {
39             Fields::Unnamed(FieldsUnnamed {
40                 unnamed: ref fields,
41                 ..
42             }) => fields,
43             _ => {
44                 return Err(MacroError::Message(
45                     format!("Tuple struct required for {}", msg),
46                     s.fields.span(),
47                 ))
48             }
49         };
50         if unnamed.len() != 1 {
51             return Err(MacroError::Message(
52                 format!("A single field is required for {}", msg),
53                 s.fields.span(),
54             ));
55         }
56         Ok(&unnamed[0])
57     } else {
58         Err(MacroError::Message(
59             format!("Struct required for {}", msg),
60             input.ident.span(),
61         ))
62     }
63 }
64 
65 fn is_c_repr(input: &DeriveInput, msg: &str) -> Result<(), MacroError> {
66     let expected = parse_quote! { #[repr(C)] };
67 
68     if input.attrs.iter().any(|attr| attr == &expected) {
69         Ok(())
70     } else {
71         Err(MacroError::Message(
72             format!("#[repr(C)] required for {}", msg),
73             input.ident.span(),
74         ))
75     }
76 }
77 
78 fn is_transparent_repr(input: &DeriveInput, msg: &str) -> Result<(), MacroError> {
79     let expected = parse_quote! { #[repr(transparent)] };
80 
81     if input.attrs.iter().any(|attr| attr == &expected) {
82         Ok(())
83     } else {
84         Err(MacroError::Message(
85             format!("#[repr(transparent)] required for {}", msg),
86             input.ident.span(),
87         ))
88     }
89 }
90 
91 fn derive_object_or_error(input: DeriveInput) -> Result<proc_macro2::TokenStream, MacroError> {
92     is_c_repr(&input, "#[derive(Object)]")?;
93 
94     let name = &input.ident;
95     let parent = &get_fields(&input, "#[derive(Object)]")?[0].ident;
96 
97     Ok(quote! {
98         ::qemu_api::assert_field_type!(#name, #parent,
99             ::qemu_api::qom::ParentField<<#name as ::qemu_api::qom::ObjectImpl>::ParentType>);
100 
101         ::qemu_api::module_init! {
102             MODULE_INIT_QOM => unsafe {
103                 ::qemu_api::bindings::type_register_static(&<#name as ::qemu_api::qom::ObjectImpl>::TYPE_INFO);
104             }
105         }
106     })
107 }
108 
109 #[proc_macro_derive(Object)]
110 pub fn derive_object(input: TokenStream) -> TokenStream {
111     let input = parse_macro_input!(input as DeriveInput);
112     let expanded = derive_object_or_error(input).unwrap_or_else(Into::into);
113 
114     TokenStream::from(expanded)
115 }
116 
117 fn derive_opaque_or_error(input: DeriveInput) -> Result<proc_macro2::TokenStream, MacroError> {
118     is_transparent_repr(&input, "#[derive(Wrapper)]")?;
119 
120     let name = &input.ident;
121     let field = &get_unnamed_field(&input, "#[derive(Wrapper)]")?;
122     let typ = &field.ty;
123 
124     // TODO: how to add "::qemu_api"?  For now, this is only used in the
125     // qemu_api crate so it's not a problem.
126     Ok(quote! {
127         unsafe impl crate::cell::Wrapper for #name {
128             type Wrapped = <#typ as crate::cell::Wrapper>::Wrapped;
129         }
130         impl #name {
131             pub unsafe fn from_raw<'a>(ptr: *mut <Self as crate::cell::Wrapper>::Wrapped) -> &'a Self {
132                 let ptr = ::std::ptr::NonNull::new(ptr).unwrap().cast::<Self>();
133                 unsafe { ptr.as_ref() }
134             }
135 
136             pub const fn as_mut_ptr(&self) -> *mut <Self as crate::cell::Wrapper>::Wrapped {
137                 self.0.as_mut_ptr()
138             }
139 
140             pub const fn as_ptr(&self) -> *const <Self as crate::cell::Wrapper>::Wrapped {
141                 self.0.as_ptr()
142             }
143 
144             pub const fn as_void_ptr(&self) -> *mut ::core::ffi::c_void {
145                 self.0.as_void_ptr()
146             }
147 
148             pub const fn raw_get(slot: *mut Self) -> *mut <Self as crate::cell::Wrapper>::Wrapped {
149                 slot.cast()
150             }
151         }
152     })
153 }
154 
155 #[proc_macro_derive(Wrapper)]
156 pub fn derive_opaque(input: TokenStream) -> TokenStream {
157     let input = parse_macro_input!(input as DeriveInput);
158     let expanded = derive_opaque_or_error(input).unwrap_or_else(Into::into);
159 
160     TokenStream::from(expanded)
161 }
162 
163 #[rustfmt::skip::macros(quote)]
164 fn derive_offsets_or_error(input: DeriveInput) -> Result<proc_macro2::TokenStream, MacroError> {
165     is_c_repr(&input, "#[derive(offsets)]")?;
166 
167     let name = &input.ident;
168     let fields = get_fields(&input, "#[derive(offsets)]")?;
169     let field_names: Vec<&Ident> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
170     let field_types: Vec<&Type> = fields.iter().map(|f| &f.ty).collect();
171     let field_vis: Vec<&Visibility> = fields.iter().map(|f| &f.vis).collect();
172 
173     Ok(quote! {
174 	::qemu_api::with_offsets! {
175 	    struct #name {
176 		#(#field_vis #field_names: #field_types,)*
177 	    }
178 	}
179     })
180 }
181 
182 #[proc_macro_derive(offsets)]
183 pub fn derive_offsets(input: TokenStream) -> TokenStream {
184     let input = parse_macro_input!(input as DeriveInput);
185     let expanded = derive_offsets_or_error(input).unwrap_or_else(Into::into);
186 
187     TokenStream::from(expanded)
188 }
189 
190 #[allow(non_snake_case)]
191 fn get_repr_uN(input: &DeriveInput, msg: &str) -> Result<Path, MacroError> {
192     let repr = input.attrs.iter().find(|attr| attr.path().is_ident("repr"));
193     if let Some(repr) = repr {
194         let nested = repr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
195         for meta in nested {
196             match meta {
197                 Meta::Path(path) if path.is_ident("u8") => return Ok(path),
198                 Meta::Path(path) if path.is_ident("u16") => return Ok(path),
199                 Meta::Path(path) if path.is_ident("u32") => return Ok(path),
200                 Meta::Path(path) if path.is_ident("u64") => return Ok(path),
201                 _ => {}
202             }
203         }
204     }
205 
206     Err(MacroError::Message(
207         format!("#[repr(u8/u16/u32/u64) required for {}", msg),
208         input.ident.span(),
209     ))
210 }
211 
212 fn get_variants(input: &DeriveInput) -> Result<&Punctuated<Variant, Comma>, MacroError> {
213     if let Data::Enum(e) = &input.data {
214         if let Some(v) = e.variants.iter().find(|v| v.fields != Fields::Unit) {
215             return Err(MacroError::Message(
216                 "Cannot derive TryInto for enum with non-unit variants.".to_string(),
217                 v.fields.span(),
218             ));
219         }
220         Ok(&e.variants)
221     } else {
222         Err(MacroError::Message(
223             "Cannot derive TryInto for union or struct.".to_string(),
224             input.ident.span(),
225         ))
226     }
227 }
228 
229 #[rustfmt::skip::macros(quote)]
230 fn derive_tryinto_or_error(input: DeriveInput) -> Result<proc_macro2::TokenStream, MacroError> {
231     let repr = get_repr_uN(&input, "#[derive(TryInto)]")?;
232 
233     let name = &input.ident;
234     let variants = get_variants(&input)?;
235     let discriminants: Vec<&Ident> = variants.iter().map(|f| &f.ident).collect();
236 
237     Ok(quote! {
238         impl core::convert::TryFrom<#repr> for #name {
239             type Error = #repr;
240 
241             fn try_from(value: #repr) -> Result<Self, Self::Error> {
242                 #(const #discriminants: #repr = #name::#discriminants as #repr;)*;
243                 match value {
244                     #(#discriminants => Ok(Self::#discriminants),)*
245                     _ => Err(value),
246                 }
247             }
248         }
249     })
250 }
251 
252 #[proc_macro_derive(TryInto)]
253 pub fn derive_tryinto(input: TokenStream) -> TokenStream {
254     let input = parse_macro_input!(input as DeriveInput);
255     let expanded = derive_tryinto_or_error(input).unwrap_or_else(Into::into);
256 
257     TokenStream::from(expanded)
258 }
259