1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3 * Copyright (C) 2023 ARM Limited.
4 */
5
6 #include <limits.h>
7 #include <stdbool.h>
8
9 #include <linux/prctl.h>
10
11 #include <sys/mman.h>
12 #include <asm/mman.h>
13 #include <asm/hwcap.h>
14 #include <linux/sched.h>
15
16 #include "kselftest.h"
17 #include "gcs-util.h"
18
/*
 * Size used for the GCS mapping in map_guarded_stack(). nolibc provides
 * no sysconf(), so rather than querying the page size we hard code the
 * largest one (64K).
 */
static size_t page_size = 64 * 1024;
21
/*
 * A real out-of-line function call/return for the tests to make while GCS
 * is enabled. noinline keeps the call from being folded into the caller,
 * and the body has a side effect so the call cannot be optimised away.
 */
static __attribute__((noinline)) void valid_gcs_function(void)
{
	/* Any syscall would do; the result is intentionally ignored. */
	(void)syscall(__NR_prctl, PR_SVE_GET_VL);
}
27
gcs_set_status(unsigned long mode)28 static inline int gcs_set_status(unsigned long mode)
29 {
30 bool enabling = mode & PR_SHADOW_STACK_ENABLE;
31 int ret;
32 unsigned long new_mode;
33
34 /*
35 * The prctl takes 1 argument but we need to ensure that the
36 * other 3 values passed in registers to the syscall are zero
37 * since the kernel validates them.
38 */
39 ret = syscall(__NR_prctl, PR_SET_SHADOW_STACK_STATUS, mode, 0, 0, 0);
40
41 if (ret == 0) {
42 ret = syscall(__NR_prctl, PR_GET_SHADOW_STACK_STATUS, &new_mode, 0, 0, 0);
43 if (ret == 0) {
44 if (new_mode != mode) {
45 ksft_print_msg("Mode set to %lx not %lx\n",
46 new_mode, mode);
47 ret = -EINVAL;
48 }
49 } else {
50 ksft_print_msg("Failed to validate mode: %d\n", errno);
51 }
52
53 if (enabling != chkfeat_gcs()) {
54 ksft_print_msg("%senabled by prctl but %senabled in CHKFEAT\n",
55 enabling ? "" : "not ",
56 chkfeat_gcs() ? "" : "not ");
57 ret = -EINVAL;
58 }
59 }
60
61 return ret;
62 }
63
64 /* Try to read the status */
read_status(void)65 static bool read_status(void)
66 {
67 unsigned long state;
68 int ret;
69
70 ret = syscall(__NR_prctl, PR_GET_SHADOW_STACK_STATUS, &state, 0, 0, 0);
71 if (ret != 0) {
72 ksft_print_msg("Failed to read state: %d\n", errno);
73 return false;
74 }
75
76 return state & PR_SHADOW_STACK_ENABLE;
77 }
78
79 /* Just a straight enable */
base_enable(void)80 static bool base_enable(void)
81 {
82 int ret;
83
84 ret = gcs_set_status(PR_SHADOW_STACK_ENABLE);
85 if (ret) {
86 ksft_print_msg("PR_SHADOW_STACK_ENABLE failed %d\n", ret);
87 return false;
88 }
89
90 return true;
91 }
92
/*
 * Test: read GCSPR_EL0 via get_gcspr() while GCS is enabled and log it.
 * The value itself is not checked; the test always returns true once
 * the read has been performed.
 */
