1 // SPDX-License-Identifier: GPL-2.0
2 #ifndef _GNU_SOURCE
3 #define _GNU_SOURCE
4 #endif
5 #include <fcntl.h>
6 #include <sys/types.h>
7 #include <dirent.h>
8 #include <grp.h>
9 #include <linux/limits.h>
10 #include <sched.h>
11 #include <stdio.h>
12 #include <stdlib.h>
13 #include <sys/eventfd.h>
14 #include <sys/fsuid.h>
15 #include <sys/prctl.h>
16 #include <sys/socket.h>
17 #include <sys/stat.h>
18 #include <sys/types.h>
19 #include <sys/wait.h>
20 #include <sys/xattr.h>
21 #include <sys/mount.h>
22
23 #include "../kselftest.h"
24 #include "wrappers.h"
25 #include "utils.h"
26
27 #define MAX_USERNS_LEVEL 32
28
/*
 * syserror() - log "<strerror(errno)> - <message>" to stderr and
 * evaluate to -errno, so callers can write `return syserror(...)`.
 */
#define syserror(format, ...)                                        \
	({                                                           \
		fprintf(stderr, "%m - " format "\n", ##__VA_ARGS__); \
		(-errno);                                            \
	})
34
/*
 * syserror_set() - set errno from a (possibly negative) error code,
 * log "<strerror(errno)> - <message>" to stderr, and evaluate to the
 * original code, so callers can write `return syserror_set(-EINVAL, ...)`.
 */
#define syserror_set(__ret__, format, ...)                           \
	({                                                           \
		__typeof__(__ret__) __internal_ret__ = (__ret__);    \
		/* Use the captured value: evaluate __ret__ once. */ \
		errno = labs(__internal_ret__);                      \
		fprintf(stderr, "%m - " format "\n", ##__VA_ARGS__); \
		__internal_ret__;                                    \
	})
42
/* Length of a string literal without its terminating NUL. */
#define STRLITERALLEN(x) (sizeof(""x"") - 1)

/*
 * Upper bound on the number of characters needed to print a value of
 * the given integer type in decimal (sign + digits + slack); produces
 * a compile-time error (negative array size) for types wider than 64
 * bits.
 */
#define INTTYPE_TO_STRLEN(type)             \
	(2 + (sizeof(type) <= 1             \
		  ? 3                       \
		  : sizeof(type) <= 2       \
			? 5                 \
			: sizeof(type) <= 4 \
			      ? 10          \
			      : sizeof(type) <= 8 ? 20 : sizeof(int[-2 * (sizeof(type) > 8)])))

/* Iterate over every node of a circular list, skipping the sentinel head. */
#define list_for_each(__iterator, __list) \
	for (__iterator = (__list)->next; __iterator != __list; __iterator = __iterator->next)
56
/* Whether an id mapping entry applies to uids or gids. */
typedef enum idmap_type_t {
	ID_TYPE_UID,
	ID_TYPE_GID
} idmap_type_t;

/* A single "nsid hostid range" idmapping entry. */
struct id_map {
	idmap_type_t map_type;	/* uid or gid mapping */
	__u32 nsid;		/* first id inside the namespace */
	__u32 hostid;		/* first id on the host */
	__u32 range;		/* number of consecutive ids mapped */
};

/* Minimal circular doubly-linked list node with an external payload. */
struct list {
	void *elem;		/* payload pointer; NULL for the sentinel head */
	struct list *next;
	struct list *prev;
};

/* One level in a nested user namespace hierarchy. */
struct userns_hierarchy {
	int fd_userns;		/* fd to /proc/<pid>/ns/user for this level */
	int fd_event;		/* socketpair end used to sync with the child */
	unsigned int level;	/* nesting depth */
	struct list id_map;	/* list head of struct id_map entries */
};
81
list_init(struct list * list)82 static inline void list_init(struct list *list)
83 {
84 list->elem = NULL;
85 list->next = list->prev = list;
86 }
87
list_empty(const struct list * list)88 static inline int list_empty(const struct list *list)
89 {
90 return list == list->next;
91 }
92
__list_add(struct list * new,struct list * prev,struct list * next)93 static inline void __list_add(struct list *new, struct list *prev, struct list *next)
94 {
95 next->prev = new;
96 new->next = next;
97 new->prev = prev;
98 prev->next = new;
99 }
100
list_add_tail(struct list * head,struct list * list)101 static inline void list_add_tail(struct list *head, struct list *list)
102 {
103 __list_add(list, head->prev, head);
104 }
105
list_del(struct list * list)106 static inline void list_del(struct list *list)
107 {
108 struct list *next, *prev;
109
110 next = list->next;
111 prev = list->prev;
112 next->prev = prev;
113 prev->next = next;
114 }
115
/* read() wrapper that retries when interrupted by a signal. */
static ssize_t read_nointr(int fd, void *buf, size_t count)
{
	ssize_t nread;

	for (;;) {
		nread = read(fd, buf, count);
		if (nread >= 0 || errno != EINTR)
			return nread;
	}
}
126
/* write() wrapper that retries when interrupted by a signal. */
static ssize_t write_nointr(int fd, const void *buf, size_t count)
{
	ssize_t nwritten;

	for (;;) {
		nwritten = write(fd, buf, count);
		if (nwritten >= 0 || errno != EINTR)
			return nwritten;
	}
}
137
#define __STACK_SIZE (8 * 1024 * 1024)
/*
 * do_clone() - clone() helper running @fn on a freshly allocated stack.
 * @fn:		child entry point
 * @arg:	argument passed to @fn
 * @flags:	extra clone flags (SIGCHLD is always added)
 *
 * Returns the child's pid, -ENOMEM when the stack allocation fails, or
 * clone()'s negative return on failure. On success the stack is
 * intentionally not freed: the child is still running on it.
 */
static pid_t do_clone(int (*fn)(void *), void *arg, int flags)
{
	void *stack;
	pid_t pid;

	stack = malloc(__STACK_SIZE);
	if (!stack)
		return -ENOMEM;

#ifdef __ia64__
	pid = __clone2(fn, stack, __STACK_SIZE, flags | SIGCHLD, arg, NULL);
#else
	pid = clone(fn, stack + __STACK_SIZE, flags | SIGCHLD, arg, NULL);
#endif
	if (pid < 0)
		free(stack); /* fix: don't leak the stack when clone() fails */
	return pid;
}
153
/*
 * get_userns_fd_cb() - child entry point that simply parks forever;
 * the parent grabs /proc/<pid>/ns/user and then SIGKILLs us.
 */
static int get_userns_fd_cb(void *data)
{
	while (1)
		pause();
	_exit(0);
}
160
/*
 * wait_for_pid() - reap @pid, retrying on EINTR.
 * Returns the child's exit status, or -1 when waitpid() fails or the
 * child did not exit normally.
 */
static int wait_for_pid(pid_t pid)
{
	int status;

	for (;;) {
		if (waitpid(pid, &status, 0) >= 0)
			break;
		if (errno != EINTR)
			return -1;
	}

	if (!WIFEXITED(status))
		return -1;

	return WEXITSTATUS(status);
}
179
/*
 * write_id_mapping() - write an idmapping for a child process.
 * @map_type:	whether to write a uid or a gid mapping
 * @pid:	process whose /proc/<pid>/{u,g}id_map is written
 * @buf:	the "nsid hostid range\n" line(s) to write
 * @buf_size:	number of bytes in @buf
 *
 * When running unprivileged and writing a gid mapping, "deny" is first
 * written to /proc/<pid>/setgroups (required before an unprivileged
 * gid_map write); a missing setgroups file (ENOENT) is tolerated.
 *
 * Returns 0 on success, -1 on failure.
 */
static int write_id_mapping(idmap_type_t map_type, pid_t pid, const char *buf, size_t buf_size)
{
	int fd = -EBADF, setgroups_fd = -EBADF;
	int fret = -1;
	int ret;
	/* Sized for the longest path variant, "/proc/<pid>/setgroups". */
	char path[STRLITERALLEN("/proc/") + INTTYPE_TO_STRLEN(pid_t) +
		  STRLITERALLEN("/setgroups") + 1];

	if (geteuid() != 0 && map_type == ID_TYPE_GID) {
		ret = snprintf(path, sizeof(path), "/proc/%d/setgroups", pid);
		if (ret < 0 || ret >= sizeof(path))
			goto out;

		setgroups_fd = open(path, O_WRONLY | O_CLOEXEC);
		if (setgroups_fd < 0 && errno != ENOENT) {
			syserror("Failed to open \"%s\"", path);
			goto out;
		}

		if (setgroups_fd >= 0) {
			ret = write_nointr(setgroups_fd, "deny\n", STRLITERALLEN("deny\n"));
			if (ret != STRLITERALLEN("deny\n")) {
				syserror("Failed to write \"deny\" to \"/proc/%d/setgroups\"", pid);
				goto out;
			}
		}
	}

	/* Reuse @path for "/proc/<pid>/uid_map" or "/proc/<pid>/gid_map". */
	ret = snprintf(path, sizeof(path), "/proc/%d/%cid_map", pid, map_type == ID_TYPE_UID ? 'u' : 'g');
	if (ret < 0 || ret >= sizeof(path))
		goto out;

	fd = open(path, O_WRONLY | O_CLOEXEC);
	if (fd < 0) {
		syserror("Failed to open \"%s\"", path);
		goto out;
	}

	ret = write_nointr(fd, buf, buf_size);
	if (ret != buf_size) {
		syserror("Failed to write %cid mapping to \"%s\"",
			 map_type == ID_TYPE_UID ? 'u' : 'g', path);
		goto out;
	}

	fret = 0;
out:
	/* close() on the -EBADF sentinels fails harmlessly with EBADF. */
	close(fd);
	close(setgroups_fd);

	return fret;
}
232
map_ids_from_idmap(struct list * idmap,pid_t pid)233 static int map_ids_from_idmap(struct list *idmap, pid_t pid)
234 {
235 int fill, left;
236 char mapbuf[4096] = {};
237 bool had_entry = false;
238 idmap_type_t map_type, u_or_g;
239
240 if (list_empty(idmap))
241 return 0;
242
243 for (map_type = ID_TYPE_UID, u_or_g = 'u';
244 map_type <= ID_TYPE_GID; map_type++, u_or_g = 'g') {
245 char *pos = mapbuf;
246 int ret;
247 struct list *iterator;
248
249
250 list_for_each(iterator, idmap) {
251 struct id_map *map = iterator->elem;
252 if (map->map_type != map_type)
253 continue;
254
255 had_entry = true;
256
257 left = 4096 - (pos - mapbuf);
258 fill = snprintf(pos, left, "%u %u %u\n", map->nsid, map->hostid, map->range);
259 /*
260 * The kernel only takes <= 4k for writes to
261 * /proc/<pid>/{g,u}id_map
262 */
263 if (fill <= 0 || fill >= left)
264 return syserror_set(-E2BIG, "Too many %cid mappings defined", u_or_g);
265
266 pos += fill;
267 }
268 if (!had_entry)
269 continue;
270
271 ret = write_id_mapping(map_type, pid, mapbuf, pos - mapbuf);
272 if (ret < 0)
273 return syserror("Failed to write mapping: %s", mapbuf);
274
275 memset(mapbuf, 0, sizeof(mapbuf));
276 }
277
278 return 0;
279 }
280
get_userns_fd_from_idmap(struct list * idmap)281 static int get_userns_fd_from_idmap(struct list *idmap)
282 {
283 int ret;
284 pid_t pid;
285 char path_ns[STRLITERALLEN("/proc/") + INTTYPE_TO_STRLEN(pid_t) +
286 STRLITERALLEN("/ns/user") + 1];
287
288 pid = do_clone(get_userns_fd_cb, NULL, CLONE_NEWUSER | CLONE_NEWNS);
289 if (pid < 0)
290 return -errno;
291
292 ret = map_ids_from_idmap(idmap, pid);
293 if (ret < 0)
294 return ret;
295
296 ret = snprintf(path_ns, sizeof(path_ns), "/proc/%d/ns/user", pid);
297 if (ret < 0 || (size_t)ret >= sizeof(path_ns))
298 ret = -EIO;
299 else
300 ret = open(path_ns, O_RDONLY | O_CLOEXEC | O_NOCTTY);
301
302 (void)kill(pid, SIGKILL);
303 (void)wait_for_pid(pid);
304 return ret;
305 }
306
get_userns_fd(unsigned long nsid,unsigned long hostid,unsigned long range)307 int get_userns_fd(unsigned long nsid, unsigned long hostid, unsigned long range)
308 {
309 struct list head, uid_mapl, gid_mapl;
310 struct id_map uid_map = {
311 .map_type = ID_TYPE_UID,
312 .nsid = nsid,
313 .hostid = hostid,
314 .range = range,
315 };
316 struct id_map gid_map = {
317 .map_type = ID_TYPE_GID,
318 .nsid = nsid,
319 .hostid = hostid,
320 .range = range,
321 };
322
323 list_init(&head);
324 uid_mapl.elem = &uid_map;
325 gid_mapl.elem = &gid_map;
326 list_add_tail(&head, &uid_mapl);
327 list_add_tail(&head, &gid_mapl);
328
329 return get_userns_fd_from_idmap(&head);
330 }
331
/*
 * switch_ids() - drop supplementary groups, switch to @uid/@gid, and
 * make the process dumpable again so its /proc files stay accessible
 * to processes that can ptrace us.
 *
 * Returns true on success, false on failure.
 */
bool switch_ids(uid_t uid, gid_t gid)
{
	/*
	 * Fix: the failure paths used to `return syserror(...)`, which
	 * yields -errno (non-zero) and converts to true, so callers
	 * checking !switch_ids() saw success on failure.
	 */
	if (setgroups(0, NULL)) {
		syserror("failure: setgroups");
		return false;
	}

	if (setresgid(gid, gid, gid)) {
		syserror("failure: setresgid");
		return false;
	}

	if (setresuid(uid, uid, uid)) {
		syserror("failure: setresuid");
		return false;
	}

	/* Ensure we can access proc files from processes we can ptrace. */
	if (prctl(PR_SET_DUMPABLE, 1, 0, 0, 0)) {
		syserror("failure: make dumpable");
		return false;
	}

	return true;
}
349
350 static int create_userns_hierarchy(struct userns_hierarchy *h);
351
/*
 * userns_fd_cb() - child callback run inside a freshly cloned user
 * namespace.
 * @data:	pointer into an array of struct userns_hierarchy; the
 *		current element describes this nesting level.
 *
 * Waits for the parent's byte on the socketpair, switches to uid/gid 0
 * in the new namespace when a mapping was written ('1'), signals the
 * parent back, and recurses to create the next deeper userns level.
 *
 * NOTE(review): runs with CLONE_FILES | CLONE_VM (see the caller), so
 * fds and memory are shared with the parent.
 */
static int userns_fd_cb(void *data)
{
	struct userns_hierarchy *h = data;
	char c;
	int ret;

	ret = read_nointr(h->fd_event, &c, 1);
	if (ret < 0)
		return syserror("failure: read from socketpair");

	/* Only switch ids if someone actually wrote a mapping for us. */
	if (c == '1') {
		if (!switch_ids(0, 0))
			return syserror("failure: switch ids to 0");
	}

	ret = write_nointr(h->fd_event, "1", 1);
	if (ret < 0)
		return syserror("failure: write to socketpair");

	/* Advance to the next array element for the next nesting level. */
	ret = create_userns_hierarchy(++h);
	if (ret < 0)
		return syserror("failure: userns level %d", h->level);

	return 0;
}
378
/*
 * create_userns_hierarchy() - recursively build a chain of nested user
 * namespaces, one per array element, up to MAX_USERNS_LEVEL deep.
 * @h:	current element of a struct userns_hierarchy array; h->id_map
 *	holds the mappings to write for this level, and on success
 *	h->fd_userns receives an fd for the new user namespace.
 *
 * Parent and child synchronize over a socketpair; the child is cloned
 * with CLONE_FILES | CLONE_VM, so fds and memory are shared.
 *
 * Returns 0 on success (also when the maximum depth is reached), -1 on
 * failure.
 */
static int create_userns_hierarchy(struct userns_hierarchy *h)
{
	int fret = -1;
	char c;
	int fd_socket[2];
	int fd_userns = -EBADF, ret = -1;
	ssize_t bytes;
	pid_t pid;
	char path[256];

	if (h->level == MAX_USERNS_LEVEL)
		return 0;

	ret = socketpair(AF_LOCAL, SOCK_STREAM | SOCK_CLOEXEC, 0, fd_socket);
	if (ret < 0)
		return syserror("failure: create socketpair");

	/* Note the CLONE_FILES | CLONE_VM when mucking with fds and memory. */
	h->fd_event = fd_socket[1];
	pid = do_clone(userns_fd_cb, h, CLONE_NEWUSER | CLONE_FILES | CLONE_VM);
	if (pid < 0) {
		syserror("failure: userns level %d", h->level);
		goto out_close;
	}

	ret = map_ids_from_idmap(&h->id_map, pid);
	if (ret < 0) {
		kill(pid, SIGKILL);
		syserror("failure: writing id mapping for userns level %d for %d", h->level, pid);
		goto out_wait;
	}

	/* Tell the child whether a mapping was written ('1') or not ('0'). */
	if (!list_empty(&h->id_map))
		bytes = write_nointr(fd_socket[0], "1", 1); /* Inform the child we wrote a mapping. */
	else
		bytes = write_nointr(fd_socket[0], "0", 1); /* Inform the child we didn't write a mapping. */
	if (bytes < 0) {
		kill(pid, SIGKILL);
		syserror("failure: write to socketpair");
		goto out_wait;
	}

	/* Wait for child to set*id() and become dumpable. */
	bytes = read_nointr(fd_socket[0], &c, 1);
	if (bytes < 0) {
		kill(pid, SIGKILL);
		syserror("failure: read from socketpair");
		goto out_wait;
	}

	snprintf(path, sizeof(path), "/proc/%d/ns/user", pid);
	fd_userns = open(path, O_RDONLY | O_CLOEXEC);
	if (fd_userns < 0) {
		kill(pid, SIGKILL);
		syserror("failure: open userns level %d for %d", h->level, pid);
		goto out_wait;
	}

	fret = 0;

out_wait:
	/* Only hand the userns fd to the caller if the child exited cleanly. */
	if (!wait_for_pid(pid) && !fret) {
		h->fd_userns = fd_userns;
		fd_userns = -EBADF;
	}

out_close:
	if (fd_userns >= 0)
		close(fd_userns);
	close(fd_socket[0]);
	close(fd_socket[1]);
	return fret;
}
452
/*
 * write_file() - write the NUL-terminated string @val to @path.
 *
 * Logs via ksft_print_msg() and returns -1 on any failure (open,
 * short/failed write, close); returns 0 on success.
 */
static int write_file(const char *path, const char *val)
{
	size_t len = strlen(val);
	ssize_t ret;
	int fd;

	fd = open(path, O_WRONLY);
	if (fd == -1) {
		ksft_print_msg("opening %s for write: %s\n", path, strerror(errno));
		return -1;
	}

	ret = write(fd, val, len);
	if (ret == -1) {
		ksft_print_msg("writing to %s: %s\n", path, strerror(errno));
		close(fd); /* fix: don't leak the fd on the error path */
		return -1;
	}
	if ((size_t)ret != len) {
		ksft_print_msg("short write to %s\n", path);
		close(fd); /* fix: don't leak the fd on the error path */
		return -1;
	}

	if (close(fd) == -1) {
		ksft_print_msg("closing %s\n", path);
		return -1;
	}

	return 0;
}
482
/*
 * setup_userns() - move the caller into fresh mount, user, and pid
 * namespaces, map the caller's uid/gid to 0 inside the new user
 * namespace, and make the mount tree private.
 *
 * Exits the test via ksft_exit_fail_msg() when unsharing fails;
 * returns non-zero on the other failure paths, 0 on success.
 */
int setup_userns(void)
{
	int ret;
	char buf[32];
	uid_t uid = getuid();
	gid_t gid = getgid();

	ret = unshare(CLONE_NEWNS|CLONE_NEWUSER|CLONE_NEWPID);
	if (ret) {
		ksft_exit_fail_msg("unsharing mountns and userns: %s\n",
				   strerror(errno));
		return ret;
	}

	/* Map our own uid/gid to root inside the new user namespace. */
	sprintf(buf, "0 %d 1", uid);
	ret = write_file("/proc/self/uid_map", buf);
	if (ret)
		return ret;
	/* setgroups must be denied before an unprivileged gid_map write. */
	ret = write_file("/proc/self/setgroups", "deny");
	if (ret)
		return ret;
	sprintf(buf, "0 %d 1", gid);
	ret = write_file("/proc/self/gid_map", buf);
	if (ret)
		return ret;

	/* Don't let our mount changes propagate to the original namespace. */
	ret = mount("", "/", NULL, MS_REC|MS_PRIVATE, NULL);
	if (ret) {
		ksft_print_msg("making mount tree private: %s\n", strerror(errno));
		return ret;
	}

	return 0;
}
517
518 /* caps_down - lower all effective caps */
caps_down(void)519 int caps_down(void)
520 {
521 bool fret = false;
522 cap_t caps = NULL;
523 int ret = -1;
524
525 caps = cap_get_proc();
526 if (!caps)
527 goto out;
528
529 ret = cap_clear_flag(caps, CAP_EFFECTIVE);
530 if (ret)
531 goto out;
532
533 ret = cap_set_proc(caps);
534 if (ret)
535 goto out;
536
537 fret = true;
538
539 out:
540 cap_free(caps);
541 return fret;
542 }
543
544 /* cap_down - lower an effective cap */
/*
 * cap_down - lower a single effective cap.
 * @down:	the capability to drop from the effective set.
 * Returns true (1) on success, false (0) on failure.
 */
int cap_down(cap_value_t down)
{
	bool ok = false;
	cap_value_t cap = down;
	cap_t caps = cap_get_proc();

	if (!caps)
		return false;

	if (!cap_set_flag(caps, CAP_EFFECTIVE, 1, &cap, 0) &&
	    !cap_set_proc(caps))
		ok = true;

	cap_free(caps);
	return ok;
}
570
get_unique_mnt_id(const char * path)571 uint64_t get_unique_mnt_id(const char *path)
572 {
573 struct statx sx;
574 int ret;
575
576 ret = statx(AT_FDCWD, path, 0, STATX_MNT_ID_UNIQUE, &sx);
577 if (ret == -1) {
578 ksft_print_msg("retrieving unique mount ID for %s: %s\n", path,
579 strerror(errno));
580 return 0;
581 }
582
583 if (!(sx.stx_mask & STATX_MNT_ID_UNIQUE)) {
584 ksft_print_msg("no unique mount ID available for %s\n", path);
585 return 0;
586 }
587
588 return sx.stx_mnt_id;
589 }
590