xref: /linux/net/ipv6/inet6_hashtables.c (revision a0b0f6c7d7f29f1ade9ec59699d02e3b153ee8e4)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * INET		An implementation of the TCP/IP protocol suite for the LINUX
4  *		operating system.  INET is implemented using the BSD Socket
5  *		interface as the means of communication with the user level.
6  *
7  *		Generic INET6 transport hashtables
8  *
9  * Authors:	Lotsa people, from code originally in tcp, generalised here
10  *		by Arnaldo Carvalho de Melo <acme@mandriva.com>
11  */
12 
13 #include <linux/module.h>
14 #include <linux/random.h>
15 
16 #include <net/addrconf.h>
17 #include <net/hotdata.h>
18 #include <net/inet_connection_sock.h>
19 #include <net/inet_hashtables.h>
20 #include <net/inet6_hashtables.h>
21 #include <net/secure_seq.h>
22 #include <net/ip.h>
23 #include <net/sock_reuseport.h>
24 #include <net/tcp.h>
25 
inet6_init_ehash_secret(void)26 void inet6_init_ehash_secret(void)
27 {
28 	net_get_random_sleepable_once(&inet6_ehash_secret,
29 				      sizeof(inet6_ehash_secret));
30 	net_get_random_sleepable_once(&tcp_ipv6_hash_secret,
31 				      sizeof(tcp_ipv6_hash_secret));
32 }
33 
inet6_ehashfn(const struct net * net,const struct in6_addr * laddr,const u16 lport,const struct in6_addr * faddr,const __be16 fport)34 u32 inet6_ehashfn(const struct net *net,
35 		  const struct in6_addr *laddr, const u16 lport,
36 		  const struct in6_addr *faddr, const __be16 fport)
37 {
38 	u32 a, b, c;
39 
40 	/*
41 	 * Please look at jhash() implementation for reference.
42 	 * Hash laddr + faddr + lport/fport + net_hash_mix.
43 	 * Notes:
44 	 * We combine laddr[0] (high order 32 bits of local address)
45 	 * with net_hash_mix() to hash a multiple of 3 words.
46 	 *
47 	 * We do not include JHASH_INITVAL + 36 contribution
48 	 * to initial values of a, b, c.
49 	 */
50 
51 	a = b = c = tcp_ipv6_hash_secret;
52 
53 	a += (__force u32)laddr->s6_addr32[0] ^ net_hash_mix(net);
54 	b += (__force u32)laddr->s6_addr32[1];
55 	c += (__force u32)laddr->s6_addr32[2];
56 	__jhash_mix(a, b, c);
57 
58 	a += (__force u32)laddr->s6_addr32[3];
59 	b += (__force u32)faddr->s6_addr32[0];
60 	c += (__force u32)faddr->s6_addr32[1];
61 	__jhash_mix(a, b, c);
62 
63 	a += (__force u32)faddr->s6_addr32[2];
64 	b += (__force u32)faddr->s6_addr32[3];
65 	c += (__force u32)fport;
66 	__jhash_final(a, b, c);
67 
68 	/* Note: We need to add @lport instead of fully hashing it.
69 	 * See commits 9544d60a2605 ("inet: change lport contribution
70 	 * to inet_ehashfn() and inet6_ehashfn()") and d4438ce68bf1
71 	 * ("inet: call inet6_ehashfn() once from inet6_hash_connect()")
72 	 * for references.
73 	 */
74 	return lport + c;
75 }
76 EXPORT_SYMBOL_GPL(inet6_ehashfn);
77 
78 /*
79  * Sockets in TCP_CLOSE state are _always_ taken out of the hash, so
80  * we need not check it for TCP lookups anymore, thanks Alexey. -DaveM
81  *
82  * The sockhash lock must be held as a reader here.
83  */
__inet6_lookup_established(const struct net * net,const struct in6_addr * saddr,const __be16 sport,const struct in6_addr * daddr,const u16 hnum,const int dif,const int sdif)84 struct sock *__inet6_lookup_established(const struct net *net,
85 					const struct in6_addr *saddr,
86 					const __be16 sport,
87 					const struct in6_addr *daddr,
88 					const u16 hnum,
89 					const int dif, const int sdif)
90 {
91 	const __portpair ports = INET_COMBINED_PORTS(sport, hnum);
92 	const struct hlist_nulls_node *node;
93 	struct inet_ehash_bucket *head;
94 	struct inet_hashinfo *hashinfo;
95 	unsigned int hash, slot;
96 	struct sock *sk;
97 
98 	hashinfo = net->ipv4.tcp_death_row.hashinfo;
99 	hash = inet6_ehashfn(net, daddr, hnum, saddr, sport);
100 	slot = hash & hashinfo->ehash_mask;
101 	head = &hashinfo->ehash[slot];
102 begin:
103 	sk_nulls_for_each_rcu(sk, node, &head->chain) {
104 		if (sk->sk_hash != hash)
105 			continue;
106 		if (!inet6_match(net, sk, saddr, daddr, ports, dif, sdif))
107 			continue;
108 		if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt)))
109 			goto out;
110 
111 		if (unlikely(!inet6_match(net, sk, saddr, daddr, ports, dif, sdif))) {
112 			sock_gen_put(sk);
113 			goto begin;
114 		}
115 		goto found;
116 	}
117 	if (get_nulls_value(node) != slot)
118 		goto begin;
119 out:
120 	sk = NULL;
121 found:
122 	return sk;
123 }
124 EXPORT_SYMBOL(__inet6_lookup_established);
125 
compute_score(struct sock * sk,const struct net * net,const unsigned short hnum,const struct in6_addr * daddr,const int dif,const int sdif)126 static inline int compute_score(struct sock *sk, const struct net *net,
127 				const unsigned short hnum,
128 				const struct in6_addr *daddr,
129 				const int dif, const int sdif)
130 {
131 	int score = -1;
132 
133 	if (net_eq(sock_net(sk), net) &&
134 	    READ_ONCE(inet_sk(sk)->inet_num) == hnum &&
135 	    sk->sk_family == PF_INET6) {
136 		if (!ipv6_addr_equal(&sk->sk_v6_rcv_saddr, daddr))
137 			return -1;
138 
139 		if (!inet_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif))
140 			return -1;
141 
142 		score =  sk->sk_bound_dev_if ? 2 : 1;
143 		if (READ_ONCE(sk->sk_incoming_cpu) == raw_smp_processor_id())
144 			score++;
145 	}
146 	return score;
147 }
148 
149 /**
150  * inet6_lookup_reuseport() - execute reuseport logic on AF_INET6 socket if necessary.
151  * @net: network namespace.
152  * @sk: AF_INET6 socket, must be in TCP_LISTEN state for TCP or TCP_CLOSE for UDP.
153  * @skb: context for a potential SK_REUSEPORT program.
154  * @doff: header offset.
155  * @saddr: source address.
156  * @sport: source port.
157  * @daddr: destination address.
158  * @hnum: destination port in host byte order.
159  * @ehashfn: hash function used to generate the fallback hash.
160  *
161  * Return: NULL if sk doesn't have SO_REUSEPORT set, otherwise a pointer to
162  *         the selected sock or an error.
163  */
inet6_lookup_reuseport(const struct net * net,struct sock * sk,struct sk_buff * skb,int doff,const struct in6_addr * saddr,__be16 sport,const struct in6_addr * daddr,unsigned short hnum,inet6_ehashfn_t * ehashfn)164 struct sock *inet6_lookup_reuseport(const struct net *net, struct sock *sk,
165 				    struct sk_buff *skb, int doff,
166 				    const struct in6_addr *saddr,
167 				    __be16 sport,
168 				    const struct in6_addr *daddr,
169 				    unsigned short hnum,
170 				    inet6_ehashfn_t *ehashfn)
171 {
172 	struct sock *reuse_sk = NULL;
173 	u32 phash;
174 
175 	if (sk->sk_reuseport) {
176 		phash = INDIRECT_CALL_INET(ehashfn, udp6_ehashfn, inet6_ehashfn,
177 					   net, daddr, hnum, saddr, sport);
178 		reuse_sk = reuseport_select_sock(sk, phash, skb, doff);
179 	}
180 	return reuse_sk;
181 }
182 EXPORT_SYMBOL_GPL(inet6_lookup_reuseport);
183 
184 /* called with rcu_read_lock() */
inet6_lhash2_lookup(const struct net * net,struct inet_listen_hashbucket * ilb2,struct sk_buff * skb,int doff,const struct in6_addr * saddr,const __be16 sport,const struct in6_addr * daddr,const unsigned short hnum,const int dif,const int sdif)185 static struct sock *inet6_lhash2_lookup(const struct net *net,
186 		struct inet_listen_hashbucket *ilb2,
187 		struct sk_buff *skb, int doff,
188 		const struct in6_addr *saddr,
189 		const __be16 sport, const struct in6_addr *daddr,
190 		const unsigned short hnum, const int dif, const int sdif)
191 {
192 	struct sock *sk, *result = NULL;
193 	struct hlist_nulls_node *node;
194 	int score, hiscore = 0;
195 
196 	sk_nulls_for_each_rcu(sk, node, &ilb2->nulls_head) {
197 		score = compute_score(sk, net, hnum, daddr, dif, sdif);
198 		if (score > hiscore) {
199 			result = inet6_lookup_reuseport(net, sk, skb, doff,
200 							saddr, sport, daddr, hnum, inet6_ehashfn);
201 			if (result)
202 				return result;
203 
204 			result = sk;
205 			hiscore = score;
206 		}
207 	}
208 
209 	return result;
210 }
211 
/*
 * Ask BPF sk_lookup (bpf_sk_lookup_run_v6()) to choose a socket for an
 * incoming IPv6 flow.
 *
 * Return: the BPF selection (a socket, NULL, or an error pointer) as-is
 * when the program requested no_reuseport or did not select a real socket.
 * Otherwise the selected socket's SO_REUSEPORT group is consulted using
 * @ehashfn, and a non-NULL choice from it replaces the BPF selection.
 */
