xref: /linux/tools/lib/python/unittest_helper.py (revision 5181afcdf99527dd92a88f80fc4d0d8013e1b510)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: GPL-2.0
3# Copyright(c) 2025-2026: Mauro Carvalho Chehab <mchehab@kernel.org>.
4#
5# pylint: disable=C0103,R0912,R0914,E1101
6
7"""
Provides helper functions and classes to execute Python unit tests.
9
These helper functions provide a nice colored output summary of each
executed test and, when a test fails, show the difference in diff
format when running in verbose mode, like::
13
14    $ tools/unittests/nested_match.py -v
15    ...
16    Traceback (most recent call last):
17    File "/new_devel/docs/tools/unittests/nested_match.py", line 69, in test_count_limit
18        self.assertEqual(replaced, "bar(a); bar(b); foo(c)")
19        ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
20    AssertionError: 'bar(a) foo(b); foo(c)' != 'bar(a); bar(b); foo(c)'
21    - bar(a) foo(b); foo(c)
22    ?       ^^^^
23    + bar(a); bar(b); foo(c)
24    ?       ^^^^^
25    ...
26
It also allows filtering which tests will be executed via the ``-k`` parameter.
28
29Typical usage is to do::
30
31    from unittest_helper import run_unittest
32    ...
33
34    if __name__ == "__main__":
35        run_unittest(__file__)
36
If arguments need to be passed in a more complex scenario, it can be
used as in this example::
39
40    from unittest_helper import TestUnits, run_unittest
41    ...
42    env = {'sudo': ""}
43    ...
44    if __name__ == "__main__":
45        runner = TestUnits()
46        base_parser = runner.parse_args()
47        base_parser.add_argument('--sudo', action='store_true',
48                                help='Enable tests requiring sudo privileges')
49
50        args = base_parser.parse_args()
51
52        # Update module-level flag
53        if args.sudo:
54            env['sudo'] = "1"
55
56        # Run tests with customized arguments
57        runner.run(__file__, parser=base_parser, args=args, env=env)
58"""
59
60import argparse
61import atexit
62import os
63import re
64import unittest
65import sys
66
67from unittest.mock import patch
68
69
class Summary(unittest.TestResult):
    """
    Overrides ``unittest.TestResult`` class to provide a nice colored
    summary, printed by :meth:`printResults` after all tests ran.

    Results are grouped in :attr:`test_results` as
    ``{module_name: {class_name: [(method_name, status), ...]}}``.
    This class does not print tracebacks nor actual/expected diffs.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        #: Dictionary to store organized test results:
        #: ``{module_name: {class_name: [(method_name, status), ...]}}``.
        self.test_results = {}

        #: max length of the test names (including the trailing colon)
        #: seen by startTest().
        self.max_name_length = 0

    @staticmethod
    def _split_test_id(test):
        """
        Split ``test.id()`` into ``(module_name, class_name, method_name)``.

        Components missing from a short id are returned as empty strings.
        """
        parts = test.id().split(".")
        module_name = parts[-3] if len(parts) >= 3 else ""
        class_name = parts[-2] if len(parts) >= 2 else ""
        return module_name, class_name, parts[-1]

    def _class_results(self, module_name, class_name):
        """Return the result list of a module/class pair, creating it if needed."""
        classes = self.test_results.setdefault(module_name, {})
        return classes.setdefault(class_name, [])

    def startTest(self, test):
        super().startTest(test)
        module_name, class_name, method_name = self._split_test_id(test)

        # Build the hierarchical structure, keeping execution order
        self._class_results(module_name, class_name)

        # Track maximum test name length for alignment
        display_name = f"{method_name}:"
        self.max_name_length = max(len(display_name), self.max_name_length)

    def _record_test(self, test, status):
        """Append ``(method_name, status)`` under the test's module/class."""
        module_name, class_name, method_name = self._split_test_id(test)
        # Buckets are created on demand: fixture errors (e.g. a failing
        # setUpClass()) are reported without a previous startTest() call.
        self._class_results(module_name, class_name).append((method_name, status))

    def addSuccess(self, test):
        super().addSuccess(test)
        self._record_test(test, "OK")

    def addFailure(self, test, err):
        super().addFailure(test, err)
        self._record_test(test, "FAIL")

    def addError(self, test, err):
        super().addError(test, err)
        self._record_test(test, "ERROR")

    def addSkip(self, test, reason):
        super().addSkip(test, reason)
        self._record_test(test, f"SKIP ({reason})")

    def addExpectedFailure(self, test, err):
        super().addExpectedFailure(test, err)
        self._record_test(test, "EXPECTED_FAIL")

    def addUnexpectedSuccess(self, test):
        super().addUnexpectedSuccess(test)
        self._record_test(test, "UNEXPECTED_SUCCESS")

    def addSubTest(self, test, subtest, err):
        super().addSubTest(test, subtest, err)

        # Passing subtests are not listed: when all of them pass, the
        # enclosing test is reported through addSuccess().
        if err is None:
            return

        if issubclass(err[0], test.failureException):
            status = "FAIL"
        else:
            status = "ERROR"

        # subtest.id() is the parent id followed by the subtest parameters
        params = subtest.id()[len(test.id()):].strip()
        self._record_test(test, f"{status} {params}" if params else status)

    def printResults(self, verbose):
        """
        Print results using colors if tty.

        :param verbose: when false, module/class headers and tests whose
            status is ``OK`` or ``EXPECTED_FAIL`` are not printed.
        """
        # Check for ANSI color support
        use_color = sys.stdout.isatty()
        COLORS = {
            "OK":                 "\033[32m",   # Green
            "FAIL":               "\033[31m",   # Red
            "ERROR":              "\033[31m",   # Red
            "UNEXPECTED_SUCCESS": "\033[31m",   # Red
            "SKIP":               "\033[1;33m", # Yellow
            "PARTIAL":            "\033[33m",   # Orange
            "EXPECTED_FAIL":      "\033[36m",   # Cyan
            "reset":              "\033[0m",    # Reset to default terminal color
        }
        if not use_color:
            COLORS = dict.fromkeys(COLORS, "")

        if not self.test_results:
            return

        # Calculate maximum test name length
        lengths = []
        for classes in self.test_results.values():
            for tests in classes.values():
                lengths.extend(len(name) + 1 for name, _ in tests)  # +1 for colon
        if not lengths:
            sys.exit("Test list is empty")
        max_length = max(lengths) + 2  # Additional padding

        # Print results
        for module_name, classes in self.test_results.items():
            if verbose:
                print(f"{module_name}:")
            for class_name, tests in classes.items():
                if verbose:
                    print(f"    {class_name}:")
                for test_name, status in tests:
                    if not verbose and status in ("OK", "EXPECTED_FAIL"):
                        continue

                    # Color by the first word: "SKIP (reason)" -> "SKIP"
                    color = COLORS.get(status.split(maxsplit=1)[0], "")
                    print(
                        f"        {test_name + ':':<{max_length}}{color}{status}{COLORS['reset']}"
                    )
            if verbose:
                print()

        # Print summary
        print(f"\nRan {self.testsRun} tests", end="")
        if hasattr(self, "timeTaken"):
            print(f" in {self.timeTaken:.3f}s", end="")
        print()

        if not self.wasSuccessful():
            # wasSuccessful() is also false on unexpected successes, so
            # list them too instead of printing an empty "FAILED ()".
            counts = []
            if self.failures:
                counts.append(f"failures={len(self.failures)}")
            if self.errors:
                counts.append(f"errors={len(self.errors)}")
            if self.unexpectedSuccesses:
                counts.append(f"unexpected successes={len(self.unexpectedSuccesses)}")
            print(f"\n{COLORS['FAIL']}FAILED ({', '.join(counts)}){COLORS['reset']}")
