xref: /linux/rust/macros/module.rs (revision 26ff969926a08eee069767ddbbbc301adbcd9676)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 use std::ffi::CString;
4 
5 use proc_macro2::{
6     Literal,
7     TokenStream, //
8 };
9 use quote::{
10     format_ident,
11     quote, //
12 };
13 use syn::{
14     braced,
15     bracketed,
16     ext::IdentExt,
17     parse::{
18         Parse,
19         ParseStream, //
20     },
21     parse_quote,
22     punctuated::Punctuated,
23     Error,
24     Expr,
25     Ident,
26     LitStr,
27     Path,
28     Result,
29     Token,
30     Type, //
31 };
32 
33 use crate::helpers::*;
34 
35 struct ModInfoBuilder<'a> {
36     module: &'a str,
37     counter: usize,
38     ts: TokenStream,
39     param_ts: TokenStream,
40 }
41 
42 impl<'a> ModInfoBuilder<'a> {
new(module: &'a str) -> Self43     fn new(module: &'a str) -> Self {
44         ModInfoBuilder {
45             module,
46             counter: 0,
47             ts: TokenStream::new(),
48             param_ts: TokenStream::new(),
49         }
50     }
51 
emit_base(&mut self, field: &str, content: &str, builtin: bool, param: bool)52     fn emit_base(&mut self, field: &str, content: &str, builtin: bool, param: bool) {
53         let string = if builtin {
54             // Built-in modules prefix their modinfo strings by `module.`.
55             format!("{module}.{field}={content}\0", module = self.module)
56         } else {
57             // Loadable modules' modinfo strings go as-is.
58             format!("{field}={content}\0")
59         };
60         let length = string.len();
61         let string = Literal::byte_string(string.as_bytes());
62         let cfg = if builtin {
63             quote!(#[cfg(not(MODULE))])
64         } else {
65             quote!(#[cfg(MODULE)])
66         };
67 
68         let counter = format_ident!(
69             "__{module}_{counter}",
70             module = self.module.to_uppercase(),
71             counter = self.counter
72         );
73         let item = quote! {
74             #cfg
75             #[cfg_attr(not(target_os = "macos"), link_section = ".modinfo")]
76             #[used(compiler)]
77             pub static #counter: [u8; #length] = *#string;
78         };
79 
80         if param {
81             self.param_ts.extend(item);
82         } else {
83             self.ts.extend(item);
84         }
85 
86         self.counter += 1;
87     }
88 
emit_only_builtin(&mut self, field: &str, content: &str, param: bool)89     fn emit_only_builtin(&mut self, field: &str, content: &str, param: bool) {
90         self.emit_base(field, content, true, param)
91     }
92 
emit_only_loadable(&mut self, field: &str, content: &str, param: bool)93     fn emit_only_loadable(&mut self, field: &str, content: &str, param: bool) {
94         self.emit_base(field, content, false, param)
95     }
96 
emit(&mut self, field: &str, content: &str)97     fn emit(&mut self, field: &str, content: &str) {
98         self.emit_internal(field, content, false);
99     }
100 
emit_internal(&mut self, field: &str, content: &str, param: bool)101     fn emit_internal(&mut self, field: &str, content: &str, param: bool) {
102         self.emit_only_builtin(field, content, param);
103         self.emit_only_loadable(field, content, param);
104     }
105 
emit_param(&mut self, field: &str, param: &str, content: &str)106     fn emit_param(&mut self, field: &str, param: &str, content: &str) {
107         let content = format!("{param}:{content}");
108         self.emit_internal(field, &content, true);
109     }
110 
emit_params(&mut self, info: &ModuleInfo)111     fn emit_params(&mut self, info: &ModuleInfo) {
112         let Some(params) = &info.params else {
113             return;
114         };
115 
116         for param in params {
117             let param_name_str = param.name.to_string();
118             let param_type_str = param.ptype.to_string();
119 
120             let ops = param_ops_path(&param_type_str);
121 
122             // Note: The spelling of these fields is dictated by the user space
123             // tool `modinfo`.
124             self.emit_param("parmtype", &param_name_str, &param_type_str);
125             self.emit_param("parm", &param_name_str, &param.description.value());
126 
127             let static_name = format_ident!("__{}_{}_struct", self.module, param.name);
128             let param_name_cstr =
129                 CString::new(param_name_str).expect("name contains NUL-terminator");
130             let param_name_cstr_with_module =
131                 CString::new(format!("{}.{}", self.module, param.name))
132                     .expect("name contains NUL-terminator");
133 
134             let param_name = &param.name;
135             let param_type = &param.ptype;
136             let param_default = &param.default;
137 
138             self.param_ts.extend(quote! {
139                 #[allow(non_upper_case_globals)]
140                 pub(crate) static #param_name:
141                     ::kernel::module_param::ModuleParamAccess<#param_type> =
142                         ::kernel::module_param::ModuleParamAccess::new(#param_default);
143 
144                 const _: () = {
145                     #[allow(non_upper_case_globals)]
146                     #[link_section = "__param"]
147                     #[used(compiler)]
148                     static #static_name:
149                         ::kernel::module_param::KernelParam =
150                         ::kernel::module_param::KernelParam::new(
151                             ::kernel::bindings::kernel_param {
152                                 name: kernel::str::as_char_ptr_in_const_context(
153                                     if ::core::cfg!(MODULE) {
154                                         #param_name_cstr
155                                     } else {
156                                         #param_name_cstr_with_module
157                                     }
158                                 ),
159                                 // SAFETY: `__this_module` is constructed by the kernel at load
160                                 // time and will not be freed until the module is unloaded.
161                                 #[cfg(MODULE)]
162                                 mod_: unsafe {
163                                     core::ptr::from_ref(&::kernel::bindings::__this_module)
164                                         .cast_mut()
165                                 },
166                                 #[cfg(not(MODULE))]
167                                 mod_: ::core::ptr::null_mut(),
168                                 ops: core::ptr::from_ref(&#ops),
169                                 perm: 0, // Will not appear in sysfs
170                                 level: -1,
171                                 flags: 0,
172                                 __bindgen_anon_1: ::kernel::bindings::kernel_param__bindgen_ty_1 {
173                                     arg: #param_name.as_void_ptr()
174                                 },
175                             }
176                         );
177                 };
178             });
179         }
180     }
181 }
182 
param_ops_path(param_type: &str) -> Path183 fn param_ops_path(param_type: &str) -> Path {
184     match param_type {
185         "i8" => parse_quote!(::kernel::module_param::PARAM_OPS_I8),
186         "u8" => parse_quote!(::kernel::module_param::PARAM_OPS_U8),
187         "i16" => parse_quote!(::kernel::module_param::PARAM_OPS_I16),
188         "u16" => parse_quote!(::kernel::module_param::PARAM_OPS_U16),
189         "i32" => parse_quote!(::kernel::module_param::PARAM_OPS_I32),
190         "u32" => parse_quote!(::kernel::module_param::PARAM_OPS_U32),
191         "i64" => parse_quote!(::kernel::module_param::PARAM_OPS_I64),
192         "u64" => parse_quote!(::kernel::module_param::PARAM_OPS_U64),
193         "isize" => parse_quote!(::kernel::module_param::PARAM_OPS_ISIZE),
194         "usize" => parse_quote!(::kernel::module_param::PARAM_OPS_USIZE),
195         t => panic!("Unsupported parameter type {}", t),
196     }
197 }
198 
/// Parse fields that are required to use a specific order.
///
/// As fields must follow a specific order, we *could* just parse fields one by one by peeking.
/// However the error message generated when implementing that way is not very friendly.
///
/// So instead we parse fields in an arbitrary order, but only enforce the ordering after parsing,
/// and if the wrong order is used, the proper order is communicated to the user with an error
/// message.
///
/// Every field is written as `key: <value>,`; the trailing comma is required after each field,
/// including the last one. Unknown keys, duplicated keys and missing required keys are reported
/// as errors.
///
/// Usage looks like this:
/// ```ignore
/// parse_ordered_fields! {
///     from input;
///
///     // This will extract "foo: <field>" into a variable named "foo".
///     // The variable will have type `Option<_>`.
///     foo => <expression that parses the field>,
///
///     // If you need the variable name to be different from the key name.
///     // This extracts "baz: <field>" into a variable named "bar".
///     // You might want this if "baz" is a keyword.
///     baz as bar => <expression that parses the field>,
///
///     // You can mark a key as required, and the variable will no longer be `Option`.
///     // foobar will be of type `Expr` instead of `Option<Expr>`.
///     foobar [required] => input.parse::<Expr>()?,
/// }
/// ```
macro_rules! parse_ordered_fields {
    // Final arm: all field declarations have been collected into
    // `[name; key; parser]` entries and `[req_name; req_key]` entries.
    (@gen
        [$input:expr]
        [$([$name:ident; $key:ident; $parser:expr])*]
        [$([$req_name:ident; $req_key:ident])*]
    ) => {
        // One `Option` slot per declared field, filled in when its key is seen.
        $(let mut $name = None;)*

        // Keys in declaration order, which is the order users must write them in.
        const EXPECTED_KEYS: &[&str] = &[$(stringify!($key),)*];
        const REQUIRED_KEYS: &[&str] = &[$(stringify!($req_key),)*];

        // Captured before parsing; used for errors about the field list as a whole.
        let span = $input.span();
        let mut seen_keys = Vec::new();

        while !$input.is_empty() {
            // `parse_any` also accepts keywords, so keys such as `type` work.
            let key = $input.call(Ident::parse_any)?;

            if seen_keys.contains(&key) {
                Err(Error::new_spanned(
                    &key,
                    format!(r#"duplicated key "{key}". Keys can only be specified once."#),
                ))?
            }

            $input.parse::<Token![:]>()?;

            match &*key.to_string() {
                $(
                    stringify!($key) => $name = Some($parser),
                )*
                _ => {
                    Err(Error::new_spanned(
                        &key,
                        format!(r#"unknown key "{key}". Valid keys are: {EXPECTED_KEYS:?}."#),
                    ))?
                }
            }

            // Every field, including the last one, must be followed by a comma.
            $input.parse::<Token![,]>()?;
            seen_keys.push(key);
        }

        for key in REQUIRED_KEYS {
            if !seen_keys.iter().any(|e| e == key) {
                Err(Error::new(span, format!(r#"missing required key "{key}""#)))?
            }
        }

        // The keys that were given, rearranged into declaration order. If this differs
        // from the order the user wrote them in, report the expected order.
        let mut ordered_keys: Vec<&str> = Vec::new();
        for key in EXPECTED_KEYS {
            if seen_keys.iter().any(|e| e == key) {
                ordered_keys.push(key);
            }
        }

        if seen_keys != ordered_keys {
            Err(Error::new(
                span,
                format!(r#"keys are not ordered as expected. Order them like: {ordered_keys:?}."#),
            ))?
        }

        // Presence of every required key was checked above, so these cannot fail.
        $(let $req_name = $req_name.expect("required field");)*
    };

    // Handle required fields.
    // Each recursive arm consumes one field declaration and appends it to the accumulators.
    (@gen
        [$input:expr] [$($tok:tt)*] [$($req:tt)*]
        $key:ident as $name:ident [required] => $parser:expr,
        $($rest:tt)*
    ) => {
        parse_ordered_fields!(
            @gen [$input] [$($tok)* [$name; $key; $parser]] [$($req)* [$name; $key]] $($rest)*
        )
    };
    (@gen
        [$input:expr] [$($tok:tt)*] [$($req:tt)*]
        $name:ident [required] => $parser:expr,
        $($rest:tt)*
    ) => {
        parse_ordered_fields!(
            @gen [$input] [$($tok)* [$name; $name; $parser]] [$($req)* [$name; $name]] $($rest)*
        )
    };

    // Handle optional fields.
    (@gen
        [$input:expr] [$($tok:tt)*] [$($req:tt)*]
        $key:ident as $name:ident => $parser:expr,
        $($rest:tt)*
    ) => {
        parse_ordered_fields!(
            @gen [$input] [$($tok)* [$name; $key; $parser]] [$($req)*] $($rest)*
        )
    };
    (@gen
        [$input:expr] [$($tok:tt)*] [$($req:tt)*]
        $name:ident => $parser:expr,
        $($rest:tt)*
    ) => {
        parse_ordered_fields!(
            @gen [$input] [$($tok)* [$name; $name; $parser]] [$($req)*] $($rest)*
        )
    };

    // Entry point: start with empty field and required-field accumulators.
    (from $input:expr; $($tok:tt)*) => {
        parse_ordered_fields!(@gen [$input] [] [] $($tok)*)
    }
}
335 
/// One entry of the `params` block of `module!`:
/// `name: type { default: <expr>, description: "<text>", }`.
struct Parameter {
    /// Parameter name; also the name of the generated `ModuleParamAccess` static.
    name: Ident,
    /// Parameter type name; must be one accepted by `param_ops_path`.
    ptype: Ident,
    /// Expression used to initialize the parameter.
    default: Expr,
    /// Description, emitted as the `parm` modinfo string.
    description: LitStr,
}
342 
343 impl Parse for Parameter {
parse(input: ParseStream<'_>) -> Result<Self>344     fn parse(input: ParseStream<'_>) -> Result<Self> {
345         let name = input.parse()?;
346         input.parse::<Token![:]>()?;
347         let ptype = input.parse()?;
348 
349         let fields;
350         braced!(fields in input);
351 
352         parse_ordered_fields! {
353             from fields;
354             default [required] => fields.parse()?,
355             description [required] => fields.parse()?,
356         }
357 
358         Ok(Self {
359             name,
360             ptype,
361             default,
362             description,
363         })
364     }
365 }
366 
367 pub(crate) struct ModuleInfo {
368     type_: Type,
369     license: AsciiLitStr,
370     name: AsciiLitStr,
371     authors: Option<Punctuated<AsciiLitStr, Token![,]>>,
372     description: Option<LitStr>,
373     alias: Option<Punctuated<AsciiLitStr, Token![,]>>,
374     firmware: Option<Punctuated<AsciiLitStr, Token![,]>>,
375     imports_ns: Option<Punctuated<AsciiLitStr, Token![,]>>,
376     params: Option<Punctuated<Parameter, Token![,]>>,
377 }
378 
379 impl Parse for ModuleInfo {
parse(input: ParseStream<'_>) -> Result<Self>380     fn parse(input: ParseStream<'_>) -> Result<Self> {
381         parse_ordered_fields!(
382             from input;
383             type as type_ [required] => input.parse()?,
384             name [required] => input.parse()?,
385             authors => {
386                 let list;
387                 bracketed!(list in input);
388                 Punctuated::parse_terminated(&list)?
389             },
390             description => input.parse()?,
391             license [required] => input.parse()?,
392             alias => {
393                 let list;
394                 bracketed!(list in input);
395                 Punctuated::parse_terminated(&list)?
396             },
397             firmware => {
398                 let list;
399                 bracketed!(list in input);
400                 Punctuated::parse_terminated(&list)?
401             },
402             imports_ns => {
403                 let list;
404                 bracketed!(list in input);
405                 Punctuated::parse_terminated(&list)?
406             },
407             params => {
408                 let list;
409                 braced!(list in input);
410                 Punctuated::parse_terminated(&list)?
411             },
412         );
413 
414         Ok(ModuleInfo {
415             type_,
416             license,
417             name,
418             authors,
419             description,
420             alias,
421             firmware,
422             imports_ns,
423             params,
424         })
425     }
426 }
427 
module(info: ModuleInfo) -> Result<TokenStream>428 pub(crate) fn module(info: ModuleInfo) -> Result<TokenStream> {
429     let ModuleInfo {
430         type_,
431         license,
432         name,
433         authors,
434         description,
435         alias,
436         firmware,
437         imports_ns,
438         params: _,
439     } = &info;
440 
441     // Rust does not allow hyphens in identifiers, use underscore instead.
442     let ident = name.value().replace('-', "_");
443     let mut modinfo = ModInfoBuilder::new(ident.as_ref());
444     if let Some(authors) = authors {
445         for author in authors {
446             modinfo.emit("author", &author.value());
447         }
448     }
449     if let Some(description) = description {
450         modinfo.emit("description", &description.value());
451     }
452     modinfo.emit("license", &license.value());
453     if let Some(aliases) = alias {
454         for alias in aliases {
455             modinfo.emit("alias", &alias.value());
456         }
457     }
458     if let Some(firmware) = firmware {
459         for fw in firmware {
460             modinfo.emit("firmware", &fw.value());
461         }
462     }
463     if let Some(imports) = imports_ns {
464         for ns in imports {
465             modinfo.emit("import_ns", &ns.value());
466         }
467     }
468 
469     // Built-in modules also export the `file` modinfo string.
470     let file =
471         std::env::var("RUST_MODFILE").expect("Unable to fetch RUST_MODFILE environmental variable");
472     modinfo.emit_only_builtin("file", &file, false);
473 
474     modinfo.emit_params(&info);
475 
476     let modinfo_ts = modinfo.ts;
477     let params_ts = modinfo.param_ts;
478 
479     let ident_init = format_ident!("__{ident}_init");
480     let ident_exit = format_ident!("__{ident}_exit");
481     let ident_initcall = format_ident!("__{ident}_initcall");
482     let initcall_section = ".initcall6.init";
483 
484     let global_asm = format!(
485         r#".section "{initcall_section}", "a"
486         __{ident}_initcall:
487             .long   __{ident}_init - .
488             .previous
489         "#
490     );
491 
492     let name_cstr = CString::new(name.value()).expect("name contains NUL-terminator");
493 
494     Ok(quote! {
495         /// The module name.
496         ///
497         /// Used by the printing macros, e.g. [`info!`].
498         const __LOG_PREFIX: &[u8] = #name_cstr.to_bytes_with_nul();
499 
500         // SAFETY: `__this_module` is constructed by the kernel at load time and will not be
501         // freed until the module is unloaded.
502         #[cfg(MODULE)]
503         static THIS_MODULE: ::kernel::ThisModule = unsafe {
504             extern "C" {
505                 static __this_module: ::kernel::types::Opaque<::kernel::bindings::module>;
506             };
507 
508             ::kernel::ThisModule::from_ptr(__this_module.get())
509         };
510 
511         #[cfg(not(MODULE))]
512         static THIS_MODULE: ::kernel::ThisModule = unsafe {
513             ::kernel::ThisModule::from_ptr(::core::ptr::null_mut())
514         };
515 
516         /// The `LocalModule` type is the type of the module created by `module!`,
517         /// `module_pci_driver!`, `module_platform_driver!`, etc.
518         type LocalModule = #type_;
519 
520         impl ::kernel::ModuleMetadata for #type_ {
521             const NAME: &'static ::kernel::str::CStr = #name_cstr;
522         }
523 
524         // Double nested modules, since then nobody can access the public items inside.
525         #[doc(hidden)]
526         mod __module_init {
527             mod __module_init {
528                 use pin_init::PinInit;
529 
530                 /// The "Rust loadable module" mark.
531                 //
532                 // This may be best done another way later on, e.g. as a new modinfo
533                 // key or a new section. For the moment, keep it simple.
534                 #[cfg(MODULE)]
535                 #[used(compiler)]
536                 static __IS_RUST_MODULE: () = ();
537 
538                 static mut __MOD: ::core::mem::MaybeUninit<super::super::LocalModule> =
539                     ::core::mem::MaybeUninit::uninit();
540 
541                 // Loadable modules need to export the `{init,cleanup}_module` identifiers.
542                 /// # Safety
543                 ///
544                 /// This function must not be called after module initialization, because it may be
545                 /// freed after that completes.
546                 #[cfg(MODULE)]
547                 #[no_mangle]
548                 #[link_section = ".init.text"]
549                 pub unsafe extern "C" fn init_module() -> ::kernel::ffi::c_int {
550                     // SAFETY: This function is inaccessible to the outside due to the double
551                     // module wrapping it. It is called exactly once by the C side via its
552                     // unique name.
553                     unsafe { __init() }
554                 }
555 
556                 #[cfg(MODULE)]
557                 #[used(compiler)]
558                 #[link_section = ".init.data"]
559                 static __UNIQUE_ID___addressable_init_module: unsafe extern "C" fn() -> i32 =
560                     init_module;
561 
562                 #[cfg(MODULE)]
563                 #[no_mangle]
564                 #[link_section = ".exit.text"]
565                 pub extern "C" fn cleanup_module() {
566                     // SAFETY:
567                     // - This function is inaccessible to the outside due to the double
568                     //   module wrapping it. It is called exactly once by the C side via its
569                     //   unique name,
570                     // - furthermore it is only called after `init_module` has returned `0`
571                     //   (which delegates to `__init`).
572                     unsafe { __exit() }
573                 }
574 
575                 #[cfg(MODULE)]
576                 #[used(compiler)]
577                 #[link_section = ".exit.data"]
578                 static __UNIQUE_ID___addressable_cleanup_module: extern "C" fn() = cleanup_module;
579 
580                 // Built-in modules are initialized through an initcall pointer
581                 // and the identifiers need to be unique.
582                 #[cfg(not(MODULE))]
583                 #[cfg(not(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS))]
584                 #[link_section = #initcall_section]
585                 #[used(compiler)]
586                 pub static #ident_initcall: extern "C" fn() ->
587                     ::kernel::ffi::c_int = #ident_init;
588 
589                 #[cfg(not(MODULE))]
590                 #[cfg(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS)]
591                 ::core::arch::global_asm!(#global_asm);
592 
593                 #[cfg(not(MODULE))]
594                 #[no_mangle]
595                 pub extern "C" fn #ident_init() -> ::kernel::ffi::c_int {
596                     // SAFETY: This function is inaccessible to the outside due to the double
597                     // module wrapping it. It is called exactly once by the C side via its
598                     // placement above in the initcall section.
599                     unsafe { __init() }
600                 }
601 
602                 #[cfg(not(MODULE))]
603                 #[no_mangle]
604                 pub extern "C" fn #ident_exit() {
605                     // SAFETY:
606                     // - This function is inaccessible to the outside due to the double
607                     //   module wrapping it. It is called exactly once by the C side via its
608                     //   unique name,
609                     // - furthermore it is only called after `#ident_init` has
610                     //   returned `0` (which delegates to `__init`).
611                     unsafe { __exit() }
612                 }
613 
614                 /// # Safety
615                 ///
616                 /// This function must only be called once.
617                 unsafe fn __init() -> ::kernel::ffi::c_int {
618                     let initer = <super::super::LocalModule as ::kernel::InPlaceModule>::init(
619                         &super::super::THIS_MODULE
620                     );
621                     // SAFETY: No data race, since `__MOD` can only be accessed by this module
622                     // and there only `__init` and `__exit` access it. These functions are only
623                     // called once and `__exit` cannot be called before or during `__init`.
624                     match unsafe { initer.__pinned_init(__MOD.as_mut_ptr()) } {
625                         Ok(m) => 0,
626                         Err(e) => e.to_errno(),
627                     }
628                 }
629 
630                 /// # Safety
631                 ///
632                 /// This function must
633                 /// - only be called once,
634                 /// - be called after `__init` has been called and returned `0`.
635                 unsafe fn __exit() {
636                     // SAFETY: No data race, since `__MOD` can only be accessed by this module
637                     // and there only `__init` and `__exit` access it. These functions are only
638                     // called once and `__init` was already called.
639                     unsafe {
640                         // Invokes `drop()` on `__MOD`, which should be used for cleanup.
641                         __MOD.assume_init_drop();
642                     }
643                 }
644 
645                 #modinfo_ts
646             }
647         }
648 
649         mod module_parameters {
650             #params_ts
651         }
652     })
653 }
654