1 // SPDX-License-Identifier: GPL-2.0
2 /*
3 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4 */
5
6 #include "netlink.h"
7 #include "device.h"
8 #include "peer.h"
9 #include "socket.h"
10 #include "queueing.h"
11 #include "messages.h"
12
13 #include <uapi/linux/wireguard.h>
14
15 #include <linux/if.h>
16 #include <net/genetlink.h>
17 #include <net/sock.h>
18 #include <crypto/utils.h>
19
20 static struct genl_family genl_family;
21
/* Parsing policy for the top-level WGDEVICE_A_* netlink attributes. */
static const struct nla_policy device_policy[WGDEVICE_A_MAX + 1] = {
	[WGDEVICE_A_IFINDEX] = { .type = NLA_U32 },
	[WGDEVICE_A_IFNAME] = { .type = NLA_NUL_STRING, .len = IFNAMSIZ - 1 },
	[WGDEVICE_A_PRIVATE_KEY] = NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN),
	[WGDEVICE_A_PUBLIC_KEY] = NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN),
	[WGDEVICE_A_FLAGS] = NLA_POLICY_MASK(NLA_U32, __WGDEVICE_F_ALL),
	[WGDEVICE_A_LISTEN_PORT] = { .type = NLA_U16 },
	[WGDEVICE_A_FWMARK] = { .type = NLA_U32 },
	[WGDEVICE_A_PEERS] = { .type = NLA_NESTED }
};
32
/* Parsing policy for nested WGPEER_A_* attributes (one peer entry). */
static const struct nla_policy peer_policy[WGPEER_A_MAX + 1] = {
	[WGPEER_A_PUBLIC_KEY] = NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN),
	[WGPEER_A_PRESHARED_KEY] = NLA_POLICY_EXACT_LEN(NOISE_SYMMETRIC_KEY_LEN),
	[WGPEER_A_FLAGS] = NLA_POLICY_MASK(NLA_U32, __WGPEER_F_ALL),
	[WGPEER_A_ENDPOINT] = NLA_POLICY_MIN_LEN(sizeof(struct sockaddr)),
	[WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL] = { .type = NLA_U16 },
	[WGPEER_A_LAST_HANDSHAKE_TIME] = NLA_POLICY_EXACT_LEN(sizeof(struct __kernel_timespec)),
	[WGPEER_A_RX_BYTES] = { .type = NLA_U64 },
	[WGPEER_A_TX_BYTES] = { .type = NLA_U64 },
	[WGPEER_A_ALLOWEDIPS] = { .type = NLA_NESTED },
	[WGPEER_A_PROTOCOL_VERSION] = { .type = NLA_U32 }
};
45
/* Parsing policy for nested WGALLOWEDIP_A_* attributes (one allowed IP). */
static const struct nla_policy allowedip_policy[WGALLOWEDIP_A_MAX + 1] = {
	[WGALLOWEDIP_A_FAMILY] = { .type = NLA_U16 },
	[WGALLOWEDIP_A_IPADDR] = NLA_POLICY_MIN_LEN(sizeof(struct in_addr)),
	[WGALLOWEDIP_A_CIDR_MASK] = { .type = NLA_U8 },
	[WGALLOWEDIP_A_FLAGS] = NLA_POLICY_MASK(NLA_U32, __WGALLOWEDIP_F_ALL),
};
52
lookup_interface(struct nlattr ** attrs,struct sk_buff * skb)53 static struct wg_device *lookup_interface(struct nlattr **attrs,
54 struct sk_buff *skb)
55 {
56 struct net_device *dev = NULL;
57
58 if (!attrs[WGDEVICE_A_IFINDEX] == !attrs[WGDEVICE_A_IFNAME])
59 return ERR_PTR(-EBADR);
60 if (attrs[WGDEVICE_A_IFINDEX])
61 dev = dev_get_by_index(sock_net(skb->sk),
62 nla_get_u32(attrs[WGDEVICE_A_IFINDEX]));
63 else if (attrs[WGDEVICE_A_IFNAME])
64 dev = dev_get_by_name(sock_net(skb->sk),
65 nla_data(attrs[WGDEVICE_A_IFNAME]));
66 if (!dev)
67 return ERR_PTR(-ENODEV);
68 if (!dev->rtnl_link_ops || !dev->rtnl_link_ops->kind ||
69 strcmp(dev->rtnl_link_ops->kind, KBUILD_MODNAME)) {
70 dev_put(dev);
71 return ERR_PTR(-EOPNOTSUPP);
72 }
73 return netdev_priv(dev);
74 }
75
/* Append one allowed IP as an anonymous nested attribute containing
 * WGALLOWEDIP_A_CIDR_MASK, WGALLOWEDIP_A_FAMILY, and WGALLOWEDIP_A_IPADDR.
 * Returns 0 on success or -EMSGSIZE if the skb ran out of room, in which
 * case the partially-written nest is cancelled.
 */
