1 // SPDX-License-Identifier: GPL-2.0
#include <fcntl.h>
#include <pthread.h>
#include <sched.h>
#include <signal.h>
#include <unistd.h>
#include "aolib.h"
7
8 /*
9 * Can't be included in the header: it defines static variables which
10 * will be unique to every object. Let's include it only once here.
11 */
12 #include "../../../kselftest.h"
13
/* Prevent overriding of one thread's output by another */
static pthread_mutex_t ksft_print_lock = PTHREAD_MUTEX_INITIALIZER;

/*
 * Emit an informational line, serialized against other threads' output.
 * @buf: an already-formatted message.  It is passed as the argument of a
 * constant "%s" format, never as the format itself: the kselftest
 * reporters are printf-style, so a '%' inside the message text would
 * otherwise be parsed as a conversion specifier (reading bogus varargs).
 */
void __test_msg(const char *buf)
{
	pthread_mutex_lock(&ksft_print_lock);
	ksft_print_msg("%s", buf);
	pthread_mutex_unlock(&ksft_print_lock);
}
__test_ok(const char * buf)23 void __test_ok(const char *buf)
24 {
25 pthread_mutex_lock(&ksft_print_lock);
26 ksft_test_result_pass(buf);
27 pthread_mutex_unlock(&ksft_print_lock);
28 }
__test_fail(const char * buf)29 void __test_fail(const char *buf)
30 {
31 pthread_mutex_lock(&ksft_print_lock);
32 ksft_test_result_fail(buf);
33 pthread_mutex_unlock(&ksft_print_lock);
34 }
__test_xfail(const char * buf)35 void __test_xfail(const char *buf)
36 {
37 pthread_mutex_lock(&ksft_print_lock);
38 ksft_test_result_xfail(buf);
39 pthread_mutex_unlock(&ksft_print_lock);
40 }
__test_error(const char * buf)41 void __test_error(const char *buf)
42 {
43 pthread_mutex_lock(&ksft_print_lock);
44 ksft_test_result_error(buf);
45 pthread_mutex_unlock(&ksft_print_lock);
46 }
__test_skip(const char * buf)47 void __test_skip(const char *buf)
48 {
49 pthread_mutex_lock(&ksft_print_lock);
50 ksft_test_result_skip(buf);
51 pthread_mutex_unlock(&ksft_print_lock);
52 }
53
/*
 * Outcome flags inspected by test_exit() (and by __test_init() once the
 * main peer returns) to choose the process exit status.
 */
static volatile int failed, skipped;

/* Mark the run as failed; the flag is sticky and never cleared here. */
void test_failed(void)
{
	failed = 1;
}
61
test_exit(void)62 static void test_exit(void)
63 {
64 if (failed) {
65 ksft_exit_fail();
66 } else if (skipped) {
67 /* ksft_exit_skip() is different from ksft_exit_*() */
68 ksft_print_cnts();
69 exit(KSFT_SKIP);
70 } else {
71 ksft_exit_pass();
72 }
73 }
74
/* Node of the singly-linked stack of cleanup callbacks run at exit. */
struct dlist_t {
	void (*destruct)(void);
	struct dlist_t *next;
};
static struct dlist_t *destructors_list;

/*
 * Register @d to be invoked by test_destructor() when the process exits.
 * New entries are pushed on the head, so callbacks run in reverse order of
 * registration.  An allocation failure is reported through test_error().
 */
void test_add_destructor(void (*d)(void))
{
	struct dlist_t *node = malloc(sizeof(*node));

	if (!node)
		test_error("malloc() failed");

	node->destruct = d;
	node->next = destructors_list;
	destructors_list = node;
}
93
94 static void test_destructor(void) __attribute__((destructor));
test_destructor(void)95 static void test_destructor(void)
96 {
97 while (destructors_list) {
98 struct dlist_t *p = destructors_list->next;
99
100 destructors_list->destruct();
101 free(destructors_list);
102 destructors_list = p;
103 }
104 test_exit();
105 }
106
/*
 * SIGINT handler installed by __test_init(): reports the interruption as a
 * test error.
 * NOTE(review): test_error() presumably formats and prints the message
 * (and __test_error() takes ksft_print_lock). None of that is
 * async-signal-safe, so a SIGINT arriving while the lock is held could
 * deadlock. This is left unchanged here.
 */