215
216
def flatten_suite(suite):
    """
    Return the test cases of *suite* as a flat list.

    Nested ``unittest.TestSuite`` objects are expanded recursively and
    the original execution order is preserved.
    """
    def _walk(node):
        for entry in node:
            if isinstance(entry, unittest.TestSuite):
                yield from _walk(entry)
            else:
                yield entry

    return list(_walk(suite))
226
227
228class TestUnits:
229    """
230    Helper class to set verbosity level.
231
232    This class discover test files, import its unittest classes and
233    executes the test on it.
234    """
235    def parse_args(self):
236        """Returns a parser for command line arguments."""
237        parser = argparse.ArgumentParser(description="Test runner with regex filtering")
238        parser.add_argument("-v", "--verbose", action="count", default=1)
239        parser.add_argument("-q", "--quiet", action="store_true")
240        parser.add_argument("-f", "--failfast", action="store_true")
241        parser.add_argument("-k", "--keyword",
242                            help="Regex pattern to filter test methods")
243        return parser
244
245    def run(self, caller_file=None, pattern=None,
246            suite=None, parser=None, args=None, env=None):
247        """
248        Execute all tests from the unity test file.
249
250        It contains several optional parameters:
251
252        ``caller_file``:
253            -  name of the file that contains test.
254
255               typical usage is to place __file__ at the caller test, e.g.::
256
257                    if __name__ == "__main__":
258                        TestUnits().run(__file__)
259
260        ``pattern``:
261            - optional pattern to match multiple file names. Defaults
262              to basename of ``caller_file``.
263
264        ``suite``:
265            - an unittest suite initialized by the caller using
266              ``unittest.TestLoader().discover()``.
267
268        ``parser``:
269            - an argparse parser. If not defined, this helper will create
270              one.
271
272        ``args``:
273            - an ``argparse.Namespace`` data filled by the caller.
274
275        ``env``:
276            - environment variables that will be passed to the test suite
277
278        At least ``caller_file`` or ``suite`` must be used, otherwise a
279        ``TypeError`` will be raised.
280        """
281        if not args:
282            if not parser:
283                parser = self.parse_args()
284            args = parser.parse_args()
285
286        if not caller_file and not suite:
287            raise TypeError("Either caller_file or suite is needed at TestUnits")
288
289        if args.quiet:
290            verbose = 0
291        else:
292            verbose = args.verbose
293
294        if not env:
295            env = os.environ.copy()
296
297        env["VERBOSE"] = f"{verbose}"
298
299        patcher = patch.dict(os.environ, env)
300        patcher.start()
301        # ensure it gets stopped after
302        atexit.register(patcher.stop)
303
304
305        if verbose >= 2:
306            unittest.TextTestRunner(verbosity=verbose).run = lambda suite: suite
307
308        # Load ONLY tests from the calling file
309        if not suite:
310            if not pattern:
311                pattern = caller_file
312
313            loader = unittest.TestLoader()
314            suite = loader.discover(start_dir=os.path.dirname(caller_file),
315                                    pattern=os.path.basename(caller_file))
316
317        # Flatten the suite for environment injection
318        tests_to_inject = flatten_suite(suite)
319
320        # Filter tests by method name if -k specified
321        if args.keyword:
322            try:
323                pattern = re.compile(args.keyword)
324                filtered_suite = unittest.TestSuite()
325                for test in tests_to_inject:  # Use the pre-flattened list
326                    method_name = test.id().split(".")[-1]
327                    if pattern.search(method_name):
328                        filtered_suite.addTest(test)
329                suite = filtered_suite
330            except re.error as e:
331                sys.stderr.write(f"Invalid regex pattern: {e}\n")
332                sys.exit(1)
333        else:
334            # Maintain original suite structure if no keyword filtering
335            suite = unittest.TestSuite(tests_to_inject)
336
337        if verbose >= 2:
338            resultclass = None
339        else:
340            resultclass = Summary
341
342        runner = unittest.TextTestRunner(verbosity=args.verbose,
343                                            resultclass=resultclass,
344                                            failfast=args.failfast)
345        result = runner.run(suite)
346        if resultclass:
347            result.printResults(verbose)
348
349        sys.exit(not result.wasSuccessful())
350
351
def run_unittest(fname):
    """
    Run the tests found in ``fname`` using the default ``TestUnits`` options.

    This is a shortcut for ``TestUnits().run(fname)``, meant for test
    modules that don't need to pass any extra argument to the tests.
    The recommended way is to place this at the end of each unittest
    module::

        if __name__ == "__main__":
            run_unittest(__file__)
    """
    runner = TestUnits()
    runner.run(fname)
364