static int get_allowedips(struct sk_buff *skb, const u8 *ip, u8 cidr,
			  int family)
{
	size_t addr_len = family == AF_INET6 ? sizeof(struct in6_addr) :
					       sizeof(struct in_addr);
	struct nlattr *nest = nla_nest_start(skb, 0);

	if (!nest)
		return -EMSGSIZE;

	if (nla_put_u8(skb, WGALLOWEDIP_A_CIDR_MASK, cidr) ||
	    nla_put_u16(skb, WGALLOWEDIP_A_FAMILY, family) ||
	    nla_put(skb, WGALLOWEDIP_A_IPADDR, addr_len, ip))
		goto cancel;

	nla_nest_end(skb, nest);
	return 0;

cancel:
	nla_nest_cancel(skb, nest);
	return -EMSGSIZE;
}
96
/* Resumption state for a WG_CMD_GET_DEVICE dump: the device being dumped,
 * the peer and allowed-IP node to continue from on the next dumpit call,
 * and the allowedips sequence number used to detect mutation mid-dump.
 */
struct dump_ctx {
	struct wg_device *wg;
	struct wg_peer *next_peer;
	u64 allowedips_seq;
	struct allowedips_node *next_allowedip;
};

/* The dump state lives directly in the netlink_callback's args scratch
 * area; this macro reinterprets it as a struct dump_ctx.
 */
#define DUMP_CTX(cb) ((struct dump_ctx *)(cb)->args)
105
/* Serialize @peer as one anonymous nested attribute set of WGPEER_A_*
 * attributes into the dump skb.
 *
 * When resuming after a previous -EMSGSIZE (ctx->next_allowedip non-NULL),
 * the scalar peer attributes were already written in an earlier message, so
 * only the public key and the remaining allowed IPs are emitted.  Returns 0
 * on success, or -EMSGSIZE when the skb filled, in which case ctx records
 * the allowedips node to resume from.
 */
