1 // SPDX-License-Identifier: GPL-2.0
2 #include <assert.h>
3 #include <errno.h>
4 #include <error.h>
5 #include <fcntl.h>
6 #include <limits.h>
7 #include <stdbool.h>
8 #include <stdint.h>
9 #include <stdio.h>
10 #include <stdlib.h>
11 #include <string.h>
12 #include <unistd.h>
13 
14 #include <arpa/inet.h>
15 #include <linux/errqueue.h>
16 #include <linux/if_packet.h>
17 #include <linux/ipv6.h>
18 #include <linux/socket.h>
19 #include <linux/sockios.h>
20 #include <net/ethernet.h>
21 #include <net/if.h>
22 #include <netinet/in.h>
23 #include <netinet/ip.h>
24 #include <netinet/ip6.h>
25 #include <netinet/tcp.h>
26 #include <netinet/udp.h>
27 #include <sys/epoll.h>
28 #include <sys/ioctl.h>
29 #include <sys/mman.h>
30 #include <sys/resource.h>
31 #include <sys/socket.h>
32 #include <sys/stat.h>
33 #include <sys/time.h>
34 #include <sys/types.h>
35 #include <sys/un.h>
36 #include <sys/wait.h>
37 
38 #include <liburing.h>
39 
#define PAGE_SIZE (4096)
/* Size of the buffer area registered for zero-copy receive. */
#define AREA_SIZE (8192 * PAGE_SIZE)
/* Size of payload[], the reference data sent by the client and verified by the server. */
#define SEND_SIZE (512 * PAGE_SIZE)

/*
 * Minimum of two values, evaluating each argument exactly once.
 * __typeof__ (rather than typeof) keeps this usable in strict ISO modes
 * such as -std=c11, where the bare typeof keyword is not available.
 */
#define min(a, b) \
	({ \
		__typeof__(a) _a = (a); \
		__typeof__(b) _b = (b); \
		_a < _b ? _a : _b; \
	})
/* Minimum of a and b after converting both to type t. */
#define min_t(t, a, b) \
	({ \
		t _ta = (a); \
		t _tb = (b); \
		min(_ta, _tb); \
	})

/* Round v up to a multiple of align; align must be a power of two. */
#define ALIGN_UP(v, align) (((v) + (align) - 1) & ~((align) - 1))
57 
58 static int cfg_server;
59 static int cfg_client;
60 static int cfg_port = 8000;
61 static int cfg_payload_len;
62 static const char *cfg_ifname;
63 static int cfg_queue_id = -1;
64 static bool cfg_oneshot;
65 static int cfg_oneshot_recvs;
66 static int cfg_send_size = SEND_SIZE;
67 static struct sockaddr_in6 cfg_addr;
68 
69 static char payload[SEND_SIZE] __attribute__((aligned(PAGE_SIZE)));
70 static void *area_ptr;
71 static void *ring_ptr;
72 static size_t ring_size;
73 static struct io_uring_zcrx_rq rq_ring;
74 static unsigned long area_token;
75 static int connfd;
76 static bool stop;
77 static size_t received;
78 
/*
 * Return the current wall-clock time in milliseconds.
 *
 * tv_sec is widened to unsigned long before scaling. Where time_t and long
 * are 32 bits, tv_sec * 1000 would overflow a signed type, which is
 * undefined behaviour. Unsigned arithmetic wraps instead, which is defined.
 */
static unsigned long gettimeofday_ms(void)
{
	struct timeval tv;

	gettimeofday(&tv, NULL);
	return (unsigned long)tv.tv_sec * 1000UL +
	       (unsigned long)tv.tv_usec / 1000UL;
}
86 
/*
 * Store port and the address parsed from str into sin6.
 *
 * str may be an IPv6 literal or a dotted-quad IPv4 address. An IPv4 address
 * is stored as an IPv4-mapped IPv6 address (::ffff:a.b.c.d), so the AF_INET6
 * sockets used by this program can reach it. Only sin6_family, sin6_port and
 * sin6_addr are written, and only on success.
 *
 * Returns 0 on success, or -1 if str is not a valid address or port is
 * outside 0..65535.
 */
