1 // SPDX-License-Identifier: GPL-2.0
2
3 //! Procedural macro to run KUnit tests using a user-space like syntax.
4 //!
5 //! Copyright (c) 2023 José Expósito <jose.exposito89@gmail.com>
6
7 use proc_macro::{Delimiter, Group, TokenStream, TokenTree};
8 use std::collections::HashMap;
9 use std::fmt::Write;
10
kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream11 pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream {
12 let attr = attr.to_string();
13
14 if attr.is_empty() {
15 panic!("Missing test name in `#[kunit_tests(test_name)]` macro")
16 }
17
18 if attr.len() > 255 {
19 panic!("The test suite name `{attr}` exceeds the maximum length of 255 bytes")
20 }
21
22 let mut tokens: Vec<_> = ts.into_iter().collect();
23
24 // Scan for the `mod` keyword.
25 tokens
26 .iter()
27 .find_map(|token| match token {
28 TokenTree::Ident(ident) => match ident.to_string().as_str() {
29 "mod" => Some(true),
30 _ => None,
31 },
32 _ => None,
33 })
34 .expect("`#[kunit_tests(test_name)]` attribute should only be applied to modules");
35
36 // Retrieve the main body. The main body should be the last token tree.
37 let body = match tokens.pop() {
38 Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group,
39 _ => panic!("Cannot locate main body of module"),
40 };
41
42 // Get the functions set as tests. Search for `[test]` -> `fn`.
43 let mut body_it = body.stream().into_iter();
44 let mut tests = Vec::new();
45 let mut attributes: HashMap<String, TokenStream> = HashMap::new();
46 while let Some(token) = body_it.next() {
47 match token {
48 TokenTree::Punct(ref p) if p.as_char() == '#' => match body_it.next() {
49 Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => {
50 if let Some(TokenTree::Ident(name)) = g.stream().into_iter().next() {
51 // Collect attributes because we need to find which are tests. We also
52 // need to copy `cfg` attributes so tests can be conditionally enabled.
53 attributes
54 .entry(name.to_string())
55 .or_default()
56 .extend([token, TokenTree::Group(g)]);
57 }
58 continue;
59 }
60 _ => (),
61 },
62 TokenTree::Ident(i) if i.to_string() == "fn" && attributes.contains_key("test") => {
63 if let Some(TokenTree::Ident(test_name)) = body_it.next() {
64 tests.push((test_name, attributes.remove("cfg").unwrap_or_default()))
65 }
66 }
67
68 _ => (),
69 }
70 attributes.clear();
71 }
72
73 // Add `#[cfg(CONFIG_KUNIT="y")]` before the module declaration.
74 let config_kunit = "#[cfg(CONFIG_KUNIT=\"y\")]".to_owned().parse().unwrap();
75 tokens.insert(
76 0,
77 TokenTree::Group(Group::new(Delimiter::None, config_kunit)),
78 );
79
80 // Generate the test KUnit test suite and a test case for each `#[test]`.
81 // The code generated for the following test module:
82 //
83 // ```
84 // #[kunit_tests(kunit_test_suit_name)]
85 // mod tests {
86 // #[test]
87 // fn foo() {
88 // assert_eq!(1, 1);
89 // }
90 //
91 // #[test]
92 // fn bar() {
93 // assert_eq!(2, 2);
94 // }
95 // }
96 // ```
97 //
98 // Looks like:
99 //
100 // ```
101 // unsafe extern "C" fn kunit_rust_wrapper_foo(_test: *mut ::kernel::bindings::kunit) { foo(); }
102 // unsafe extern "C" fn kunit_rust_wrapper_bar(_test: *mut ::kernel::bindings::kunit) { bar(); }
103 //
104 // static mut TEST_CASES: [::kernel::bindings::kunit_case; 3] = [
105 // ::kernel::kunit::kunit_case(::kernel::c_str!("foo"), kunit_rust_wrapper_foo),
106 // ::kernel::kunit::kunit_case(::kernel::c_str!("bar"), kunit_rust_wrapper_bar),
107 // ::kernel::kunit::kunit_case_null(),
108 // ];
109 //
110 // ::kernel::kunit_unsafe_test_suite!(kunit_test_suit_name, TEST_CASES);
111 // ```
112 let mut kunit_macros = "".to_owned();
113 let mut test_cases = "".to_owned();
114 let mut assert_macros = "".to_owned();
115 let path = crate::helpers::file();
116 let num_tests = tests.len();
117 for (test, cfg_attr) in tests {
118 let kunit_wrapper_fn_name = format!("kunit_rust_wrapper_{test}");
119 // Append any `cfg` attributes the user might have written on their tests so we don't
120 // attempt to call them when they are `cfg`'d out. An extra `use` is used here to reduce
121 // the length of the assert message.
122 let kunit_wrapper = format!(
123 r#"unsafe extern "C" fn {kunit_wrapper_fn_name}(_test: *mut ::kernel::bindings::kunit)
124 {{
125 (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED;
126 {cfg_attr} {{
127 (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS;
128 use ::kernel::kunit::is_test_result_ok;
129 assert!(is_test_result_ok({test}()));
130 }}
131 }}"#,
132 );
133 writeln!(kunit_macros, "{kunit_wrapper}").unwrap();
134 writeln!(
135 test_cases,
136 " ::kernel::kunit::kunit_case(::kernel::c_str!(\"{test}\"), {kunit_wrapper_fn_name}),"
137 )
138 .unwrap();
139 writeln!(
140 assert_macros,
141 r#"
142 /// Overrides the usual [`assert!`] macro with one that calls KUnit instead.
143 #[allow(unused)]
144 macro_rules! assert {{
145 ($cond:expr $(,)?) => {{{{
146 kernel::kunit_assert!("{test}", "{path}", 0, $cond);
147 }}}}
148 }}
149
150 /// Overrides the usual [`assert_eq!`] macro with one that calls KUnit instead.
151 #[allow(unused)]
152 macro_rules! assert_eq {{
153 ($left:expr, $right:expr $(,)?) => {{{{
154 kernel::kunit_assert_eq!("{test}", "{path}", 0, $left, $right);
155 }}}}
156 }}
157 "#
158 )
159 .unwrap();
160 }
161
162 writeln!(kunit_macros).unwrap();
163 writeln!(
164 kunit_macros,
165 "static mut TEST_CASES: [::kernel::bindings::kunit_case; {}] = [\n{test_cases} ::kernel::kunit::kunit_case_null(),\n];",
166 num_tests + 1
167 )
168 .unwrap();
169
170 writeln!(
171 kunit_macros,
172 "::kernel::kunit_unsafe_test_suite!({attr}, TEST_CASES);"
173 )
174 .unwrap();
175
176 // Remove the `#[test]` macros.
177 // We do this at a token level, in order to preserve span information.
178 let mut new_body = vec![];
179 let mut body_it = body.stream().into_iter();
180
181 while let Some(token) = body_it.next() {
182 match token {
183 TokenTree::Punct(ref c) if c.as_char() == '#' => match body_it.next() {
184 Some(TokenTree::Group(group)) if group.to_string() == "[test]" => (),
185 Some(next) => {
186 new_body.extend([token, next]);
187 }
188 _ => {
189 new_body.push(token);
190 }
191 },
192 _ => {
193 new_body.push(token);
194 }
195 }
196 }
197
198 let mut final_body = TokenStream::new();
199 final_body.extend::<TokenStream>(assert_macros.parse().unwrap());
200 final_body.extend(new_body);
201 final_body.extend::<TokenStream>(kunit_macros.parse().unwrap());
202
203 tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, final_body)));
204
205 tokens.into_iter().collect()
206 }
207