1 /* SPDX-License-Identifier: GPL-2.0 */ 2 3 #ifndef __LWT_HELPERS_H 4 #define __LWT_HELPERS_H 5 6 #include <time.h> 7 #include <net/if.h> 8 #include <linux/icmp.h> 9 10 #include "test_progs.h" 11 12 #define log_err(MSG, ...) \ 13 fprintf(stderr, "(%s:%d: errno: %s) " MSG "\n", \ 14 __FILE__, __LINE__, strerror(errno), ##__VA_ARGS__) 15 16 #define RUN_TEST(name) \ 17 ({ \ 18 if (test__start_subtest(#name)) \ 19 if (ASSERT_OK(netns_create(), "netns_create")) { \ 20 struct nstoken *token = open_netns(NETNS); \ 21 if (ASSERT_OK_PTR(token, "setns")) { \ 22 test_ ## name(); \ 23 close_netns(token); \ 24 } \ 25 netns_delete(); \ 26 } \ 27 }) 28 29 static inline int netns_create(void) 30 { 31 return system("ip netns add " NETNS); 32 } 33 34 static inline int netns_delete(void) 35 { 36 return system("ip netns del " NETNS ">/dev/null 2>&1"); 37 } 38 39 #define ICMP_PAYLOAD_SIZE 100 40 41 /* Match an ICMP packet with payload len ICMP_PAYLOAD_SIZE */ 42 static int __expect_icmp_ipv4(char *buf, ssize_t len) 43 { 44 struct iphdr *ip = (struct iphdr *)buf; 45 struct icmphdr *icmp = (struct icmphdr *)(ip + 1); 46 ssize_t min_header_len = sizeof(*ip) + sizeof(*icmp); 47 48 if (len < min_header_len) 49 return -1; 50 51 if (ip->protocol != IPPROTO_ICMP) 52 return -1; 53 54 if (icmp->type != ICMP_ECHO) 55 return -1; 56 57 return len == ICMP_PAYLOAD_SIZE + min_header_len; 58 } 59 60 typedef int (*filter_t) (char *, ssize_t); 61 62 /* wait_for_packet - wait for a packet that matches the filter 63 * 64 * @fd: tun fd/packet socket to read packet 65 * @filter: filter function, returning 1 if matches 66 * @timeout: timeout to wait for the packet 67 * 68 * Returns 1 if a matching packet is read, 0 if timeout expired, -1 on error. 69 */ 70 static int wait_for_packet(int fd, filter_t filter, struct timeval *timeout) 71 { 72 char buf[4096]; 73 int max_retry = 5; /* in case we read some spurious packets */ 74 fd_set fds; 75 76 FD_ZERO(&fds); 77 while (max_retry--) { 78 /* Linux modifies timeout arg... So make a copy */ 79 struct timeval copied_timeout = *timeout; 80 ssize_t ret = -1; 81 82 FD_SET(fd, &fds); 83 84 ret = select(1 + fd, &fds, NULL, NULL, &copied_timeout); 85 if (ret <= 0) { 86 if (errno == EINTR) 87 continue; 88 else if (errno == EAGAIN || ret == 0) 89 return 0; 90 91 log_err("select failed"); 92 return -1; 93 } 94 95 ret = read(fd, buf, sizeof(buf)); 96 97 if (ret <= 0) { 98 log_err("read(dev): %ld", ret); 99 return -1; 100 } 101 102 if (filter && filter(buf, ret) > 0) 103 return 1; 104 } 105 106 return 0; 107 } 108 109 #endif /* __LWT_HELPERS_H */ 110