static int parse_address(const char *str, int port, struct sockaddr_in6 *sin6)
{
	static const unsigned char v4mapped_prefix[12] = {
		0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff,
	};
	struct in6_addr addr;

	/* htons() would silently truncate an out-of-range port. */
	if (port < 0 || port > 65535)
		return -1;

	if (inet_pton(AF_INET6, str, &addr) != 1) {
		/* fallback to plain IPv4, stored in the last 4 address bytes */
		if (inet_pton(AF_INET, str, &addr.s6_addr[12]) != 1)
			return -1;

		/* add ::ffff prefix via the standard s6_addr byte array */
		memcpy(addr.s6_addr, v4mapped_prefix, sizeof(v4mapped_prefix));
	}

	sin6->sin6_family = AF_INET6;
	sin6->sin6_port = htons(port);
	sin6->sin6_addr = addr;
	return 0;
}
110 
get_refill_ring_size(unsigned int rq_entries)111 static inline size_t get_refill_ring_size(unsigned int rq_entries)
112 {
113 	size_t size;
114 
115 	ring_size = rq_entries * sizeof(struct io_uring_zcrx_rqe);
116 	/* add space for the header (head/tail/etc.) */
117 	ring_size += PAGE_SIZE;
118 	return ALIGN_UP(ring_size, 4096);
119 }
120 
setup_zcrx(struct io_uring * ring)121 static void setup_zcrx(struct io_uring *ring)
122 {
123 	unsigned int ifindex;
124 	unsigned int rq_entries = 4096;
125 	int ret;
126 
127 	ifindex = if_nametoindex(cfg_ifname);
128 	if (!ifindex)
129 		error(1, 0, "bad interface name: %s", cfg_ifname);
130 
131 	area_ptr = mmap(NULL,
132 			AREA_SIZE,
133 			PROT_READ | PROT_WRITE,
134 			MAP_ANONYMOUS | MAP_PRIVATE,
135 			0,
136 			0);
137 	if (area_ptr == MAP_FAILED)
138 		error(1, 0, "mmap(): zero copy area");
139 
140 	ring_size = get_refill_ring_size(rq_entries);
141 	ring_ptr = mmap(NULL,
142 			ring_size,
143 			PROT_READ | PROT_WRITE,
144 			MAP_ANONYMOUS | MAP_PRIVATE,
145 			0,
146 			0);
147 
148 	struct io_uring_region_desc region_reg = {
149 		.size = ring_size,
150 		.user_addr = (__u64)(unsigned long)ring_ptr,
151 		.flags = IORING_MEM_REGION_TYPE_USER,
152 	};
153 
154 	struct io_uring_zcrx_area_reg area_reg = {
155 		.addr = (__u64)(unsigned long)area_ptr,
156 		.len = AREA_SIZE,
157 		.flags = 0,
158 	};
159 
160 	struct io_uring_zcrx_ifq_reg reg = {
161 		.if_idx = ifindex,
162 		.if_rxq = cfg_queue_id,
163 		.rq_entries = rq_entries,
164 		.area_ptr = (__u64)(unsigned long)&area_reg,
165 		.region_ptr = (__u64)(unsigned long)&region_reg,
166 	};
167 
168 	ret = io_uring_register_ifq(ring, &reg);
169 	if (ret)
170 		error(1, 0, "io_uring_register_ifq(): %d", ret);
171 
172 	rq_ring.khead = (unsigned int *)((char *)ring_ptr + reg.offsets.head);
173 	rq_ring.ktail = (unsigned int *)((char *)ring_ptr + reg.offsets.tail);
174 	rq_ring.rqes = (struct io_uring_zcrx_rqe *)((char *)ring_ptr + reg.offsets.rqes);
175 	rq_ring.rq_tail = 0;
176 	rq_ring.ring_entries = reg.rq_entries;
177 
178 	area_token = area_reg.rq_area_token;
179 }
180 
add_accept(struct io_uring * ring,int sockfd)181 static void add_accept(struct io_uring *ring, int sockfd)
182 {
183 	struct io_uring_sqe *sqe;
184 
185 	sqe = io_uring_get_sqe(ring);
186 
187 	io_uring_prep_accept(sqe, sockfd, NULL, NULL, 0);
188 	sqe->user_data = 1;
189 }
190 
add_recvzc(struct io_uring * ring,int sockfd)191 static void add_recvzc(struct io_uring *ring, int sockfd)
192 {
193 	struct io_uring_sqe *sqe;
194 
195 	sqe = io_uring_get_sqe(ring);
196 
197 	io_uring_prep_rw(IORING_OP_RECV_ZC, sqe, sockfd, NULL, 0, 0);
198 	sqe->ioprio |= IORING_RECV_MULTISHOT;
199 	sqe->user_data = 2;
200 }
201 
add_recvzc_oneshot(struct io_uring * ring,int sockfd,size_t len)202 static void add_recvzc_oneshot(struct io_uring *ring, int sockfd, size_t len)
203 {
204 	struct io_uring_sqe *sqe;
205 
206 	sqe = io_uring_get_sqe(ring);
207 
208 	io_uring_prep_rw(IORING_OP_RECV_ZC, sqe, sockfd, NULL, len, 0);
209 	sqe->ioprio |= IORING_RECV_MULTISHOT;
210 	sqe->user_data = 2;
211 }
212 
process_accept(struct io_uring * ring,struct io_uring_cqe * cqe)213 static void process_accept(struct io_uring *ring, struct io_uring_cqe *cqe)
214 {
215 	if (cqe->res < 0)
216 		error(1, 0, "accept()");
217 	if (connfd)
218 		error(1, 0, "Unexpected second connection");
219 
220 	connfd = cqe->res;
221 	if (cfg_oneshot)
222 		add_recvzc_oneshot(ring, connfd, PAGE_SIZE);
223 	else
224 		add_recvzc(ring, connfd);
225 }
226 
process_recvzc(struct io_uring * ring,struct io_uring_cqe * cqe)227 static void process_recvzc(struct io_uring *ring, struct io_uring_cqe *cqe)
228 {
229 	unsigned rq_mask = rq_ring.ring_entries - 1;
230 	struct io_uring_zcrx_cqe *rcqe;
231 	struct io_uring_zcrx_rqe *rqe;
232 	struct io_uring_sqe *sqe;
233 	uint64_t mask;
234 	char *data;
235 	ssize_t n;
236 	int i;
237 
238 	if (cqe->res == 0 && cqe->flags == 0 && cfg_oneshot_recvs == 0) {
239 		stop = true;
240 		return;
241 	}
242 
243 	if (cqe->res < 0)
244 		error(1, 0, "recvzc(): %d", cqe->res);
245 
246 	if (cfg_oneshot) {
247 		if (cqe->res == 0 && cqe->flags == 0 && cfg_oneshot_recvs) {
248 			add_recvzc_oneshot(ring, connfd, PAGE_SIZE);
249 			cfg_oneshot_recvs--;
250 		}
251 	} else if (!(cqe->flags & IORING_CQE_F_MORE)) {
252 		add_recvzc(ring, connfd);
253 	}
254 
255 	rcqe = (struct io_uring_zcrx_cqe *)(cqe + 1);
256 
257 	n = cqe->res;
258 	mask = (1ULL << IORING_ZCRX_AREA_SHIFT) - 1;
259 	data = (char *)area_ptr + (rcqe->off & mask);
260 
261 	for (i = 0; i < n; i++) {
262 		if (*(data + i) != payload[(received + i)])
263 			error(1, 0, "payload mismatch at ", i);
264 	}
265 	received += n;
266 
267 	rqe = &rq_ring.rqes[(rq_ring.rq_tail & rq_mask)];
268 	rqe->off = (rcqe->off & ~IORING_ZCRX_AREA_MASK) | area_token;
269 	rqe->len = cqe->res;
270 	io_uring_smp_store_release(rq_ring.ktail, ++rq_ring.rq_tail);
271 }
272 
server_loop(struct io_uring * ring)273 static void server_loop(struct io_uring *ring)
274 {
275 	struct io_uring_cqe *cqe;
276 	unsigned int count = 0;
277 	unsigned int head;
278 	int i, ret;
279 
280 	io_uring_submit_and_wait(ring, 1);
281 
282 	io_uring_for_each_cqe(ring, head, cqe) {
283 		if (cqe->user_data == 1)
284 			process_accept(ring, cqe);
285 		else if (cqe->user_data == 2)
286 			process_recvzc(ring, cqe);
287 		else
288 			error(1, 0, "unknown cqe");
289 		count++;
290 	}
291 	io_uring_cq_advance(ring, count);
292 }
293 
run_server(void)294 static void run_server(void)
295 {
296 	unsigned int flags = 0;
297 	struct io_uring ring;
298 	int fd, enable, ret;
299 	uint64_t tstop;
300 
301 	fd = socket(AF_INET6, SOCK_STREAM, 0);
302 	if (fd == -1)
303 		error(1, 0, "socket()");
304 
305 	enable = 1;
306 	ret = setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
307 	if (ret < 0)
308 		error(1, 0, "setsockopt(SO_REUSEADDR)");
309 
310 	ret = bind(fd, (struct sockaddr *)&cfg_addr, sizeof(cfg_addr));
311 	if (ret < 0)
312 		error(1, 0, "bind()");
313 
314 	if (listen(fd, 1024) < 0)
315 		error(1, 0, "listen()");
316 
317 	flags |= IORING_SETUP_COOP_TASKRUN;
318 	flags |= IORING_SETUP_SINGLE_ISSUER;
319 	flags |= IORING_SETUP_DEFER_TASKRUN;
320 	flags |= IORING_SETUP_SUBMIT_ALL;
321 	flags |= IORING_SETUP_CQE32;
322 
323 	io_uring_queue_init(512, &ring, flags);
324 
325 	setup_zcrx(&ring);
326 
327 	add_accept(&ring, fd);
328 
329 	tstop = gettimeofday_ms() + 5000;
330 	while (!stop && gettimeofday_ms() < tstop)
331 		server_loop(&ring);
332 
333 	if (!stop)
334 		error(1, 0, "test failed\n");
335 }
336 
run_client(void)337 static void run_client(void)
338 {
339 	ssize_t to_send = cfg_send_size;
340 	ssize_t sent = 0;
341 	ssize_t chunk, res;
342 	int fd;
343 
344 	fd = socket(AF_INET6, SOCK_STREAM, 0);
345 	if (fd == -1)
346 		error(1, 0, "socket()");
347 
348 	if (connect(fd, (struct sockaddr *)&cfg_addr, sizeof(cfg_addr)))
349 		error(1, 0, "connect()");
350 
351 	while (to_send) {
352 		void *src = &payload[sent];
353 
354 		chunk = min_t(ssize_t, cfg_payload_len, to_send);
355 		res = send(fd, src, chunk, 0);
356 		if (res < 0)
357 			error(1, 0, "send(): %d", sent);
358 		sent += res;
359 		to_send -= res;
360 	}
361 
362 	close(fd);
363 }
364 
/*
 * Print the command-line synopsis and exit with status 1. The synopsis
 * lists exactly the options accepted by parse_opts().
 */
