xref: /qemu/rust/qemu-api-macros/src/lib.rs (revision ec7e5a90fea996f04ea24e81b680a87bc975354a)
1 // Copyright 2024, Linaro Limited
2 // Author(s): Manos Pitsidianakis <manos.pitsidianakis@linaro.org>
3 // SPDX-License-Identifier: GPL-2.0-or-later
4 
5 use proc_macro::TokenStream;
6 use quote::quote;
7 use syn::{
8     parse_macro_input, parse_quote, punctuated::Punctuated, spanned::Spanned, token::Comma, Data,
9     DeriveInput, Field, Fields, FieldsUnnamed, Ident, Meta, Path, Token, Variant,
10 };
11 
12 mod utils;
13 use utils::MacroError;
14 
15 fn get_fields<'a>(
16     input: &'a DeriveInput,
17     msg: &str,
18 ) -> Result<&'a Punctuated<Field, Comma>, MacroError> {
19     let Data::Struct(ref s) = &input.data else {
20         return Err(MacroError::Message(
21             format!("Struct required for {msg}"),
22             input.ident.span(),
23         ));
24     };
25     let Fields::Named(ref fs) = &s.fields else {
26         return Err(MacroError::Message(
27             format!("Named fields required for {msg}"),
28             input.ident.span(),
29         ));
30     };
31     Ok(&fs.named)
32 }
33 
34 fn get_unnamed_field<'a>(input: &'a DeriveInput, msg: &str) -> Result<&'a Field, MacroError> {
35     let Data::Struct(ref s) = &input.data else {
36         return Err(MacroError::Message(
37             format!("Struct required for {msg}"),
38             input.ident.span(),
39         ));
40     };
41     let Fields::Unnamed(FieldsUnnamed { ref unnamed, .. }) = &s.fields else {
42         return Err(MacroError::Message(
43             format!("Tuple struct required for {msg}"),
44             s.fields.span(),
45         ));
46     };
47     if unnamed.len() != 1 {
48         return Err(MacroError::Message(
49             format!("A single field is required for {msg}"),
50             s.fields.span(),
51         ));
52     }
53     Ok(&unnamed[0])
54 }
55 
56 fn is_c_repr(input: &DeriveInput, msg: &str) -> Result<(), MacroError> {
57     let expected = parse_quote! { #[repr(C)] };
58 
59     if input.attrs.iter().any(|attr| attr == &expected) {
60         Ok(())
61     } else {
62         Err(MacroError::Message(
63             format!("#[repr(C)] required for {msg}"),
64             input.ident.span(),
65         ))
66     }
67 }
68 
69 fn is_transparent_repr(input: &DeriveInput, msg: &str) -> Result<(), MacroError> {
70     let expected = parse_quote! { #[repr(transparent)] };
71 
72     if input.attrs.iter().any(|attr| attr == &expected) {
73         Ok(())
74     } else {
75         Err(MacroError::Message(
76             format!("#[repr(transparent)] required for {msg}"),
77             input.ident.span(),
78         ))
79     }
80 }
81 
82 fn derive_object_or_error(input: DeriveInput) -> Result<proc_macro2::TokenStream, MacroError> {
83     is_c_repr(&input, "#[derive(Object)]")?;
84 
85     let name = &input.ident;
86     let parent = &get_fields(&input, "#[derive(Object)]")?[0].ident;
87 
88     Ok(quote! {
89         ::qemu_api::assert_field_type!(#name, #parent,
90             ::qemu_api::qom::ParentField<<#name as ::qemu_api::qom::ObjectImpl>::ParentType>);
91 
92         ::qemu_api::module_init! {
93             MODULE_INIT_QOM => unsafe {
94                 ::qemu_api::bindings::type_register_static(&<#name as ::qemu_api::qom::ObjectImpl>::TYPE_INFO);
95             }
96         }
97     })
98 }
99 
100 #[proc_macro_derive(Object)]
101 pub fn derive_object(input: TokenStream) -> TokenStream {
102     let input = parse_macro_input!(input as DeriveInput);
103     let expanded = derive_object_or_error(input).unwrap_or_else(Into::into);
104 
105     TokenStream::from(expanded)
106 }
107 
108 fn derive_opaque_or_error(input: DeriveInput) -> Result<proc_macro2::TokenStream, MacroError> {
109     is_transparent_repr(&input, "#[derive(Wrapper)]")?;
110 
111     let name = &input.ident;
112     let field = &get_unnamed_field(&input, "#[derive(Wrapper)]")?;
113     let typ = &field.ty;
114 
115     // TODO: how to add "::qemu_api"?  For now, this is only used in the
116     // qemu_api crate so it's not a problem.
117     Ok(quote! {
118         unsafe impl crate::cell::Wrapper for #name {
119             type Wrapped = <#typ as crate::cell::Wrapper>::Wrapped;
120         }
121         impl #name {
122             pub unsafe fn from_raw<'a>(ptr: *mut <Self as crate::cell::Wrapper>::Wrapped) -> &'a Self {
123                 let ptr = ::std::ptr::NonNull::new(ptr).unwrap().cast::<Self>();
124                 unsafe { ptr.as_ref() }
125             }
126 
127             pub const fn as_mut_ptr(&self) -> *mut <Self as crate::cell::Wrapper>::Wrapped {
128                 self.0.as_mut_ptr()
129             }
130 
131             pub const fn as_ptr(&self) -> *const <Self as crate::cell::Wrapper>::Wrapped {
132                 self.0.as_ptr()
133             }
134 
135             pub const fn as_void_ptr(&self) -> *mut ::core::ffi::c_void {
136                 self.0.as_void_ptr()
137             }
138 
139             pub const fn raw_get(slot: *mut Self) -> *mut <Self as crate::cell::Wrapper>::Wrapped {
140                 slot.cast()
141             }
142         }
143     })
144 }
145 
146 #[proc_macro_derive(Wrapper)]
147 pub fn derive_opaque(input: TokenStream) -> TokenStream {
148     let input = parse_macro_input!(input as DeriveInput);
149     let expanded = derive_opaque_or_error(input).unwrap_or_else(Into::into);
150 
151     TokenStream::from(expanded)
152 }
153 
154 #[allow(non_snake_case)]
155 fn get_repr_uN(input: &DeriveInput, msg: &str) -> Result<Path, MacroError> {
156     let repr = input.attrs.iter().find(|attr| attr.path().is_ident("repr"));
157     if let Some(repr) = repr {
158         let nested = repr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
159         for meta in nested {
160             match meta {
161                 Meta::Path(path) if path.is_ident("u8") => return Ok(path),
162                 Meta::Path(path) if path.is_ident("u16") => return Ok(path),
163                 Meta::Path(path) if path.is_ident("u32") => return Ok(path),
164                 Meta::Path(path) if path.is_ident("u64") => return Ok(path),
165                 _ => {}
166             }
167         }
168     }
169 
170     Err(MacroError::Message(
171         format!("#[repr(u8/u16/u32/u64) required for {msg}"),
172         input.ident.span(),
173     ))
174 }
175 
176 fn get_variants(input: &DeriveInput) -> Result<&Punctuated<Variant, Comma>, MacroError> {
177     let Data::Enum(ref e) = &input.data else {
178         return Err(MacroError::Message(
179             "Cannot derive TryInto for union or struct.".to_string(),
180             input.ident.span(),
181         ));
182     };
183     if let Some(v) = e.variants.iter().find(|v| v.fields != Fields::Unit) {
184         return Err(MacroError::Message(
185             "Cannot derive TryInto for enum with non-unit variants.".to_string(),
186             v.fields.span(),
187         ));
188     }
189     Ok(&e.variants)
190 }
191 
192 #[rustfmt::skip::macros(quote)]
193 fn derive_tryinto_or_error(input: DeriveInput) -> Result<proc_macro2::TokenStream, MacroError> {
194     let repr = get_repr_uN(&input, "#[derive(TryInto)]")?;
195 
196     let name = &input.ident;
197     let variants = get_variants(&input)?;
198     let discriminants: Vec<&Ident> = variants.iter().map(|f| &f.ident).collect();
199 
200     Ok(quote! {
201         impl core::convert::TryFrom<#repr> for #name {
202             type Error = #repr;
203 
204             fn try_from(value: #repr) -> Result<Self, Self::Error> {
205                 #(const #discriminants: #repr = #name::#discriminants as #repr;)*;
206                 match value {
207                     #(#discriminants => Ok(Self::#discriminants),)*
208                     _ => Err(value),
209                 }
210             }
211         }
212     })
213 }
214 
215 #[proc_macro_derive(TryInto)]
216 pub fn derive_tryinto(input: TokenStream) -> TokenStream {
217     let input = parse_macro_input!(input as DeriveInput);
218     let expanded = derive_tryinto_or_error(input).unwrap_or_else(Into::into);
219 
220     TokenStream::from(expanded)
221 }
222