xref: /linux/tools/testing/selftests/net/lib/xdp_native.bpf.c (revision 8be4d31cb8aaeea27bde4b7ddb26e28a89062ebf)
// SPDX-License-Identifier: GPL-2.0

#include <stddef.h>
#include <linux/bpf.h>
#include <linux/in.h>
#include <linux/if_ether.h>
#include <linux/ip.h>
#include <linux/ipv6.h>
#include <linux/udp.h>
#include <bpf/bpf_endian.h>
#include <bpf/bpf_helpers.h>

#define MAX_ADJST_OFFSET 256
#define MAX_PAYLOAD_LEN 5000
#define MAX_HDR_LEN 64

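/*
 * Test program for exercising XDP actions from the net selftests. The
 * behavior is selected at runtime through the map_xdp_setup array: XDP_MODE
 * picks one of the pass/drop/tx/tail-adjust/head-adjust handlers below,
 * XDP_PORT selects the UDP destination port to match on, and
 * XDP_ADJST_OFFSET/XDP_ADJST_TAG parameterize the head/tail adjustments.
 * Per-action counters are reported back through map_xdp_stats.
 */
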
enum {
	XDP_MODE = 0,
	XDP_PORT = 1,
	XDP_ADJST_OFFSET = 2,
	XDP_ADJST_TAG = 3,
} xdp_map_setup_keys;

enum {
	XDP_MODE_PASS = 0,
	XDP_MODE_DROP = 1,
	XDP_MODE_TX = 2,
	XDP_MODE_TAIL_ADJST = 3,
	XDP_MODE_HEAD_ADJST = 4,
} xdp_map_modes;

enum {
	STATS_RX = 0,
	STATS_PASS = 1,
	STATS_DROP = 2,
	STATS_TX = 3,
	STATS_ABORT = 4,
} xdp_stats;

struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, 5);
	__type(key, __u32);
	__type(value, __s32);
} map_xdp_setup SEC(".maps");

struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, 5);
	__type(key, __u32);
	__type(value, __u64);
} map_xdp_stats SEC(".maps");

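/*
 * A minimal userspace sketch, assuming the test harness pins both maps
 * under /sys/fs/bpf (the pin paths here are hypothetical). For example, to
 * request XDP_MODE_TX on UDP port 8000 (0x1f40) with bpftool:
 *
 *   bpftool map update pinned /sys/fs/bpf/map_xdp_setup \
 *           key 0 0 0 0 value 2 0 0 0        # XDP_MODE = XDP_MODE_TX
 *   bpftool map update pinned /sys/fs/bpf/map_xdp_setup \
 *           key 1 0 0 0 value 0x40 0x1f 0 0  # XDP_PORT = 8000
 *
 * Keys and values are little-endian byte sequences on most test hosts.
 */
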
static __u32 min(__u32 a, __u32 b)
{
	return a < b ? a : b;
}

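/* Atomically bump the counter for @stat_type in map_xdp_stats. */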
static void record_stats(struct xdp_md *ctx, __u32 stat_type)
{
	__u64 *count;

	count = bpf_map_lookup_elem(&map_xdp_stats, &stat_type);

	if (count)
		__sync_fetch_and_add(count, 1);
}

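/*
 * Parse the Ethernet and IPv4/IPv6 headers and return a pointer to the UDP
 * header if the packet is UDP and its destination port matches @port;
 * otherwise return NULL. A match is also counted as STATS_RX.
 */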
static struct udphdr *filter_udphdr(struct xdp_md *ctx, __u16 port)
{
	void *data_end = (void *)(long)ctx->data_end;
	void *data = (void *)(long)ctx->data;
	struct udphdr *udph = NULL;
	struct ethhdr *eth = data;

	if (data + sizeof(*eth) > data_end)
		return NULL;

	if (eth->h_proto == bpf_htons(ETH_P_IP)) {
		struct iphdr *iph = data + sizeof(*eth);

		if (iph + 1 > (struct iphdr *)data_end ||
		    iph->protocol != IPPROTO_UDP)
			return NULL;

		udph = (void *)eth + sizeof(*eth) + sizeof(*iph);
	} else if (eth->h_proto == bpf_htons(ETH_P_IPV6)) {
		struct ipv6hdr *ipv6h = data + sizeof(*eth);

		if (ipv6h + 1 > (struct ipv6hdr *)data_end ||
		    ipv6h->nexthdr != IPPROTO_UDP)
			return NULL;

		udph = (void *)eth + sizeof(*eth) + sizeof(*ipv6h);
	} else {
		return NULL;
	}

	if (udph + 1 > (struct udphdr *)data_end)
		return NULL;

	if (udph->dest != bpf_htons(port))
		return NULL;

	record_stats(ctx, STATS_RX);

	return udph;
}

static int xdp_mode_pass(struct xdp_md *ctx, __u16 port)
{
	struct udphdr *udph = NULL;

	udph = filter_udphdr(ctx, port);
	if (!udph)
		return XDP_PASS;

	record_stats(ctx, STATS_PASS);

	return XDP_PASS;
}

static int xdp_mode_drop_handler(struct xdp_md *ctx, __u16 port)
{
	struct udphdr *udph = NULL;

	udph = filter_udphdr(ctx, port);
	if (!udph)
		return XDP_PASS;

	record_stats(ctx, STATS_DROP);

	return XDP_DROP;
}

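/* Swap the source and destination MAC addresses in the Ethernet header. */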
static void swap_machdr(void *data)
{
	struct ethhdr *eth = data;
	__u8 tmp_mac[ETH_ALEN];

	__builtin_memcpy(tmp_mac, eth->h_source, ETH_ALEN);
	__builtin_memcpy(eth->h_source, eth->h_dest, ETH_ALEN);
	__builtin_memcpy(eth->h_dest, tmp_mac, ETH_ALEN);
}

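/*
 * Reflect matching UDP packets back out of the ingress interface: swap the
 * MAC addresses and the IPv4/IPv6 source and destination addresses, then
 * return XDP_TX. Everything else is passed up the stack.
 */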
static int xdp_mode_tx_handler(struct xdp_md *ctx, __u16 port)
{
	void *data_end = (void *)(long)ctx->data_end;
	void *data = (void *)(long)ctx->data;
	struct udphdr *udph = NULL;
	struct ethhdr *eth = data;

	if (data + sizeof(*eth) > data_end)
		return XDP_PASS;

	if (eth->h_proto == bpf_htons(ETH_P_IP)) {
		struct iphdr *iph = data + sizeof(*eth);
		__be32 tmp_ip;

		if (iph + 1 > (struct iphdr *)data_end ||
		    iph->protocol != IPPROTO_UDP)
			return XDP_PASS;

		udph = data + sizeof(*eth) + sizeof(*iph);

		if (udph + 1 > (struct udphdr *)data_end)
			return XDP_PASS;
		if (udph->dest != bpf_htons(port))
			return XDP_PASS;

		record_stats(ctx, STATS_RX);
		swap_machdr((void *)eth);

		/* Read saddr only after the bounds check above. */
		tmp_ip = iph->saddr;
		iph->saddr = iph->daddr;
		iph->daddr = tmp_ip;

		record_stats(ctx, STATS_TX);

		return XDP_TX;

	} else if (eth->h_proto == bpf_htons(ETH_P_IPV6)) {
		struct ipv6hdr *ipv6h = data + sizeof(*eth);
		struct in6_addr tmp_ipv6;

		if (ipv6h + 1 > (struct ipv6hdr *)data_end ||
		    ipv6h->nexthdr != IPPROTO_UDP)
			return XDP_PASS;

		udph = data + sizeof(*eth) + sizeof(*ipv6h);

		if (udph + 1 > (struct udphdr *)data_end)
			return XDP_PASS;
		if (udph->dest != bpf_htons(port))
			return XDP_PASS;

		record_stats(ctx, STATS_RX);
		swap_machdr((void *)eth);

		__builtin_memcpy(&tmp_ipv6, &ipv6h->saddr, sizeof(tmp_ipv6));
		__builtin_memcpy(&ipv6h->saddr, &ipv6h->daddr,
				 sizeof(tmp_ipv6));
		__builtin_memcpy(&ipv6h->daddr, &tmp_ipv6, sizeof(tmp_ipv6));

		record_stats(ctx, STATS_TX);

		return XDP_TX;
	}

	return XDP_PASS;
}

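/*
 * Adjust the IP total/payload length and the UDP length fields by @offset
 * bytes. For IPv6 the incremental UDP checksum is also seeded into
 * @udp_csum via bpf_csum_diff() (IPv6 makes the UDP checksum mandatory).
 * Returns a pointer to the UDP header, or NULL on a parse failure.
 */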
static void *update_pkt(struct xdp_md *ctx, __s16 offset, __u32 *udp_csum)
{
	void *data_end = (void *)(long)ctx->data_end;
	void *data = (void *)(long)ctx->data;
	struct udphdr *udph = NULL;
	struct ethhdr *eth = data;
	__u32 len, len_new;

	if (data + sizeof(*eth) > data_end)
		return NULL;

	if (eth->h_proto == bpf_htons(ETH_P_IP)) {
		struct iphdr *iph = data + sizeof(*eth);

		if (iph + 1 > (struct iphdr *)data_end)
			return NULL;

		iph->tot_len = bpf_htons(bpf_ntohs(iph->tot_len) + offset);

		udph = (void *)eth + sizeof(*eth) + sizeof(*iph);
		if (!udph || udph + 1 > (struct udphdr *)data_end)
			return NULL;

		len_new = bpf_htons(bpf_ntohs(udph->len) + offset);
	} else if (eth->h_proto == bpf_htons(ETH_P_IPV6)) {
		struct ipv6hdr *ipv6h = data + sizeof(*eth);

		if (ipv6h + 1 > (struct ipv6hdr *)data_end)
			return NULL;

		udph = (void *)eth + sizeof(*eth) + sizeof(*ipv6h);
		if (!udph || udph + 1 > (struct udphdr *)data_end)
			return NULL;

		*udp_csum = ~((__u32)udph->check);

		len = ipv6h->payload_len;
		len_new = bpf_htons(bpf_ntohs(len) + offset);
		ipv6h->payload_len = len_new;

		*udp_csum = bpf_csum_diff(&len, sizeof(len), &len_new,
					  sizeof(len_new), *udp_csum);

		len = udph->len;
		len_new = bpf_htons(bpf_ntohs(udph->len) + offset);
		*udp_csum = bpf_csum_diff(&len, sizeof(len), &len_new,
					  sizeof(len_new), *udp_csum);
	} else {
		return NULL;
	}

	udph->len = len_new;

	return udph;
}

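/*
 * Fold a 32-bit checksum accumulator into 16 bits. A result of zero is
 * replaced with 0xffff, since an all-zero UDP checksum means "no checksum".
 */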
static __u16 csum_fold_helper(__u32 csum)
{
	return ~((csum & 0xffff) + (csum >> 16)) ? : 0xffff;
}

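/*
 * Shrink the packet by trimming @offset bytes from the tail: fix up the
 * length fields, subtract the to-be-removed bytes from the UDP checksum
 * while they are still readable, then call bpf_xdp_adjust_tail() with a
 * negative delta.
 */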
static int xdp_adjst_tail_shrnk_data(struct xdp_md *ctx, __u16 offset,
				     __u32 hdr_len)
{
	char tmp_buff[MAX_ADJST_OFFSET];
	__u32 buff_pos, udp_csum = 0;
	struct udphdr *udph = NULL;
	__u32 buff_len;

	udph = update_pkt(ctx, 0 - offset, &udp_csum);
	if (!udph)
		return -1;

	buff_len = bpf_xdp_get_buff_len(ctx);

	/* Bound offset to MAX_ADJST_OFFSET so the verifier can prove the
	 * tmp_buff accesses below are in range.
	 */
	offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET :
				     offset & 0xff;
	if (offset == 0)
		return -1;

	/* Make sure we have enough data to avoid eating the header */
	if (buff_len - offset < hdr_len)
		return -1;

	buff_pos = buff_len - offset;
	if (bpf_xdp_load_bytes(ctx, buff_pos, tmp_buff, offset) < 0)
		return -1;

	udp_csum = bpf_csum_diff((__be32 *)tmp_buff, offset, 0, 0, udp_csum);
	udph->check = (__u16)csum_fold_helper(udp_csum);

	if (bpf_xdp_adjust_tail(ctx, 0 - offset) < 0)
		return -1;

	return 0;
}

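/*
 * Grow the packet by appending @offset bytes of the configured tag value:
 * fix up the length fields and checksum first (the payload contribution of
 * the new bytes is known in advance), then extend the tail and write the
 * tag bytes into the new area.
 */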
static int xdp_adjst_tail_grow_data(struct xdp_md *ctx, __u16 offset)
{
	char tmp_buff[MAX_ADJST_OFFSET];
	__u32 buff_len, key;
	__u32 udp_csum = 0;
	struct udphdr *udph;
	__s32 *val;
	__u8 tag;

	/* Proceed to update the packet headers before attempting to adjust
	 * the tail. Once the tail is adjusted, we lose access to the offset
	 * amount of data at the end of the packet, which is crucial for
	 * updating the checksum.
	 * Since any failure beyond this point aborts the packet, we need not
	 * worry about passing a packet up the stack with wrong headers.
	 */
	udph = update_pkt(ctx, offset, &udp_csum);
	if (!udph)
		return -1;

	key = XDP_ADJST_TAG;
	val = bpf_map_lookup_elem(&map_xdp_setup, &key);
	if (!val)
		return -1;

	tag = (__u8)(*val);

	for (int i = 0; i < MAX_ADJST_OFFSET; i++)
		__builtin_memcpy(&tmp_buff[i], &tag, 1);

	offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET :
				     offset & 0xff;
	if (offset == 0)
		return -1;

	udp_csum = bpf_csum_diff(0, 0, (__be32 *)tmp_buff, offset, udp_csum);
	udph->check = (__u16)csum_fold_helper(udp_csum);

	buff_len = bpf_xdp_get_buff_len(ctx);

	if (bpf_xdp_adjust_tail(ctx, offset) < 0) {
		bpf_printk("Failed to adjust tail\n");
		return -1;
	}

	if (bpf_xdp_store_bytes(ctx, buff_len, tmp_buff, offset) < 0)
		return -1;

	return 0;
}

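/*
 * XDP_MODE_TAIL_ADJST entry point: read the signed adjustment from
 * map_xdp_setup and shrink (negative) or grow (positive) the packet tail,
 * aborting the packet if the adjustment fails.
 */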
static int xdp_adjst_tail(struct xdp_md *ctx, __u16 port)
{
	void *data = (void *)(long)ctx->data;
	struct udphdr *udph = NULL;
	__s32 *adjust_offset;
	__u32 key, hdr_len;
	int ret;

	udph = filter_udphdr(ctx, port);
	if (!udph)
		return XDP_PASS;

	hdr_len = (void *)udph - data + sizeof(struct udphdr);
	key = XDP_ADJST_OFFSET;
	adjust_offset = bpf_map_lookup_elem(&map_xdp_setup, &key);
	if (!adjust_offset)
		return XDP_PASS;

	if (*adjust_offset < 0)
		ret = xdp_adjst_tail_shrnk_data(ctx,
						(__u16)(0 - *adjust_offset),
						hdr_len);
	else
		ret = xdp_adjst_tail_grow_data(ctx, (__u16)(*adjust_offset));
	if (ret)
		goto abort_pkt;

	record_stats(ctx, STATS_PASS);
	return XDP_PASS;

abort_pkt:
	record_stats(ctx, STATS_ABORT);
	return XDP_ABORTED;
}

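/*
 * Shrink the packet from the head: drop @offset bytes of payload directly
 * after the headers by saving the headers, moving the head pointer forward
 * and writing the headers back at the new front of the packet.
 */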
static int xdp_adjst_head_shrnk_data(struct xdp_md *ctx, __u64 hdr_len,
				     __u32 offset)
{
	char tmp_buff[MAX_ADJST_OFFSET];
	struct udphdr *udph;
	__u32 udp_csum = 0;

	/* Update the length information in the IP and UDP headers before
	 * adjusting the headroom. This simplifies accessing the relevant
	 * fields in the IP and UDP headers for fragmented packets. Any
	 * failure beyond this point will result in the packet being aborted,
	 * so we don't need to worry about incorrect length information for
	 * passed packets.
	 */
	udph = update_pkt(ctx, (__s16)(0 - offset), &udp_csum);
	if (!udph)
		return -1;

	offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET :
				     offset & 0xff;
	if (offset == 0)
		return -1;

	/* Subtract the to-be-dropped payload bytes from the UDP checksum. */
	if (bpf_xdp_load_bytes(ctx, hdr_len, tmp_buff, offset) < 0)
		return -1;

	udp_csum = bpf_csum_diff((__be32 *)tmp_buff, offset, 0, 0, udp_csum);

	udph->check = (__u16)csum_fold_helper(udp_csum);

	/* Save the headers so they can be restored after the head moves. */
	if (bpf_xdp_load_bytes(ctx, 0, tmp_buff, MAX_ADJST_OFFSET) < 0)
		return -1;

	if (bpf_xdp_adjust_head(ctx, offset) < 0)
		return -1;

	if (offset > MAX_ADJST_OFFSET)
		return -1;

	if (hdr_len > MAX_ADJST_OFFSET || hdr_len == 0)
		return -1;

	/* Mask to keep clang/the verifier happy about the bounds of hdr_len */
	hdr_len = hdr_len & 0xff;

	if (hdr_len == 0)
		return -1;

	if (bpf_xdp_store_bytes(ctx, 0, tmp_buff, hdr_len) < 0)
		return -1;

	return 0;
}

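/*
 * Grow the packet from the head: save the existing headers, extend the
 * headroom by @offset bytes, then write the headers back followed by
 * @offset tag bytes, so the new data sits between the headers and the
 * original payload.
 */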
static int xdp_adjst_head_grow_data(struct xdp_md *ctx, __u64 hdr_len,
				    __u32 offset)
{
	char hdr_buff[MAX_HDR_LEN];
	char data_buff[MAX_ADJST_OFFSET];
	__u32 udp_csum = 0;
	struct udphdr *udph;
	__s32 *val;
	__u32 key;
	__u8 tag;

	udph = update_pkt(ctx, (__s16)(offset), &udp_csum);
	if (!udph)
		return -1;

	key = XDP_ADJST_TAG;
	val = bpf_map_lookup_elem(&map_xdp_setup, &key);
	if (!val)
		return -1;

	tag = (__u8)(*val);
	for (int i = 0; i < MAX_ADJST_OFFSET; i++)
		__builtin_memcpy(&data_buff[i], &tag, 1);

	offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET :
				     offset & 0xff;
	if (offset == 0)
		return -1;

	udp_csum = bpf_csum_diff(0, 0, (__be32 *)data_buff, offset, udp_csum);
	udph->check = (__u16)csum_fold_helper(udp_csum);

	/* hdr_buff can hold at most MAX_HDR_LEN bytes. */
	if (hdr_len > MAX_HDR_LEN || hdr_len == 0)
		return -1;

	/* Mask to keep clang/the verifier happy about the bounds of hdr_len */
	hdr_len = hdr_len & 0xff;

	if (hdr_len == 0)
		return -1;

	if (bpf_xdp_load_bytes(ctx, 0, hdr_buff, hdr_len) < 0)
		return -1;

	if (offset > MAX_ADJST_OFFSET)
		return -1;

	if (bpf_xdp_adjust_head(ctx, 0 - offset) < 0)
		return -1;

	if (bpf_xdp_store_bytes(ctx, 0, hdr_buff, hdr_len) < 0)
		return -1;

	if (bpf_xdp_store_bytes(ctx, hdr_len, data_buff, offset) < 0)
		return -1;

	return 0;
}

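/*
 * XDP_MODE_HEAD_ADJST entry point: validate that the configured offset is
 * one of the supported power-of-two sizes, then grow (negative) or shrink
 * (positive) the packet head by that amount.
 */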
static int xdp_head_adjst(struct xdp_md *ctx, __u16 port)
{
	void *data = (void *)(long)ctx->data;
	struct udphdr *udph_ptr = NULL;
	__u32 key, size, hdr_len;
	__s32 *val;
	int res;

	/* Filter packets based on UDP port */
	udph_ptr = filter_udphdr(ctx, port);
	if (!udph_ptr)
		return XDP_PASS;

	hdr_len = (void *)udph_ptr - data + sizeof(struct udphdr);

	key = XDP_ADJST_OFFSET;
	val = bpf_map_lookup_elem(&map_xdp_setup, &key);
	if (!val)
		return XDP_PASS;

	switch (*val) {
	case -16:
	case 16:
		size = 16;
		break;
	case -32:
	case 32:
		size = 32;
		break;
	case -64:
	case 64:
		size = 64;
		break;
	case -128:
	case 128:
		size = 128;
		break;
	case -256:
	case 256:
		size = 256;
		break;
	default:
		bpf_printk("Invalid adjustment offset: %d\n", *val);
		goto abort;
	}

	if (*val < 0)
		res = xdp_adjst_head_grow_data(ctx, hdr_len, size);
	else
		res = xdp_adjst_head_shrnk_data(ctx, hdr_len, size);

	if (res)
		goto abort;

	record_stats(ctx, STATS_PASS);
	return XDP_PASS;

abort:
	record_stats(ctx, STATS_ABORT);
	return XDP_ABORTED;
}

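/*
 * Common dispatcher for both program entry points: look up the configured
 * mode and UDP port, then hand the packet to the matching handler.
 */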
static int xdp_prog_common(struct xdp_md *ctx)
{
	__u32 key, *port;
	__s32 *mode;

	key = XDP_MODE;
	mode = bpf_map_lookup_elem(&map_xdp_setup, &key);
	if (!mode)
		return XDP_PASS;

	key = XDP_PORT;
	port = bpf_map_lookup_elem(&map_xdp_setup, &key);
	if (!port)
		return XDP_PASS;

	switch (*mode) {
	case XDP_MODE_PASS:
		return xdp_mode_pass(ctx, (__u16)(*port));
	case XDP_MODE_DROP:
		return xdp_mode_drop_handler(ctx, (__u16)(*port));
	case XDP_MODE_TX:
		return xdp_mode_tx_handler(ctx, (__u16)(*port));
	case XDP_MODE_TAIL_ADJST:
		return xdp_adjst_tail(ctx, (__u16)(*port));
	case XDP_MODE_HEAD_ADJST:
		return xdp_head_adjst(ctx, (__u16)(*port));
	}

	/* The default action is to simply pass the packet up the stack. */
	return XDP_PASS;
}

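/*
 * Two entry points share the same logic: "xdp" for drivers with linear-only
 * buffers and "xdp.frags" for multi-buffer (fragment-aware) drivers. A
 * hypothetical manual attach of the linear program might look like:
 *
 *   ip link set dev eth0 xdp obj xdp_native.bpf.o sec xdp
 */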
SEC("xdp")
int xdp_prog(struct xdp_md *ctx)
{
	return xdp_prog_common(ctx);
}

SEC("xdp.frags")
int xdp_prog_frags(struct xdp_md *ctx)
{
	return xdp_prog_common(ctx);
}

char _license[] SEC("license") = "GPL";