static void usage(const char *filepath)
{
	error(1, 0, "Usage: %s (-s|-c) -h<server_ip> -p<port> "
		    "-l<payload_size> -i<ifname> -q<rxq_id> "
		    "[-o<oneshot_recvs>] [-z<send_size>]", filepath);
}
370 
parse_opts(int argc,char ** argv)371 static void parse_opts(int argc, char **argv)
372 {
373 	const int max_payload_len = sizeof(payload) -
374 				    sizeof(struct ipv6hdr) -
375 				    sizeof(struct tcphdr) -
376 				    40 /* max tcp options */;
377 	struct sockaddr_in6 *addr6 = (void *) &cfg_addr;
378 	char *addr = NULL;
379 	int ret;
380 	int c;
381 
382 	if (argc <= 1)
383 		usage(argv[0]);
384 	cfg_payload_len = max_payload_len;
385 
386 	while ((c = getopt(argc, argv, "sch:p:l:i:q:o:z:")) != -1) {
387 		switch (c) {
388 		case 's':
389 			if (cfg_client)
390 				error(1, 0, "Pass one of -s or -c");
391 			cfg_server = 1;
392 			break;
393 		case 'c':
394 			if (cfg_server)
395 				error(1, 0, "Pass one of -s or -c");
396 			cfg_client = 1;
397 			break;
398 		case 'h':
399 			addr = optarg;
400 			break;
401 		case 'p':
402 			cfg_port = strtoul(optarg, NULL, 0);
403 			break;
404 		case 'l':
405 			cfg_payload_len = strtoul(optarg, NULL, 0);
406 			break;
407 		case 'i':
408 			cfg_ifname = optarg;
409 			break;
410 		case 'q':
411 			cfg_queue_id = strtoul(optarg, NULL, 0);
412 			break;
413 		case 'o': {
414 			cfg_oneshot = true;
415 			cfg_oneshot_recvs = strtoul(optarg, NULL, 0);
416 			break;
417 		}
418 		case 'z':
419 			cfg_send_size = strtoul(optarg, NULL, 0);
420 			break;
421 		}
422 	}
423 
424 	if (cfg_server && addr)
425 		error(1, 0, "Receiver cannot have -h specified");
426 
427 	memset(addr6, 0, sizeof(*addr6));
428 	addr6->sin6_family = AF_INET6;
429 	addr6->sin6_port = htons(cfg_port);
430 	addr6->sin6_addr = in6addr_any;
431 	if (addr) {
432 		ret = parse_address(addr, cfg_port, addr6);
433 		if (ret)
434 			error(1, 0, "receiver address parse error: %s", addr);
435 	}
436 
437 	if (cfg_payload_len > max_payload_len)
438 		error(1, 0, "-l: payload exceeds max (%d)", max_payload_len);
439 }
440 
main(int argc,char ** argv)441 int main(int argc, char **argv)
442 {
443 	const char *cfg_test = argv[argc - 1];
444 	int i;
445 
446 	parse_opts(argc, argv);
447 
448 	for (i = 0; i < SEND_SIZE; i++)
449 		payload[i] = 'a' + (i % 26);
450 
451 	if (cfg_server)
452 		run_server();
453 	else if (cfg_client)
454 		run_client();
455 
456 	return 0;
457 }
458