static int
get_peer(struct wg_peer *peer, struct sk_buff *skb, struct dump_ctx *ctx)
{

	struct nlattr *allowedips_nest, *peer_nest = nla_nest_start(skb, 0);
	struct allowedips_node *allowedips_node = ctx->next_allowedip;
	bool fail;

	if (!peer_nest)
		return -EMSGSIZE;

	/* Always include the public key, so userspace can match continuation
	 * messages to the right peer.
	 */
	down_read(&peer->handshake.lock);
	fail = nla_put(skb, WGPEER_A_PUBLIC_KEY, NOISE_PUBLIC_KEY_LEN,
		       peer->handshake.remote_static);
	up_read(&peer->handshake.lock);
	if (fail)
		goto err;

	if (!allowedips_node) {
		/* First message for this peer: emit the scalar attributes. */
		const struct __kernel_timespec last_handshake = {
			.tv_sec = peer->walltime_last_handshake.tv_sec,
			.tv_nsec = peer->walltime_last_handshake.tv_nsec
		};

		down_read(&peer->handshake.lock);
		fail = nla_put(skb, WGPEER_A_PRESHARED_KEY,
			       NOISE_SYMMETRIC_KEY_LEN,
			       peer->handshake.preshared_key);
		up_read(&peer->handshake.lock);
		if (fail)
			goto err;

		if (nla_put(skb, WGPEER_A_LAST_HANDSHAKE_TIME,
			    sizeof(last_handshake), &last_handshake) ||
		    nla_put_u16(skb, WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL,
				peer->persistent_keepalive_interval) ||
		    nla_put_u64_64bit(skb, WGPEER_A_TX_BYTES, peer->tx_bytes,
				      WGPEER_A_UNSPEC) ||
		    nla_put_u64_64bit(skb, WGPEER_A_RX_BYTES, peer->rx_bytes,
				      WGPEER_A_UNSPEC) ||
		    nla_put_u32(skb, WGPEER_A_PROTOCOL_VERSION, 1))
			goto err;

		/* The endpoint may be updated from softirq context, hence
		 * the bh-disabling read lock.
		 */
		read_lock_bh(&peer->endpoint_lock);
		if (peer->endpoint.addr.sa_family == AF_INET)
			fail = nla_put(skb, WGPEER_A_ENDPOINT,
				       sizeof(peer->endpoint.addr4),
				       &peer->endpoint.addr4);
		else if (peer->endpoint.addr.sa_family == AF_INET6)
			fail = nla_put(skb, WGPEER_A_ENDPOINT,
				       sizeof(peer->endpoint.addr6),
				       &peer->endpoint.addr6);
		read_unlock_bh(&peer->endpoint_lock);
		if (fail)
			goto err;
		allowedips_node =
			list_first_entry_or_null(&peer->allowedips_list,
					struct allowedips_node, peer_list);
	}
	if (!allowedips_node)
		goto no_allowedips;
	if (!ctx->allowedips_seq)
		ctx->allowedips_seq = ctx->wg->peer_allowedips.seq;
	else if (ctx->allowedips_seq != ctx->wg->peer_allowedips.seq)
		/* The allowedips trie changed since the dump started; stop
		 * emitting allowed IPs rather than produce an incoherent
		 * list from a stale cursor.
		 */
		goto no_allowedips;

	allowedips_nest = nla_nest_start(skb, WGPEER_A_ALLOWEDIPS);
	if (!allowedips_nest)
		goto err;

	list_for_each_entry_from(allowedips_node, &peer->allowedips_list,
				 peer_list) {
		u8 cidr, ip[16] __aligned(__alignof(u64));
		int family;

		family = wg_allowedips_read_node(allowedips_node, ip, &cidr);
		if (get_allowedips(skb, ip, cidr, family)) {
			/* Out of room: close what we have so the message is
			 * still valid, and remember where to resume.
			 */
			nla_nest_end(skb, allowedips_nest);
			nla_nest_end(skb, peer_nest);
			ctx->next_allowedip = allowedips_node;
			return -EMSGSIZE;
		}
	}
	nla_nest_end(skb, allowedips_nest);
no_allowedips:
	nla_nest_end(skb, peer_nest);
	ctx->next_allowedip = NULL;
	ctx->allowedips_seq = 0;
	return 0;
err:
	nla_nest_cancel(skb, peer_nest);
	return -EMSGSIZE;
}
199
wg_get_device_start(struct netlink_callback * cb)200 static int wg_get_device_start(struct netlink_callback *cb)
201 {
202 struct wg_device *wg;
203
204 wg = lookup_interface(genl_info_dump(cb)->attrs, cb->skb);
205 if (IS_ERR(wg))
206 return PTR_ERR(wg);
207 DUMP_CTX(cb)->wg = wg;
208 return 0;
209 }
210
/* Dump .dumpit callback for WG_CMD_GET_DEVICE: emits one NLM_F_MULTI
 * message per invocation, containing the device-wide attributes (first
 * message only) followed by as many peers as fit.  Returns skb->len to
 * request another round, 0 when the dump is complete, or a negative error.
 */
