// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2021 Cong Wang <cong.wang@bytedance.com> */

#include <linux/bpf.h>
#include <linux/skmsg.h>
#include <net/af_unix.h>

#include "af_unix.h"

#define unix_sk_has_data(__sk, __psock)					\
		({	!skb_queue_empty(&__sk->sk_receive_queue) ||	\
			!skb_queue_empty(&__psock->ingress_skb) ||	\
			!list_empty(&__psock->ingress_msg);		\
		})

/* Sleep until data is queued on either the plain receive queue or one of
 * the psock ingress queues, or until @timeo expires. u->iolock is dropped
 * while sleeping so the peer can make progress.
 */
static int unix_msg_wait_data(struct sock *sk, struct sk_psock *psock,
			      long timeo)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	struct unix_sock *u = unix_sk(sk);
	int ret = 0;

	if (sk->sk_shutdown & RCV_SHUTDOWN)
		return 1;

	if (!timeo)
		return ret;

	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	if (!unix_sk_has_data(sk, psock)) {
		mutex_unlock(&u->iolock);
		wait_woken(&wait, TASK_INTERRUPTIBLE, timeo);
		mutex_lock(&u->iolock);
		ret = unix_sk_has_data(sk, psock);
	}
	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
	return ret;
}

static int __unix_recvmsg(struct sock *sk, struct msghdr *msg,
			  size_t len, int flags)
{
	if (sk->sk_type == SOCK_DGRAM)
		return __unix_dgram_recvmsg(sk, msg, len, flags);
	else
		return __unix_stream_recvmsg(sk, msg, len, flags);
}

static int unix_bpf_recvmsg(struct sock *sk, struct msghdr *msg,
			    size_t len, int flags, int *addr_len)
{
	struct unix_sock *u = unix_sk(sk);
	struct sk_psock *psock;
	int copied;

	if (flags & MSG_OOB)
		return -EOPNOTSUPP;

	if (!len)
		return 0;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return __unix_recvmsg(sk, msg, len, flags);

	mutex_lock(&u->iolock);
	if (!skb_queue_empty(&sk->sk_receive_queue) &&
	    sk_psock_queue_empty(psock)) {
		/* Data only on the plain receive queue: fall back to the
		 * native recvmsg path.
		 */
		mutex_unlock(&u->iolock);
		sk_psock_put(sk, psock);
		return __unix_recvmsg(sk, msg, len, flags);
	}

msg_bytes_ready:
	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		long timeo;
		int data;

		timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
		data = unix_msg_wait_data(sk, psock, timeo);
		if (data) {
			if (!sk_psock_queue_empty(psock))
				goto msg_bytes_ready;
			mutex_unlock(&u->iolock);
			sk_psock_put(sk, psock);
			return __unix_recvmsg(sk, msg, len, flags);
		}
		copied = -EAGAIN;
	}
	mutex_unlock(&u->iolock);
	sk_psock_put(sk, psock);
	return copied;
}

static struct proto *unix_dgram_prot_saved __read_mostly;
static DEFINE_SPINLOCK(unix_dgram_prot_lock);
static struct proto unix_dgram_bpf_prot;

static struct proto *unix_stream_prot_saved __read_mostly;
static DEFINE_SPINLOCK(unix_stream_prot_lock);
static struct proto unix_stream_bpf_prot;

static void unix_dgram_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
{
	*prot = *base;
	prot->close = sock_map_close;
	prot->recvmsg = unix_bpf_recvmsg;
	prot->sock_is_readable = sk_msg_is_readable;
}

static void unix_stream_bpf_rebuild_protos(struct proto *prot,
					   const struct proto *base)
{
	*prot = *base;
	prot->close = sock_map_close;
	prot->recvmsg = unix_bpf_recvmsg;
	prot->sock_is_readable = sk_msg_is_readable;
	prot->unhash = sock_map_unhash;
}
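
/* unix_dgram_bpf_prot and unix_stream_bpf_prot are filled in lazily, the
 * first time a psock backed by the corresponding base proto is attached.
 * Classic double-checked locking, with acquire/release ordering on the
 * saved proto pointer, keeps the common case down to a single load.
 */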
static void unix_dgram_bpf_check_needs_rebuild(struct proto *ops)
{
	if (unlikely(ops != smp_load_acquire(&unix_dgram_prot_saved))) {
		spin_lock_bh(&unix_dgram_prot_lock);
		if (likely(ops != unix_dgram_prot_saved)) {
			unix_dgram_bpf_rebuild_protos(&unix_dgram_bpf_prot, ops);
			smp_store_release(&unix_dgram_prot_saved, ops);
		}
		spin_unlock_bh(&unix_dgram_prot_lock);
	}
}

static void unix_stream_bpf_check_needs_rebuild(struct proto *ops)
{
	if (unlikely(ops != smp_load_acquire(&unix_stream_prot_saved))) {
		spin_lock_bh(&unix_stream_prot_lock);
		if (likely(ops != unix_stream_prot_saved)) {
			unix_stream_bpf_rebuild_protos(&unix_stream_bpf_prot, ops);
			smp_store_release(&unix_stream_prot_saved, ops);
		}
		spin_unlock_bh(&unix_stream_prot_lock);
	}
}

int unix_dgram_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
{
	if (sk->sk_type != SOCK_DGRAM)
		return -EOPNOTSUPP;

	if (restore) {
		sk->sk_write_space = psock->saved_write_space;
		sock_replace_proto(sk, psock->sk_proto);
		return 0;
	}

	unix_dgram_bpf_check_needs_rebuild(psock->sk_proto);
	sock_replace_proto(sk, &unix_dgram_bpf_prot);
	return 0;
}

int unix_stream_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
{
	struct sock *sk_pair;

	/* Restore does not decrement the sk_pair reference yet because we must
	 * keep a reference to the socket until after an RCU grace period
	 * and any pending sends have completed.
	 */
	if (restore) {
		sk->sk_write_space = psock->saved_write_space;
		sock_replace_proto(sk, psock->sk_proto);
		return 0;
	}

	/* psock_update_sk_prot can be called multiple times if psock is
	 * added to multiple maps and/or slots in the same map. There is
	 * also an edge case where replacing a psock with itself can trigger
	 * an extra psock_update_sk_prot during the insert process. So it
	 * must be safe to do multiple calls. Here we need to ensure we don't
	 * increment the refcnt through sock_hold many times. There will only
	 * be a single matching destroy operation.
	 */
	if (!psock->sk_pair) {
		sk_pair = unix_peer(sk);
		sock_hold(sk_pair);
		psock->sk_pair = sk_pair;
	}

	unix_stream_bpf_check_needs_rebuild(psock->sk_proto);
	sock_replace_proto(sk, &unix_stream_bpf_prot);
	return 0;
}

void __init unix_bpf_build_proto(void)
{
	unix_dgram_bpf_rebuild_protos(&unix_dgram_bpf_prot, &unix_dgram_proto);
	unix_stream_bpf_rebuild_protos(&unix_stream_bpf_prot, &unix_stream_proto);
}
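
/* Usage sketch (illustrative, not part of this file): the handlers above
 * only take effect once an AF_UNIX socket is inserted into a sockmap or
 * sockhash. A minimal, hypothetical libbpf sequence from userspace,
 * assuming "unix_fd" is a connected AF_UNIX socket:
 *
 *	int map_fd = bpf_map_create(BPF_MAP_TYPE_SOCKMAP, NULL,
 *				    sizeof(int), sizeof(int), 2, NULL);
 *	int key = 0;
 *
 *	bpf_map_update_elem(map_fd, &key, &unix_fd, BPF_ANY);
 *
 * The map update path calls psock_update_sk_prot(), which lands in
 * unix_dgram_bpf_update_proto() or unix_stream_bpf_update_proto() and
 * swaps the socket over to the BPF proto built in this file.
 */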