1 // SPDX-License-Identifier: GPL-2.0
2
3 #include "vmlinux.h"
4
5 #include <bpf/bpf_helpers.h>
6 #include <bpf/bpf_tracing.h>
7 #include "bpf_tracing_net.h"
8
/*
 * License string emitted into the "license" section of the object.
 * NOTE(review): presumably must stay GPL-compatible because this object calls
 * GPL-only helpers such as bpf_get_current_task_btf() — confirm.
 */
char _license[] SEC("license") = "GPL";
10
/*
 * Socket state that bpf_smc_release() treats as "listening" and skips.
 * NOTE(review): presumably mirrors SMC_LISTEN from the kernel's SMC state
 * enum, redefined locally under a BPF_ prefix to avoid name clashes — confirm.
 */
enum { BPF_SMC_LISTEN = 10 };
14
/*
 * Trimmed-down local view of the kernel's struct smc_sock, declaring only the
 * fields this file cares about (listen_smc is read by
 * bpf_smc_switch_to_fallback()). preserve_access_index makes every field
 * access a CO-RE relocation, and the "___local" flavor suffix lets libbpf
 * match this type against the running kernel's "struct smc_sock".
 */
struct smc_sock___local {
	struct sock sk;
	struct smc_sock *listen_smc;
	bool use_fallback;
} __attribute__((preserve_access_index));
20
/*
 * Event counters incremented by the fentry programs below.
 * NOTE(review): presumably read back by the user-space test through the
 * skeleton's .bss — confirm.
 */
int smc_cnt = 0;	/* smc_release() calls on non-listening sockets */
int fallback_cnt = 0;	/* smc_switch_to_fallback() calls without listen_smc */
23
24 SEC("fentry/smc_release")
BPF_PROG(bpf_smc_release,struct socket * sock)25 int BPF_PROG(bpf_smc_release, struct socket *sock)
26 {
27 /* only count from one side (client) */
28 if (sock->sk->__sk_common.skc_state == BPF_SMC_LISTEN)
29 return 0;
30 smc_cnt++;
31 return 0;
32 }
33
34 SEC("fentry/smc_switch_to_fallback")
BPF_PROG(bpf_smc_switch_to_fallback,struct smc_sock___local * smc)35 int BPF_PROG(bpf_smc_switch_to_fallback, struct smc_sock___local *smc)
36 {
37 /* only count from one side (client) */
38 if (smc && !smc->listen_smc)
39 fallback_cnt++;
40 return 0;
41 }
42
/*
 * Result smc_check() falls back to when smc_policy_ip has no entry for the
 * address pair.
 */
bool default_ip_strat_value = true;
45
46 struct smc_policy_ip_key {
47 __u32 sip;
48 __u32 dip;
49 };
50
51 struct smc_policy_ip_value {
52 __u8 mode;
53 };
54
/*
 * Per-address-pair policy table consulted by smc_check(): a hash map with at
 * most 128 entries. BPF_F_NO_PREALLOC makes the kernel allocate elements on
 * insert instead of preallocating all of them when the map is created.
 */
struct {
	__uint(type, BPF_MAP_TYPE_HASH);
	__uint(key_size, sizeof(struct smc_policy_ip_key));
	__uint(value_size, sizeof(struct smc_policy_ip_value));
	__uint(max_entries, 128);
	__uint(map_flags, BPF_F_NO_PREALLOC);
} smc_policy_ip SEC(".maps");
62
/*
 * Look up the policy for the (src, dst) address pair.
 *
 * Returns the stored mode when smc_policy_ip has an entry for the pair (any
 * non-zero mode reads as true). Otherwise returns default_ip_strat_value.
 */
static bool smc_check(__u32 src, __u32 dst)
{
	struct smc_policy_ip_key k = { .sip = src, .dip = dst };
	struct smc_policy_ip_value *v;

	v = bpf_map_lookup_elem(&smc_policy_ip, &k);
	if (!v)
		return default_ip_strat_value;

	return v->mode;
}
74
75 SEC("fmod_ret/update_socket_protocol")
BPF_PROG(smc_run,int family,int type,int protocol)76 int BPF_PROG(smc_run, int family, int type, int protocol)
77 {
78 struct task_struct *task;
79
80 if (family != AF_INET && family != AF_INET6)
81 return protocol;
82
83 if ((type & 0xf) != SOCK_STREAM)
84 return protocol;
85
86 if (protocol != 0 && protocol != IPPROTO_TCP)
87 return protocol;
88
89 task = bpf_get_current_task_btf();
90 /* Prevent from affecting other tests */
91 if (!task || !task->nsproxy->net_ns->smc.hs_ctrl)
92 return protocol;
93
94 return IPPROTO_SMC;
95 }
96
97 SEC("struct_ops")
BPF_PROG(bpf_smc_set_tcp_option_cond,const struct tcp_sock * tp,struct inet_request_sock * ireq)98 int BPF_PROG(bpf_smc_set_tcp_option_cond, const struct tcp_sock *tp,
99 struct inet_request_sock *ireq)
100 {
101 return smc_check(ireq->req.__req_common.skc_daddr,
102 ireq->req.__req_common.skc_rcv_saddr);
103 }
104
105 SEC("struct_ops")
BPF_PROG(bpf_smc_set_tcp_option,struct tcp_sock * tp)106 int BPF_PROG(bpf_smc_set_tcp_option, struct tcp_sock *tp)
107 {
108 return smc_check(tp->inet_conn.icsk_inet.sk.__sk_common.skc_rcv_saddr,
109 tp->inet_conn.icsk_inet.sk.__sk_common.skc_daddr);
110 }
111
112 SEC(".struct_ops")
113 struct smc_hs_ctrl linkcheck = {
114 .name = "linkcheck",
115 .syn_option = (void *)bpf_smc_set_tcp_option,
116 .synack_option = (void *)bpf_smc_set_tcp_option_cond,
117 };
118