static int wg_get_device_dump(struct sk_buff *skb, struct netlink_callback *cb)
{
	struct wg_peer *peer, *next_peer_cursor;
	struct dump_ctx *ctx = DUMP_CTX(cb);
	struct wg_device *wg = ctx->wg;
	struct nlattr *peers_nest;
	int ret = -EMSGSIZE;
	bool done = true;
	void *hdr;

	rtnl_lock();
	mutex_lock(&wg->device_update_lock);
	/* Tie netlink dump-consistency checking to the device's change
	 * generation, so userspace learns if config changed mid-dump.
	 */
	cb->seq = wg->device_update_gen;
	next_peer_cursor = ctx->next_peer;

	hdr = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
			  &genl_family, NLM_F_MULTI, WG_CMD_GET_DEVICE);
	if (!hdr)
		goto out;
	genl_dump_check_consistent(cb, hdr);

	if (!ctx->next_peer) {
		/* Only the first message of the dump carries the device-wide
		 * attributes and keys.
		 */
		if (nla_put_u16(skb, WGDEVICE_A_LISTEN_PORT,
				wg->incoming_port) ||
		    nla_put_u32(skb, WGDEVICE_A_FWMARK, wg->fwmark) ||
		    nla_put_u32(skb, WGDEVICE_A_IFINDEX, wg->dev->ifindex) ||
		    nla_put_string(skb, WGDEVICE_A_IFNAME, wg->dev->name))
			goto out;

		down_read(&wg->static_identity.lock);
		if (wg->static_identity.has_identity) {
			if (nla_put(skb, WGDEVICE_A_PRIVATE_KEY,
				    NOISE_PUBLIC_KEY_LEN,
				    wg->static_identity.static_private) ||
			    nla_put(skb, WGDEVICE_A_PUBLIC_KEY,
				    NOISE_PUBLIC_KEY_LEN,
				    wg->static_identity.static_public)) {
				up_read(&wg->static_identity.lock);
				goto out;
			}
		}
		up_read(&wg->static_identity.lock);
	}

	peers_nest = nla_nest_start(skb, WGDEVICE_A_PEERS);
	if (!peers_nest)
		goto out;
	ret = 0;
	lockdep_assert_held(&wg->device_update_lock);
	/* If the last cursor was removed in peer_remove or peer_remove_all, then
	 * we just treat this the same as there being no more peers left. The
	 * reason is that seq_nr should indicate to userspace that this isn't a
	 * coherent dump anyway, so they'll try again.
	 */
	if (list_empty(&wg->peer_list) ||
	    (ctx->next_peer && ctx->next_peer->is_dead)) {
		nla_nest_cancel(skb, peers_nest);
		goto out;
	}
	peer = list_prepare_entry(ctx->next_peer, &wg->peer_list, peer_list);
	list_for_each_entry_continue(peer, &wg->peer_list, peer_list) {
		if (get_peer(peer, skb, ctx)) {
			/* Out of room: resume from this peer next round. */
			done = false;
			break;
		}
		next_peer_cursor = peer;
	}
	nla_nest_end(skb, peers_nest);

out:
	/* Hold a reference on the new cursor across calls, then drop the
	 * reference taken for the previous one.
	 */
	if (!ret && !done && next_peer_cursor)
		wg_peer_get(next_peer_cursor);
	wg_peer_put(ctx->next_peer);
	mutex_unlock(&wg->device_update_lock);
	rtnl_unlock();

	if (ret) {
		genlmsg_cancel(skb, hdr);
		return ret;
	}
	genlmsg_end(skb, hdr);
	if (done) {
		ctx->next_peer = NULL;
		return 0;
	}
	ctx->next_peer = next_peer_cursor;
	return skb->len;

	/* At this point, we can't really deal ourselves with safely zeroing out
	 * the private key material after usage. This will need an additional API
	 * in the kernel for marking skbs as zero_on_free.
	 */
}
304
wg_get_device_done(struct netlink_callback * cb)305 static int wg_get_device_done(struct netlink_callback *cb)
306 {
307 struct dump_ctx *ctx = DUMP_CTX(cb);
308
309 if (ctx->wg)
310 dev_put(ctx->wg->dev);
311 wg_peer_put(ctx->next_peer);
312 return 0;
313 }
314
/* Change the device's listening port.  Cached source addresses of every
 * peer endpoint are cleared, since the local socket is about to change.
 * If the interface is down, only the port number is recorded; otherwise
 * the UDP socket is re-initialized on the new port.
 */
