1 // SPDX-License-Identifier: GPL-2.0
2
3 #define _GNU_SOURCE
4 #include <errno.h>
5 #include <fcntl.h>
6 #include <inttypes.h>
7 #include <limits.h>
8 #include <linux/types.h>
9 #include <sched.h>
10 #include <signal.h>
11 #include <stdbool.h>
12 #include <stdio.h>
13 #include <stdlib.h>
14 #include <string.h>
15 #include <syscall.h>
16 #include <sys/ioctl.h>
17 #include <sys/mount.h>
18 #include <sys/prctl.h>
19 #include <sys/wait.h>
20 #include <unistd.h>
21
22 #include "pidfd.h"
23 #include "../kselftest.h"
24
/*
 * Parse @numstr with strtol() using base 0 (decimal, "0x" hex or leading-0
 * octal) and store the result in *converted.
 *
 * The entire string must be consumed. *converted is written only on success.
 *
 * Returns 0 on success, -EINVAL for malformed input (no digits or trailing
 * characters), or -ERANGE when the value does not fit in an int.
 */
static int safe_int(const char *numstr, int *converted)
{
	char *endp = NULL;
	long parsed;
	bool saturated;

	errno = 0;
	parsed = strtol(numstr, &endp, 0);
	/* strtol() clamps to LONG_MIN/LONG_MAX and sets ERANGE on overflow. */
	saturated = parsed == LONG_MIN || parsed == LONG_MAX;

	if (errno == ERANGE && saturated)
		return -ERANGE;
	if (parsed == 0 && errno != 0)
		return -EINVAL;
	if (endp == numstr || *endp != '\0')
		return -EINVAL;
	/* long may be wider than int: narrow only when it is lossless. */
	if (parsed < INT_MIN || parsed > INT_MAX)
		return -ERANGE;

	*converted = (int)parsed;
	return 0;
}
47
/*
 * Count the leading spaces and tabs in the first @len bytes of @buffer.
 *
 * Returns the offset of the first byte that is neither ' ' nor '\t'. If no
 * such byte exists (empty or all-blank input), returns 0. Newlines are not
 * treated as blank here.
 */
static int char_left_gc(const char *buffer, size_t len)
{
	const char *cur = buffer;
	const char *const end = buffer + len;

	while (cur < end && (*cur == ' ' || *cur == '\t'))
		cur++;

	if (cur == end)
		return 0;

	return (int)(cur - buffer);
}
62
/*
 * Measure the first @len bytes of @buffer with trailing blanks removed.
 *
 * Trailing ' ', '\t', '\n' and '\0' bytes are skipped. Returns the length
 * of what remains, i.e. the index just past the last other byte, or 0 if
 * every byte is blank.
 */
static int char_right_gc(const char *buffer, size_t len)
{
	size_t keep = len;

	while (keep > 0) {
		char last = buffer[keep - 1];

		if (last != ' ' && last != '\t' && last != '\n' && last != '\0')
			break;
		keep--;
	}

	return (int)keep;
}
79
/*
 * Trim @buffer in place.
 *
 * Leading spaces/tabs are skipped by advancing the returned pointer.
 * Trailing spaces/tabs/newlines are cut off by writing a NUL terminator
 * into @buffer. The returned pointer points into @buffer.
 */
static char *trim_whitespace_in_place(char *buffer)
{
	char *start = buffer + char_left_gc(buffer, strlen(buffer));
	int keep = char_right_gc(start, strlen(start));

	start[keep] = '\0';
	return start;
}
86
/*
 * Look up the line starting with @key (of length @keylen) in
 * /proc/self/fdinfo/<pidfd> and parse the rest of that line as an int.
 *
 * Only the first matching line is considered.
 *
 * Returns the parsed value, or -1 if the fdinfo file cannot be opened, no
 * line matches @key, or the value after @key is not a valid int.
 */