static bool read_gcspr_el0(void)
{
	ksft_print_msg("GET GCSPR\n");

	unsigned long *gcspr = get_gcspr();

	ksft_print_msg("GCSPR_EL0 is %p\n", gcspr);
	return true;
}
104
105 /* Also allow writes to stack */
enable_writeable(void)106 static bool enable_writeable(void)
107 {
108 int ret;
109
110 ret = gcs_set_status(PR_SHADOW_STACK_ENABLE | PR_SHADOW_STACK_WRITE);
111 if (ret) {
112 ksft_print_msg("PR_SHADOW_STACK_ENABLE writeable failed: %d\n", ret);
113 return false;
114 }
115
116 ret = gcs_set_status(PR_SHADOW_STACK_ENABLE);
117 if (ret) {
118 ksft_print_msg("failed to restore plain enable %d\n", ret);
119 return false;
120 }
121
122 return true;
123 }
124
125 /* Also allow writes to stack */
enable_push_pop(void)126 static bool enable_push_pop(void)
127 {
128 int ret;
129
130 ret = gcs_set_status(PR_SHADOW_STACK_ENABLE | PR_SHADOW_STACK_PUSH);
131 if (ret) {
132 ksft_print_msg("PR_SHADOW_STACK_ENABLE with push failed: %d\n",
133 ret);
134 return false;
135 }
136
137 ret = gcs_set_status(PR_SHADOW_STACK_ENABLE);
138 if (ret) {
139 ksft_print_msg("failed to restore plain enable %d\n", ret);
140 return false;
141 }
142
143 return true;
144 }
145
146 /* Enable GCS and allow everything */
enable_all(void)147 static bool enable_all(void)
148 {
149 int ret;
150
151 ret = gcs_set_status(PR_SHADOW_STACK_ENABLE | PR_SHADOW_STACK_PUSH |
152 PR_SHADOW_STACK_WRITE);
153 if (ret) {
154 ksft_print_msg("PR_SHADOW_STACK_ENABLE with everything failed: %d\n",
155 ret);
156 return false;
157 }
158
159 ret = gcs_set_status(PR_SHADOW_STACK_ENABLE);
160 if (ret) {
161 ksft_print_msg("failed to restore plain enable %d\n", ret);
162 return false;
163 }
164
165 return true;
166 }
167
enable_invalid(void)168 static bool enable_invalid(void)
169 {
170 int ret = gcs_set_status(ULONG_MAX);
171 if (ret == 0) {
172 ksft_print_msg("GCS_SET_STATUS %lx succeeded\n", ULONG_MAX);
173 return false;
174 }
175
176 return true;
177 }
178
179 /* Map a GCS */
map_guarded_stack(void)180 static bool map_guarded_stack(void)
181 {
182 int ret;
183 uint64_t *buf;
184 uint64_t expected_cap;
185 int elem;
186 bool pass = true;
187
188 buf = (void *)syscall(__NR_map_shadow_stack, 0, page_size,
189 SHADOW_STACK_SET_MARKER | SHADOW_STACK_SET_TOKEN);
190 if (buf == MAP_FAILED) {
191 ksft_print_msg("Failed to map %lu byte GCS: %d\n",
192 page_size, errno);
193 return false;
194 }
195 ksft_print_msg("Mapped GCS at %p-%p\n", buf,
196 (void *)((uint64_t)buf + page_size));
197
198 /* The top of the newly allocated region should be 0 */
199 elem = (page_size / sizeof(uint64_t)) - 1;
200 if (buf[elem]) {
201 ksft_print_msg("Last entry is 0x%llx not 0x0\n", buf[elem]);
202 pass = false;
203 }
204
205 /* Then a valid cap token */
206 elem--;
207 expected_cap = ((uint64_t)buf + page_size - 16);
208 expected_cap &= GCS_CAP_ADDR_MASK;
209 expected_cap |= GCS_CAP_VALID_TOKEN;
210 if (buf[elem] != expected_cap) {
211 ksft_print_msg("Cap entry is 0x%llx not 0x%llx\n",
212 buf[elem], expected_cap);
213 pass = false;
214 }
215 ksft_print_msg("cap token is 0x%llx\n", buf[elem]);
216
217 /* The rest should be zeros */
218 for (elem = 0; elem < page_size / sizeof(uint64_t) - 2; elem++) {
219 if (!buf[elem])
220 continue;
221 ksft_print_msg("GCS slot %d is 0x%llx not 0x0\n",
222 elem, buf[elem]);
223 pass = false;
224 }
225
226 ret = munmap(buf, page_size);
227 if (ret != 0) {
228 ksft_print_msg("Failed to unmap %ld byte GCS: %d\n",
229 page_size, errno);
230 pass = false;
231 }
232
233 return pass;
234 }
235
236 /* A fork()ed process can run */
test_fork(void)237 static bool test_fork(void)
238 {
239 unsigned long child_mode;
240 int ret, status;
241 pid_t pid;
242 bool pass = true;
243
244 pid = fork();
245 if (pid == -1) {
246 ksft_print_msg("fork() failed: %d\n", errno);
247 pass = false;
248 goto out;
249 }
250 if (pid == 0) {
251 /* In child, make sure we can call a function, read
252 * the GCS pointer and status and then exit */
253 valid_gcs_function();
254 get_gcspr();
255
256 ret = syscall(__NR_prctl, PR_GET_SHADOW_STACK_STATUS, &child_mode, 0, 0, 0);
257 if (ret == 0 && !(child_mode & PR_SHADOW_STACK_ENABLE)) {
258 ksft_print_msg("GCS not enabled in child\n");
259 ret = -EINVAL;
260 }
261
262 exit(ret);
263 }
264
265 /*
266 * In parent, check we can still do function calls then block
267 * for the child.
268 */
269 valid_gcs_function();
270
271 ksft_print_msg("Waiting for child %d\n", pid);
272
273 ret = waitpid(pid, &status, 0);
274 if (ret == -1) {
275 ksft_print_msg("Failed to wait for child: %d\n",
276 errno);
277 return false;
278 }
279
280 if (!WIFEXITED(status)) {
281 ksft_print_msg("Child exited due to signal %d\n",
282 WTERMSIG(status));
283 pass = false;
284 } else {
285 if (WEXITSTATUS(status)) {
286 ksft_print_msg("Child exited with status %d\n",
287 WEXITSTATUS(status));
288 pass = false;
289 }
290 }
291
292 out:
293
294 return pass;
295 }
296
297 /* A vfork()ed process can run and exit */
test_vfork(void)298 static bool test_vfork(void)
299 {
300 unsigned long child_mode;
301 int ret, status;
302 pid_t pid;
303 bool pass = true;
304
305 pid = vfork();
306 if (pid == -1) {
307 ksft_print_msg("vfork() failed: %d\n", errno);
308 pass = false;
309 goto out;
310 }
311 if (pid == 0) {
312 /*
313 * In child, make sure we can call a function, read
314 * the GCS pointer and status and then exit.
315 */
316 valid_gcs_function();
317 get_gcspr();
318
319 ret = syscall(__NR_prctl, PR_GET_SHADOW_STACK_STATUS, &child_mode, 0, 0, 0);
320 if (ret == 0 && !(child_mode & PR_SHADOW_STACK_ENABLE)) {
321 ksft_print_msg("GCS not enabled in child\n");
322 ret = EXIT_FAILURE;
323 }
324
325 _exit(ret);
326 }
327
328 /*
329 * In parent, check we can still do function calls then check
330 * on the child.
331 */
332 valid_gcs_function();
333
334 ksft_print_msg("Waiting for child %d\n", pid);
335
336 ret = waitpid(pid, &status, 0);
337 if (ret == -1) {
338 ksft_print_msg("Failed to wait for child: %d\n",
339 errno);
340 return false;
341 }
342
343 if (!WIFEXITED(status)) {
344 ksft_print_msg("Child exited due to signal %d\n",
345 WTERMSIG(status));
346 pass = false;
347 } else if (WEXITSTATUS(status)) {
348 ksft_print_msg("Child exited with status %d\n",
349 WEXITSTATUS(status));
350 pass = false;
351 }
352
353 out:
354
355 return pass;
356 }
357
358 typedef bool (*gcs_test)(void);
359
360 static struct {
361 char *name;
362 gcs_test test;
363 bool needs_enable;
364 } tests[] = {
365 { "read_status", read_status },
366 { "base_enable", base_enable, true },
367 { "read_gcspr_el0", read_gcspr_el0 },
368 { "enable_writeable", enable_writeable, true },
369 { "enable_push_pop", enable_push_pop, true },
370 { "enable_all", enable_all, true },
371 { "enable_invalid", enable_invalid, true },
372 { "map_guarded_stack", map_guarded_stack },
373 { "fork", test_fork },
374 { "vfork", test_vfork },
375 };
376
/*
 * Entry point: skip unless HWCAP_GCS is advertised, make sure GCS is
 * enabled, run every entry in tests[] (one TAP result each), then try to
 * disable GCS again.
 */