static int set_port(struct wg_device *wg, u16 port)
{
	struct wg_peer *peer;

	if (wg->incoming_port == port)
		return 0;
	list_for_each_entry(peer, &wg->peer_list, peer_list)
		wg_socket_clear_peer_endpoint_src(peer);
	if (netif_running(wg->dev))
		return wg_socket_init(wg, port);
	wg->incoming_port = port;
	return 0;
}
329
set_allowedip(struct wg_peer * peer,struct nlattr ** attrs)330 static int set_allowedip(struct wg_peer *peer, struct nlattr **attrs)
331 {
332 int ret = -EINVAL;
333 u32 flags = 0;
334 u16 family;
335 u8 cidr;
336
337 if (!attrs[WGALLOWEDIP_A_FAMILY] || !attrs[WGALLOWEDIP_A_IPADDR] ||
338 !attrs[WGALLOWEDIP_A_CIDR_MASK])
339 return ret;
340 family = nla_get_u16(attrs[WGALLOWEDIP_A_FAMILY]);
341 cidr = nla_get_u8(attrs[WGALLOWEDIP_A_CIDR_MASK]);
342 if (attrs[WGALLOWEDIP_A_FLAGS])
343 flags = nla_get_u32(attrs[WGALLOWEDIP_A_FLAGS]);
344
345 if (family == AF_INET && cidr <= 32 &&
346 nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in_addr)) {
347 if (flags & WGALLOWEDIP_F_REMOVE_ME)
348 ret = wg_allowedips_remove_v4(&peer->device->peer_allowedips,
349 nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr,
350 peer, &peer->device->device_update_lock);
351 else
352 ret = wg_allowedips_insert_v4(&peer->device->peer_allowedips,
353 nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr,
354 peer, &peer->device->device_update_lock);
355 } else if (family == AF_INET6 && cidr <= 128 &&
356 nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in6_addr)) {
357 if (flags & WGALLOWEDIP_F_REMOVE_ME)
358 ret = wg_allowedips_remove_v6(&peer->device->peer_allowedips,
359 nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr,
360 peer, &peer->device->device_update_lock);
361 else
362 ret = wg_allowedips_insert_v6(&peer->device->peer_allowedips,
363 nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr,
364 peer, &peer->device->device_update_lock);
365 }
366
367 return ret;
368 }
369
/* Apply one parsed WGPEER_A_* attribute set to @wg: create, update, or
 * (with WGPEER_F_REMOVE_ME) remove a peer, then apply its preshared key,
 * endpoint, allowed IPs, and keepalive interval.  Called with
 * wg->device_update_lock held by wg_set_device().  Returns 0 on success
 * or a negative errno; the caller-supplied preshared key attribute is
 * zeroed before returning in all cases.
 */