static void sig_int(int signo)
{
	(void)signo;	/* only SIGINT is routed here */
	test_error("Caught SIGINT - exiting");
}
111
/*
 * Return a read-only fd referring to the network namespace of the
 * *calling thread*; failure is reported through test_error().
 *
 * Network namespaces are per-thread (setns()/unshare() affect only the
 * caller), but /proc/self resolves to the thread-group leader.  A peer
 * thread using /proc/self would get the main thread's namespace, and
 * switch_save_ns()/switch_ns() would then "restore" it into the wrong one.
 * /proc/thread-self always names the caller; for the main thread it is the
 * same namespace as before.
 */
int open_netns(void)
{
	const char *netns_path = "/proc/thread-self/ns/net";
	int fd;

	fd = open(netns_path, O_RDONLY);
	if (fd < 0)
		test_error("open(%s)", netns_path);
	return fd;
}
122
unshare_open_netns(void)123 int unshare_open_netns(void)
124 {
125 if (unshare(CLONE_NEWNET) != 0)
126 test_error("unshare()");
127
128 return open_netns();
129 }
130
switch_ns(int fd)131 void switch_ns(int fd)
132 {
133 if (setns(fd, CLONE_NEWNET))
134 test_error("setns()");
135 }
136
/*
 * Enter @new_ns and return an fd for the namespace that was current just
 * before the switch, so the caller can return with switch_ns().
 * The returned descriptor is a fresh open(); the caller owns it and
 * should close() it when done.
 */
int switch_save_ns(int new_ns)
{
	int saved_ns = open_netns();	/* must be taken before switching */

	switch_ns(new_ns);
	return saved_ns;
}
144
/* fds of the three namespaces used by the test; -1 until init_namespaces() */
static int nsfd_outside = -1;
static int nsfd_parent = -1;
static int nsfd_child = -1;
/* Name of the veth pair that __test_init() creates via add_veth() */
const char veth_name[] = "ktst-veth";

/*
 * Record the namespace the process was started in, then unshare twice.
 * The order matters. After the first unshare the caller is in
 * nsfd_parent; after the second it is in nsfd_child, where it remains
 * when this function returns.
 */
static void init_namespaces(void)
{
	nsfd_outside = open_netns();
	nsfd_parent = unshare_open_netns();
	nsfd_child = unshare_open_netns();
}
156
/*
 * Configure @veth in the current namespace, in three steps:
 *   1. bring the link up;
 *   2. assign @addr/@prefix of @family to it;
 *   3. install a route via ip_route_add() using @addr and @dest.
 * Each failing step is reported through test_error().
 */
static void link_init(const char *veth, int family, uint8_t prefix,
		      union tcp_addr addr, union tcp_addr dest)
{
	if (link_set_up(veth) != 0)
		test_error("Failed to set link up");

	if (ip_addr_add(veth, family, addr, prefix) != 0)
		test_error("Failed to add ip address");

	if (ip_route_add(veth, family, addr, dest) != 0)
		test_error("Failed to add route");
}
167
/* Number of threads taking part in synchronize_threads() barriers. */
static unsigned int nr_threads = 1;

static pthread_mutex_t sync_lock = PTHREAD_MUTEX_INITIALIZER;
static pthread_cond_t sync_cond = PTHREAD_COND_INITIALIZER;
static volatile unsigned int stage_threads[2];
static volatile unsigned int stage_nr;

/*
 * synchronize all threads in the same stage
 *
 * Reusable barrier: returns once all nr_threads threads have called it for
 * the current stage.  Two arrival counters alternate between consecutive
 * stages. The last thread to arrive flips stage_nr and zeroes the counter
 * of the next stage. Waiters of the finished stage keep seeing their
 * counter at nr_threads until they have all left.
 */
