xref: /linux/tools/testing/selftests/filesystems/utils.c (revision ab93e0dd72c37d378dd936f031ffb83ff2bd87ce)
1 // SPDX-License-Identifier: GPL-2.0
2 #ifndef _GNU_SOURCE
3 #define _GNU_SOURCE
4 #endif
5 #include <fcntl.h>
6 #include <sys/types.h>
7 #include <dirent.h>
8 #include <grp.h>
9 #include <linux/limits.h>
10 #include <sched.h>
11 #include <stdio.h>
12 #include <stdlib.h>
13 #include <sys/eventfd.h>
14 #include <sys/fsuid.h>
15 #include <sys/prctl.h>
16 #include <sys/socket.h>
17 #include <sys/stat.h>
18 #include <sys/types.h>
19 #include <sys/wait.h>
20 #include <sys/xattr.h>
21 #include <sys/mount.h>
22 
23 #include "../kselftest.h"
24 #include "wrappers.h"
25 #include "utils.h"
26 
/* Deepest nesting level create_userns_hierarchy() will create. */
#define MAX_USERNS_LEVEL 32

/*
 * syserror - print "<strerror(errno)> - <format>" to stderr and evaluate to
 * -errno.
 *
 * errno is captured before printing and restored afterwards, because
 * fprintf() is allowed to modify errno.
 */
