// SPDX-License-Identifier: GPL-2.0
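/* XDP program used by the networking selftests.
 *
 * Filters UDP packets destined to a port configured via map_xdp_setup and
 * applies the configured action: pass, drop, reflect (XDP_TX), or grow or
 * shrink the packet at the head or tail. Per-event counters are kept in
 * map_xdp_stats.
 */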

#include <stddef.h>
#include <linux/bpf.h>
#include <linux/in.h>
#include <linux/if_ether.h>
#include <linux/ip.h>
#include <linux/ipv6.h>
#include <linux/udp.h>
#include <bpf/bpf_endian.h>
#include <bpf/bpf_helpers.h>

#define MAX_ADJST_OFFSET 256
#define MAX_PAYLOAD_LEN 5000
#define MAX_HDR_LEN 64

enum {
	XDP_MODE = 0,
	XDP_PORT = 1,
	XDP_ADJST_OFFSET = 2,
	XDP_ADJST_TAG = 3,
} xdp_map_setup_keys;

enum {
	XDP_MODE_PASS = 0,
	XDP_MODE_DROP = 1,
	XDP_MODE_TX = 2,
	XDP_MODE_TAIL_ADJST = 3,
	XDP_MODE_HEAD_ADJST = 4,
} xdp_map_modes;

enum {
	STATS_RX = 0,
	STATS_PASS = 1,
	STATS_DROP = 2,
	STATS_TX = 3,
	STATS_ABORT = 4,
} xdp_stats;

struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, 5);
	__type(key, __u32);
	__type(value, __s32);
} map_xdp_setup SEC(".maps");

struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, 5);
	__type(key, __u32);
	__type(value, __u64);
} map_xdp_stats SEC(".maps");
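
/* A minimal user-space configuration sketch (illustrative, not part of this
 * file; assumes the object is loaded via libbpf and setup_fd is the fd of
 * map_xdp_setup):
 *
 *	__u32 key = XDP_MODE;
 *	__s32 val = XDP_MODE_TX;
 *	bpf_map_update_elem(setup_fd, &key, &val, BPF_ANY);
 *
 *	key = XDP_PORT;
 *	val = 8080;
 *	bpf_map_update_elem(setup_fd, &key, &val, BPF_ANY);
 */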

static __u32 min(__u32 a, __u32 b)
{
	return a < b ? a : b;
}

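/* Bump the counter for @stat_type in map_xdp_stats, if it exists. */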
static void record_stats(struct xdp_md *ctx, __u32 stat_type)
{
	__u64 *count;

	count = bpf_map_lookup_elem(&map_xdp_stats, &stat_type);

	if (count)
		__sync_fetch_and_add(count, 1);
}

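/* Return a pointer to the UDP header if the packet is IPv4/IPv6 UDP and its
 * destination port matches @port, counting it as received; return NULL
 * otherwise. Assumes an option-less IPv4 header and no IPv6 extension
 * headers.
 */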
static struct udphdr *filter_udphdr(struct xdp_md *ctx, __u16 port)
{
	void *data_end = (void *)(long)ctx->data_end;
	void *data = (void *)(long)ctx->data;
	struct udphdr *udph = NULL;
	struct ethhdr *eth = data;

	if (data + sizeof(*eth) > data_end)
		return NULL;

	if (eth->h_proto == bpf_htons(ETH_P_IP)) {
		struct iphdr *iph = data + sizeof(*eth);

		if (iph + 1 > (struct iphdr *)data_end ||
		    iph->protocol != IPPROTO_UDP)
			return NULL;

		udph = (void *)eth + sizeof(*iph) + sizeof(*eth);
	} else if (eth->h_proto == bpf_htons(ETH_P_IPV6)) {
		struct ipv6hdr *ipv6h = data + sizeof(*eth);

		if (ipv6h + 1 > (struct ipv6hdr *)data_end ||
		    ipv6h->nexthdr != IPPROTO_UDP)
			return NULL;

		udph = (void *)eth + sizeof(*ipv6h) + sizeof(*eth);
	} else {
		return NULL;
	}

	if (udph + 1 > (struct udphdr *)data_end)
		return NULL;

	if (udph->dest != bpf_htons(port))
		return NULL;

	record_stats(ctx, STATS_RX);

	return udph;
}

static int xdp_mode_pass(struct xdp_md *ctx, __u16 port)
{
	struct udphdr *udph = NULL;

	udph = filter_udphdr(ctx, port);
	if (!udph)
		return XDP_PASS;

	record_stats(ctx, STATS_PASS);

	return XDP_PASS;
}

static int xdp_mode_drop_handler(struct xdp_md *ctx, __u16 port)
{
	struct udphdr *udph = NULL;

	udph = filter_udphdr(ctx, port);
	if (!udph)
		return XDP_PASS;

	record_stats(ctx, STATS_DROP);

	return XDP_DROP;
}

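/* Swap the source and destination MAC addresses in place. */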
static void swap_machdr(void *data)
{
	struct ethhdr *eth = data;
	__u8 tmp_mac[ETH_ALEN];

	__builtin_memcpy(tmp_mac, eth->h_source, ETH_ALEN);
	__builtin_memcpy(eth->h_source, eth->h_dest, ETH_ALEN);
	__builtin_memcpy(eth->h_dest, tmp_mac, ETH_ALEN);
}

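/* Reflect a matching UDP packet back out the same interface: swap the MAC
 * and IP addresses and return XDP_TX. UDP ports and checksum are left
 * untouched; swapping addresses keeps the pseudo-header sum intact.
 */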
static int xdp_mode_tx_handler(struct xdp_md *ctx, __u16 port)
{
	void *data_end = (void *)(long)ctx->data_end;
	void *data = (void *)(long)ctx->data;
	struct udphdr *udph = NULL;
	struct ethhdr *eth = data;

	if (data + sizeof(*eth) > data_end)
		return XDP_PASS;

	if (eth->h_proto == bpf_htons(ETH_P_IP)) {
		struct iphdr *iph = data + sizeof(*eth);
		__be32 tmp_ip;

		if (iph + 1 > (struct iphdr *)data_end ||
		    iph->protocol != IPPROTO_UDP)
			return XDP_PASS;

		udph = data + sizeof(*iph) + sizeof(*eth);

		if (udph + 1 > (struct udphdr *)data_end)
			return XDP_PASS;
		if (udph->dest != bpf_htons(port))
			return XDP_PASS;

		record_stats(ctx, STATS_RX);
		swap_machdr((void *)eth);

		/* Read saddr only after the bounds check above */
		tmp_ip = iph->saddr;
		iph->saddr = iph->daddr;
		iph->daddr = tmp_ip;

		record_stats(ctx, STATS_TX);

		return XDP_TX;

	} else if (eth->h_proto == bpf_htons(ETH_P_IPV6)) {
		struct ipv6hdr *ipv6h = data + sizeof(*eth);
		struct in6_addr tmp_ipv6;

		if (ipv6h + 1 > (struct ipv6hdr *)data_end ||
		    ipv6h->nexthdr != IPPROTO_UDP)
			return XDP_PASS;

		udph = data + sizeof(*ipv6h) + sizeof(*eth);

		if (udph + 1 > (struct udphdr *)data_end)
			return XDP_PASS;
		if (udph->dest != bpf_htons(port))
			return XDP_PASS;

		record_stats(ctx, STATS_RX);
		swap_machdr((void *)eth);

		__builtin_memcpy(&tmp_ipv6, &ipv6h->saddr, sizeof(tmp_ipv6));
		__builtin_memcpy(&ipv6h->saddr, &ipv6h->daddr,
				 sizeof(tmp_ipv6));
		__builtin_memcpy(&ipv6h->daddr, &tmp_ipv6, sizeof(tmp_ipv6));

		record_stats(ctx, STATS_TX);

		return XDP_TX;
	}

	return XDP_PASS;
}

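/* Adjust the L3/L4 length fields by @offset bytes: IPv4 tot_len or IPv6
 * payload_len, plus the UDP length. For IPv6, the checksum delta of the
 * length changes is accumulated into *udp_csum; in the IPv4 branch neither
 * the IP header checksum nor *udp_csum is updated here. Returns the UDP
 * header pointer, or NULL if the headers cannot be parsed.
 */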
static void *update_pkt(struct xdp_md *ctx, __s16 offset, __u32 *udp_csum)
{
	void *data_end = (void *)(long)ctx->data_end;
	void *data = (void *)(long)ctx->data;
	struct udphdr *udph = NULL;
	struct ethhdr *eth = data;
	__u32 len, len_new;

	if (data + sizeof(*eth) > data_end)
		return NULL;

	if (eth->h_proto == bpf_htons(ETH_P_IP)) {
		struct iphdr *iph = data + sizeof(*eth);

		if (iph + 1 > (struct iphdr *)data_end)
			return NULL;

		iph->tot_len = bpf_htons(bpf_ntohs(iph->tot_len) + offset);

		udph = (void *)eth + sizeof(*iph) + sizeof(*eth);
		if (!udph || udph + 1 > (struct udphdr *)data_end)
			return NULL;

		len_new = bpf_htons(bpf_ntohs(udph->len) + offset);
	} else if (eth->h_proto == bpf_htons(ETH_P_IPV6)) {
		struct ipv6hdr *ipv6h = data + sizeof(*eth);

		if (ipv6h + 1 > (struct ipv6hdr *)data_end)
			return NULL;

		udph = (void *)eth + sizeof(*ipv6h) + sizeof(*eth);
		if (!udph || udph + 1 > (struct udphdr *)data_end)
			return NULL;

		*udp_csum = ~((__u32)udph->check);

		len = ipv6h->payload_len;
		len_new = bpf_htons(bpf_ntohs(len) + offset);
		ipv6h->payload_len = len_new;

		*udp_csum = bpf_csum_diff(&len, sizeof(len), &len_new,
					  sizeof(len_new), *udp_csum);

		len = udph->len;
		len_new = bpf_htons(bpf_ntohs(udph->len) + offset);
		*udp_csum = bpf_csum_diff(&len, sizeof(len), &len_new,
					  sizeof(len_new), *udp_csum);
	} else {
		return NULL;
	}

	udph->len = len_new;

	return udph;
}

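/* Fold a 32-bit partial checksum into 16 bits and return its one's
 * complement, substituting 0xffff for a zero result (zero means "no
 * checksum" for UDP).
 */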
static __u16 csum_fold_helper(__u32 csum)
{
	return ~((csum & 0xffff) + (csum >> 16)) ? : 0xffff;
}

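/* Shrink the packet by @offset bytes at the tail. The length fields and the
 * UDP checksum are updated first, while the to-be-trimmed bytes are still
 * readable, and only then is the tail adjusted.
 */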
static int xdp_adjst_tail_shrnk_data(struct xdp_md *ctx, __u16 offset,
				     __u32 hdr_len)
{
	char tmp_buff[MAX_ADJST_OFFSET];
	__u32 buff_pos, udp_csum = 0;
	struct udphdr *udph = NULL;
	__u32 buff_len;

	udph = update_pkt(ctx, 0 - offset, &udp_csum);
	if (!udph)
		return -1;

	buff_len = bpf_xdp_get_buff_len(ctx);

	offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET :
		 offset & 0xff;
	if (offset == 0)
		return -1;

	/* Make sure we have enough data to avoid eating the header */
	if (buff_len - offset < hdr_len)
		return -1;

	buff_pos = buff_len - offset;
	if (bpf_xdp_load_bytes(ctx, buff_pos, tmp_buff, offset) < 0)
		return -1;

	udp_csum = bpf_csum_diff((__be32 *)tmp_buff, offset, 0, 0, udp_csum);
	udph->check = (__u16)csum_fold_helper(udp_csum);

	if (bpf_xdp_adjust_tail(ctx, 0 - offset) < 0)
		return -1;

	return 0;
}

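/* Grow the packet by @offset bytes at the tail, filling the new area with
 * the configured tag byte and updating the UDP checksum to match.
 */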
static int xdp_adjst_tail_grow_data(struct xdp_md *ctx, __u16 offset)
{
	char tmp_buff[MAX_ADJST_OFFSET];
	__u32 buff_len, key;
	__u32 udp_csum = 0;
	struct udphdr *udph;
	__s32 *val;
	__u8 tag;

	/* Update the packet headers before attempting to adjust the tail.
	 * Once the tail is adjusted we lose access to the offset amount of
	 * data at the end of the packet, which is crucial for updating the
	 * checksum.
	 * Since any failure beyond this point aborts the packet, we need not
	 * worry about passing a packet up the stack with wrong headers.
	 */
	udph = update_pkt(ctx, offset, &udp_csum);
	if (!udph)
		return -1;

	key = XDP_ADJST_TAG;
	val = bpf_map_lookup_elem(&map_xdp_setup, &key);
	if (!val)
		return -1;

	tag = (__u8)(*val);

	for (int i = 0; i < MAX_ADJST_OFFSET; i++)
		__builtin_memcpy(&tmp_buff[i], &tag, 1);

	offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET :
		 offset & 0xff;
	if (offset == 0)
		return -1;

	udp_csum = bpf_csum_diff(0, 0, (__be32 *)tmp_buff, offset, udp_csum);
	udph->check = (__u16)csum_fold_helper(udp_csum);

	buff_len = bpf_xdp_get_buff_len(ctx);

	if (bpf_xdp_adjust_tail(ctx, offset) < 0) {
		bpf_printk("Failed to adjust tail\n");
		return -1;
	}

	if (bpf_xdp_store_bytes(ctx, buff_len, tmp_buff, offset) < 0)
		return -1;

	return 0;
}

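/* Dispatch a tail adjustment: a negative configured offset shrinks the
 * packet, a positive one grows it. The packet is aborted on failure.
 */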
static int xdp_adjst_tail(struct xdp_md *ctx, __u16 port)
{
	void *data = (void *)(long)ctx->data;
	struct udphdr *udph = NULL;
	__s32 *adjust_offset;
	__u32 key, hdr_len;
	int ret;

	udph = filter_udphdr(ctx, port);
	if (!udph)
		return XDP_PASS;

	hdr_len = (void *)udph - data + sizeof(struct udphdr);
	key = XDP_ADJST_OFFSET;
	adjust_offset = bpf_map_lookup_elem(&map_xdp_setup, &key);
	if (!adjust_offset)
		return XDP_PASS;

	if (*adjust_offset < 0)
		ret = xdp_adjst_tail_shrnk_data(ctx,
						(__u16)(0 - *adjust_offset),
						hdr_len);
	else
		ret = xdp_adjst_tail_grow_data(ctx, (__u16)(*adjust_offset));
	if (ret)
		goto abort_pkt;

	record_stats(ctx, STATS_PASS);
	return XDP_PASS;

abort_pkt:
	record_stats(ctx, STATS_ABORT);
	return XDP_ABORTED;
}

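/* Shrink the payload by @offset bytes just after the headers: fix up the
 * lengths and checksum while the bytes are still readable, save the headers,
 * move the head pointer forward, then write the headers back at the new
 * front.
 */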
static int xdp_adjst_head_shrnk_data(struct xdp_md *ctx, __u64 hdr_len,
				     __u32 offset)
{
	char tmp_buff[MAX_ADJST_OFFSET];
	struct udphdr *udph;
	__u32 udp_csum = 0;

	/* Update the length information in the IP and UDP headers before
	 * adjusting the headroom. This simplifies accessing the relevant
	 * fields in the IP and UDP headers for fragmented packets. Any
	 * failure beyond this point will result in the packet being aborted,
	 * so we don't need to worry about incorrect length information for
	 * passed packets.
	 */
	udph = update_pkt(ctx, (__s16)(0 - offset), &udp_csum);
	if (!udph)
		return -1;

	offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET :
		 offset & 0xff;
	if (offset == 0)
		return -1;

	if (bpf_xdp_load_bytes(ctx, hdr_len, tmp_buff, offset) < 0)
		return -1;

	udp_csum = bpf_csum_diff((__be32 *)tmp_buff, offset, 0, 0, udp_csum);

	udph->check = (__u16)csum_fold_helper(udp_csum);

	if (bpf_xdp_load_bytes(ctx, 0, tmp_buff, MAX_ADJST_OFFSET) < 0)
		return -1;

	if (bpf_xdp_adjust_head(ctx, offset) < 0)
		return -1;

	if (offset > MAX_ADJST_OFFSET)
		return -1;

	if (hdr_len > MAX_ADJST_OFFSET || hdr_len == 0)
		return -1;

	/* Mask hdr_len so the verifier sees a bounded, non-negative size */
	hdr_len = hdr_len & 0xff;

	if (hdr_len == 0)
		return -1;

	if (bpf_xdp_store_bytes(ctx, 0, tmp_buff, hdr_len) < 0)
		return -1;

	return 0;
}

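/* Grow the payload by @offset bytes just after the headers: expand the
 * headroom, rewrite the saved headers at the new front, and fill the gap
 * with the configured tag byte, updating the UDP checksum to match.
 */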
static int xdp_adjst_head_grow_data(struct xdp_md *ctx, __u64 hdr_len,
				    __u32 offset)
{
	char hdr_buff[MAX_HDR_LEN];
	char data_buff[MAX_ADJST_OFFSET];
	__u32 udp_csum = 0;
	struct udphdr *udph;
	__s32 *val;
	__u32 key;
	__u8 tag;

	udph = update_pkt(ctx, (__s16)(offset), &udp_csum);
	if (!udph)
		return -1;

	key = XDP_ADJST_TAG;
	val = bpf_map_lookup_elem(&map_xdp_setup, &key);
	if (!val)
		return -1;

	tag = (__u8)(*val);
	for (int i = 0; i < MAX_ADJST_OFFSET; i++)
		__builtin_memcpy(&data_buff[i], &tag, 1);

	offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET :
		 offset & 0xff;
	if (offset == 0)
		return -1;

	udp_csum = bpf_csum_diff(0, 0, (__be32 *)data_buff, offset, udp_csum);
	udph->check = (__u16)csum_fold_helper(udp_csum);

	/* hdr_buff holds at most MAX_HDR_LEN bytes */
	if (hdr_len > MAX_HDR_LEN || hdr_len == 0)
		return -1;

	/* Mask hdr_len so the verifier sees a bounded, non-negative size */
	hdr_len = hdr_len & 0x7f;

	if (hdr_len == 0)
		return -1;

	if (bpf_xdp_load_bytes(ctx, 0, hdr_buff, hdr_len) < 0)
		return -1;

	if (offset > MAX_ADJST_OFFSET)
		return -1;

	if (bpf_xdp_adjust_head(ctx, 0 - offset) < 0)
		return -1;

	if (bpf_xdp_store_bytes(ctx, 0, hdr_buff, hdr_len) < 0)
		return -1;

	if (bpf_xdp_store_bytes(ctx, hdr_len, data_buff, offset) < 0)
		return -1;

	return 0;
}

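/* Head-adjustment mode: the configured offset selects the size (only the
 * powers of two from 16 to 256 are accepted) and its sign the direction.
 * Note the inverted convention here: a negative offset grows the packet at
 * the head, a positive one shrinks it.
 */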
static int xdp_head_adjst(struct xdp_md *ctx, __u16 port)
{
	void *data = (void *)(long)ctx->data;
	struct udphdr *udph_ptr = NULL;
	__u32 key, size, hdr_len;
	__s32 *val;
	int res;

	/* Filter packets based on UDP port */
	udph_ptr = filter_udphdr(ctx, port);
	if (!udph_ptr)
		return XDP_PASS;

	hdr_len = (void *)udph_ptr - data + sizeof(struct udphdr);

	key = XDP_ADJST_OFFSET;
	val = bpf_map_lookup_elem(&map_xdp_setup, &key);
	if (!val)
		return XDP_PASS;

	switch (*val) {
	case -16:
	case 16:
		size = 16;
		break;
	case -32:
	case 32:
		size = 32;
		break;
	case -64:
	case 64:
		size = 64;
		break;
	case -128:
	case 128:
		size = 128;
		break;
	case -256:
	case 256:
		size = 256;
		break;
	default:
		bpf_printk("Invalid adjustment offset: %d\n", *val);
		goto abort;
	}

	if (*val < 0)
		res = xdp_adjst_head_grow_data(ctx, hdr_len, size);
	else
		res = xdp_adjst_head_shrnk_data(ctx, hdr_len, size);

	if (res)
		goto abort;

	record_stats(ctx, STATS_PASS);
	return XDP_PASS;

abort:
	record_stats(ctx, STATS_ABORT);
	return XDP_ABORTED;
}

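/* Look up the configured mode and port, then invoke the matching handler. */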
static int xdp_prog_common(struct xdp_md *ctx)
{
	__u32 key, *port;
	__s32 *mode;

	key = XDP_MODE;
	mode = bpf_map_lookup_elem(&map_xdp_setup, &key);
	if (!mode)
		return XDP_PASS;

	key = XDP_PORT;
	port = bpf_map_lookup_elem(&map_xdp_setup, &key);
	if (!port)
		return XDP_PASS;

	switch (*mode) {
	case XDP_MODE_PASS:
		return xdp_mode_pass(ctx, (__u16)(*port));
	case XDP_MODE_DROP:
		return xdp_mode_drop_handler(ctx, (__u16)(*port));
	case XDP_MODE_TX:
		return xdp_mode_tx_handler(ctx, (__u16)(*port));
	case XDP_MODE_TAIL_ADJST:
		return xdp_adjst_tail(ctx, (__u16)(*port));
	case XDP_MODE_HEAD_ADJST:
		return xdp_head_adjst(ctx, (__u16)(*port));
	}

	/* The default action is to simply pass */
	return XDP_PASS;
}

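/* Two entry points sharing the same logic: a linear-buffer program and a
 * multi-buffer (fragment-aware) variant.
 */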
SEC("xdp")
int xdp_prog(struct xdp_md *ctx)
{
	return xdp_prog_common(ctx);
}

SEC("xdp.frags")
int xdp_prog_frags(struct xdp_md *ctx)
{
	return xdp_prog_common(ctx);
}

char _license[] SEC("license") = "GPL";