void synchronize_threads(void)
{
	unsigned int q;

	pthread_mutex_lock(&sync_lock);
	q = stage_nr;	/* read under the lock, like every other access */
	stage_threads[q]++;
	if (stage_threads[q] == nr_threads) {
		stage_nr ^= 1;
		stage_threads[stage_nr] = 0;
		/*
		 * Wake *all* waiters: pthread_cond_signal() releases at most
		 * one, which leaves the rest blocked forever as soon as more
		 * than two threads participate.
		 */
		pthread_cond_broadcast(&sync_cond);
	}
	while (stage_threads[q] < nr_threads)
		pthread_cond_wait(&sync_cond, &sync_lock);
	pthread_mutex_unlock(&sync_lock);
}
191
192 __thread union tcp_addr this_ip_addr;
193 __thread union tcp_addr this_ip_dest;
194 int test_family;
195
196 struct new_pthread_arg {
197 thread_fn func;
198 union tcp_addr my_ip;
199 union tcp_addr dest_ip;
200 };
new_pthread_entry(void * arg)201 static void *new_pthread_entry(void *arg)
202 {
203 struct new_pthread_arg *p = arg;
204
205 this_ip_addr = p->my_ip;
206 this_ip_dest = p->dest_ip;
207 p->func(NULL); /* shouldn't return */
208 exit(KSFT_FAIL);
209 }
210
__test_skip_all(const char * msg)211 static void __test_skip_all(const char *msg)
212 {
213 ksft_set_plan(1);
214 ksft_print_header();
215 skipped = 1;
216 test_skip("%s", msg);
217 exit(KSFT_SKIP);
218 }
219
/*
 * Common entry point of a TCP-AO selftest; never returns.
 *
 * 1. Installs the SIGINT handler.
 * 2. Skips the whole suite if a required kernel config is missing.
 * 3. Prints the TAP header and the random seed.
 * 4. Builds the outside/parent/child namespaces joined by a veth pair.
 *    In the child namespace the link gets @addr2 (peer @addr1). In the
 *    parent namespace it gets @addr1 (peer @addr2).
 * 5. Runs @peer2 (if any) in a thread inside the child namespace, and
 *    @peer1 on the calling thread inside the parent namespace.
 * The process exits with KSFT_FAIL if test_failed() was called, otherwise
 * with KSFT_PASS.
 */