static pid_t get_pid_from_fdinfo_file(int pidfd, const char *key, size_t keylen)
{
	char path[512];
	FILE *f;
	size_t n = 0;
	pid_t result = -1;
	char *line = NULL;

	snprintf(path, sizeof(path), "/proc/self/fdinfo/%d", pidfd);

	f = fopen(path, "re");
	if (!f)
		return -1;

	while (getline(&line, &n, f) != -1) {
		char *numstr;

		if (strncmp(line, key, keylen))
			continue;

		/*
		 * Skip exactly the matched key, whatever its length, instead of
		 * assuming a 4-character key such as "Pid:".
		 */
		numstr = trim_whitespace_in_place(line + keylen);
		if (safe_int(numstr, &result) < 0)
			result = -1;

		break;
	}

	free(line);
	fclose(f);
	return result;
}
121
main(int argc,char ** argv)122 int main(int argc, char **argv)
123 {
124 struct pidfd_info info = {
125 .mask = PIDFD_INFO_CGROUPID,
126 };
127 int pidfd = -1, ret = 1;
128 pid_t pid;
129
130 ksft_set_plan(4);
131
132 pidfd = sys_pidfd_open(-1, 0);
133 if (pidfd >= 0) {
134 ksft_print_msg(
135 "%s - succeeded to open pidfd for invalid pid -1\n",
136 strerror(errno));
137 goto on_error;
138 }
139 ksft_test_result_pass("do not allow invalid pid test: passed\n");
140
141 pidfd = sys_pidfd_open(getpid(), 1);
142 if (pidfd >= 0) {
143 ksft_print_msg(
144 "%s - succeeded to open pidfd with invalid flag value specified\n",
145 strerror(errno));
146 goto on_error;
147 }
148 ksft_test_result_pass("do not allow invalid flag test: passed\n");
149
150 pidfd = sys_pidfd_open(getpid(), 0);
151 if (pidfd < 0) {
152 ksft_print_msg("%s - failed to open pidfd\n", strerror(errno));
153 goto on_error;
154 }
155 ksft_test_result_pass("open a new pidfd test: passed\n");
156
157 pid = get_pid_from_fdinfo_file(pidfd, "Pid:", sizeof("Pid:") - 1);
158 ksft_print_msg("pidfd %d refers to process with pid %d\n", pidfd, pid);
159
160 if (ioctl(pidfd, PIDFD_GET_INFO, &info) < 0) {
161 ksft_print_msg("%s - failed to get info from pidfd\n", strerror(errno));
162 goto on_error;
163 }
164 if (info.pid != pid) {
165 ksft_print_msg("pid from fdinfo file %d does not match pid from ioctl %d\n",
166 pid, info.pid);
167 goto on_error;
168 }
169 if (info.ppid != getppid()) {
170 ksft_print_msg("ppid %d does not match ppid from ioctl %d\n",
171 pid, info.pid);
172 goto on_error;
173 }
174 if (info.ruid != getuid()) {
175 ksft_print_msg("uid %d does not match uid from ioctl %d\n",
176 getuid(), info.ruid);
177 goto on_error;
178 }
179 if (info.rgid != getgid()) {
180 ksft_print_msg("gid %d does not match gid from ioctl %d\n",
181 getgid(), info.rgid);
182 goto on_error;
183 }
184 if (info.euid != geteuid()) {
185 ksft_print_msg("euid %d does not match euid from ioctl %d\n",
186 geteuid(), info.euid);
187 goto on_error;
188 }
189 if (info.egid != getegid()) {
190 ksft_print_msg("egid %d does not match egid from ioctl %d\n",
191 getegid(), info.egid);
192 goto on_error;
193 }
194 if (info.suid != geteuid()) {
195 ksft_print_msg("suid %d does not match suid from ioctl %d\n",
196 geteuid(), info.suid);
197 goto on_error;
198 }
199 if (info.sgid != getegid()) {
200 ksft_print_msg("sgid %d does not match sgid from ioctl %d\n",
201 getegid(), info.sgid);
202 goto on_error;
203 }
204 if ((info.mask & PIDFD_INFO_CGROUPID) && info.cgroupid == 0) {
205 ksft_print_msg("cgroupid should not be 0 when PIDFD_INFO_CGROUPID is set\n");
206 goto on_error;
207 }
208 ksft_test_result_pass("get info from pidfd test: passed\n");
209
210 ret = 0;
211
212 on_error:
213 if (pidfd >= 0)
214 close(pidfd);
215
216 if (ret)
217 ksft_exit_fail();
218 ksft_exit_pass();
219 }
220