1 /* SPDX-License-Identifier: GPL-2.0 */
2 
3 #ifndef __LWT_HELPERS_H
4 #define __LWT_HELPERS_H
5 
6 #include <time.h>
7 #include <net/if.h>
8 #include <linux/icmp.h>
9 
10 #include "test_progs.h"
11 
/* log_err - print a printf-style diagnostic to stderr.
 *
 * The message is prefixed with the call site (__FILE__:__LINE__) and the
 * strerror() text for the current value of errno, and a newline is appended.
 * Evaluates to the return value of fprintf().
 */
#define log_err(fmt, ...)                                      \
	fprintf(stderr,                                        \
		"(%s:%d: errno: %s) " fmt "\n",                \
		__FILE__, __LINE__, strerror(errno),           \
		##__VA_ARGS__)
15 
/* RUN_TEST - run test_<name>() as a subtest inside the NETNS namespace.
 *
 * Steps, each one only attempted if the previous one succeeded:
 *   1. test__start_subtest("<name>") accepts the subtest;
 *   2. netns_create() returns 0;
 *   3. open_netns(NETNS) returns a valid token.
 * test_<name>() is then called and the namespace is left via close_netns().
 * netns_delete() runs whenever step 2 succeeded, even if step 3 failed.
 */
#define RUN_TEST(name)                                                        \
	({                                                                    \
		if (test__start_subtest(#name) &&                             \
		    ASSERT_OK(netns_create(), "netns_create")) {              \
			struct nstoken *ns_tok = open_netns(NETNS);           \
			if (ASSERT_OK_PTR(ns_tok, "setns")) {                 \
				test_ ## name();                              \
				close_netns(ns_tok);                          \
			}                                                     \
			netns_delete();                                       \
		}                                                             \
	})
28 
netns_create(void)29 static inline int netns_create(void)
30 {
31 	return system("ip netns add " NETNS);
32 }
33 
netns_delete(void)34 static inline int netns_delete(void)
35 {
36 	return system("ip netns del " NETNS ">/dev/null 2>&1");
37 }
38 
#define ICMP_PAYLOAD_SIZE     100

/* __expect_icmp_ipv4 - match an ICMP echo request with payload length
 * ICMP_PAYLOAD_SIZE.
 *
 * @buf: packet data, starting at the IPv4 header
 * @len: number of valid bytes in @buf
 *
 * The IPv4 header length is taken from the header's ihl field, so packets
 * carrying IP options are parsed at the right offset.
 *
 * Returns 1 if the packet is an IPv4 ICMP echo request whose total length is
 * exactly IPv4 header + ICMP header + ICMP_PAYLOAD_SIZE, 0 if it is an ICMP
 * echo request of a different length, and -1 if it is too short, not IPv4,
 * has a malformed header length, is not ICMP, or is not an echo request.
 */
static int __expect_icmp_ipv4(char *buf, ssize_t len)
{
	struct iphdr *ip = (struct iphdr *)buf;
	struct icmphdr *icmp;
	ssize_t ip_hdr_len;
	ssize_t min_header_len;

	/* The fixed part of the IPv4 header must be present before any of
	 * its fields can be read.
	 */
	if (len < (ssize_t)sizeof(*ip))
		return -1;

	if (ip->version != 4)
		return -1;

	/* ihl counts 32-bit words; fewer than 5 words is malformed */
	ip_hdr_len = (ssize_t)ip->ihl * 4;
	if (ip_hdr_len < (ssize_t)sizeof(*ip))
		return -1;

	min_header_len = ip_hdr_len + (ssize_t)sizeof(*icmp);
	if (len < min_header_len)
		return -1;

	if (ip->protocol != IPPROTO_ICMP)
		return -1;

	/* The ICMP header follows the IPv4 header including any options */
	icmp = (struct icmphdr *)(buf + ip_hdr_len);
	if (icmp->type != ICMP_ECHO)
		return -1;

	return len == ICMP_PAYLOAD_SIZE + min_header_len;
}
59 
/* Packet filter callback: receives the packet bytes and their length, and
 * returns a value > 0 when the packet matches.
 */
typedef int (*filter_t) (char *, ssize_t);

/* wait_for_packet - wait for a packet that matches the filter
 *
 * @fd: tun fd/packet socket to read packet
 * @filter: filter function, returning 1 if matches
 * @timeout: timeout to wait for the packet
 *
 * Every select() call is given a fresh copy of @timeout, so @timeout itself
 * is never modified. At most five select()/read() rounds are attempted so
 * that a few non-matching packets can be skipped; if none of them matches,
 * 0 is returned as well.
 *
 * Returns 1 if a matching packet is read, 0 if timeout expired, -1 on error.
 */
static int wait_for_packet(int fd, filter_t filter, struct timeval *timeout)
{
	char buf[4096];
	int max_retry = 5; /* in case we read some spurious packets */
	fd_set fds;

	while (max_retry--) {
		/* Linux modifies timeout arg... So make a copy */
		struct timeval copied_timeout = *timeout;
		ssize_t ret;

		/* select() rewrites the set in place, so rebuild it each round */
		FD_ZERO(&fds);
		FD_SET(fd, &fds);

		ret = select(1 + fd, &fds, NULL, NULL, &copied_timeout);

		/* A timeout does not set errno, so it must be handled before
		 * errno is inspected; otherwise a stale EINTR from an earlier
		 * round would turn the timeout into another full wait.
		 */
		if (ret == 0)
			return 0;

		if (ret < 0) {
			if (errno == EINTR)
				continue;
			if (errno == EAGAIN)
				return 0;

			log_err("select failed");
			return -1;
		}

		ret = read(fd, buf, sizeof(buf));
		if (ret <= 0) {
			log_err("read(dev): %zd", ret);
			return -1;
		}

		if (filter && filter(buf, ret) > 0)
			return 1;
	}

	return 0;
}
108 
109 #endif /* __LWT_HELPERS_H */
110