void __test_init(unsigned int ntests, int family, unsigned int prefix,
		 union tcp_addr addr1, union tcp_addr addr2,
		 thread_fn peer1, thread_fn peer2)
{
	struct sigaction sa = {
		.sa_handler = sig_int,
		.sa_flags = SA_RESTART,
	};
	time_t seed = time(NULL);

	sigemptyset(&sa.sa_mask);
	if (sigaction(SIGINT, &sa, NULL))
		test_error("Can't set SIGINT handler");

	test_family = family;
	if (!kernel_config_has(KCONFIG_NET_NS))
		__test_skip_all(tests_skip_reason[KCONFIG_NET_NS]);
	if (!kernel_config_has(KCONFIG_VETH))
		__test_skip_all(tests_skip_reason[KCONFIG_VETH]);
	if (!kernel_config_has(KCONFIG_TCP_AO))
		__test_skip_all(tests_skip_reason[KCONFIG_TCP_AO]);

	ksft_set_plan(ntests);
	/* The TAP header must come before any diagnostic line. */
	ksft_print_header();
	test_print("rand seed %u", (unsigned int)seed);
	srand(seed);

	init_namespaces();

	if (add_veth(veth_name, nsfd_parent, nsfd_child))
		test_error("Failed to add veth");

	switch_ns(nsfd_child);
	link_init(veth_name, family, prefix, addr2, addr1);
	if (peer2) {
		/*
		 * Static storage duration. The new thread reads this struct
		 * asynchronously, possibly after this block has been left.
		 * An automatic variable's lifetime would end at the closing
		 * brace below, and its stack slot could be reused by then.
		 */
		static struct new_pthread_arg targ;
		pthread_t t;

		targ.my_ip = addr2;
		targ.dest_ip = addr1;
		targ.func = peer2;
		nr_threads++;
		/*
		 * Created while this thread sits in nsfd_child. setns() only
		 * affects the calling thread, so the peer stays there after the
		 * switch below.
		 */
		if (pthread_create(&t, NULL, new_pthread_entry, &targ))
			test_error("Failed to create pthread");
	}
	switch_ns(nsfd_parent);
	link_init(veth_name, family, prefix, addr1, addr2);

	this_ip_addr = addr1;
	this_ip_dest = addr2;
	peer1(NULL);
	exit(failed ? KSFT_FAIL : KSFT_PASS);
}
277
/* /proc/sys/net/core/optmem_max artificially limits the amount of memory
 * that can be allocated with sock_kmalloc() on each socket in the system.
 * It is not virtualized in v6.7, so it has to be written outside the test
 * namespaces. To be nice, a test reverts optmem back to the old value.
 * This is kept simple, without any file lock, which means that tests that
 * need to set/increase the optmem value shouldn't run in parallel.
 * It is also not re-entrant.
 * Since commit f5769faeec36 ("net: Namespace-ify sysctl_optmem_max")
 * the value is per-namespace; the logic for a non-virtualized optmem_max
 * is kept for v6.7, which already supports TCP-AO.
 */
289 static const char *optmem_file = "/proc/sys/net/core/optmem_max";
290 static size_t saved_optmem;
291 static int optmem_ns = -1;
292
is_optmem_namespaced(void)293 static bool is_optmem_namespaced(void)
294 {
295 if (optmem_ns == -1) {
296 int old_ns = switch_save_ns(nsfd_child);
297
298 optmem_ns = !access(optmem_file, F_OK);
299 switch_ns(old_ns);
300 }
301 return !!optmem_ns;
302 }
303
test_get_optmem(void)304 size_t test_get_optmem(void)
305 {
306 int old_ns = 0;
307 FILE *foptmem;
308 size_t ret;
309
310 if (!is_optmem_namespaced())
311 old_ns = switch_save_ns(nsfd_outside);
312 foptmem = fopen(optmem_file, "r");
313 if (!foptmem)
314 test_error("failed to open %s", optmem_file);
315
316 if (fscanf(foptmem, "%zu", &ret) != 1)
317 test_error("can't read from %s", optmem_file);
318 fclose(foptmem);
319 if (!is_optmem_namespaced())
320 switch_ns(old_ns);
321 return ret;
322 }
323
__test_set_optmem(size_t new,size_t * old)324 static void __test_set_optmem(size_t new, size_t *old)
325 {
326 int old_ns = 0;
327 FILE *foptmem;
328
329 if (old != NULL)
330 *old = test_get_optmem();
331
332 if (!is_optmem_namespaced())
333 old_ns = switch_save_ns(nsfd_outside);
334 foptmem = fopen(optmem_file, "w");
335 if (!foptmem)
336 test_error("failed to open %s", optmem_file);
337
338 if (fprintf(foptmem, "%zu", new) <= 0)
339 test_error("can't write %zu to %s", new, optmem_file);
340 fclose(foptmem);
341 if (!is_optmem_namespaced())
342 switch_ns(old_ns);
343 }
344
test_revert_optmem(void)345 static void test_revert_optmem(void)
346 {
347 if (saved_optmem == 0)
348 return;
349
350 __test_set_optmem(saved_optmem, NULL);
351 }
352
test_set_optmem(size_t value)353 void test_set_optmem(size_t value)
354 {
355 if (saved_optmem == 0) {
356 __test_set_optmem(value, &saved_optmem);
357 test_add_destructor(test_revert_optmem);
358 } else {
359 __test_set_optmem(value, NULL);
360 }
361 }
362