#define syserror(format, ...)                                          \
	({                                                             \
		int __saved_errno__ = errno;                           \
		fprintf(stderr, "%m - " format "\n", ##__VA_ARGS__);   \
		errno = __saved_errno__;                               \
		(-__saved_errno__);                                    \
	})

/*
 * syserror_set - set errno to |__ret__|, print like syserror() and evaluate
 * to __ret__.
 *
 * __ret__ is evaluated exactly once. errno is re-set after printing so that
 * it still reflects __ret__ even if fprintf() changed it.
 */
#define syserror_set(__ret__, format, ...)                             \
	({                                                             \
		__typeof__(__ret__) __internal_ret__ = (__ret__);      \
		errno = labs(__internal_ret__);                        \
		fprintf(stderr, "%m - " format "\n", ##__VA_ARGS__);   \
		errno = labs(__internal_ret__);                        \
		__internal_ret__;                                      \
	})

/* Length of a string literal, excluding the terminating NUL. */
#define STRLITERALLEN(x) (sizeof(""x"") - 1)

/*
 * Buffer size for the decimal representation of an integer type: the
 * maximum number of digits for its width plus 2 spare bytes (enough for a
 * sign and the terminating NUL). Types wider than 8 bytes fail to compile
 * through the negative-size array.
 */
#define INTTYPE_TO_STRLEN(type)             \
	(2 + (sizeof(type) <= 1             \
		  ? 3                       \
		  : sizeof(type) <= 2       \
			? 5                 \
			: sizeof(type) <= 4 \
			      ? 10          \
			      : sizeof(type) <= 8 ? 20 : sizeof(int[-2 * (sizeof(type) > 8)])))

/* Iterate over every node of the circular list headed by __list. */
#define list_for_each(__iterator, __list)                              \
	for ((__iterator) = (__list)->next; (__iterator) != (__list);  \
	     (__iterator) = (__iterator)->next)
56 
/* Which /proc/<pid>/{u,g}id_map file an id_map entry is written to. */
typedef enum idmap_type_t {
	ID_TYPE_UID = 0,	/* written to /proc/<pid>/uid_map */
	ID_TYPE_GID = 1,	/* written to /proc/<pid>/gid_map */
} idmap_type_t;

/*
 * One mapping line, written as "<nsid> <hostid> <range>\n" by
 * map_ids_from_idmap(): @range ids starting at @nsid inside the namespace
 * correspond to ids starting at @hostid outside it.
 */
struct id_map {
	idmap_type_t map_type;
	__u32 nsid;
	__u32 hostid;
	__u32 range;
};

/*
 * Node of a circular doubly-linked list. A head initialised by list_init()
 * has @elem == NULL and both links pointing at itself; other nodes carry
 * their payload in @elem.
 */
struct list {
	void *elem;
	struct list *next;
	struct list *prev;
};

/* Per-level state used by create_userns_hierarchy(). */
struct userns_hierarchy {
	int fd_userns;		/* fd for this level's user namespace, set on success */
	int fd_event;		/* child's end of the parent/child socketpair */
	unsigned int level;	/* nesting depth; recursion stops at MAX_USERNS_LEVEL */
	struct list id_map;	/* list of struct id_map written for this level */
};
81 
list_init(struct list * list)82 static inline void list_init(struct list *list)
83 {
84 	list->elem = NULL;
85 	list->next = list->prev = list;
86 }
87 
list_empty(const struct list * list)88 static inline int list_empty(const struct list *list)
89 {
90 	return list == list->next;
91 }
92 
__list_add(struct list * new,struct list * prev,struct list * next)93 static inline void __list_add(struct list *new, struct list *prev, struct list *next)
94 {
95 	next->prev = new;
96 	new->next = next;
97 	new->prev = prev;
98 	prev->next = new;
99 }
100 
list_add_tail(struct list * head,struct list * list)101 static inline void list_add_tail(struct list *head, struct list *list)
102 {
103 	__list_add(list, head->prev, head);
104 }
105 
list_del(struct list * list)106 static inline void list_del(struct list *list)
107 {
108 	struct list *next, *prev;
109 
110 	next = list->next;
111 	prev = list->prev;
112 	next->prev = prev;
113 	prev->next = next;
114 }
115 
/* read(2) that transparently restarts when interrupted by a signal (EINTR). */
static ssize_t read_nointr(int fd, void *buf, size_t count)
{
	for (;;) {
		ssize_t n = read(fd, buf, count);

		if (n >= 0 || errno != EINTR)
			return n;
	}
}
126 
/* write(2) that transparently restarts when interrupted by a signal (EINTR). */
static ssize_t write_nointr(int fd, const void *buf, size_t count)
{
	for (;;) {
		ssize_t n = write(fd, buf, count);

		if (n >= 0 || errno != EINTR)
			return n;
	}
}
137 
/* Size of the heap-allocated stack handed to each clone()d child. */
#define __STACK_SIZE (8 * 1024 * 1024)

/*
 * do_clone - run @fn(@arg) in a child created with clone(@flags | SIGCHLD).
 *
 * Returns the child's pid on success. On failure returns a negative value
 * with errno set: -ENOMEM if the child stack cannot be allocated, otherwise
 * clone()'s own -1.
 */
static pid_t do_clone(int (*fn)(void *), void *arg, int flags)
{
	char *stack;
	pid_t pid;

	stack = malloc(__STACK_SIZE);
	if (!stack) {
		/* Callers report failure as -errno, so make sure it is set. */
		errno = ENOMEM;
		return -ENOMEM;
	}

#ifdef __ia64__
	pid = __clone2(fn, stack, __STACK_SIZE, flags | SIGCHLD, arg, NULL);
#else
	/* Pass the top of the allocation, as clone(2) expects here. */
	pid = clone(fn, stack + __STACK_SIZE, flags | SIGCHLD, arg, NULL);
#endif

	/*
	 * Without CLONE_VM the child runs on its own copy of the address
	 * space, so the parent's copy of the stack is no longer needed. With
	 * CLONE_VM the child is running on this very memory and it must stay
	 * allocated. If clone() failed nobody uses it.
	 */
	if (pid < 0 || !(flags & CLONE_VM)) {
		int saved_errno = errno;

		free(stack);
		errno = saved_errno;
	}

	return pid;
}
153 
/*
 * get_userns_fd_cb - body of the idle child used by
 * get_userns_fd_from_idmap(). It never finishes on its own: it sleeps in
 * pause() until a signal ends it (that caller sends SIGKILL).
 */
static int get_userns_fd_cb(void *data)
{
	(void)data; /* unused */

	while (true)
		pause();

	_exit(0); /* not reached */
}
160 
/*
 * wait_for_pid - reap @pid, retrying waitpid() when interrupted (EINTR).
 *
 * Returns the child's exit status if it exited normally, or -1 if
 * waitpid() failed or the child did not exit normally (e.g. it was killed).
 */
static int wait_for_pid(pid_t pid)
{
	int status;

	for (;;) {
		if (waitpid(pid, &status, 0) != -1)
			break;
		if (errno != EINTR)
			return -1;
	}

	return WIFEXITED(status) ? WEXITSTATUS(status) : -1;
}
179 
/*
 * write_id_mapping - write @buf to /proc/@pid/uid_map or /proc/@pid/gid_map.
 * @map_type: ID_TYPE_UID selects uid_map, anything else gid_map.
 * @pid:      target process.
 * @buf:      mapping lines; written with a single write().
 * @buf_size: number of bytes of @buf to write.
 *
 * For a gid mapping written by a caller whose euid is not 0, "deny" is first
 * written to /proc/@pid/setgroups if that file exists.
 *
 * Returns 0 on success, -1 on failure.
 */
static int write_id_mapping(idmap_type_t map_type, pid_t pid, const char *buf, size_t buf_size)
{
	int fd = -EBADF, setgroups_fd = -EBADF;
	int fret = -1;
	int ret;
	ssize_t bytes;
	/* Sized for the longest path built below, "/proc/<pid>/setgroups". */
	char path[STRLITERALLEN("/proc/") + INTTYPE_TO_STRLEN(pid_t) +
		  STRLITERALLEN("/setgroups") + 1];

	if (geteuid() != 0 && map_type == ID_TYPE_GID) {
		ret = snprintf(path, sizeof(path), "/proc/%d/setgroups", pid);
		if (ret < 0 || (size_t)ret >= sizeof(path))
			goto out;

		setgroups_fd = open(path, O_WRONLY | O_CLOEXEC);
		/* A missing setgroups file is tolerated. */
		if (setgroups_fd < 0 && errno != ENOENT) {
			syserror("Failed to open \"%s\"", path);
			goto out;
		}

		if (setgroups_fd >= 0) {
			bytes = write_nointr(setgroups_fd, "deny\n", STRLITERALLEN("deny\n"));
			if (bytes != (ssize_t)STRLITERALLEN("deny\n")) {
				syserror("Failed to write \"deny\" to \"/proc/%d/setgroups\"", pid);
				goto out;
			}
		}
	}

	ret = snprintf(path, sizeof(path), "/proc/%d/%cid_map", pid, map_type == ID_TYPE_UID ? 'u' : 'g');
	if (ret < 0 || (size_t)ret >= sizeof(path))
		goto out;

	fd = open(path, O_WRONLY | O_CLOEXEC);
	if (fd < 0) {
		syserror("Failed to open \"%s\"", path);
		goto out;
	}

	/* Keep the full ssize_t result so the size_t comparison is exact. */
	bytes = write_nointr(fd, buf, buf_size);
	if (bytes < 0 || (size_t)bytes != buf_size) {
		syserror("Failed to write %cid mapping to \"%s\"",
			 map_type == ID_TYPE_UID ? 'u' : 'g', path);
		goto out;
	}

	fret = 0;
out:
	/* Only close descriptors that were actually opened. */
	if (fd >= 0)
		close(fd);
	if (setgroups_fd >= 0)
		close(setgroups_fd);

	return fret;
}
232 
map_ids_from_idmap(struct list * idmap,pid_t pid)233 static int map_ids_from_idmap(struct list *idmap, pid_t pid)
234 {
235 	int fill, left;
236 	char mapbuf[4096] = {};
237 	bool had_entry = false;
238 	idmap_type_t map_type, u_or_g;
239 
240 	if (list_empty(idmap))
241 		return 0;
242 
243 	for (map_type = ID_TYPE_UID, u_or_g = 'u';
244 	     map_type <= ID_TYPE_GID; map_type++, u_or_g = 'g') {
245 		char *pos = mapbuf;
246 		int ret;
247 		struct list *iterator;
248 
249 
250 		list_for_each(iterator, idmap) {
251 			struct id_map *map = iterator->elem;
252 			if (map->map_type != map_type)
253 				continue;
254 
255 			had_entry = true;
256 
257 			left = 4096 - (pos - mapbuf);
258 			fill = snprintf(pos, left, "%u %u %u\n", map->nsid, map->hostid, map->range);
259 			/*
260 			 * The kernel only takes <= 4k for writes to
261 			 * /proc/<pid>/{g,u}id_map
262 			 */
263 			if (fill <= 0 || fill >= left)
264 				return syserror_set(-E2BIG, "Too many %cid mappings defined", u_or_g);
265 
266 			pos += fill;
267 		}
268 		if (!had_entry)
269 			continue;
270 
271 		ret = write_id_mapping(map_type, pid, mapbuf, pos - mapbuf);
272 		if (ret < 0)
273 			return syserror("Failed to write mapping: %s", mapbuf);
274 
275 		memset(mapbuf, 0, sizeof(mapbuf));
276 	}
277 
278 	return 0;
279 }
280 
get_userns_fd_from_idmap(struct list * idmap)281 static int get_userns_fd_from_idmap(struct list *idmap)
282 {
283 	int ret;
284 	pid_t pid;
285 	char path_ns[STRLITERALLEN("/proc/") + INTTYPE_TO_STRLEN(pid_t) +
286 		     STRLITERALLEN("/ns/user") + 1];
287 
288 	pid = do_clone(get_userns_fd_cb, NULL, CLONE_NEWUSER | CLONE_NEWNS);
289 	if (pid < 0)
290 		return -errno;
291 
292 	ret = map_ids_from_idmap(idmap, pid);
293 	if (ret < 0)
294 		return ret;
295 
296 	ret = snprintf(path_ns, sizeof(path_ns), "/proc/%d/ns/user", pid);
297 	if (ret < 0 || (size_t)ret >= sizeof(path_ns))
298 		ret = -EIO;
299 	else
300 		ret = open(path_ns, O_RDONLY | O_CLOEXEC | O_NOCTTY);
301 
302 	(void)kill(pid, SIGKILL);
303 	(void)wait_for_pid(pid);
304 	return ret;
305 }
306 
get_userns_fd(unsigned long nsid,unsigned long hostid,unsigned long range)307 int get_userns_fd(unsigned long nsid, unsigned long hostid, unsigned long range)
308 {
309 	struct list head, uid_mapl, gid_mapl;
310 	struct id_map uid_map = {
311 		.map_type	= ID_TYPE_UID,
312 		.nsid		= nsid,
313 		.hostid		= hostid,
314 		.range		= range,
315 	};
316 	struct id_map gid_map = {
317 		.map_type	= ID_TYPE_GID,
318 		.nsid		= nsid,
319 		.hostid		= hostid,
320 		.range		= range,
321 	};
322 
323 	list_init(&head);
324 	uid_mapl.elem = &uid_map;
325 	gid_mapl.elem = &gid_map;
326 	list_add_tail(&head, &uid_mapl);
327 	list_add_tail(&head, &gid_mapl);
328 
329 	return get_userns_fd_from_idmap(&head);
330 }
331 
/*
 * switch_ids - drop all supplementary groups and make @uid/@gid the real,
 * effective and saved ids of the calling process.
 *
 * Also sets the process dumpable so that its proc files stay accessible to
 * processes that can ptrace it.
 *
 * Returns true on success. On failure an error is printed and false is
 * returned.
 */
bool switch_ids(uid_t uid, gid_t gid)
{
	/*
	 * syserror() evaluates to -errno, which converts to true as a bool,
	 * so it cannot be used as the return value of this function.
	 */
	if (setgroups(0, NULL)) {
		syserror("failure: setgroups");
		return false;
	}

	if (setresgid(gid, gid, gid)) {
		syserror("failure: setresgid");
		return false;
	}

	if (setresuid(uid, uid, uid)) {
		syserror("failure: setresuid");
		return false;
	}

	/* Ensure we can access proc files from processes we can ptrace. */
	if (prctl(PR_SET_DUMPABLE, 1, 0, 0, 0)) {
		syserror("failure: make dumpable");
		return false;
	}

	return true;
}
349 
350 static int create_userns_hierarchy(struct userns_hierarchy *h);
351 
userns_fd_cb(void * data)352 static int userns_fd_cb(void *data)
353 {
354 	struct userns_hierarchy *h = data;
355 	char c;
356 	int ret;
357 
358 	ret = read_nointr(h->fd_event, &c, 1);
359 	if (ret < 0)
360 		return syserror("failure: read from socketpair");
361 
362 	/* Only switch ids if someone actually wrote a mapping for us. */
363 	if (c == '1') {
364 		if (!switch_ids(0, 0))
365 			return syserror("failure: switch ids to 0");
366 	}
367 
368 	ret = write_nointr(h->fd_event, "1", 1);
369 	if (ret < 0)
370 		return syserror("failure: write to socketpair");
371 
372 	ret = create_userns_hierarchy(++h);
373 	if (ret < 0)
374 		return syserror("failure: userns level %d", h->level);
375 
376 	return 0;
377 }
378 
create_userns_hierarchy(struct userns_hierarchy * h)379 static int create_userns_hierarchy(struct userns_hierarchy *h)
380 {
381 	int fret = -1;
382 	char c;
383 	int fd_socket[2];
384 	int fd_userns = -EBADF, ret = -1;
385 	ssize_t bytes;
386 	pid_t pid;
387 	char path[256];
388 
389 	if (h->level == MAX_USERNS_LEVEL)
390 		return 0;
391 
392 	ret = socketpair(AF_LOCAL, SOCK_STREAM | SOCK_CLOEXEC, 0, fd_socket);
393 	if (ret < 0)
394 		return syserror("failure: create socketpair");
395 
396 	/* Note the CLONE_FILES | CLONE_VM when mucking with fds and memory. */
397 	h->fd_event = fd_socket[1];
398 	pid = do_clone(userns_fd_cb, h, CLONE_NEWUSER | CLONE_FILES | CLONE_VM);
399 	if (pid < 0) {
400 		syserror("failure: userns level %d", h->level);
401 		goto out_close;
402 	}
403 
404 	ret = map_ids_from_idmap(&h->id_map, pid);
405 	if (ret < 0) {
406 		kill(pid, SIGKILL);
407 		syserror("failure: writing id mapping for userns level %d for %d", h->level, pid);
408 		goto out_wait;
409 	}
410 
411 	if (!list_empty(&h->id_map))
412 		bytes = write_nointr(fd_socket[0], "1", 1); /* Inform the child we wrote a mapping. */
413 	else
414 		bytes = write_nointr(fd_socket[0], "0", 1); /* Inform the child we didn't write a mapping. */
415 	if (bytes < 0) {
416 		kill(pid, SIGKILL);
417 		syserror("failure: write to socketpair");
418 		goto out_wait;
419 	}
420 
421 	/* Wait for child to set*id() and become dumpable. */
422 	bytes = read_nointr(fd_socket[0], &c, 1);
423 	if (bytes < 0) {
424 		kill(pid, SIGKILL);
425 		syserror("failure: read from socketpair");
426 		goto out_wait;
427 	}
428 
429 	snprintf(path, sizeof(path), "/proc/%d/ns/user", pid);
430 	fd_userns = open(path, O_RDONLY | O_CLOEXEC);
431 	if (fd_userns < 0) {
432 		kill(pid, SIGKILL);
433 		syserror("failure: open userns level %d for %d", h->level, pid);
434 		goto out_wait;
435 	}
436 
437 	fret = 0;
438 
439 out_wait:
440 	if (!wait_for_pid(pid) && !fret) {
441 		h->fd_userns = fd_userns;
442 		fd_userns = -EBADF;
443 	}
444 
445 out_close:
446 	if (fd_userns >= 0)
447 		close(fd_userns);
448 	close(fd_socket[0]);
449 	close(fd_socket[1]);
450 	return fret;
451 }
452 
/*
 * write_file - write the string @val (without its NUL) to @path with a
 * single write().
 *
 * Returns 0 on success, -1 on failure (open error, write error, short write
 * or close error), after printing a message. The file descriptor is closed
 * on every path.
 */
static int write_file(const char *path, const char *val)
{
	int fd = open(path, O_WRONLY);
	size_t len = strlen(val);
	ssize_t ret;

	if (fd == -1) {
		ksft_print_msg("opening %s for write: %s\n", path, strerror(errno));
		return -1;
	}

	ret = write(fd, val, len);
	if (ret == -1) {
		ksft_print_msg("writing to %s: %s\n", path, strerror(errno));
		close(fd);
		return -1;
	}
	if ((size_t)ret != len) {
		ksft_print_msg("short write to %s\n", path);
		close(fd);
		return -1;
	}

	if (close(fd) == -1) {
		ksft_print_msg("closing %s\n", path);
		return -1;
	}

	return 0;
}
482 
/*
 * setup_userns - unshare mount, user and pid namespaces and map the current
 * uid and gid to 0 inside the new user namespace, then make the whole mount
 * tree recursively private.
 *
 * Returns 0 on success, non-zero on failure.
 */
int setup_userns(void)
{
	int ret;
	char buf[32];
	uid_t uid = getuid();
	gid_t gid = getgid();

	ret = unshare(CLONE_NEWNS|CLONE_NEWUSER|CLONE_NEWPID);
	if (ret) {
		ksft_exit_fail_msg("unsharing mountns and userns: %s\n",
				   strerror(errno));
		return ret;
	}

	/* uid_t and gid_t are unsigned, so format them with %u. */
	snprintf(buf, sizeof(buf), "0 %u 1", (unsigned int)uid);
	ret = write_file("/proc/self/uid_map", buf);
	if (ret)
		return ret;

	/*
	 * setgroups must be denied before an unprivileged process may write
	 * gid_map (see user_namespaces(7)).
	 */
	ret = write_file("/proc/self/setgroups", "deny");
	if (ret)
		return ret;
	snprintf(buf, sizeof(buf), "0 %u 1", (unsigned int)gid);
	ret = write_file("/proc/self/gid_map", buf);
	if (ret)
		return ret;

	ret = mount("", "/", NULL, MS_REC|MS_PRIVATE, NULL);
	if (ret) {
		ksft_print_msg("making mount tree private: %s\n", strerror(errno));
		return ret;
	}

	return 0;
}
517 
518 /* caps_down - lower all effective caps */
caps_down(void)519 int caps_down(void)
520 {
521 	bool fret = false;
522 	cap_t caps = NULL;
523 	int ret = -1;
524 
525 	caps = cap_get_proc();
526 	if (!caps)
527 		goto out;
528 
529 	ret = cap_clear_flag(caps, CAP_EFFECTIVE);
530 	if (ret)
531 		goto out;
532 
533 	ret = cap_set_proc(caps);
534 	if (ret)
535 		goto out;
536 
537 	fret = true;
538 
539 out:
540 	cap_free(caps);
541 	return fret;
542 }
543 
544 /* cap_down - lower an effective cap */
cap_down(cap_value_t down)545 int cap_down(cap_value_t down)
546 {
547 	bool fret = false;
548 	cap_t caps = NULL;
549 	cap_value_t cap = down;
550 	int ret = -1;
551 
552 	caps = cap_get_proc();
553 	if (!caps)
554 		goto out;
555 
556 	ret = cap_set_flag(caps, CAP_EFFECTIVE, 1, &cap, 0);
557 	if (ret)
558 		goto out;
559 
560 	ret = cap_set_proc(caps);
561 	if (ret)
562 		goto out;
563 
564 	fret = true;
565 
566 out:
567 	cap_free(caps);
568 	return fret;
569 }
570 
get_unique_mnt_id(const char * path)571 uint64_t get_unique_mnt_id(const char *path)
572 {
573 	struct statx sx;
574 	int ret;
575 
576 	ret = statx(AT_FDCWD, path, 0, STATX_MNT_ID_UNIQUE, &sx);
577 	if (ret == -1) {
578 		ksft_print_msg("retrieving unique mount ID for %s: %s\n", path,
579 			 strerror(errno));
580 		return 0;
581 	}
582 
583 	if (!(sx.stx_mask & STATX_MNT_ID_UNIQUE)) {
584 		ksft_print_msg("no unique mount ID available for %s\n", path);
585 		return 0;
586 	}
587 
588 	return sx.stx_mnt_id;
589 }
590