struct sock *inet6_lookup_run_sk_lookup(const struct net *net,
					int protocol,
					struct sk_buff *skb, int doff,
					const struct in6_addr *saddr,
					const __be16 sport,
					const struct in6_addr *daddr,
					const u16 hnum, const int dif,
					inet6_ehashfn_t *ehashfn)
{
	struct sock *selected, *reuse_sk;

	/* A true return means the program asked to skip reuseport selection. */
	if (bpf_sk_lookup_run_v6(net, protocol, saddr, sport,
				 daddr, hnum, dif, &selected))
		return selected;

	if (IS_ERR_OR_NULL(selected))
		return selected;

	reuse_sk = inet6_lookup_reuseport(net, selected, skb, doff,
					  saddr, sport, daddr, hnum, ehashfn);
	return reuse_sk ? reuse_sk : selected;
}
EXPORT_SYMBOL_GPL(inet6_lookup_run_sk_lookup);
235 EXPORT_SYMBOL_GPL(inet6_lookup_run_sk_lookup);
236 
inet6_lookup_listener(const struct net * net,struct sk_buff * skb,int doff,const struct in6_addr * saddr,const __be16 sport,const struct in6_addr * daddr,const unsigned short hnum,const int dif,const int sdif)237 struct sock *inet6_lookup_listener(const struct net *net,
238 				   struct sk_buff *skb, int doff,
239 				   const struct in6_addr *saddr,
240 				   const __be16 sport,
241 				   const struct in6_addr *daddr,
242 				   const unsigned short hnum,
243 				   const int dif, const int sdif)
244 {
245 	struct inet_listen_hashbucket *ilb2;
246 	struct inet_hashinfo *hashinfo;
247 	struct sock *result = NULL;
248 	unsigned int hash2;
249 
250 	/* Lookup redirect from BPF */
251 	if (static_branch_unlikely(&bpf_sk_lookup_enabled)) {
252 		result = inet6_lookup_run_sk_lookup(net, IPPROTO_TCP, skb, doff,
253 						    saddr, sport, daddr, hnum, dif,
254 						    inet6_ehashfn);
255 		if (result)
256 			goto done;
257 	}
258 
259 	hashinfo = net->ipv4.tcp_death_row.hashinfo;
260 	hash2 = ipv6_portaddr_hash(net, daddr, hnum);
261 	ilb2 = inet_lhash2_bucket(hashinfo, hash2);
262 
263 	result = inet6_lhash2_lookup(net, ilb2, skb, doff,
264 				     saddr, sport, daddr, hnum,
265 				     dif, sdif);
266 	if (result)
267 		goto done;
268 
269 	/* Lookup lhash2 with in6addr_any */
270 	hash2 = ipv6_portaddr_hash(net, &in6addr_any, hnum);
271 	ilb2 = inet_lhash2_bucket(hashinfo, hash2);
272 
273 	result = inet6_lhash2_lookup(net, ilb2, skb, doff,
274 				     saddr, sport, &in6addr_any, hnum,
275 				     dif, sdif);
276 done:
277 	if (IS_ERR(result))
278 		return NULL;
279 	return result;
280 }
281 EXPORT_SYMBOL_GPL(inet6_lookup_listener);
282 
inet6_lookup(const struct net * net,struct sk_buff * skb,int doff,const struct in6_addr * saddr,const __be16 sport,const struct in6_addr * daddr,const __be16 dport,const int dif)283 struct sock *inet6_lookup(const struct net *net,
284 			  struct sk_buff *skb, int doff,
285 			  const struct in6_addr *saddr, const __be16 sport,
286 			  const struct in6_addr *daddr, const __be16 dport,
287 			  const int dif)
288 {
289 	struct sock *sk;
290 	bool refcounted;
291 
292 	sk = __inet6_lookup(net, skb, doff, saddr, sport, daddr,
293 			    ntohs(dport), dif, 0, &refcounted);
294 	if (sk && !refcounted && !refcount_inc_not_zero(&sk->sk_refcnt))
295 		sk = NULL;
296 	return sk;
297 }
298 EXPORT_SYMBOL_GPL(inet6_lookup);
299 
/*
 * __inet6_check_established - try to claim local port @lport for connecting
 * socket @sk, whose 4-tuple hashes to @hash in the established table.
 *
 * Naming follows the lookup side: @daddr below is the socket's local
 * address (sk_v6_rcv_saddr) and @saddr its peer (sk_v6_daddr).
 *
 * @rcu_lookup true: lockless pre-check only. The chain is walked without
 * the bucket lock. Returns -EADDRNOTAVAIL if a non-TIME_WAIT socket already
 * owns the tuple, 0 otherwise; a TIME_WAIT owner also yields 0. Nothing is
 * inserted. NOTE(review): presumably the caller (__inet_hash_connect())
 * follows a 0 with a locked call (@rcu_lookup false) -- confirm.
 *
 * @rcu_lookup false: under the bucket spinlock. Fails with -EADDRNOTAVAIL
 * if the tuple is owned by a live socket, or by a TIME_WAIT socket that
 * tcp_twsk_unique() declines to recycle. Otherwise it:
 *   - records @lport and @hash in @sk and inserts @sk;
 *   - unlinks the recycled TIME_WAIT socket, if any;
 *   - returns 0.
 * If @twp is non-NULL, the recycled timewait socket (or NULL) is returned
 * through it. Otherwise it is released here with inet_twsk_deschedule_put().
 */