static int set_peer(struct wg_device *wg, struct nlattr **attrs)
{
	u8 *public_key = NULL, *preshared_key = NULL;
	struct wg_peer *peer = NULL;
	u32 flags = 0;
	int ret;

	/* A correctly-sized public key is mandatory; it identifies the peer. */
	ret = -EINVAL;
	if (attrs[WGPEER_A_PUBLIC_KEY] &&
	    nla_len(attrs[WGPEER_A_PUBLIC_KEY]) == NOISE_PUBLIC_KEY_LEN)
		public_key = nla_data(attrs[WGPEER_A_PUBLIC_KEY]);
	else
		goto out;
	if (attrs[WGPEER_A_PRESHARED_KEY] &&
	    nla_len(attrs[WGPEER_A_PRESHARED_KEY]) == NOISE_SYMMETRIC_KEY_LEN)
		preshared_key = nla_data(attrs[WGPEER_A_PRESHARED_KEY]);

	if (attrs[WGPEER_A_FLAGS])
		flags = nla_get_u32(attrs[WGPEER_A_FLAGS]);

	/* Only protocol version 1 is implemented. */
	ret = -EPFNOSUPPORT;
	if (attrs[WGPEER_A_PROTOCOL_VERSION]) {
		if (nla_get_u32(attrs[WGPEER_A_PROTOCOL_VERSION]) != 1)
			goto out;
	}

	peer = wg_pubkey_hashtable_lookup(wg->peer_hashtable,
					  nla_data(attrs[WGPEER_A_PUBLIC_KEY]));
	ret = 0;
	if (!peer) { /* Peer doesn't exist yet. Add a new one. */
		/* Removing or update-only on a nonexistent peer is a no-op. */
		if (flags & (WGPEER_F_REMOVE_ME | WGPEER_F_UPDATE_ONLY))
			goto out;

		/* The peer is new, so there aren't allowed IPs to remove. */
		flags &= ~WGPEER_F_REPLACE_ALLOWEDIPS;

		down_read(&wg->static_identity.lock);
		if (wg->static_identity.has_identity &&
		    !memcmp(nla_data(attrs[WGPEER_A_PUBLIC_KEY]),
			    wg->static_identity.static_public,
			    NOISE_PUBLIC_KEY_LEN)) {
			/* We silently ignore peers that have the same public
			 * key as the device. The reason we do it silently is
			 * that we'd like for people to be able to reuse the
			 * same set of API calls across peers.
			 */
			up_read(&wg->static_identity.lock);
			ret = 0;
			goto out;
		}
		up_read(&wg->static_identity.lock);

		peer = wg_peer_create(wg, public_key, preshared_key);
		if (IS_ERR(peer)) {
			ret = PTR_ERR(peer);
			peer = NULL;
			goto out;
		}
		/* Take additional reference, as though we've just been
		 * looked up.
		 */
		wg_peer_get(peer);
	}

	if (flags & WGPEER_F_REMOVE_ME) {
		wg_peer_remove(peer);
		goto out;
	}

	if (preshared_key) {
		down_write(&peer->handshake.lock);
		memcpy(&peer->handshake.preshared_key, preshared_key,
		       NOISE_SYMMETRIC_KEY_LEN);
		up_write(&peer->handshake.lock);
	}

	if (attrs[WGPEER_A_ENDPOINT]) {
		struct sockaddr *addr = nla_data(attrs[WGPEER_A_ENDPOINT]);
		size_t len = nla_len(attrs[WGPEER_A_ENDPOINT]);
		struct endpoint endpoint = { { { 0 } } };

		/* Silently ignore endpoints whose length doesn't match the
		 * claimed address family.
		 */
		if (len == sizeof(struct sockaddr_in) && addr->sa_family == AF_INET) {
			endpoint.addr4 = *(struct sockaddr_in *)addr;
			wg_socket_set_peer_endpoint(peer, &endpoint);
		} else if (len == sizeof(struct sockaddr_in6) && addr->sa_family == AF_INET6) {
			endpoint.addr6 = *(struct sockaddr_in6 *)addr;
			wg_socket_set_peer_endpoint(peer, &endpoint);
		}
	}

	if (flags & WGPEER_F_REPLACE_ALLOWEDIPS)
		wg_allowedips_remove_by_peer(&wg->peer_allowedips, peer,
					     &wg->device_update_lock);

	if (attrs[WGPEER_A_ALLOWEDIPS]) {
		struct nlattr *attr, *allowedip[WGALLOWEDIP_A_MAX + 1];
		int rem;

		nla_for_each_nested(attr, attrs[WGPEER_A_ALLOWEDIPS], rem) {
			ret = nla_parse_nested(allowedip, WGALLOWEDIP_A_MAX,
					       attr, allowedip_policy, NULL);
			if (ret < 0)
				goto out;
			ret = set_allowedip(peer, allowedip);
			if (ret < 0)
				goto out;
		}
	}

	if (attrs[WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL]) {
		const u16 persistent_keepalive_interval = nla_get_u16(
				attrs[WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL]);
		/* Kick off a keepalive only on the 0 -> nonzero transition
		 * while the interface is up.
		 */
		const bool send_keepalive =
			!peer->persistent_keepalive_interval &&
			persistent_keepalive_interval &&
			netif_running(wg->dev);

		peer->persistent_keepalive_interval = persistent_keepalive_interval;
		if (send_keepalive)
			wg_packet_send_keepalive(peer);
	}

	if (netif_running(wg->dev))
		wg_packet_send_staged_packets(peer);

out:
	wg_peer_put(peer);
	/* Scrub the caller's preshared key from the received skb. */
	if (attrs[WGPEER_A_PRESHARED_KEY])
		memzero_explicit(nla_data(attrs[WGPEER_A_PRESHARED_KEY]),
				 nla_len(attrs[WGPEER_A_PRESHARED_KEY]));
	return ret;
}
502
/* Doit handler for WG_CMD_SET_DEVICE: applies device-wide settings
 * (fwmark, listen port, private key, peer replacement) and then each
 * nested peer via set_peer(), all under rtnl_lock and the device update
 * mutex.  Any private key supplied in the request is zeroed from the skb
 * before returning.
 */