int main(void)
{
	int i, ret;
	unsigned long gcs_mode;

	ksft_print_header();

	/* Nothing to test unless the kernel reports GCS support */
	if (!(getauxval(AT_HWCAP) & HWCAP_GCS))
		ksft_exit_skip("SKIP GCS not supported\n");

	ret = syscall(__NR_prctl, PR_GET_SHADOW_STACK_STATUS, &gcs_mode, 0, 0, 0);
	if (ret != 0)
		ksft_exit_fail_msg("Failed to read GCS state: %d\n", errno);

	/*
	 * The tests expect GCS to be enabled when they start, so turn it on
	 * here if it is not already.
	 * NOTE(review): this uses the raw prctl directly in main() rather
	 * than gcs_set_status(); presumably so that no function entered
	 * before GCS was enabled has to return through the GCS afterwards.
	 * Confirm before moving this into a helper.
	 */
	if (!(gcs_mode & PR_SHADOW_STACK_ENABLE)) {
		gcs_mode = PR_SHADOW_STACK_ENABLE;
		ret = syscall(__NR_prctl, PR_SET_SHADOW_STACK_STATUS, gcs_mode, 0, 0, 0);
		if (ret != 0)
			ksft_exit_fail_msg("Failed to enable GCS: %d\n", errno);
	}

	ksft_set_plan(ARRAY_SIZE(tests));

	for (i = 0; i < ARRAY_SIZE(tests); i++) {
		ksft_test_result((*tests[i].test)(), "%s\n", tests[i].name);
	}

	/*
	 * One last test: disable GCS, we can do this one time.
	 * A failure here is only logged; it is not counted in the plan.
	 */
	ret = syscall(__NR_prctl, PR_SET_SHADOW_STACK_STATUS, 0, 0, 0, 0);
	if (ret != 0)
		ksft_print_msg("Failed to disable GCS: %d\n", errno);

	ksft_finished();

	return 0;
}
413