static int __inet6_check_established(struct inet_timewait_death_row *death_row,
				     struct sock *sk, const __u16 lport,
				     struct inet_timewait_sock **twp,
				     bool rcu_lookup,
				     u32 hash)
{
	struct inet_hashinfo *hinfo = death_row->hashinfo;
	struct inet_sock *inet = inet_sk(sk);
	const struct in6_addr *daddr = &sk->sk_v6_rcv_saddr;
	const struct in6_addr *saddr = &sk->sk_v6_daddr;
	const int dif = sk->sk_bound_dev_if;
	struct net *net = sock_net(sk);
	const int sdif = l3mdev_master_ifindex_by_index(net, dif);
	const __portpair ports = INET_COMBINED_PORTS(inet->inet_dport, lport);
	struct inet_ehash_bucket *head = inet_ehash_bucket(hinfo, hash);
	struct inet_timewait_sock *tw = NULL;
	const struct hlist_nulls_node *node;
	struct sock *sk2;
	spinlock_t *lock;

	if (rcu_lookup) {
		/* No bucket lock is held and the chain may be modified
		 * concurrently. Walk it with the RCU iterator, as the other
		 * lockless walks in this file do, not the plain
		 * sk_nulls_for_each().
		 */
		sk_nulls_for_each_rcu(sk2, node, &head->chain) {
			if (sk2->sk_hash != hash ||
			    !inet6_match(net, sk2, saddr, daddr,
					 ports, dif, sdif))
				continue;
			if (sk2->sk_state == TCP_TIME_WAIT)
				break;
			return -EADDRNOTAVAIL;
		}
		return 0;
	}

	lock = inet_ehash_lockp(hinfo, hash);
	spin_lock(lock);

	sk_nulls_for_each(sk2, node, &head->chain) {
		if (sk2->sk_hash != hash)
			continue;

		if (likely(inet6_match(net, sk2, saddr, daddr, ports,
				       dif, sdif))) {
			if (sk2->sk_state == TCP_TIME_WAIT) {
				tw = inet_twsk(sk2);
				if (tcp_twsk_unique(sk, sk2, twp))
					break;
			}
			goto not_unique;
		}
	}

	/* Must record num and sport now. Otherwise we will see
	 * in hash table socket with a funny identity.
	 */
	inet->inet_num = lport;
	inet->inet_sport = htons(lport);
	sk->sk_hash = hash;
	WARN_ON(!sk_unhashed(sk));
	__sk_nulls_add_node_rcu(sk, &head->chain);
	if (tw) {
		sk_nulls_del_node_init_rcu((struct sock *)tw);
		__NET_INC_STATS(net, LINUX_MIB_TIMEWAITRECYCLED);
	}
	spin_unlock(lock);
	sock_prot_inuse_add(net, sk->sk_prot, 1);

	if (twp) {
		*twp = tw;
	} else if (tw) {
		/* Silly. Should hash-dance instead... */
		inet_twsk_deschedule_put(tw);
	}
	return 0;

not_unique:
	spin_unlock(lock);
	return -EADDRNOTAVAIL;
}
377 }
378 
inet6_sk_port_offset(const struct sock * sk)379 static u64 inet6_sk_port_offset(const struct sock *sk)
380 {
381 	const struct inet_sock *inet = inet_sk(sk);
382 
383 	return secure_ipv6_port_ephemeral(sk->sk_v6_rcv_saddr.s6_addr32,
384 					  sk->sk_v6_daddr.s6_addr32,
385 					  inet->inet_dport);
386 }
387 
inet6_hash_connect(struct inet_timewait_death_row * death_row,struct sock * sk)388 int inet6_hash_connect(struct inet_timewait_death_row *death_row,
389 		       struct sock *sk)
390 {
391 	const struct in6_addr *daddr = &sk->sk_v6_rcv_saddr;
392 	const struct in6_addr *saddr = &sk->sk_v6_daddr;
393 	const struct inet_sock *inet = inet_sk(sk);
394 	const struct net *net = sock_net(sk);
395 	u64 port_offset = 0;
396 	u32 hash_port0;
397 
398 	if (!inet_sk(sk)->inet_num)
399 		port_offset = inet6_sk_port_offset(sk);
400 
401 	inet6_init_ehash_secret();
402 
403 	hash_port0 = inet6_ehashfn(net, daddr, 0, saddr, inet->inet_dport);
404 
405 	return __inet_hash_connect(death_row, sk, port_offset, hash_port0,
406 				   __inet6_check_established);
407 }
408 EXPORT_SYMBOL_GPL(inet6_hash_connect);
409