static int wg_set_device(struct sk_buff *skb, struct genl_info *info)
{
	struct wg_device *wg = lookup_interface(info->attrs, skb);
	u32 flags = 0;
	int ret;

	if (IS_ERR(wg)) {
		ret = PTR_ERR(wg);
		goto out_nodev;
	}

	rtnl_lock();
	mutex_lock(&wg->device_update_lock);

	if (info->attrs[WGDEVICE_A_FLAGS])
		flags = nla_get_u32(info->attrs[WGDEVICE_A_FLAGS]);

	/* Changing the port or fwmark affects the transit net namespace of
	 * the device, so it additionally requires CAP_NET_ADMIN in the
	 * namespace that created the interface.
	 */
	if (info->attrs[WGDEVICE_A_LISTEN_PORT] || info->attrs[WGDEVICE_A_FWMARK]) {
		struct net *net;
		rcu_read_lock();
		net = rcu_dereference(wg->creating_net);
		ret = !net || !ns_capable(net->user_ns, CAP_NET_ADMIN) ? -EPERM : 0;
		rcu_read_unlock();
		if (ret)
			goto out;
	}

	/* Bump the generation so concurrent dumps notice the change. */
	++wg->device_update_gen;

	if (info->attrs[WGDEVICE_A_FWMARK]) {
		struct wg_peer *peer;

		wg->fwmark = nla_get_u32(info->attrs[WGDEVICE_A_FWMARK]);
		list_for_each_entry(peer, &wg->peer_list, peer_list)
			wg_socket_clear_peer_endpoint_src(peer);
	}

	if (info->attrs[WGDEVICE_A_LISTEN_PORT]) {
		ret = set_port(wg,
			nla_get_u16(info->attrs[WGDEVICE_A_LISTEN_PORT]));
		if (ret)
			goto out;
	}

	if (flags & WGDEVICE_F_REPLACE_PEERS)
		wg_peer_remove_all(wg);

	if (info->attrs[WGDEVICE_A_PRIVATE_KEY] &&
	    nla_len(info->attrs[WGDEVICE_A_PRIVATE_KEY]) ==
		    NOISE_PUBLIC_KEY_LEN) {
		u8 *private_key = nla_data(info->attrs[WGDEVICE_A_PRIVATE_KEY]);
		u8 public_key[NOISE_PUBLIC_KEY_LEN];
		struct wg_peer *peer, *temp;
		bool send_staged_packets;

		/* Constant-time compare: skip all the work if the key is
		 * unchanged.
		 */
		if (!crypto_memneq(wg->static_identity.static_private,
				   private_key, NOISE_PUBLIC_KEY_LEN))
			goto skip_set_private_key;

		/* We remove before setting, to prevent race, which means doing
		 * two 25519-genpub ops.
		 */
		if (curve25519_generate_public(public_key, private_key)) {
			/* If a peer exists with our new public key, drop it;
			 * a device must not peer with itself.
			 */
			peer = wg_pubkey_hashtable_lookup(wg->peer_hashtable,
							  public_key);
			if (peer) {
				wg_peer_put(peer);
				wg_peer_remove(peer);
			}
		}

		down_write(&wg->static_identity.lock);
		/* Gaining an identity for the first time while the interface
		 * is up means previously-staged packets can now be sent.
		 */
		send_staged_packets = !wg->static_identity.has_identity && netif_running(wg->dev);
		wg_noise_set_static_identity_private_key(&wg->static_identity, private_key);
		send_staged_packets = send_staged_packets && wg->static_identity.has_identity;

		wg_cookie_checker_precompute_device_keys(&wg->cookie_checker);
		list_for_each_entry_safe(peer, temp, &wg->peer_list, peer_list) {
			wg_noise_precompute_static_static(peer);
			wg_noise_expire_current_peer_keypairs(peer);
			if (send_staged_packets)
				wg_packet_send_staged_packets(peer);
		}
		up_write(&wg->static_identity.lock);
	}
skip_set_private_key:

	if (info->attrs[WGDEVICE_A_PEERS]) {
		struct nlattr *attr, *peer[WGPEER_A_MAX + 1];
		int rem;

		nla_for_each_nested(attr, info->attrs[WGDEVICE_A_PEERS], rem) {
			ret = nla_parse_nested(peer, WGPEER_A_MAX, attr,
					       peer_policy, NULL);
			if (ret < 0)
				goto out;
			ret = set_peer(wg, peer);
			if (ret < 0)
				goto out;
		}
	}
	ret = 0;

out:
	mutex_unlock(&wg->device_update_lock);
	rtnl_unlock();
	/* Drop the reference taken by lookup_interface(). */
	dev_put(wg->dev);
out_nodev:
	/* Scrub the caller's private key from the received skb. */
	if (info->attrs[WGDEVICE_A_PRIVATE_KEY])
		memzero_explicit(nla_data(info->attrs[WGDEVICE_A_PRIVATE_KEY]),
				 nla_len(info->attrs[WGDEVICE_A_PRIVATE_KEY]));
	return ret;
}
616
/* Command table: GET_DEVICE is a dump (start/dumpit/done triplet),
 * SET_DEVICE a plain doit.  GENL_UNS_ADMIN_PERM requires net-admin
 * capability in the socket's user namespace.
 */
