1 // SPDX-License-Identifier: GPL-2.0
2 /*
3 * The test creates shmem PMD huge pages, fills all pages with known patterns,
4 * then continuously verifies non-punched pages with 16 threads. Meanwhile, the
5 * main thread punches holes via MADV_REMOVE on the shmem.
6 *
7 * It tests the race condition between folio_split() and filemap_get_entry(),
8 * where the hole punches on shmem lead to folio_split() and reading the shmem
9 * lead to filemap_get_entry().
10 */
11
12 #define _GNU_SOURCE
13 #include <errno.h>
14 #include <inttypes.h>
15 #include <linux/mman.h>
16 #include <pthread.h>
17 #include <stdatomic.h>
18 #include <stdbool.h>
19 #include <stdint.h>
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <string.h>
23 #include <sys/mman.h>
24 #include <signal.h>
25 #include <unistd.h>
26 #include "vm_util.h"
27 #include "kselftest.h"
28 #include "thp_settings.h"
29
/* Base page size and PMD (huge page) size, read at runtime in main(). */
uint64_t page_size;
uint64_t pmd_pagesize;
/* Number of PMD huge pages backing the test file. */
#define NR_PMD_PAGE 5
#define FILE_SIZE (pmd_pagesize * NR_PMD_PAGE)
#define TOTAL_PAGES (FILE_SIZE / page_size)

/* Every N-th to N+M-th pages are punched; not aligned with huge page boundaries. */
#define PUNCH_INTERVAL 50 /* N */
#define PUNCH_SIZE_FACTOR 3 /* M */

#define NUM_READER_THREADS 16
/* Pattern byte used to fill pages; first 8 bytes hold the page index instead. */
#define FILL_BYTE 0xAF
/* How many mmap/fill/punch/verify rounds main() runs before declaring a pass. */
#define NUM_ITERATIONS 100
43
/* Shared control block: control reading threads and record stats */
struct shared_ctl {
	/* Set to 1 by the main thread once hole punching is done; readers poll it. */
	atomic_uint_fast32_t stop;
	/* Count of pages that failed verification (corrupted). */
	atomic_uint_fast64_t reader_failures;
	/* Count of pages that verified OK. */
	atomic_uint_fast64_t reader_verified;
	/* Synchronizes reader start with the punching thread (NUM_READER_THREADS + 1 parties). */
	pthread_barrier_t barrier;
};
51
fill_page(unsigned char * base,size_t page_idx)52 static void fill_page(unsigned char *base, size_t page_idx)
53 {
54 unsigned char *page_ptr = base + page_idx * page_size;
55 uint64_t idx = (uint64_t)page_idx;
56
57 memset(page_ptr, FILL_BYTE, page_size);
58 memcpy(page_ptr, &idx, sizeof(idx));
59 }
60
61 /* Returns true if valid, false if corrupted. */
check_page(unsigned char * base,uint64_t page_idx)62 static bool check_page(unsigned char *base, uint64_t page_idx)
63 {
64 unsigned char *page_ptr = base + page_idx * page_size;
65 uint64_t expected_idx = (uint64_t)page_idx;
66 uint64_t got_idx;
67
68 memcpy(&got_idx, page_ptr, 8);
69
70 if (got_idx != expected_idx) {
71 uint64_t off;
72 int all_zero = 1;
73
74 for (off = 0; off < page_size; off++) {
75 if (page_ptr[off] != 0) {
76 all_zero = 0;
77 break;
78 }
79 }
80 if (all_zero) {
81 ksft_print_msg("CORRUPTED: page %" PRIu64
82 " (huge page %" PRIu64
83 ") is ALL ZEROS\n",
84 page_idx,
85 (page_idx * page_size) / pmd_pagesize);
86 } else {
87 ksft_print_msg("CORRUPTED: page %" PRIu64
88 " (huge page %" PRIu64
89 "): expected idx %" PRIu64
90 ", got %" PRIu64 "\n",
91 page_idx,
92 (page_idx * page_size) / pmd_pagesize,
93 page_idx, got_idx);
94 }
95 return false;
96 }
97 return true;
98 }
99
/* Per-thread argument for reader_thread(). */
struct reader_arg {
	/* Start of the mapped shmem region to verify. */
	unsigned char *base;
	/* Shared control block (stop flag, barrier). */
	struct shared_ctl *ctl;
	/* Thread index in [0, NUM_READER_THREADS); selects which pages to scan. */
	int tid;
	/* Convenience pointers into ctl's counters. */
	atomic_uint_fast64_t *failures;
	atomic_uint_fast64_t *verified;
};
107
/*
 * Reader thread body: repeatedly verify all non-punched pages assigned to
 * this thread (tid-th page, stepping by NUM_READER_THREADS) until the main
 * thread sets ctl->stop or any thread records a failure.
 */