static const struct genl_ops genl_ops[] = {
	{
		.cmd = WG_CMD_GET_DEVICE,
		.start = wg_get_device_start,
		.dumpit = wg_get_device_dump,
		.done = wg_get_device_done,
		.flags = GENL_UNS_ADMIN_PERM
	}, {
		.cmd = WG_CMD_SET_DEVICE,
		.doit = wg_set_device,
		.flags = GENL_UNS_ADMIN_PERM
	}
};
630
/* The wireguard generic netlink family definition, registered at module
 * init.  device_policy is applied to all top-level attributes.
 */
static struct genl_family genl_family __ro_after_init = {
	.ops = genl_ops,
	.n_ops = ARRAY_SIZE(genl_ops),
	.resv_start_op = WG_CMD_SET_DEVICE + 1,
	.name = WG_GENL_NAME,
	.version = WG_GENL_VERSION,
	.maxattr = WGDEVICE_A_MAX,
	.module = THIS_MODULE,
	.policy = device_policy,
	.netnsok = true
};
642
/* Register the wireguard generic netlink family; called at module init. */
int __init wg_genetlink_init(void)
{
	return genl_register_family(&genl_family);
}
647
/* Unregister the family; called at module exit. */
void __exit wg_genetlink_uninit(void)
{
	genl_unregister_family(&genl_family);
}
652