static void *reader_thread(void *arg)
{
	struct reader_arg *ra = (struct reader_arg *)arg;
	unsigned char *base = ra->base;
	struct shared_ctl *ctl = ra->ctl;
	int tid = ra->tid;
	atomic_uint_fast64_t *failures = ra->failures;
	atomic_uint_fast64_t *verified = ra->verified;
	uint64_t page_idx;

	/* Rendezvous with the other readers and the punching (main) thread. */
	pthread_barrier_wait(&ctl->barrier);

	/* Acquire pairs with the release store of ctl->stop in run_iteration(). */
	while (atomic_load_explicit(&ctl->stop, memory_order_acquire) == 0) {
		for (page_idx = (size_t)tid; page_idx < TOTAL_PAGES;
		     page_idx += NUM_READER_THREADS) {
			/*
			 * page_idx % PUNCH_INTERVAL is in [0, PUNCH_INTERVAL),
			 * skip [0, PUNCH_SIZE_FACTOR)
			 */
			if (page_idx % PUNCH_INTERVAL < PUNCH_SIZE_FACTOR)
				continue;
			if (check_page(base, page_idx))
				atomic_fetch_add_explicit(verified, 1,
							  memory_order_relaxed);
			else
				atomic_fetch_add_explicit(failures, 1,
							  memory_order_relaxed);
		}
		/* Stop early once any reader has seen a corrupted page. */
		if (atomic_load_explicit(failures, memory_order_relaxed) > 0)
			break;
	}

	return NULL;
}
142
create_readers(pthread_t * threads,struct reader_arg * args,unsigned char * base,struct shared_ctl * ctl)143 static void create_readers(pthread_t *threads, struct reader_arg *args,
144 unsigned char *base, struct shared_ctl *ctl)
145 {
146 int i;
147
148 for (i = 0; i < NUM_READER_THREADS; i++) {
149 args[i].base = base;
150 args[i].ctl = ctl;
151 args[i].tid = i;
152 args[i].failures = &ctl->reader_failures;
153 args[i].verified = &ctl->reader_verified;
154 if (pthread_create(&threads[i], NULL, reader_thread,
155 &args[i]) != 0)
156 ksft_exit_fail_msg("pthread_create failed\n");
157 }
158 }
159
/* Run a single iteration. Returns total number of corrupted pages. */
static uint64_t run_iteration(void)
{
	uint64_t reader_failures, reader_verified;
	struct reader_arg args[NUM_READER_THREADS];
	pthread_t threads[NUM_READER_THREADS];
	unsigned char *mmap_base;
	struct shared_ctl ctl;
	uint64_t i;

	/* Zeroing clears the stop flag and both stat counters. */
	memset(&ctl, 0, sizeof(struct shared_ctl));

	/* MAP_SHARED | MAP_ANONYMOUS gives a shmem-backed mapping. */
	mmap_base = mmap(NULL, FILE_SIZE, PROT_READ | PROT_WRITE,
			 MAP_SHARED | MAP_ANONYMOUS, -1, 0);

	if (mmap_base == MAP_FAILED)
		ksft_exit_fail_msg("mmap failed: %d\n", errno);

	if (madvise(mmap_base, FILE_SIZE, MADV_HUGEPAGE) != 0)
		ksft_exit_fail_msg("madvise(MADV_HUGEPAGE) failed: %d\n",
				   errno);

	/* Populate every page with its verification pattern. */
	for (i = 0; i < TOTAL_PAGES; i++)
		fill_page(mmap_base, i);

	/* The test is meaningless unless the region is actually THP-backed. */
	if (!check_huge_shmem(mmap_base, NR_PMD_PAGE, pmd_pagesize))
		ksft_exit_fail_msg("No shmem THP is allocated\n");

	/* +1 party: the main (punching) thread also waits on the barrier. */
	if (pthread_barrier_init(&ctl.barrier, NULL, NUM_READER_THREADS + 1) != 0)
		ksft_exit_fail_msg("pthread_barrier_init failed\n");

	create_readers(threads, args, mmap_base, &ctl);

	/* Wait for all reader threads to be ready before punching holes. */
	pthread_barrier_wait(&ctl.barrier);

	/* Punch PUNCH_SIZE_FACTOR pages at every PUNCH_INTERVAL-page stride. */
	for (i = 0; i < TOTAL_PAGES; i++) {
		if (i % PUNCH_INTERVAL != 0)
			continue;
		if (madvise(mmap_base + i * page_size,
			    PUNCH_SIZE_FACTOR * page_size, MADV_REMOVE) != 0) {
			ksft_exit_fail_msg(
				"madvise(MADV_REMOVE) failed on page %" PRIu64 ": %d\n",
				i, errno);
		}

		/* Skip over the pages just punched (loop's i++ adds one more). */
		i += PUNCH_SIZE_FACTOR - 1;
	}

	/* Release store pairs with the readers' acquire load of ctl.stop. */
	atomic_store_explicit(&ctl.stop, 1, memory_order_release);

	for (i = 0; i < NUM_READER_THREADS; i++)
		pthread_join(threads[i], NULL);

	/* Safe only after all threads have joined. */
	pthread_barrier_destroy(&ctl.barrier);

	reader_failures = atomic_load_explicit(&ctl.reader_failures,
					       memory_order_acquire);
	reader_verified = atomic_load_explicit(&ctl.reader_verified,
					       memory_order_acquire);
	if (reader_failures)
		ksft_print_msg("Child: %" PRIu64 " pages verified, %" PRIu64 " failures\n",
			       reader_verified, reader_failures);

	munmap(mmap_base, FILE_SIZE);

	return reader_failures;
}
228
/* Signal handler: restore THP settings, then die by the original signal. */
static void thp_cleanup_handler(int signum)
{
	thp_restore_settings();
	/*
	 * Restore default handler and re-raise the signal to exit.
	 * This is to ensure the test process exits with the correct
	 * status code corresponding to the signal.
	 */
	signal(signum, SIG_DFL);
	raise(signum);
}
240
/* atexit() hook: undo the shmem_enabled change made in main(). */
static void thp_settings_cleanup(void)
{
	thp_restore_settings();
}
245
main(void)246 int main(void)
247 {
248 struct thp_settings current_settings;
249 uint64_t corrupted_pages;
250 uint64_t iter;
251
252 ksft_print_header();
253
254 page_size = getpagesize();
255 pmd_pagesize = read_pmd_pagesize();
256
257 if (!thp_available() || !pmd_pagesize)
258 ksft_exit_skip("Transparent Hugepages not available\n");
259
260 if (geteuid() != 0)
261 ksft_exit_skip("Please run the test as root\n");
262
263 thp_save_settings();
264 /* make sure thp settings are restored */
265 if (atexit(thp_settings_cleanup) != 0)
266 ksft_exit_fail_msg("atexit failed\n");
267
268 signal(SIGINT, thp_cleanup_handler);
269 signal(SIGTERM, thp_cleanup_handler);
270
271 thp_read_settings(¤t_settings);
272 current_settings.shmem_enabled = SHMEM_ADVISE;
273 thp_write_settings(¤t_settings);
274
275 ksft_set_plan(1);
276
277 ksft_print_msg("folio split race test\n");
278
279 for (iter = 0; iter < NUM_ITERATIONS; iter++) {
280 corrupted_pages = run_iteration();
281 if (corrupted_pages > 0)
282 break;
283 }
284
285 if (iter < NUM_ITERATIONS)
286 ksft_test_result_fail("FAILED on iteration %" PRIu64
287 ": %" PRIu64
288 " pages corrupted by MADV_REMOVE!\n",
289 iter, corrupted_pages);
290 else
291 ksft_test_result_pass("All %d iterations passed\n",
292 NUM_ITERATIONS);
293
294 ksft_exit(iter == NUM_ITERATIONS);
295
296 return 0;
297 }
298